package main

import (
	"context"
	"crypto/sha1"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"time"

	"code.hackerspace.pl/hscloud/go/mirko"
	"github.com/golang/glog"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"code.hackerspace.pl/hscloud/games/factorio/modproxy/modportal"
	pb "code.hackerspace.pl/hscloud/games/factorio/modproxy/proto"
)

// init makes logging go to stderr by default.
func init() {
	// The "logtostderr" flag is assumed to be registered by glog, whose
	// package initializer runs before this one. If the flag were missing,
	// Set would fail and logging would simply keep its default destination,
	// so the error is deliberately discarded.
	_ = flag.Set("logtostderr", "true")
}

// flagCASDirectory holds the value of the -cas_directory flag: the directory
// in which mirrored mod files are stored, one file per sha1. The flag is
// registered and parsed in main.
var flagCASDirectory string

// main parses flags, makes sure the CAS directory exists, and serves the
// ModProxy gRPC service through mirko until it is done.
func main() {
	flag.StringVar(&flagCASDirectory, "cas_directory", "cas", "directory in which to store cached files")
	flag.Parse()

	// Mirror creates files directly inside the CAS directory (os.Create does
	// not create parent directories), so make sure it exists up front rather
	// than failing every mirror attempt on a fresh deployment.
	if err := os.MkdirAll(flagCASDirectory, 0755); err != nil {
		glog.Exitf("MkdirAll(%q): %v", flagCASDirectory, err)
	}

	m := mirko.New()
	if err := m.Listen(); err != nil {
		glog.Exitf("Listen(): %v", err)
	}

	srv := &service{
		cache: make(map[string]*cacheEntry),
	}

	pb.RegisterModProxyServer(m.GRPC(), srv)

	if err := m.Serve(); err != nil {
		glog.Exitf("Serve(): %v", err)
	}

	<-m.Done()
}

// reSha1 matches a lower-case hex-encoded SHA1 digest, which is always
// exactly 40 characters long. casPath uses it to make sure only well-formed
// digests (hex digits only, so no path separators or dots) become CAS file
// names.
var reSha1 = regexp.MustCompile(`^[a-f0-9]{40}$`)

// casPath returns the location of the CAS file for the given hex-encoded
// sha1 digest. The digest is matched case-insensitively: it is lower-cased
// before validation. If the digest is not well-formed (as defined by
// reSha1), casPath returns "". This means untrusted input can never
// produce a path outside flagCASDirectory.
func casPath(sha1 string) string {
	sha1 = strings.ToLower(sha1)
	if !reSha1.MatchString(sha1) {
		return ""
	}
	// filepath.Join copes with a trailing separator on the directory flag
	// and uses the OS-specific separator.
	return filepath.Join(flagCASDirectory, sha1)
}

// service is the ModProxy gRPC server implementation registered in main.
// It remembers, per requested file, what is known about it on the mod portal
// and in the local CAS.
type service struct {
	// mu protects cache.
	mu sync.Mutex

	// cache maps a file's sha1 to everything known about that file.
	cache map[string]*cacheEntry
}

// cacheEntry is the cached knowledge about a single file (identified by the
// sha1 it is stored under in service.cache).
type cacheEntry struct {
	// expires is the expiry time of the entry, or nil for an entry that does
	// not expire. Download sets it for files the mod portal does not know.
	expires *time.Time
	// modName is the mod the file belongs to; Mirror uses it to look the
	// release up on the mod portal.
	modName string

	// found is true once the mod portal has confirmed the file exists.
	found bool
	// mirrored is true when the file is considered ready to be served to
	// users from the local CAS.
	mirrored bool
}

// Mirror downloads every file that the cache knows to exist on the mod
// portal (found) but that is not yet available locally (!mirrored). It
// authenticates to the portal with req.Username and req.Token.
//
// Each file is first written to a temporary ".incoming" path. It is then
// checked against its expected sha1 and renamed into the CAS. Only after
// that is the cache entry marked as mirrored.
//
// Per-file outcomes are reported in the response, keyed by "modName/sha1".
// The RPC itself does not fail when individual downloads fail.
func (s *service) Mirror(ctx context.Context, req *pb.MirrorRequest) (*pb.MirrorResponse, error) {
	// Snapshot the work list (sha1 -> mod name) under the lock, so the slow
	// downloads below do not block Download RPCs.
	modNames := make(map[string]string)
	s.mu.Lock()
	for sha, e := range s.cache {
		if e == nil || !e.found || e.mirrored {
			continue
		}
		modNames[sha] = e.modName
	}
	s.mu.Unlock()

	res := &pb.MirrorResponse{
		ModsErrors: make(map[string]string),
	}
	for sha, modName := range modNames {
		k := fmt.Sprintf("%s/%s", modName, sha)
		if err := s.mirrorOne(ctx, req, sha, modName); err != nil {
			glog.Errorf("Could not download %q: %v", k, err)
			res.ModsErrors[k] = err.Error()
			continue
		}
		glog.Infof("Downloaded %q", k)
		res.ModsOkay = append(res.ModsOkay, k)
	}

	return res, nil
}

