package model

import (
	"context"
	"database/sql"
	"fmt"
	"strconv"
	"strings"

	pb "code.hackerspace.pl/hscloud/bgpwtf/cccampix/proto"
	"github.com/golang/glog"
	"github.com/jmoiron/sqlx"
)

// RecordPeeringDBPeers synchronizes the peers table with the given list of
// PeeringDB members inside a single transaction.
//
// Only rows whose source is "from-peeringdb" are managed here:
//   - such rows are deleted when their ASN is absent from members,
//   - such rows are renamed when the member's name differs from the stored one.
//
// Members whose ASN has no row at all are inserted with source
// "from-peeringdb". Rows with any other source are left untouched, even when
// the same ASN appears in members. If members lists an ASN more than once,
// the last occurrence wins.
//
// Any database error aborts the transaction and is returned wrapped, so
// callers may inspect it with errors.Is / errors.As.
func (m *sqlModel) RecordPeeringDBPeers(ctx context.Context, members []*pb.PeeringDBMember) error {
	tx, err := m.db.BeginTxx(ctx, &sql.TxOptions{})
	if err != nil {
		return fmt.Errorf("BEGIN: %w", err)
	}
	// Rollback after a successful Commit is a no-op (it returns
	// sql.ErrTxDone), so deferring it unconditionally covers every error path.
	defer tx.Rollback()

	wanted := make(map[string]*pb.PeeringDBMember, len(members))
	for _, member := range members {
		wanted[strconv.FormatInt(member.Asn, 10)] = member
	}

	existing := []sqlPeer{}
	q := `
		SELECT peers.id, peers.asn, peers.name, peers.source
		FROM peers
	`
	if err := tx.SelectContext(ctx, &existing, q); err != nil {
		return fmt.Errorf("SELECT peers: %w", err)
	}

	toDelete := make(map[string]bool)
	toAdd := make(map[string]bool)
	toUpdate := make(map[string]bool)
	existingMap := make(map[string]*sqlPeer, len(existing))

	// Mark ASs to delete and index existing ASs by ASN.
	for i := range existing {
		ex := &existing[i]
		if wanted[ex.ASN] == nil && ex.Source == "from-peeringdb" {
			toDelete[ex.ASN] = true
		}
		existingMap[ex.ASN] = ex
	}

	// Mark ASs to add (no row at all) and to update (a row we manage, but
	// with stale data). Rows from other sources are never touched.
	for asn, want := range wanted {
		got := existingMap[asn]
		switch {
		case got == nil:
			toAdd[asn] = true
		case got.Source == "from-peeringdb" && want.Name != got.Name:
			toUpdate[asn] = true
		}
	}

	if len(toAdd) > 0 {
		glog.Infof("RecordPeeringDBPeers: adding %v", toAdd)
	}
	if len(toDelete) > 0 {
		glog.Infof("RecordPeeringDBPeers: deleting %v", toDelete)
	}
	if len(toUpdate) > 0 {
		glog.Infof("RecordPeeringDBPeers: updating %v", toUpdate)
	}

	// Run INSERT to add new ASNs (one batched named statement).
	if len(toAdd) > 0 {
		q = `
			INSERT INTO peers
				(asn, name, source)
			VALUES
				(:asn, :name, :source)
		`

		add := make([]*sqlPeer, 0, len(toAdd))
		for asn := range toAdd {
			add = append(add, &sqlPeer{
				ASN:    asn,
				Name:   wanted[asn].Name,
				Source: "from-peeringdb",
			})
		}

		if _, err := tx.NamedExecContext(ctx, q, add); err != nil {
			return fmt.Errorf("INSERT peers: %w", err)
		}
	}

	// Run DELETE to remove ASNs that are no longer in PeeringDB.
	if len(toDelete) > 0 {
		deleteIDs := make([]string, 0, len(toDelete))
		for asn := range toDelete {
			deleteIDs = append(deleteIDs, existingMap[asn].ID)
		}
		query, args, err := sqlx.In("DELETE FROM peers WHERE id IN (?)", deleteIDs)
		if err != nil {
			return fmt.Errorf("DELETE peers: %w", err)
		}
		// sqlx.In emits '?' placeholders; convert them to the driver's bindvar style.
		if _, err := tx.ExecContext(ctx, tx.Rebind(query), args...); err != nil {
			return fmt.Errorf("DELETE peers: %w", err)
		}
	}

	// Run UPDATE to refresh changed fields of existing ASNs.
	for asn := range toUpdate {
		want := wanted[asn]
		got := existingMap[asn]

		fields := []string{}
		args := []interface{}{}
		if want.Name != got.Name {
			fields = append(fields, "name = ?")
			args = append(args, want.Name)
		}
		// Guard against emitting an invalid "SET  WHERE" statement should a
		// row ever be marked for update with nothing to change.
		if len(fields) == 0 {
			continue
		}

		q = fmt.Sprintf(`
			UPDATE peers
			SET
				%s
			WHERE
				id = ?
		`, strings.Join(fields, ",\n"))
		q = tx.Rebind(q)
		args = append(args, got.ID)
		if _, err := tx.ExecContext(ctx, q, args...); err != nil {
			return fmt.Errorf("UPDATE peers: %w", err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("COMMIT: %w", err)
	}
	return nil
}

// GetPeeringDBPeer returns the peer with the given ASN together with its
// routers, read from the peers and peer_routers tables.
//
// If no peer with that ASN exists, an empty PeeringDBMember (Asn 0, empty
// name, no routers) is returned with a nil error, so callers can detect the
// not-found case by checking Asn == 0.
//
// Router addresses that are NULL in the database are left as empty strings.
// A peer with no peer_routers rows yields a member with no routers.
func (m *sqlModel) GetPeeringDBPeer(ctx context.Context, asn int64) (*pb.PeeringDBMember, error) {
	data := []struct {
		sqlPeer       `db:"peers"`
		sqlPeerRouter `db:"peer_routers"`
	}{}
	q := `
		SELECT
			peers.id "peers.id",
			peers.asn "peers.asn",
			peers.name "peers.name",

			peer_routers.peer_id "peer_routers.peer_id",
			peer_routers.v6 "peer_routers.v6",
			peer_routers.v4 "peer_routers.v4"
		FROM peers
		LEFT JOIN peer_routers
		ON peer_routers.peer_id = peers.id
		WHERE peers.asn = $1
	`
	if err := m.db.SelectContext(ctx, &data, q, asn); err != nil {
		return nil, fmt.Errorf("SELECT peers/peerRouters: %w", err)
	}

	res := &pb.PeeringDBMember{}
	if len(data) == 0 {
		return res, nil
	}

	// Every joined row repeats the same peer columns; take them from the first.
	peer := data[0].sqlPeer
	peerASN, err := strconv.ParseInt(peer.ASN, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("data corruption: invalid ASN %q", peer.ASN)
	}
	res.Asn = peerASN
	res.Name = peer.Name

	res.Routers = make([]*pb.PeeringDBMember_Router, 0, len(data))
	for _, row := range data {
		r := row.sqlPeerRouter
		// For a peer without routers the LEFT JOIN yields one row whose
		// peer_routers columns are all NULL; don't turn it into an empty
		// router. NOTE(review): this also drops a real router row with both
		// addresses NULL, which carries no usable information anyway.
		if !r.V6.Valid && !r.V4.Valid {
			continue
		}
		router := &pb.PeeringDBMember_Router{}
		if r.V6.Valid {
			router.Ipv6 = r.V6.String
		}
		if r.V4.Valid {
			router.Ipv4 = r.V4.String
		}
		res.Routers = append(res.Routers, router)
	}

	return res, nil
}
