blob: e6b38d0282f9c4a5591298bdd6a976579f1e8fe5 [file] [log] [blame]
package model
import (
"context"
"database/sql"
"fmt"
"net"
"time"
)
// UpdateAllowedPrefixes replaces the allowed prefixes stored for every peer
// whose ASN equals asn with exactly the given prefixes.
//
// All work happens in a single transaction. Each prefix is upserted with a
// timestamp shared by this call. Afterwards, every row for the same peer(s)
// that does not carry that timestamp (i.e. was not in prefixes) is deleted.
// An empty prefixes slice therefore clears all allowed prefixes for the ASN.
// If no peer with the given ASN exists, nothing is written.
//
// Every element of prefixes must be non-nil.
func (s *sqlModel) UpdateAllowedPrefixes(ctx context.Context, asn int64, prefixes []*AllowedPrefix) error {
	// Return BEGIN failures to the caller instead of panicking.
	tx, err := s.db.BeginTxx(ctx, &sql.TxOptions{})
	if err != nil {
		return fmt.Errorf("BEGIN: %w", err)
	}
	// After a successful Commit, database/sql's Rollback just returns
	// sql.ErrTxDone, so deferring it unconditionally is safe. Its error is
	// intentionally ignored.
	defer tx.Rollback()

	// One timestamp marks every row written by this call. The DELETE below
	// uses it to find the rows that were not part of this update.
	timestamp := time.Now().UnixNano()

	const upsert = `
		INSERT INTO allowed_prefixes
			(peer_id, timestamp, prefix, max_length, ta)
		SELECT
			peers.id, :timestamp, :prefix, :max_length, :ta
		FROM peers
		WHERE peers.asn = :asn
		ON CONFLICT (peer_id, prefix)
		DO UPDATE SET
			timestamp = :timestamp,
			max_length = :max_length,
			ta = :ta
	`
	for _, prefix := range prefixes {
		ap := sqlAllowedPrefix{
			Timestamp: timestamp,
			Prefix:    prefix.Prefix.String(),
			MaxLength: prefix.MaxLength,
			TA:        prefix.TA,
			ASN:       fmt.Sprintf("%d", asn),
		}
		if _, err := tx.NamedExecContext(ctx, upsert, ap); err != nil {
			return fmt.Errorf("INSERT allowed_prefixes: %w", err)
		}
	}

	// IN (not a scalar "= (SELECT ...)") keeps this DELETE scoped to the same
	// set of peers as the INSERT above, even if several peers share this ASN.
	// A scalar subquery yielding more than one row is an error on PostgreSQL.
	const prune = `
		DELETE FROM allowed_prefixes
		WHERE timestamp != $1
		AND peer_id IN (SELECT peers.id FROM peers WHERE peers.asn = $2)
	`
	if _, err := tx.ExecContext(ctx, prune, timestamp, asn); err != nil {
		return fmt.Errorf("DELETE FROM allowed_prefixes: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("COMMIT: %w", err)
	}
	return nil
}
// GetAllowedPrefixes returns the allowed prefixes stored for every peer whose
// ASN equals asn.
//
// The returned slice is non-nil (possibly empty) on success. No ORDER BY is
// applied, so the order is whatever the database returns. An error is
// returned if the query fails or a stored prefix is not valid CIDR notation.
func (s *sqlModel) GetAllowedPrefixes(ctx context.Context, asn int64) ([]*AllowedPrefix, error) {
	// An inner JOIN is sufficient: the filter on peers.asn already requires a
	// matching peer row, so an outer join would produce the same result.
	q := `
		SELECT
			allowed_prefixes.prefix,
			allowed_prefixes.max_length,
			allowed_prefixes.ta
		FROM
			allowed_prefixes
		JOIN peers
			ON peers.id = allowed_prefixes.peer_id
		WHERE peers.asn = $1
	`
	data := []sqlAllowedPrefix{}
	if err := s.db.SelectContext(ctx, &data, q, asn); err != nil {
		return nil, fmt.Errorf("SELECT allowed_prefixes: %w", err)
	}

	res := make([]*AllowedPrefix, len(data))
	for i, d := range data {
		// ParseCIDR's second result is the network (host bits masked off).
		_, prefix, err := net.ParseCIDR(d.Prefix)
		if err != nil {
			return nil, fmt.Errorf("corrupted CIDR %q in database: %w", d.Prefix, err)
		}
		res[i] = &AllowedPrefix{
			Prefix:    *prefix,
			MaxLength: d.MaxLength,
			TA:        d.TA,
		}
	}
	return res, nil
}