package main
2
import (
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"regexp"

	"code.hackerspace.pl/hscloud/go/mirko"
	"github.com/dgraph-io/ristretto"
	tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api"
	"github.com/golang/glog"
	"github.com/ulule/limiter/v3"
	"github.com/ulule/limiter/v3/drivers/store/memory"
)
18
// init presets the "logtostderr" flag to "true" before main parses the
// command line, so an explicit command-line value can still override it.
func init() {
	// The returned error is deliberately discarded: a failure to preset the
	// flag only affects where logs go and must not stop the program.
	_ = flag.Set("logtostderr", "true")
}
22
var (
	// flagPublicListen is the listen address of the public HTTP handler
	// (set from -public_listen).
	flagPublicListen string
	// flagTelegramToken is the Telegram Bot API token (set from
	// -telegram_token).
	flagTelegramToken string
	// reTelegram extracts the file ID (group 1) and the file extension
	// (group 2) from a request path of the form /fileid/<id>.<ext>.
	//
	// The separator is a literal dot. It used to be an unescaped '.', which
	// matches any character, so paths without an extension were accepted and
	// split at an arbitrary position into a bogus ID and extension.
	reTelegram = regexp.MustCompile(`/fileid/([a-zA-Z0-9_-]+)\.([a-z0-9]+)`)
)
28
29type server struct {
30 cache *ristretto.Cache
31 limiter *limiter.Limiter
32 tel *tgbotapi.BotAPI
33}
34
35func main() {
36 flag.StringVar(&flagPublicListen, "public_listen", "127.0.0.1:5000", "Listen address for public HTTP handler")
37 flag.StringVar(&flagTelegramToken, "telegram_token", "", "Telegram Bot API Token")
38 flag.Parse()
39
40 if flagTelegramToken == "" {
41 glog.Exitf("telegram_token must be set")
42 }
43
44 cache, err := ristretto.NewCache(&ristretto.Config{
45 NumCounters: 1e7, // number of keys to track frequency of (10M).
46 MaxCost: 1 << 30, // maximum cost of cache (1GB).
47 BufferItems: 64, // number of keys per Get buffer.
48 })
49 if err != nil {
50 glog.Exit(err)
51 }
52
53 tel, err := tgbotapi.NewBotAPI(flagTelegramToken)
54 if err != nil {
55 glog.Exitf("Error when creating telegram bot: %v", err)
56 }
57
58 rate, err := limiter.NewRateFromFormatted("10-M")
59 if err != nil {
60 glog.Exit(err)
61 }
62
63 store := memory.NewStore()
64 instance := limiter.New(store, rate, limiter.WithTrustForwardHeader(true))
65
66 s := &server{
67 cache: cache,
68 limiter: instance,
69 tel: tel,
70 }
71
72 m := mirko.New()
73 if err := m.Listen(); err != nil {
74 glog.Exitf("Listen(): %v", err)
75 }
76
77 if err := m.Serve(); err != nil {
78 glog.Exitf("Serve(): %v", err)
79 }
80
81 publicMux := http.NewServeMux()
82 publicMux.HandleFunc("/", s.publicHandler)
83 publicSrv := http.Server{
84 Addr: flagPublicListen,
85 Handler: publicMux,
86 }
87 go func() {
88 if err := publicSrv.ListenAndServe(); err != nil {
89 glog.Exitf("public ListenAndServe: %v", err)
90 }
91 }()
92
93 <-m.Done()
94}
95
// setMime sets the Content-Type response header for the extensions it knows
// ("jpg" and "mp4"). For any other extension the header is left untouched.
func setMime(w http.ResponseWriter, ext string) {
	var contentType string
	switch ext {
	case "jpg":
		contentType = "image/jpeg"
	case "mp4":
		contentType = "video/mp4"
	default:
		// Unknown extension: do not set anything.
		return
	}
	w.Header().Set("Content-Type", contentType)
}
104
105func (s *server) publicHandler(w http.ResponseWriter, r *http.Request) {
106 ctx := r.Context()
107
108 if !reTelegram.MatchString(r.URL.Path) {
109 http.NotFound(w, r)
110 return
111 }
112 parts := reTelegram.FindStringSubmatch(r.URL.Path)
113 fileid := parts[1]
114 fileext := parts[2]
115 glog.Infof("FileID: %s", fileid)
116
117 c, ok := s.cache.Get(fileid)
118 if ok {
119 glog.Infof("Get %q - cache hit", fileid)
120 // cache hit
121 setMime(w, fileext)
122 w.Write(c.([]byte))
123 return
124 }
125
126 glog.Infof("Get %q - cache miss", fileid)
127
128 limit, err := s.limiter.Get(ctx, s.limiter.GetIPKey(r))
129 if err != nil {
130 w.WriteHeader(500)
131 fmt.Fprintf(w, ":(")
132 glog.Errorf("limiter.Get(%q): %v", s.limiter.GetIPKey(r), err)
133 return
134 }
135
136 if limit.Reached {
137 w.WriteHeader(420)
138 fmt.Fprintf(w, "enhance your calm")
139 glog.Warningf("Limit reached by %q", s.limiter.GetIPKey(r))
140 return
141 }
142
143 f, err := s.tel.GetFile(tgbotapi.FileConfig{fileid})
144 if err != nil {
145 w.WriteHeader(502)
146 fmt.Fprintf(w, "telegram mumbles.")
147 glog.Errorf("tel.GetFile(%q): %v", fileid, err)
148 return
149 }
150
151 target := f.Link(flagTelegramToken)
152
153 req, err := http.NewRequest("GET", target, nil)
154 if err != nil {
155 w.WriteHeader(500)
156 fmt.Fprintf(w, ":(")
157 glog.Errorf("NewRequest(GET, %q, nil): %v", target, err)
158 return
159 }
160
161 req = req.WithContext(ctx)
162 res, err := http.DefaultClient.Do(req)
163 if err != nil {
164 w.WriteHeader(500)
165 fmt.Fprintf(w, ":(")
166 glog.Errorf("GET(%q): %v", target, err)
167 return
168 }
169 defer res.Body.Close()
170
171 if res.StatusCode != 200 {
172 // do not cache errors
173 w.WriteHeader(res.StatusCode)
174 io.Copy(w, res.Body)
175 return
176 }
177
178 b, err := ioutil.ReadAll(res.Body)
179 if err != nil {
180 w.WriteHeader(500)
181 fmt.Fprintf(w, ":(")
182 glog.Errorf("Read(%q): %v", target, err)
183 return
184 }
185
186 s.cache.Set(fileid, b, int64(len(b)))
187
188 setMime(w, fileext)
189 w.Write(b)
190}