blob: 30e41bf1659cf7bf8f0049e5c210d3acc6cb2339 [file] [log] [blame]
package main
import (
"context"
"crypto/sha1"
"encoding/hex"
"flag"
"fmt"
"io"
"os"
"regexp"
"strings"
"sync"
"time"
"code.hackerspace.pl/hscloud/go/mirko"
"github.com/golang/glog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"code.hackerspace.pl/hscloud/games/factorio/modproxy/modportal"
pb "code.hackerspace.pl/hscloud/games/factorio/modproxy/proto"
)
// init makes glog log to stderr by default. main calls flag.Parse afterwards,
// so an explicit -logtostderr on the command line still takes precedence.
func init() {
	// The error (flag not registered) is deliberately ignored.
	_ = flag.Set("logtostderr", "true")
}
// flagCASDirectory is the directory holding the content-addressed file store.
// It is bound to the -cas_directory flag in main.
var flagCASDirectory string
// main registers the -cas_directory flag, starts a mirko server exposing the
// ModProxy gRPC service, and blocks until mirko reports that it is done.
func main() {
	flag.StringVar(&flagCASDirectory, "cas_directory", "cas", "directory in which to store cached files")
	flag.Parse()

	server := mirko.New()
	if err := server.Listen(); err != nil {
		glog.Exitf("Listen(): %v", err)
	}

	proxy := &service{cache: map[string]*cacheEntry{}}
	pb.RegisterModProxyServer(server.GRPC(), proxy)

	if err := server.Serve(); err != nil {
		glog.Exitf("Serve(): %v", err)
	}
	<-server.Done()
}
var (
reSha1 = regexp.MustCompile(`^[a-f0-9]+$`)
)
func casPath(sha1 string) string {
sha1 = strings.ToLower(sha1)
if !reSha1.MatchString(sha1) {
return ""
}
return fmt.Sprintf("%s/%s", flagCASDirectory, sha1)
}
// service implements the ModProxy gRPC service. It keeps an in-memory record
// of what is known about each requested file and serves mirrored files from
// the on-disk CAS directory.
type service struct {
	// mu guards cache.
	mu sync.Mutex
	// cache maps a sha1 hex digest to what is known about that file.
	cache map[string]*cacheEntry
}

// cacheEntry is the cached knowledge about a single file.
type cacheEntry struct {
	// expires, when non-nil, marks the end of this entry's intended lifetime
	// (cacheGet compares it with the current time); nil means it never expires.
	expires *time.Time
	// modName is the mod the file was requested for; Mirror uses it to look
	// the file up on the mod portal.
	modName string
	// found is true once the file has been confirmed to exist on the mod portal.
	found bool
	// mirrored is true when the file is believed to be in the CAS and ready to
	// be served to users.
	mirrored bool
}
// Mirror downloads every file that the cache records as found on the mod
// portal but not yet mirrored. Each file is stored in the CAS directory, using
// the mod portal credentials from req.
//
// Failures are per file and do not fail the RPC. Each attempted file is
// reported under the key "<modName>/<sha1>", either in ModsOkay or, together
// with its error message, in ModsErrors.
func (s *service) Mirror(ctx context.Context, req *pb.MirrorRequest) (*pb.MirrorResponse, error) {
	// Snapshot sha1 -> modName for pending downloads so the lock is not held
	// across network and disk I/O.
	pending := make(map[string]string)
	s.mu.Lock()
	for sha, e := range s.cache {
		if e == nil || !e.found || e.mirrored {
			continue
		}
		pending[sha] = e.modName
	}
	s.mu.Unlock()

	res := &pb.MirrorResponse{
		ModsErrors: make(map[string]string),
	}
	for sha, modName := range pending {
		k := fmt.Sprintf("%s/%s", modName, sha)
		if err := mirrorFile(ctx, sha, modName, req.Username, req.Token); err != nil {
			glog.Errorf("Could not download %q: %v", k, err)
			res.ModsErrors[k] = err.Error()
			continue
		}
		s.cacheFeed(sha, modName, nil, true, true)
		glog.Infof("Downloaded %q", k)
		res.ModsOkay = append(res.ModsOkay, k)
	}
	return res, nil
}

