package main

import (
	"context"
	"database/sql"
	"fmt"
	"strconv"
	"time"

	"github.com/golang/glog"
	"github.com/golang/protobuf/proto"
	_ "github.com/mattn/go-sqlite3"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	pb "code.hackerspace.pl/hscloud/bgpwtf/invoice/proto"
)

// model is the persistence layer for invoices. It wraps the SQL database
// handle opened by newModel (the "sqlite3" driver).
type model struct {
	db *sql.DB
}

// newModel opens the SQLite database identified by dsn and returns a model
// around it. sql.Open may only validate its arguments without connecting,
// so problems with the database itself can surface on first use instead.
func newModel(dsn string) (*model, error) {
	conn, err := sql.Open("sqlite3", dsn)
	if err != nil {
		return nil, err
	}
	m := &model{db: conn}
	return m, nil
}

// init creates the schema: invoice (serialized InvoiceData plus creation
// time), invoice_seal (final number and seal time per sealed invoice) and
// invoice_blob (rendered PDF per invoice). It uses plain "create table",
// so running it against a database that already has these tables fails.
func (m *model) init() error {
	if _, err := m.db.Exec(`
		create table invoice (
			id integer primary key not null,
			created_time integer not null,
			proto blob not null
		);
		create table invoice_seal (
			id integer primary key not null,
			invoice_id integer not null,
			final_uid text not null unique,
			sealed_time integer not null,
			foreign key (invoice_id) references invoice(id)
		);
		create table invoice_blob (
			id integer primary key not null,
			invoice_id integer not null,
			pdf blob not null,
			foreign key (invoice_id) references invoice(id)
		);
	`); err != nil {
		return err
	}
	return nil
}