// mirrorOne fetches the release of modName with the given sha1 from the mod
// portal and stores it in the CAS. On success it marks the file as mirrored
// in the cache.
func (s *service) mirrorOne(ctx context.Context, req *pb.MirrorRequest, sha, modName string) error {
	path := casPath(sha)
	if path == "" {
		return fmt.Errorf("invalid sha1")
	}

	mod, err := modportal.GetMod(ctx, modName)
	if err != nil {
		return err
	}
	release := mod.ReleaseBySHA1(sha)
	if release == nil {
		return fmt.Errorf("could not find sha1 in modportal - deleted?")
	}

	r, err := release.Download(ctx, req.Username, req.Token)
	if err != nil {
		return fmt.Errorf("could not download: %v", err)
	}
	// NOTE(review): the concrete type returned by Download is not visible
	// here. If it is closable (e.g. an HTTP response body), close it so the
	// underlying resources are released.
	var body io.Reader = r
	if c, ok := body.(io.Closer); ok {
		defer c.Close()
	}

	if err := writeCAS(path, sha, body); err != nil {
		return err
	}

	s.cacheFeed(sha, modName, nil, true, true)
	return nil
}

// writeCAS streams r into the CAS file at path. The data is written to
// path+".incoming" and renamed onto path only after three checks pass: it
// was copied completely, the file closed without error, and its sha1 equals
// wantSHA1. Readers of path therefore never see a partial or wrong file.
// The temporary file is removed on any failure.
func writeCAS(path, wantSHA1 string, r io.Reader) error {
	pathIncoming := path + ".incoming"

	out, err := os.Create(pathIncoming)
	if err != nil {
		return fmt.Errorf("could not create file: %v", err)
	}

	hash := sha1.New()
	_, err = io.Copy(io.MultiWriter(out, hash), r)
	// Close may report delayed write errors, so it must be checked too.
	if closeErr := out.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		os.Remove(pathIncoming) // best-effort cleanup
		return fmt.Errorf("could not save: %v", err)
	}

	if sum := hex.EncodeToString(hash.Sum(nil)); !strings.EqualFold(sum, wantSHA1) {
		os.Remove(pathIncoming) // best-effort cleanup
		return fmt.Errorf("downloaded file has sha1 %q, wanted %q", sum, wantSHA1)
	}

	if err := os.Rename(pathIncoming, path); err != nil {
		os.Remove(pathIncoming) // best-effort cleanup
		return fmt.Errorf("could not commit file: %v", err)
	}
	return nil
}

// cacheGet looks up sha1 in the cache.
//
// hit reports whether a live entry exists. found and mirrored copy that
// entry's flags and are false on a miss. An entry whose expiry time has been
// reached is evicted and reported as a miss, so the caller re-checks the
// mod portal.
func (s *service) cacheGet(sha1 string) (hit, found, mirrored bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry, ok := s.cache[sha1]
	if !ok || entry == nil {
		return false, false, false
	}

	// The entry is stale once "now" is no longer before its expiry time.
	// Fresh entries must be kept; otherwise negative results are never
	// cached, and expired ones would linger forever.
	if entry.expires != nil && !time.Now().Before(*entry.expires) {
		delete(s.cache, sha1)
		return false, false, false
	}

	return true, entry.found, entry.mirrored
}

