games/factorio: add modproxy

This adds a mod proxy system, called, well, modproxy.

It sits between Factorio server instances and the Factorio mod portal,
allowing servers to download arbitrary mods without needing to know
Factorio credentials.

Change-Id: I7bc405a25b6f9559cae1f23295249f186761f212
diff --git a/games/factorio/modproxy/main.go b/games/factorio/modproxy/main.go
new file mode 100644
index 0000000..30e41bf
--- /dev/null
+++ b/games/factorio/modproxy/main.go
@@ -0,0 +1,298 @@
+package main
+
+import (
+	"context"
+	"crypto/sha1"
+	"encoding/hex"
+	"flag"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+	"regexp"
+	"strings"
+	"sync"
+	"time"
+
+	"code.hackerspace.pl/hscloud/go/mirko"
+	"github.com/golang/glog"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+
+	"code.hackerspace.pl/hscloud/games/factorio/modproxy/modportal"
+	pb "code.hackerspace.pl/hscloud/games/factorio/modproxy/proto"
+)
+
+// init makes glog log to stderr by default. It runs before flag.Parse in
+// main, so the default can still be overridden on the command line.
+func init() {
+	// The error is deliberately ignored: this only sets a default.
+	_ = flag.Set("logtostderr", "true")
+}
+
+var (
+	flagCASDirectory string
+)
+
+// main parses flags, makes sure the CAS directory exists and serves the
+// ModProxy gRPC service through mirko until mirko reports it is done.
+func main() {
+	flag.StringVar(&flagCASDirectory, "cas_directory", "cas", "directory in which to store cached files")
+	flag.Parse()
+
+	// Mirror creates files directly inside the CAS directory and fails if the
+	// directory does not exist, so make sure it is there before serving.
+	if err := os.MkdirAll(flagCASDirectory, 0755); err != nil {
+		glog.Exitf("MkdirAll(%q): %v", flagCASDirectory, err)
+	}
+
+	m := mirko.New()
+	if err := m.Listen(); err != nil {
+		glog.Exitf("Listen(): %v", err)
+	}
+
+	srv := &service{
+		cache: make(map[string]*cacheEntry),
+	}
+	pb.RegisterModProxyServer(m.GRPC(), srv)
+
+	if err := m.Serve(); err != nil {
+		glog.Exitf("Serve(): %v", err)
+	}
+
+	<-m.Done()
+}
+
+var (
+	reSha1 = regexp.MustCompile(`^[a-f0-9]+$`)
+)
+
+// casPath returns the path of the CAS file for the given hex-encoded SHA-1.
+// The digest is lowercased first, so lookups are case-insensitive. It returns
+// "" if the digest does not match reSha1; since reSha1 only admits hex
+// characters, a non-empty result always names a file directly inside
+// flagCASDirectory.
+func casPath(sha1 string) string {
+	sha1 = strings.ToLower(sha1)
+	if !reSha1.MatchString(sha1) {
+		return ""
+	}
+	// filepath.Join rather than "%s/%s": an empty directory flag must not turn
+	// into a path at the filesystem root.
+	return filepath.Join(flagCASDirectory, sha1)
+}
+
+// service implements the ModProxy gRPC service registered in main. It keeps
+// an in-memory index of known files, keyed by their hex-encoded SHA-1.
+type service struct {
+	// mu guards cache.
+	mu sync.Mutex
+	// cache maps a file's sha1 to what is known about it; see cacheEntry.
+	cache map[string]*cacheEntry
+}
+
+// cacheEntry records what the service knows about a single file.
+type cacheEntry struct {
+	// expires, when non-nil, is the deadline checked by cacheGet; a nil
+	// expires means the entry is never treated as expired.
+	expires *time.Time
+	// modName is the mod portal name of the mod this file belongs to.
+	modName string
+
+	// found reports that the file was confirmed to exist on the mod portal.
+	found bool
+	// mirrored reports that the file is believed to be in the local CAS and
+	// can therefore be served to users.
+	mirrored bool
+}
+
+func (s *service) Mirror(ctx context.Context, req *pb.MirrorRequest) (*pb.MirrorResponse, error) {
+
+	// build map of sha1->modName for needed downloads
+	modNames := make(map[string]string)
+	s.mu.Lock()
+	for sha, e := range s.cache {
+		if e == nil {
+			continue
+		}
+		if e.found == false {
+			continue
+		}
+		if e.mirrored == true {
+			continue
+		}
+
+		modNames[sha] = e.modName
+	}
+	s.mu.Unlock()
+
+	okays := make(map[string]bool)
+	errors := make(map[string]error)
+
+	for sha, modName := range modNames {
+		k := fmt.Sprintf("%s/%s", modName, sha)
+		mod, err := modportal.GetMod(ctx, modName)
+		if err != nil {
+			errors[k] = err
+			continue
+		}
+		release := mod.ReleaseBySHA1(sha)
+		if release == nil {
+			errors[k] = fmt.Errorf("could not find sha1 in modportal - deleted?")
+			continue
+		}
+
+		r, err := release.Download(ctx, req.Username, req.Token)
+		if err != nil {
+			errors[k] = fmt.Errorf("could not download: %v", err)
+			continue
+		}
+
+		path := casPath(sha)
+		pathIncoming := path + ".incoming"
+
+		out, err := os.Create(pathIncoming)
+		if err != nil {
+			errors[k] = fmt.Errorf("could not create file: %v", err)
+			continue
+		}
+		_, err = io.Copy(out, r)
+		if err != nil {
+			errors[k] = fmt.Errorf("could not save: %v", err)
+			continue
+		}
+		err = os.Rename(pathIncoming, path)
+		if err != nil {
+			errors[k] = fmt.Errorf("could not commit file: %v", err)
+			continue
+		}
+
+		okays[k] = true
+		s.cacheFeed(sha, modName, nil, true, true)
+	}
+
+	res := &pb.MirrorResponse{
+		ModsErrors: make(map[string]string),
+	}
+	for m, _ := range okays {
+		glog.Infof("Downloaded %q", m)
+		res.ModsOkay = append(res.ModsOkay, m)
+	}
+	for m, err := range errors {
+		glog.Errorf("Could not download %q: %v", m, err)
+		res.ModsErrors[m] = fmt.Sprintf("%v", err)
+	}
+
+	return res, nil
+}
+
+// cacheGet looks up sha1 in the cache.
+//
+// hit reports whether a non-expired entry exists. found and mirrored are
+// copied from that entry and are only meaningful when hit is true. An entry
+// whose expiry deadline has been reached is evicted as a side effect of the
+// lookup and reported as a miss.
+func (s *service) cacheGet(sha1 string) (hit, found, mirrored bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	entry, ok := s.cache[sha1]
+	if !ok || entry == nil {
+		return false, false, false
+	}
+
+	// Expired once now is at or past the deadline; entries without a deadline
+	// never expire.
+	if entry.expires != nil && !time.Now().Before(*entry.expires) {
+		delete(s.cache, sha1)
+		return false, false, false
+	}
+
+	return true, entry.found, entry.mirrored
+}
+
+// cacheFeed records the given facts about sha1, replacing any existing cache
+// entry for it. A nil expires creates an entry that cacheGet's expiry check
+// never applies to.
+func (s *service) cacheFeed(sha1, modName string, expires *time.Time, found, mirrored bool) {
+	entry := &cacheEntry{
+		modName:  modName,
+		expires:  expires,
+		mirrored: mirrored,
+		found:    found,
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.cache[sha1] = entry
+}
+
+func (s *service) serve(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
+	cas := casPath(req.FileSha1)
+	if cas == "" {
+		// Invalid sha1? Fail.
+		return status.Error(codes.Aborted, "invalid sha1")
+	}
+
+	file, err := os.Open(cas)
+	if err != nil {
+		// not in CAS, update cache and fail
+		s.cacheFeed(req.FileSha1, req.ModName, nil, true, false)
+		return srv.Send(&pb.DownloadResponse{
+			Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
+		})
+	}
+	defer file.Close()
+
+	err = srv.Send(&pb.DownloadResponse{
+		Status: pb.DownloadResponse_STATUS_OKAY,
+	})
+	if err != nil {
+		return err
+	}
+
+	buf := make([]byte, 1024*1024)
+	hash := sha1.New()
+
+	for {
+		n, err := file.Read(buf)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return status.Errorf(codes.Unavailable, "error reading file: %v", err)
+		}
+		hash.Write(buf[:n])
+		err = srv.Send(&pb.DownloadResponse{
+			Chunk: buf[:n],
+		})
+		if err != nil {
+			return err
+		}
+	}
+
+	// entire file send, double-check shasum
+	sum := hex.EncodeToString(hash.Sum(nil))
+	if sum != req.FileSha1 {
+		glog.Errorf("CAS corruption: wanted %q, got %q", req.FileSha1, sum)
+		return status.Error(codes.Aborted, "CAS corruption")
+	}
+
+	return nil
+}
+
+// Download streams the mod file identified by req.ModName and req.FileSha1
+// to the client.
+//
+// Answers are cached by sha1. A sha1 that the mod portal does not list for
+// the mod is cached negatively with a one-minute expiry and answered with
+// NotFound. A sha1 the mod portal does list is served from the local CAS; if
+// the CAS does not have it yet, STATUS_NOT_AVAILABLE is sent and the file
+// becomes a candidate for the next Mirror call.
+func (s *service) Download(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
+	ctx := srv.Context()
+
+	modName := req.ModName
+	if modName == "" {
+		return status.Error(codes.InvalidArgument, "mod name must be set")
+	}
+	if req.FileSha1 == "" {
+		return status.Error(codes.InvalidArgument, "sha1 must be set")
+	}
+	// Normalize in place: the cache is keyed by, and serve compares against,
+	// the lowercase digest.
+	sha1 := strings.ToLower(req.FileSha1)
+	req.FileSha1 = sha1
+	// Reject malformed digests before they cost a mod portal round-trip or
+	// take up a cache slot.
+	if !reSha1.MatchString(sha1) {
+		return status.Error(codes.InvalidArgument, "invalid sha1")
+	}
+
+	if hit, found, mirrored := s.cacheGet(sha1); hit {
+		if !found {
+			return status.Error(codes.NotFound, "sha1 not found for mod")
+		}
+		if !mirrored {
+			return srv.Send(&pb.DownloadResponse{
+				Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
+			})
+		}
+		// We have the file; serve it.
+		return s.serve(req, srv)
+	}
+
+	// Cache miss: ask the mod portal whether this sha1 belongs to the mod.
+	mod, err := modportal.GetMod(ctx, modName)
+	if err != nil {
+		return err
+	}
+	if mod.ReleaseBySHA1(sha1) == nil {
+		expires := time.Now().Add(1 * time.Minute)
+		s.cacheFeed(sha1, modName, &expires, false, false)
+		// Same status code as the cached negative answer above.
+		return status.Error(codes.NotFound, "sha1 not found for mod")
+	}
+
+	// Optimistically assume the file is mirrored. If it is not in the CAS,
+	// serve records that in the cache and answers STATUS_NOT_AVAILABLE.
+	s.cacheFeed(sha1, modName, nil, true, true)
+	return s.serve(req, srv)
+}