blob: 30e41bf1659cf7bf8f0049e5c210d3acc6cb2339 [file] [log] [blame]
Sergiusz Bazanski0581bbf2020-05-11 03:21:32 +02001package main
2
3import (
4 "context"
5 "crypto/sha1"
6 "encoding/hex"
7 "flag"
8 "fmt"
9 "io"
10 "os"
11 "regexp"
12 "strings"
13 "sync"
14 "time"
15
16 "code.hackerspace.pl/hscloud/go/mirko"
17 "github.com/golang/glog"
18 "google.golang.org/grpc/codes"
19 "google.golang.org/grpc/status"
20
21 "code.hackerspace.pl/hscloud/games/factorio/modproxy/modportal"
22 pb "code.hackerspace.pl/hscloud/games/factorio/modproxy/proto"
23)
24
// init forces the "logtostderr" flag to "true" before main gets to parse the
// command line.
func init() {
	// The error returned by flag.Set is intentionally discarded, just as a
	// bare call statement would discard it.
	_ = flag.Set("logtostderr", "true")
}
28
// flagCASDirectory is the directory in which cached files are stored. main
// binds it to the -cas_directory command line flag.
var flagCASDirectory string
32
33func main() {
34 flag.StringVar(&flagCASDirectory, "cas_directory", "cas", "directory in which to store cached files")
35 flag.Parse()
36 m := mirko.New()
37 if err := m.Listen(); err != nil {
38 glog.Exitf("Listen(): %v", err)
39 }
40
41 srv := &service{
42 cache: make(map[string]*cacheEntry),
43 }
44
45 pb.RegisterModProxyServer(m.GRPC(), srv)
46
47 if err := m.Serve(); err != nil {
48 glog.Exitf("Serve(): %v", err)
49 }
50
51 <-m.Done()
52}
53
// reSha1 matches non-empty strings consisting solely of lowercase hexadecimal
// digits. Note that the pattern does not constrain the length of the string.
var reSha1 = regexp.MustCompile(`^[a-f0-9]+$`)
57
58func casPath(sha1 string) string {
59 sha1 = strings.ToLower(sha1)
60 if !reSha1.MatchString(sha1) {
61 return ""
62 }
63 return fmt.Sprintf("%s/%s", flagCASDirectory, sha1)
64}
65
66type service struct {
67 mu sync.Mutex
68
69 // cache of sha1 -> cache entry
70 cache map[string]*cacheEntry
71}
72
// cacheEntry records what the proxy knows about a single file.
type cacheEntry struct {
	// expires is an optional timestamp consulted by cacheGet; nil means that
	// no expiry time is attached to the entry.
	expires *time.Time
	// modName is the mod name that the entry was fed with.
	modName string

	// found means that this is an entry confirmed on the mod portal.
	// mirrored means we are ready to serve this file to users.
	found, mirrored bool
}
82
83func (s *service) Mirror(ctx context.Context, req *pb.MirrorRequest) (*pb.MirrorResponse, error) {
84
85 // build map of sha1->modName for needed downloads
86 modNames := make(map[string]string)
87 s.mu.Lock()
88 for sha, e := range s.cache {
89 if e == nil {
90 continue
91 }
92 if e.found == false {
93 continue
94 }
95 if e.mirrored == true {
96 continue
97 }
98
99 modNames[sha] = e.modName
100 }
101 s.mu.Unlock()
102
103 okays := make(map[string]bool)
104 errors := make(map[string]error)
105
106 for sha, modName := range modNames {
107 k := fmt.Sprintf("%s/%s", modName, sha)
108 mod, err := modportal.GetMod(ctx, modName)
109 if err != nil {
110 errors[k] = err
111 continue
112 }
113 release := mod.ReleaseBySHA1(sha)
114 if release == nil {
115 errors[k] = fmt.Errorf("could not find sha1 in modportal - deleted?")
116 continue
117 }
118
119 r, err := release.Download(ctx, req.Username, req.Token)
120 if err != nil {
121 errors[k] = fmt.Errorf("could not download: %v", err)
122 continue
123 }
124
125 path := casPath(sha)
126 pathIncoming := path + ".incoming"
127
128 out, err := os.Create(pathIncoming)
129 if err != nil {
130 errors[k] = fmt.Errorf("could not create file: %v", err)
131 continue
132 }
133 _, err = io.Copy(out, r)
134 if err != nil {
135 errors[k] = fmt.Errorf("could not save: %v", err)
136 continue
137 }
138 err = os.Rename(pathIncoming, path)
139 if err != nil {
140 errors[k] = fmt.Errorf("could not commit file: %v", err)
141 continue
142 }
143
144 okays[k] = true
145 s.cacheFeed(sha, modName, nil, true, true)
146 }
147
148 res := &pb.MirrorResponse{
149 ModsErrors: make(map[string]string),
150 }
151 for m, _ := range okays {
152 glog.Infof("Downloaded %q", m)
153 res.ModsOkay = append(res.ModsOkay, m)
154 }
155 for m, err := range errors {
156 glog.Errorf("Could not download %q: %v", m, err)
157 res.ModsErrors[m] = fmt.Sprintf("%v", err)
158 }
159
160 return res, nil
161}
162
163func (s *service) cacheGet(sha1 string) (hit, found, mirrored bool) {
164 s.mu.Lock()
165 defer s.mu.Unlock()
166
167 entry, ok := s.cache[sha1]
168 if !ok || entry == nil {
169 return
170 }
171
172 if entry.expires != nil && time.Now().Before(*entry.expires) {
173 delete(s.cache, sha1)
174 return
175 }
176
177 hit = true
178 found = entry.found
179 mirrored = entry.mirrored
180
181 return
182}
183
184func (s *service) cacheFeed(sha1, modName string, expires *time.Time, found, mirrored bool) {
185 s.mu.Lock()
186 defer s.mu.Unlock()
187
188 s.cache[sha1] = &cacheEntry{
189 expires: expires,
190 modName: modName,
191 found: found,
192 mirrored: mirrored,
193 }
194}
195
196func (s *service) serve(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
197 cas := casPath(req.FileSha1)
198 if cas == "" {
199 // Invalid sha1? Fail.
200 return status.Error(codes.Aborted, "invalid sha1")
201 }
202
203 file, err := os.Open(cas)
204 if err != nil {
205 // not in CAS, update cache and fail
206 s.cacheFeed(req.FileSha1, req.ModName, nil, true, false)
207 return srv.Send(&pb.DownloadResponse{
208 Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
209 })
210 }
211 defer file.Close()
212
213 err = srv.Send(&pb.DownloadResponse{
214 Status: pb.DownloadResponse_STATUS_OKAY,
215 })
216 if err != nil {
217 return err
218 }
219
220 buf := make([]byte, 1024*1024)
221 hash := sha1.New()
222
223 for {
224 n, err := file.Read(buf)
225 if err == io.EOF {
226 break
227 }
228 if err != nil {
229 return status.Errorf(codes.Unavailable, "error reading file: %v", err)
230 }
231 hash.Write(buf[:n])
232 err = srv.Send(&pb.DownloadResponse{
233 Chunk: buf[:n],
234 })
235 if err != nil {
236 return err
237 }
238 }
239
240 // entire file send, double-check shasum
241 sum := hex.EncodeToString(hash.Sum(nil))
242 if sum != req.FileSha1 {
243 glog.Errorf("CAS corruption: wanted %q, got %q", req.FileSha1, sum)
244 return status.Error(codes.Aborted, "CAS corruption")
245 }
246
247 return nil
248}
249
250func (s *service) Download(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
251 ctx := srv.Context()
252
253 modName := req.ModName
254 if modName == "" {
255 return status.Error(codes.InvalidArgument, "mod name must be set")
256 }
257 sha1 := req.FileSha1
258 if sha1 == "" {
259 return status.Error(codes.InvalidArgument, "sha1 must be set")
260 }
261 sha1 = strings.ToLower(sha1)
262 req.FileSha1 = sha1
263
264 cacheHit, found, mirrored := s.cacheGet(sha1)
265 if cacheHit {
266 if !found {
267 return status.Error(codes.NotFound, "sha1 not found for mod")
268 }
269 if !mirrored {
270 return srv.Send(&pb.DownloadResponse{
271 Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
272 })
273 }
274
275 // we have the file, serve it
276 return s.serve(req, srv)
277 }
278
279 // cache not hit, check mod portal
280 mod, err := modportal.GetMod(ctx, modName)
281 if err != nil {
282 return err
283 }
284 release := mod.ReleaseBySHA1(sha1)
285
286 // release not found in mod portal, cache and answer
287 if release == nil {
288 expires := time.Now().Add(1 * time.Minute)
289 s.cacheFeed(sha1, modName, &expires, false, false)
290 return status.Error(codes.InvalidArgument, "sha1 not found for mod")
291 }
292
293 // we assume it's mirrored - the first cas serve will prove us wrong otherwise and
294 // update the cache.
295 s.cacheFeed(sha1, modName, nil, true, true)
296 // call ourselves again now that the cache is fed. computers - it's like magic!
297 return s.Download(req, srv)
298}