package hkp

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"net/http"
	"time"
)

var (
	// keyServers lists the HKP servers that GetKeyRing queries, in order.
	//
	// NOTE(review): the SKS keyserver pool (pool.sks-keyservers.net) was
	// reportedly shut down in 2021 — confirm these hosts still serve keys.
	//
	// TODO(lb5tr): provide as flag
	keyServers = []string{
		"http://pool.sks-keyservers.net",
		"http://keys.gnupg.net",
	}

	// PerServerTimeLimit bounds each individual request made to a keyserver.
	PerServerTimeLimit = 5 * time.Second

	// PerServerRetryCount is the number of attempts made against each
	// keyserver before moving on to the next one.
	PerServerRetryCount = 3

	// ErrKeyNotFound is returned when a keyserver answers a lookup with
	// HTTP 404 Not Found.
	ErrKeyNotFound = errors.New("not found on hkp servers")
)

// ASCII-armor delimiters bracketing the public key block in a lookup response.
const (
	startMarker string = "-----BEGIN PGP PUBLIC KEY BLOCK-----"
	endMarker   string = "-----END PGP PUBLIC KEY BLOCK-----"
)

type (
	// Client looks up public key material by key ID.
	Client interface {
		// GetKeyRing returns the key data found for keyID.
		GetKeyRing(ctx context.Context, keyID []byte) ([]byte, error)
	}

	// transport fetches url and returns the key data extracted from the
	// response. Presumably a seam so tests can substitute a fake for real
	// HTTP — confirm against package tests.
	transport interface {
		get(ctx context.Context, url string) ([]byte, error)
	}

	// httpTransport is the transport that issues real HTTP requests.
	httpTransport struct{}

	// HKP is the Client implementation that queries HKP keyservers through
	// its transport.
	HKP struct {
		transport transport
	}
)

// NewClient returns a Client that looks keys up on the package's keyServers
// using plain HTTP GET requests.
func NewClient() Client {
	return HKP{transport: httpTransport{}}
}

// GetKeyRing fetches the public key block for keyID from the configured
// keyservers.
//
// Servers in keyServers are tried in order. Each one gets up to
// PerServerRetryCount attempts, and each attempt is bounded by
// PerServerTimeLimit and by ctx. The first successful response is returned.
// If every attempt fails, the error from the last attempt is returned.
// ErrKeyNotFound is returned when no attempt was made at all (no servers, or
// a non-positive retry count). If ctx is cancelled or expires first,
// ctx.Err() is returned.
func (h HKP) GetKeyRing(ctx context.Context, keyID []byte) ([]byte, error) {
	key := fmt.Sprintf("0x%x", keyID)

	// A zero-value HKP has no transport; fall back to HTTP rather than
	// panicking inside the worker goroutine (which would kill the process).
	tr := h.transport
	if tr == nil {
		tr = httpTransport{}
	}

	type result struct {
		data []byte
		err  error
	}
	// Buffered so the worker can always deliver its result and exit, even if
	// this function has already returned because ctx was done.
	done := make(chan result, 1)

	go func() {
		lastErr := ErrKeyNotFound
		for _, server := range keyServers {
			url := server + "/pks/lookup?op=get&search=" + key
			for i := 0; i < PerServerRetryCount; i++ {
				// Stop retrying as soon as the caller has given up.
				if err := ctx.Err(); err != nil {
					done <- result{err: err}
					return
				}
				// Derive from ctx so cancellation also aborts the in-flight request.
				attemptCtx, cancel := context.WithTimeout(ctx, PerServerTimeLimit)
				keyData, err := tr.get(attemptCtx, url)
				cancel()

				if err == nil {
					done <- result{data: keyData}
					return
				}
				// ErrKeyNotFound is retriable too: servers have been seen to
				// return the key on a subsequent attempt.
				lastErr = err
			}
		}
		done <- result{err: lastErr}
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case r := <-done:
		if r.err != nil {
			return nil, r.err
		}
		return r.data, nil
	}
}

// get performs an HTTP GET of url and returns the ASCII-armored public key
// block in the response body. The block runs from the begin marker through
// the end marker, inclusive.
//
// The request is bounded by ctx and, additionally, by PerServerTimeLimit.
// The possible outcomes are:
//   - a 404 response returns ErrKeyNotFound;
//   - any other non-200 status, transport failure, or body read failure
//     returns a descriptive error;
//   - a body without a complete key block also returns a descriptive error.
func (httpTransport) get(ctx context.Context, url string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(ctx, PerServerTimeLimit)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("http.NewRequest(GET, %q): %w", url, err)
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		// Report only the URL; formatting the whole *http.Request dumps
		// headers and internal fields into the message.
		return nil, fmt.Errorf("GET %q: %w", url, err)
	}
	defer res.Body.Close()

	switch res.StatusCode {
	case http.StatusOK:
		// Fall through to body parsing.
	case http.StatusNotFound:
		return nil, ErrKeyNotFound
	default:
		return nil, fmt.Errorf("GET %q: got status code %d", url, res.StatusCode)
	}

	var buf bytes.Buffer
	if _, err := buf.ReadFrom(res.Body); err != nil {
		return nil, fmt.Errorf("reading response from %q: %w", url, err)
	}
	response := buf.Bytes()

	start := bytes.Index(response, []byte(startMarker))
	if start == -1 {
		return nil, fmt.Errorf("no PGP public key block in response from %q", url)
	}
	// Look for the end marker only after the start marker. Otherwise a stray
	// end marker earlier in the (untrusted) body yields an inverted slice
	// range and a runtime panic.
	end := bytes.Index(response[start:], []byte(endMarker))
	if end == -1 {
		return nil, fmt.Errorf("no PGP public key block in response from %q", url)
	}
	end += start + len(endMarker)

	return response[start:end], nil
}
