| package main |
| |
| import ( |
| "flag" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "net/http" |
| "regexp" |
| |
| "code.hackerspace.pl/hscloud/go/mirko" |
| "github.com/dgraph-io/ristretto" |
| tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api" |
| "github.com/golang/glog" |
| "github.com/ulule/limiter/v3" |
| "github.com/ulule/limiter/v3/drivers/store/memory" |
| ) |
| |
// init switches the "logtostderr" flag to "true" before main parses the
// command line, so an explicit value passed on the command line is applied
// afterwards and still takes precedence.
func init() {
	const name, value = "logtostderr", "true"
	// Best effort: the returned error (e.g. when no such flag is registered)
	// is intentionally discarded.
	_ = flag.Set(name, value)
}
| |
var (
	// flagPublicListen is the listen address of the public HTTP handler
	// (set from the -public_listen flag in main).
	flagPublicListen string
	// flagTelegramToken is the Telegram Bot API token (set from the
	// -telegram_token flag in main).
	flagTelegramToken string
	// reTelegram matches request paths of the form /fileid/<id>.<ext>.
	// Submatch 1 is the file ID, submatch 2 the extension. The dot is escaped
	// so that it only matches a literal '.'; unescaped it matched any
	// character, which made extension-less paths such as /fileid/abcdef
	// match with a truncated ID. The pattern is deliberately left unanchored,
	// as before.
	reTelegram = regexp.MustCompile(`/fileid/([a-zA-Z0-9_-]+)\.([a-z0-9]+)`)
)
| |
| type server struct { |
| cache *ristretto.Cache |
| limiter *limiter.Limiter |
| tel *tgbotapi.BotAPI |
| } |
| |
| func main() { |
| flag.StringVar(&flagPublicListen, "public_listen", "127.0.0.1:5000", "Listen address for public HTTP handler") |
| flag.StringVar(&flagTelegramToken, "telegram_token", "", "Telegram Bot API Token") |
| flag.Parse() |
| |
| if flagTelegramToken == "" { |
| glog.Exitf("telegram_token must be set") |
| } |
| |
| cache, err := ristretto.NewCache(&ristretto.Config{ |
| NumCounters: 1e7, // number of keys to track frequency of (10M). |
| MaxCost: 1 << 30, // maximum cost of cache (1GB). |
| BufferItems: 64, // number of keys per Get buffer. |
| }) |
| if err != nil { |
| glog.Exit(err) |
| } |
| |
| tel, err := tgbotapi.NewBotAPI(flagTelegramToken) |
| if err != nil { |
| glog.Exitf("Error when creating telegram bot: %v", err) |
| } |
| |
| rate, err := limiter.NewRateFromFormatted("10-M") |
| if err != nil { |
| glog.Exit(err) |
| } |
| |
| store := memory.NewStore() |
| instance := limiter.New(store, rate, limiter.WithTrustForwardHeader(true)) |
| |
| s := &server{ |
| cache: cache, |
| limiter: instance, |
| tel: tel, |
| } |
| |
| m := mirko.New() |
| if err := m.Listen(); err != nil { |
| glog.Exitf("Listen(): %v", err) |
| } |
| |
| if err := m.Serve(); err != nil { |
| glog.Exitf("Serve(): %v", err) |
| } |
| |
| publicMux := http.NewServeMux() |
| publicMux.HandleFunc("/", s.publicHandler) |
| publicSrv := http.Server{ |
| Addr: flagPublicListen, |
| Handler: publicMux, |
| } |
| go func() { |
| if err := publicSrv.ListenAndServe(); err != nil { |
| glog.Exitf("public ListenAndServe: %v", err) |
| } |
| }() |
| |
| <-m.Done() |
| } |
| |
// setMime sets the Content-Type response header for the file extensions it
// knows about ("jpg" and "mp4"). For any other extension the headers are left
// untouched.
func setMime(w http.ResponseWriter, ext string) {
	var contentType string
	switch ext {
	case "jpg":
		contentType = "image/jpeg"
	case "mp4":
		contentType = "video/mp4"
	default:
		// Unknown extension: do not touch the header at all.
		return
	}
	w.Header().Set("Content-Type", contentType)
}
| |
| func (s *server) publicHandler(w http.ResponseWriter, r *http.Request) { |
| ctx := r.Context() |
| |
| if !reTelegram.MatchString(r.URL.Path) { |
| http.NotFound(w, r) |
| return |
| } |
| parts := reTelegram.FindStringSubmatch(r.URL.Path) |
| fileid := parts[1] |
| fileext := parts[2] |
| glog.Infof("FileID: %s", fileid) |
| |
| c, ok := s.cache.Get(fileid) |
| if ok { |
| glog.Infof("Get %q - cache hit", fileid) |
| // cache hit |
| setMime(w, fileext) |
| w.Write(c.([]byte)) |
| return |
| } |
| |
| glog.Infof("Get %q - cache miss", fileid) |
| |
| limit, err := s.limiter.Get(ctx, s.limiter.GetIPKey(r)) |
| if err != nil { |
| w.WriteHeader(500) |
| fmt.Fprintf(w, ":(") |
| glog.Errorf("limiter.Get(%q): %v", s.limiter.GetIPKey(r), err) |
| return |
| } |
| |
| if limit.Reached { |
| w.WriteHeader(420) |
| fmt.Fprintf(w, "enhance your calm") |
| glog.Warningf("Limit reached by %q", s.limiter.GetIPKey(r)) |
| return |
| } |
| |
| f, err := s.tel.GetFile(tgbotapi.FileConfig{fileid}) |
| if err != nil { |
| w.WriteHeader(502) |
| fmt.Fprintf(w, "telegram mumbles.") |
| glog.Errorf("tel.GetFile(%q): %v", fileid, err) |
| return |
| } |
| |
| target := f.Link(flagTelegramToken) |
| |
| req, err := http.NewRequest("GET", target, nil) |
| if err != nil { |
| w.WriteHeader(500) |
| fmt.Fprintf(w, ":(") |
| glog.Errorf("NewRequest(GET, %q, nil): %v", target, err) |
| return |
| } |
| |
| req = req.WithContext(ctx) |
| res, err := http.DefaultClient.Do(req) |
| if err != nil { |
| w.WriteHeader(500) |
| fmt.Fprintf(w, ":(") |
| glog.Errorf("GET(%q): %v", target, err) |
| return |
| } |
| defer res.Body.Close() |
| |
| if res.StatusCode != 200 { |
| // do not cache errors |
| w.WriteHeader(res.StatusCode) |
| io.Copy(w, res.Body) |
| return |
| } |
| |
| b, err := ioutil.ReadAll(res.Body) |
| if err != nil { |
| w.WriteHeader(500) |
| fmt.Fprintf(w, ":(") |
| glog.Errorf("Read(%q): %v", target, err) |
| return |
| } |
| |
| s.cache.Set(fileid, b, int64(len(b))) |
| |
| setMime(w, fileext) |
| w.Write(b) |
| } |