// cacheFeed stores a new entry for sha1, replacing whatever was cached for it
// before.
func (s *service) cacheFeed(sha1, modName string, expires *time.Time, found, mirrored bool) {
	entry := &cacheEntry{
		modName:  modName,
		expires:  expires,
		mirrored: mirrored,
		found:    found,
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.cache[sha1] = entry
}

// serve streams the CAS copy of the file named by req.FileSha1 to the client.
// req.FileSha1 is expected to be lower-case already; Download normalizes it.
//
// The first message carries only a status:
//   - STATUS_NOT_AVAILABLE if the file cannot be opened from the CAS. The
//     cache entry is then marked as not mirrored.
//   - STATUS_OKAY otherwise, followed by the file contents in chunks of up to
//     1 MiB.
//
// After the last chunk the content's sha1 is re-checked. On mismatch the
// stream ends with codes.Aborted, and the cache entry is marked as not
// mirrored so that a later Mirror call replaces the file.
func (s *service) serve(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
	cas := casPath(req.FileSha1)
	if cas == "" {
		// Invalid sha1? Fail.
		return status.Error(codes.Aborted, "invalid sha1")
	}

	file, err := os.Open(cas)
	if err != nil {
		// not in CAS, update cache and fail
		s.cacheFeed(req.FileSha1, req.ModName, nil, true, false)
		return srv.Send(&pb.DownloadResponse{
			Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
		})
	}
	defer file.Close()

	if err := srv.Send(&pb.DownloadResponse{
		Status: pb.DownloadResponse_STATUS_OKAY,
	}); err != nil {
		return err
	}

	const chunkSize = 1024 * 1024
	hash := sha1.New()
	for {
		// A fresh buffer per chunk: gRPC does not allow a sent message (and
		// the bytes it references) to be modified after Send returns, since
		// interceptors/stats handlers may still use it.
		buf := make([]byte, chunkSize)
		n, err := file.Read(buf)
		// Per the io.Reader contract, consume any returned bytes before
		// looking at the error.
		if n > 0 {
			hash.Write(buf[:n])
			if sendErr := srv.Send(&pb.DownloadResponse{Chunk: buf[:n]}); sendErr != nil {
				return sendErr
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return status.Errorf(codes.Unavailable, "error reading file: %v", err)
		}
	}

	// entire file sent, double-check shasum
	sum := hex.EncodeToString(hash.Sum(nil))
	if sum != req.FileSha1 {
		glog.Errorf("CAS corruption: wanted %q, got %q", req.FileSha1, sum)
		// Mark as not mirrored so the next Mirror call re-downloads it.
		s.cacheFeed(req.FileSha1, req.ModName, nil, true, false)
		return status.Error(codes.Aborted, "CAS corruption")
	}

	return nil
}

// Download serves the file identified by req.ModName and req.FileSha1.
//
// The sha1 is lower-cased, and req.FileSha1 is updated in place. A cached
// answer is used when available. Otherwise the mod portal is asked whether
// the mod has a release with that sha1:
//   - Unknown releases are negatively cached for one minute and answered
//     with codes.NotFound, the same code used for a cached miss.
//   - Known releases are optimistically recorded as mirrored and streamed
//     by serve. serve downgrades the cache entry if the file is not in the
//     CAS.
func (s *service) Download(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
	ctx := srv.Context()

	modName := req.ModName
	if modName == "" {
		return status.Error(codes.InvalidArgument, "mod name must be set")
	}
	if req.FileSha1 == "" {
		return status.Error(codes.InvalidArgument, "sha1 must be set")
	}
	fileSHA1 := strings.ToLower(req.FileSha1)
	req.FileSha1 = fileSHA1
	// Reject malformed digests before they reach the cache (which would
	// otherwise grow with arbitrary client-chosen keys) or the mod portal.
	if casPath(fileSHA1) == "" {
		return status.Error(codes.InvalidArgument, "invalid sha1")
	}

	if hit, found, mirrored := s.cacheGet(fileSHA1); hit {
		switch {
		case !found:
			return status.Error(codes.NotFound, "sha1 not found for mod")
		case !mirrored:
			return srv.Send(&pb.DownloadResponse{
				Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
			})
		default:
			// we have the file, serve it
			return s.serve(req, srv)
		}
	}

	// cache not hit, check mod portal
	mod, err := modportal.GetMod(ctx, modName)
	if err != nil {
		return err
	}

	// release not found in mod portal, cache negatively for a while and answer
	if mod.ReleaseBySHA1(fileSHA1) == nil {
		expires := time.Now().Add(1 * time.Minute)
		s.cacheFeed(fileSHA1, modName, &expires, false, false)
		return status.Error(codes.NotFound, "sha1 not found for mod")
	}

	// Assume the file is mirrored; serve will correct the cache if it is not
	// actually in the CAS.
	s.cacheFeed(fileSHA1, modName, nil, true, true)
	return s.serve(req, srv)
}