// sealInvoice turns the proforma invoice identified by uid into a final
// (sealed) invoice. In one transaction it:
//   - allocates the next sequential final number,
//   - records the seal, and
//   - renders the PDF and stores it.
//
// language is passed through to renderInvoicePDF. When useProformaTime is
// set, the seal time is the invoice's existing date rather than time.Now().
//
// It returns InvalidArgument if uid is not an integer and FailedPrecondition
// if the invoice is already sealed. Otherwise it returns whatever
// getInvoice, the database or the PDF renderer return.
func (m *model) sealInvoice(ctx context.Context, uid, language string, useProformaTime bool) error {
	id, err := strconv.Atoi(uid)
	if err != nil {
		return status.Error(codes.InvalidArgument, "invalid uid")
	}

	invoice, err := m.getInvoice(ctx, uid)
	if err != nil {
		return err
	}
	// Sealing twice would allocate a second final number and a second PDF for
	// the same invoice, after which the invoice/invoice_seal left joins return
	// multiple rows for it.
	// NOTE(review): this check runs outside the transaction; a unique index on
	// invoice_seal.invoice_id would make it authoritative under concurrency.
	if invoice.State == pb.Invoice_STATE_SEALED {
		return status.Error(codes.FailedPrecondition, "invoice already sealed")
	}

	tx, err := m.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Release the connection (and SQLite's write lock) on every early return.
	// After a successful Commit this returns sql.ErrTxDone, which is expected
	// and deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// The next final number is the current highest final_uid + 1, seeded at
	// 20000 so the first sealed invoice gets 20001.
	// NOTE(review): final_uid is text, so "order by final_uid desc" compares
	// lexicographically. It yields the numeric maximum only while all numbers
	// have the same width (five digits, i.e. up to 99999).
	q := `
		insert into invoice_seal (
			invoice_id, final_uid, sealed_time
		) values (
			?,
			( select printf("%04d", ifnull( (select final_uid as v from invoice_seal order by final_uid desc limit 1), 20000) + 1 )),
			?
		)
	`

	sealTime := time.Now()
	if useProformaTime {
		sealTime = time.Unix(0, invoice.Date)
	}
	res, err := tx.ExecContext(ctx, q, id, sealTime.UnixNano())
	if err != nil {
		return err
	}

	sealID, err := res.LastInsertId()
	if err != nil {
		return err
	}

	// Read back the number SQLite allocated so it can be printed on the PDF.
	q = `
		select final_uid from invoice_seal where id = ?
	`

	var finalUid string
	if err := tx.QueryRowContext(ctx, q, sealID).Scan(&finalUid); err != nil {
		return err
	}

	invoice.State = pb.Invoice_STATE_SEALED
	// TODO(q3k): this should be configurable.
	invoice.FinalUid = fmt.Sprintf("FV/%s", finalUid)
	invoice.Date = sealTime.UnixNano()
	calculateInvoiceData(invoice)

	pdfBlob, err := renderInvoicePDF(invoice, language)
	if err != nil {
		return err
	}

	q = `
		insert into invoice_blob (
			invoice_id, pdf
		) values (
			?, ?
		)
	`

	if _, err := tx.ExecContext(ctx, q, id, pdfBlob); err != nil {
		return err
	}

	return tx.Commit()
}

// createInvoice stores a new (proforma) invoice built from id and returns
// its uid: the decimal string of the new invoice row id.
//
// created_time is set to id.Date (nanoseconds since the Unix epoch) when it
// is non-zero, otherwise to the current time. The insert honours ctx
// cancellation and deadlines.
func (m *model) createInvoice(ctx context.Context, id *pb.InvoiceData) (string, error) {
	data, err := proto.Marshal(id)
	if err != nil {
		return "", err
	}

	// Named q rather than sql so it does not shadow the database/sql package.
	q := `
		insert into invoice (
			proto, created_time
		) values (
			?, ?
		)
	`

	t := time.Now()
	if id.Date != 0 {
		t = time.Unix(0, id.Date)
	}

	res, err := m.db.ExecContext(ctx, q, data, t.UnixNano())
	if err != nil {
		return "", err
	}
	uid, err := res.LastInsertId()
	if err != nil {
		return "", err
	}

	glog.Infof("%+v", uid)
	return strconv.FormatInt(uid, 10), nil
}

// getRendered returns the stored PDF for the invoice with the given uid.
//
// Within this file, PDFs are only written by sealInvoice. An unknown uid
// and a still-proforma invoice both yield InvalidArgument "no such invoice".
// A non-integer uid yields InvalidArgument "invalid uid".
func (m *model) getRendered(ctx context.Context, uid string) ([]byte, error) {
	id, err := strconv.Atoi(uid)
	if err != nil {
		return nil, status.Error(codes.InvalidArgument, "invalid uid")
	}

	q := `
		select invoice_blob.pdf from invoice_blob where invoice_blob.invoice_id = ?
	`

	var data []byte
	if err := m.db.QueryRowContext(ctx, q, id).Scan(&data); err != nil {
		if err == sql.ErrNoRows {
			return nil, status.Error(codes.InvalidArgument, "no such invoice")
		}
		return nil, err
	}
	return data, nil
}

// getSealedUid returns the raw final_uid from the seal record of the invoice
// with the given uid. The value does not include the "FV/" prefix that
// sealInvoice and Proto add.
//
// When no seal record exists it returns "" and a nil error. That covers both
// proforma invoices and uids that do not exist at all. A non-integer uid
// yields InvalidArgument.
func (m *model) getSealedUid(ctx context.Context, uid string) (string, error) {
	id, err := strconv.Atoi(uid)
	if err != nil {
		return "", status.Error(codes.InvalidArgument, "invalid uid")
	}

	q := `
		select invoice_seal.final_uid from invoice_seal where invoice_seal.invoice_id = ?
	`

	var finalUid string
	if err := m.db.QueryRowContext(ctx, q, id).Scan(&finalUid); err != nil {
		if err == sql.ErrNoRows {
			return "", nil
		}
		return "", err
	}
	return finalUid, nil
}

// sqlInvoiceSealRow holds one row of the "invoice left join invoice_seal"
// query used by getInvoice and getInvoices. The seal columns are nullable
// because an invoice without a seal record has no matching invoice_seal row.
type sqlInvoiceSealRow struct {
	proto       []byte         // invoice.proto: serialized pb.InvoiceData
	createdTime int64          // invoice.created_time (UnixNano, as written by createInvoice)
	sealedTime  sql.NullInt64  // invoice_seal.sealed_time; NULL when not sealed
	finalUid    sql.NullString // invoice_seal.final_uid; NULL when not sealed
	uid         int64          // invoice.id
}

// Proto decodes the stored InvoiceData and builds the pb.Invoice for this
// row.
//
// By default the row is a proforma: numbered "PROFORMA/<id>" and dated by
// its creation time. If the joined seal columns are present, the invoice is
// instead marked sealed, numbered "FV/<final_uid>" and dated by its seal
// time. Derived fields are then filled in by calculateInvoiceData.
func (s *sqlInvoiceSealRow) Proto() (*pb.Invoice, error) {
	data := new(pb.InvoiceData)
	if err := proto.Unmarshal(s.proto, data); err != nil {
		return nil, err
	}

	idStr := strconv.FormatInt(s.uid, 10)
	inv := &pb.Invoice{
		Uid:      idStr,
		Data:     data,
		State:    pb.Invoice_STATE_PROFORMA,
		FinalUid: "PROFORMA/" + idStr,
		Date:     s.createdTime,
	}
	if s.finalUid.Valid {
		inv.State = pb.Invoice_STATE_SEALED
		inv.FinalUid = "FV/" + s.finalUid.String
		inv.Date = s.sealedTime.Int64
	}

	calculateInvoiceData(inv)
	return inv, nil
}

// getInvoice loads the invoice with the given uid, together with its seal
// record if it has one, and converts it with sqlInvoiceSealRow.Proto.
//
// It returns InvalidArgument for a non-integer uid and NotFound when no
// invoice has that id. The query honours ctx cancellation and deadlines.
func (m *model) getInvoice(ctx context.Context, uid string) (*pb.Invoice, error) {
	id, err := strconv.Atoi(uid)
	if err != nil {
		return nil, status.Error(codes.InvalidArgument, "invalid uid")
	}

	q := `
		select
				invoice.id, invoice.proto, invoice.created_time, invoice_seal.sealed_time, invoice_seal.final_uid
		from invoice
		left join invoice_seal
		on invoice_seal.invoice_id = invoice.id
		where invoice.id = ?
	`

	var row sqlInvoiceSealRow
	err = m.db.QueryRowContext(ctx, q, id).Scan(&row.uid, &row.proto, &row.createdTime, &row.sealedTime, &row.finalUid)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, status.Error(codes.NotFound, "no such invoice")
		}
		return nil, err
	}

	return row.Proto()
}

// getInvoices returns every stored invoice, proforma and sealed, each
// converted with sqlInvoiceSealRow.Proto. The query has no ORDER BY, so
// results come back in whatever order SQLite produces. An empty database
// yields an empty, non-nil slice.
func (m *model) getInvoices(ctx context.Context) ([]*pb.Invoice, error) {
	q := `
		select
				invoice.id, invoice.proto, invoice.created_time, invoice_seal.sealed_time, invoice_seal.final_uid
		from invoice
		left join invoice_seal
		on invoice_seal.invoice_id = invoice.id
	`
	rows, err := m.db.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	res := []*pb.Invoice{}

	for rows.Next() {
		var row sqlInvoiceSealRow
		if err := rows.Scan(&row.uid, &row.proto, &row.createdTime, &row.sealedTime, &row.finalUid); err != nil {
			return nil, err
		}
		p, err := row.Proto()
		if err != nil {
			return nil, err
		}
		res = append(res, p)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails (e.g. ctx cancelled mid-scan). Only rows.Err tells
	// the two apart; without it a truncated list would be reported as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return res, nil
}