// mirrorFile fetches the release of modName whose sha1 is sha from the mod
// portal and commits it to the CAS. Data is written to a ".incoming" file,
// verified against sha, and only then renamed into place. A partial or corrupt
// download therefore never appears at the CAS path.
func mirrorFile(ctx context.Context, sha, modName, username, token string) error {
	path := casPath(sha)
	if path == "" {
		return fmt.Errorf("invalid sha1 %q", sha)
	}
	mod, err := modportal.GetMod(ctx, modName)
	if err != nil {
		return err
	}
	release := mod.ReleaseBySHA1(sha)
	if release == nil {
		return fmt.Errorf("could not find sha1 in modportal - deleted?")
	}
	r, err := release.Download(ctx, username, token)
	if err != nil {
		return fmt.Errorf("could not download: %w", err)
	}
	// NOTE(review): if release.Download returns a stream that must be closed
	// (e.g. an HTTP response body), close it here — confirm against modportal.

	pathIncoming := path + ".incoming"
	out, err := os.Create(pathIncoming)
	if err != nil {
		return fmt.Errorf("could not create file: %w", err)
	}
	hash := sha1.New()
	_, err = io.Copy(io.MultiWriter(out, hash), r)
	// Always close the file; Close may also surface a deferred write error.
	if cerr := out.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		os.Remove(pathIncoming) // best-effort cleanup; the save error is what is reported
		return fmt.Errorf("could not save: %w", err)
	}
	if sum := hex.EncodeToString(hash.Sum(nil)); sum != strings.ToLower(sha) {
		os.Remove(pathIncoming)
		return fmt.Errorf("checksum mismatch: got %s", sum)
	}
	if err := os.Rename(pathIncoming, path); err != nil {
		os.Remove(pathIncoming)
		return fmt.Errorf("could not commit file: %w", err)
	}
	return nil
}
// cacheGet looks up key (a sha1 hex digest) in the in-memory cache.
//
// hit reports whether a live entry exists. found and mirrored are copied from
// that entry and are false on a miss. An entry whose expiry time has been
// reached is evicted and reported as a miss, so a stale negative result gets
// re-checked against the mod portal.
func (s *service) cacheGet(key string) (hit, found, mirrored bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry, ok := s.cache[key]
	if !ok || entry == nil {
		return false, false, false
	}
	// Evict only once the expiry time is no longer in the future.
	if entry.expires != nil && !time.Now().Before(*entry.expires) {
		delete(s.cache, key)
		return false, false, false
	}
	return true, entry.found, entry.mirrored
}
// cacheFeed stores a fresh entry for key (a sha1 hex digest), replacing any
// existing one. It records the mod name, the optional expiry time, and whether
// the file is known to exist on the mod portal and in the CAS.
func (s *service) cacheFeed(key, modName string, expires *time.Time, found, mirrored bool) {
	entry := &cacheEntry{
		modName:  modName,
		expires:  expires,
		found:    found,
		mirrored: mirrored,
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.cache[key] = entry
}
// serve streams the CAS copy of req.FileSha1 to the client.
//
// The first message carries the status:
//   - STATUS_NOT_AVAILABLE if the file cannot be opened from the CAS. The cache
//     entry is then marked as found but not mirrored, so Mirror will fetch it.
//   - Otherwise STATUS_OKAY, followed by the contents in chunks of up to 1 MiB.
//
// After the last chunk, the SHA-1 of the streamed data is compared with
// req.FileSha1. On a mismatch the entry is marked not mirrored again and a
// DataLoss error is returned.
//
// Callers are expected to pass a lower-case req.FileSha1 (Download normalises it).
func (s *service) serve(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
	cas := casPath(req.FileSha1)
	if cas == "" {
		// Not a well-formed digest: this is a client error.
		return status.Error(codes.InvalidArgument, "invalid sha1")
	}
	file, err := os.Open(cas)
	if err != nil {
		// Not in CAS: record that and tell the client it is not available yet.
		s.cacheFeed(req.FileSha1, req.ModName, nil, true, false)
		return srv.Send(&pb.DownloadResponse{
			Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
		})
	}
	defer file.Close()

	if err := srv.Send(&pb.DownloadResponse{
		Status: pb.DownloadResponse_STATUS_OKAY,
	}); err != nil {
		return err
	}

	buf := make([]byte, 1024*1024)
	hash := sha1.New()
	for {
		n, err := file.Read(buf)
		// Per the io.Reader contract, consume returned bytes before looking at
		// err: a reader may return n > 0 together with io.EOF.
		if n > 0 {
			hash.Write(buf[:n])
			// gRPC may use a sent message lazily (stats handlers, tracing), so
			// each chunk gets its own backing array instead of aliasing buf.
			chunk := append([]byte(nil), buf[:n]...)
			if serr := srv.Send(&pb.DownloadResponse{Chunk: chunk}); serr != nil {
				return serr
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return status.Errorf(codes.Unavailable, "error reading file: %v", err)
		}
	}

	// Entire file sent; double-check the checksum.
	sum := hex.EncodeToString(hash.Sum(nil))
	if sum != req.FileSha1 {
		glog.Errorf("CAS corruption: wanted %q, got %q", req.FileSha1, sum)
		// Mark as not mirrored so the next Mirror call replaces the bad copy.
		s.cacheFeed(req.FileSha1, req.ModName, nil, true, false)
		return status.Error(codes.DataLoss, "CAS corruption")
	}
	return nil
}
// Download serves the mod file identified by req.ModName and req.FileSha1.
//
// The in-memory cache is consulted first. On a miss, the mod portal is asked
// whether the mod has a release with that sha1:
//   - If it does not, the negative answer is cached with a one-minute expiry
//     and NotFound is returned.
//   - Otherwise the file is optimistically assumed to be mirrored and served.
//     serve downgrades the cache entry if the CAS does not have it.
//
// req.FileSha1 is lower-cased in place before use.
func (s *service) Download(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
	modName := req.ModName
	if modName == "" {
		return status.Error(codes.InvalidArgument, "mod name must be set")
	}
	if req.FileSha1 == "" {
		return status.Error(codes.InvalidArgument, "sha1 must be set")
	}
	// Normalise the digest so cache keys and CAS paths are case-insensitive.
	// serve reads req.FileSha1, so update the request itself.
	sum := strings.ToLower(req.FileSha1)
	req.FileSha1 = sum

	// NOTE(review): cache entries are keyed by sha1 alone. While a negative
	// entry is live, it also answers requests naming a different mod with the
	// same sha1.
	if hit, found, mirrored := s.cacheGet(sum); hit {
		switch {
		case !found:
			return status.Error(codes.NotFound, "sha1 not found for mod")
		case !mirrored:
			return srv.Send(&pb.DownloadResponse{
				Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
			})
		default:
			return s.serve(req, srv)
		}
	}

	// Cache miss: check the mod portal.
	mod, err := modportal.GetMod(srv.Context(), modName)
	if err != nil {
		return err
	}
	if mod.ReleaseBySHA1(sum) == nil {
		expires := time.Now().Add(1 * time.Minute)
		s.cacheFeed(sum, modName, &expires, false, false)
		// Same code as the cached negative answer above, so clients see a
		// consistent result regardless of cache state.
		return status.Error(codes.NotFound, "sha1 not found for mod")
	}

	// Assume the file is mirrored; serve corrects the cache if it is not.
	s.cacheFeed(sum, modName, nil, true, true)
	return s.serve(req, srv)
}