hswaw/smsgw: implement
The SMS gateway service allows consumers to subscribe to SMS messages
received by a Twilio phone number.
This is useful for receiving SMS auth messages.
Change-Id: Ib02a4306ad0d856dd10c7ca9241d9163809e7084
diff --git a/hswaw/smsgw/BUILD.bazel b/hswaw/smsgw/BUILD.bazel
new file mode 100644
index 0000000..35b36f1
--- /dev/null
+++ b/hswaw/smsgw/BUILD.bazel
@@ -0,0 +1,31 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library", "go_test")
+
+go_library(
+ name = "go_default_library",
+ srcs = [
+ "dispatcher.go",
+ "main.go",
+ "twilio.go",
+ ],
+ importpath = "code.hackerspace.pl/hscloud/hswaw/smsgw",
+ visibility = ["//visibility:private"],  # only embedded by the binary and test below
+ deps = [
+ "//go/mirko:go_default_library",
+ "//hswaw/smsgw/proto:go_default_library",  # generated gRPC stubs, imported as pb
+ "@com_github_golang_glog//:go_default_library",
+ "@org_golang_google_grpc//codes:go_default_library",
+ "@org_golang_google_grpc//status:go_default_library",
+ ],
+)
+
+go_binary(
+ name = "smsgw",
+ embed = [":go_default_library"],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "go_default_test",
+ srcs = ["dispatcher_test.go"],  # imports only the standard library, so no extra deps
+ embed = [":go_default_library"],  # same package (main), so tests can use unexported names
+)
diff --git a/hswaw/smsgw/dispatcher.go b/hswaw/smsgw/dispatcher.go
new file mode 100644
index 0000000..60fba32
--- /dev/null
+++ b/hswaw/smsgw/dispatcher.go
@@ -0,0 +1,153 @@
+package main
+
+import (
+	"context"
+	"regexp"
+	"time"
+
+	"github.com/golang/glog"
+)
+
+// dispatcher is responsible for dispatching incoming SMS messages to subscribers
+// that have chosen to receive them, filtering accordingly.
+//
+// All subscriber bookkeeping is owned by the run goroutine, so no locking is
+// needed: publish and subscribe hand work over to it through unbuffered channels,
+// and therefore block until run is ready to accept it.
+type dispatcher struct {
+	// New SMS messages to be dispatched.
+	incoming chan *sms
+	// New subscribers to send messages to.
+	subscribers chan *subscriber
+}
+
+// newDispatcher creates a new dispatcher. Its run method must be started for
+// publish and subscribe to make progress.
+func newDispatcher() *dispatcher {
+	return &dispatcher{
+		incoming:    make(chan *sms),
+		subscribers: make(chan *subscriber),
+	}
+}
+
+// sms received from the upstream provider.
+type sms struct {
+	from      string
+	body      string
+	timestamp time.Time
+}
+
+// subscriber that wants to receive messages with a given body filter.
+type subscriber struct {
+	// regexp to filter message body by
+	re *regexp.Regexp
+	// channel to which messages will be sent, must be emptied regularly by the
+	// subscriber. The dispatcher never closes it, and since deliveries may still
+	// be in flight, the subscriber must not close it either.
+	data chan *sms
+	// channel that needs to be closed when the subscriber doesn't want to receive
+	// any more messages. Closing it also abandons any deliveries still pending,
+	// so data does not need to be drained afterwards.
+	cancel chan struct{}
+}
+
+// publish hands msg over to the dispatcher for delivery to all matching
+// subscribers. It blocks until run accepts the message.
+func (p *dispatcher) publish(msg *sms) {
+	p.incoming <- msg
+}
+
+// subscribe registers sub with the dispatcher. It blocks until run accepts it.
+func (p *dispatcher) subscribe(sub *subscriber) {
+	p.subscribers <- sub
+}
+
+// run is the dispatcher's event loop. It registers and removes subscribers and
+// fans out published messages until ctx is done.
+func (p *dispatcher) run(ctx context.Context) {
+	// Map of internal IDs to subscribers. Internal IDs are used to remove
+	// canceled subscribers easily.
+	subscriberMap := make(map[int64]*subscriber)
+	// Internal channel that will emit SIDs of subscribers that need to be
+	// removed.
+	subscriberCancel := make(chan int64)
+	// Next SID to hand out. A counter is used instead of a timestamp, as
+	// timestamps can collide (coarse clocks, fast subscribers) and a
+	// collision would silently overwrite an existing subscriber.
+	var nextSID int64
+
+	for {
+		select {
+
+		// Should the processor close?
+		case <-ctx.Done():
+			return
+
+		// Do we need to remove a given subscriber?
+		case sid := <-subscriberCancel:
+			glog.V(5).Infof("Removing subscriber %x", sid)
+			delete(subscriberMap, sid)
+
+		// Do we have a new subscriber?
+		case sub := <-p.subscribers:
+			sid := nextSID
+			nextSID++
+			glog.V(5).Infof("New subscriber %x, regexp %v", sid, sub.re)
+
+			// Add to subscriber map.
+			subscriberMap[sid] = sub
+
+			// On sub.cancel closed, emit info that we need to delete that
+			// subscriber. Bail out if the dispatcher stops first, so that
+			// this goroutine doesn't leak.
+			go func() {
+				select {
+				case <-sub.cancel:
+				case <-ctx.Done():
+					return
+				}
+				select {
+				case subscriberCancel <- sid:
+				case <-ctx.Done():
+				}
+			}()
+
+		// Do we have a new message to dispatch?
+		case in := <-p.incoming:
+			for sid, s := range subscriberMap {
+				glog.V(10).Infof("Considering %x", sid)
+				// If this subscriber doesn't care, ignore.
+				if !s.re.MatchString(in.body) {
+					continue
+				}
+
+				// Send asynchronously, so that a subscriber that doesn't drain
+				// fast enough can't stall the whole dispatcher.
+				go p.deliver(ctx, s, sid, in)
+			}
+		}
+	}
+}
+
+// deliver sends msg to the subscriber to. It gives up if the subscriber
+// cancels or the dispatcher stops before the message is received, so that
+// abandoned subscribers don't leak goroutines blocked on a send forever.
+func (p *dispatcher) deliver(ctx context.Context, to *subscriber, sid int64, msg *sms) {
+	// Check for cancellation first: if several select cases are ready, Go
+	// picks one at random, and an already-canceled subscriber gets nothing.
+	select {
+	case <-to.cancel:
+		glog.V(10).Infof("Subscriber %x canceled, not dispatching", sid)
+		return
+	default:
+	}
+
+	glog.V(10).Infof("Dispatching to %x, %v", sid, to.data)
+	select {
+	case to.data <- msg:
+		glog.V(10).Infof("Dispatched to %x", sid)
+	case <-to.cancel:
+		glog.V(10).Infof("Subscriber %x canceled before dispatch", sid)
+	case <-ctx.Done():
+	}
+}
diff --git a/hswaw/smsgw/dispatcher_test.go b/hswaw/smsgw/dispatcher_test.go
new file mode 100644
index 0000000..a5f3977
--- /dev/null
+++ b/hswaw/smsgw/dispatcher_test.go
@@ -0,0 +1,216 @@
+package main
+
+import (
+	"context"
+	"regexp"
+	"testing"
+	"time"
+)
+
+// makeDut starts a dispatcher under test. The returned context is the one the
+// dispatcher runs under; the cancel function stops it.
+func makeDut() (*dispatcher, context.CancelFunc, context.Context) {
+	dut := newDispatcher()
+
+	ctx, cancelCtx := context.WithCancel(context.Background())
+	go dut.run(ctx)
+
+	return dut, cancelCtx, ctx
+}
+
+// expectReceived fails the test unless a message equal to s arrives on data
+// within 100ms.
+func expectReceived(t *testing.T, s *sms, data chan *sms) {
+	t.Helper()
+	timer := time.NewTimer(100 * time.Millisecond)
+	defer timer.Stop()
+	select {
+	case d := <-data:
+		if d.from != s.from {
+			t.Errorf("Received SMS from %q, wanted %q", d.from, s.from)
+		}
+		if d.body != s.body {
+			t.Errorf("Received SMS body %q, wanted %q", d.body, s.body)
+		}
+		if !d.timestamp.Equal(s.timestamp) {
+			t.Errorf("Received SMS timestamp %v, wanted %v", d.timestamp, s.timestamp)
+		}
+	case <-timer.C:
+		t.Fatalf("Timed out waiting for message")
+	}
+}
+
+// expectEmpty fails the test if a message arrives on data within 1ms.
+func expectEmpty(t *testing.T, data chan *sms) {
+	t.Helper()
+	timer := time.NewTimer(1 * time.Millisecond)
+	defer timer.Stop()
+	select {
+	case <-data:
+		t.Fatalf("Received unwanted message")
+	case <-timer.C:
+	}
+}
+
+func TestDispatcher(t *testing.T) {
+	dut, cancelDut, _ := makeDut()
+	defer cancelDut()
+
+	data := make(chan *sms)
+	cancel := make(chan struct{})
+
+	dut.subscribe(&subscriber{
+		re:     regexp.MustCompile(".*"),
+		data:   data,
+		cancel: cancel,
+	})
+
+	in := &sms{
+		from:      "+4821372137",
+		body:      "foo",
+		timestamp: time.Now(),
+	}
+	dut.publish(in)
+
+	// Make sure we get the message.
+	expectReceived(t, in, data)
+
+	// Make sure we don't receive the message again.
+	expectEmpty(t, data)
+
+	// Publish a new message, but this time close our subscriber.
+	close(cancel)
+	// Hack: yield, to give the dispatcher a chance to process the cancellation.
+	time.Sleep(1 * time.Millisecond)
+
+	dut.publish(in)
+	expectEmpty(t, data)
+}
+
+type testSubscriber struct {
+	re     *regexp.Regexp
+	data   chan *sms
+	cancel chan struct{}
+}
+
+func TestDispatcherFilters(t *testing.T) {
+	dut, cancelDut, _ := makeDut()
+	defer cancelDut()
+
+	subscribers := []*testSubscriber{
+		{re: regexp.MustCompile(".*")},
+		{re: regexp.MustCompile("foo")},
+		{re: regexp.MustCompile("bar")},
+	}
+
+	for _, s := range subscribers {
+		s.data = make(chan *sms)
+		s.cancel = make(chan struct{})
+		dut.subscribe(&subscriber{
+			re:     s.re,
+			data:   s.data,
+			cancel: s.cancel,
+		})
+		defer close(s.cancel)
+	}
+
+	in := &sms{
+		from:      "+4821372137",
+		body:      "foo",
+		timestamp: time.Now(),
+	}
+	dut.publish(in)
+	expectReceived(t, in, subscribers[0].data)
+	expectReceived(t, in, subscribers[1].data)
+	expectEmpty(t, subscribers[2].data)
+
+	in = &sms{
+		from:      "+4821372137",
+		body:      "bar",
+		timestamp: time.Now(),
+	}
+	dut.publish(in)
+	expectReceived(t, in, subscribers[0].data)
+	expectEmpty(t, subscribers[1].data)
+	expectReceived(t, in, subscribers[2].data)
+
+	in = &sms{
+		from:      "+4821372137",
+		body:      "foobar",
+		timestamp: time.Now(),
+	}
+	dut.publish(in)
+	expectReceived(t, in, subscribers[0].data)
+	expectReceived(t, in, subscribers[1].data)
+	expectReceived(t, in, subscribers[2].data)
+}
+
+func TestDispatcherMany(t *testing.T) {
+	dut, cancelDut, _ := makeDut()
+	defer cancelDut()
+
+	subscribers := make([]*testSubscriber, 10000)
+
+	for i := range subscribers {
+		s := &testSubscriber{
+			re:     regexp.MustCompile(".*"),
+			data:   make(chan *sms),
+			cancel: make(chan struct{}),
+		}
+		subscribers[i] = s
+		dut.subscribe(&subscriber{
+			re:     s.re,
+			data:   s.data,
+			cancel: s.cancel,
+		})
+		defer close(s.cancel)
+	}
+
+	in := &sms{
+		from:      "+4821372137",
+		body:      "foo",
+		timestamp: time.Now(),
+	}
+	dut.publish(in)
+
+	for _, s := range subscribers {
+		expectReceived(t, in, s.data)
+	}
+}
+
+// TestDispatcherHammer subscribes, receives from and cancels many subscribers
+// in a row. The iteration count is reduced in -short mode to keep runs fast.
+func TestDispatcherHammer(t *testing.T) {
+	iterations := 1000000
+	if testing.Short() {
+		iterations = 10000
+	}
+	matchAll := regexp.MustCompile(".*")
+
+	dut, cancelDut, _ := makeDut()
+	defer cancelDut()
+
+	for i := 0; i < iterations; i++ {
+		s := &testSubscriber{
+			re:     matchAll,
+			data:   make(chan *sms),
+			cancel: make(chan struct{}),
+		}
+
+		dut.subscribe(&subscriber{
+			re:     s.re,
+			data:   s.data,
+			cancel: s.cancel,
+		})
+
+		in := &sms{
+			from:      "+4821372137",
+			body:      "foo",
+			timestamp: time.Now(),
+		}
+		dut.publish(in)
+		expectReceived(t, in, s.data)
+
+		close(s.cancel)
+	}
+}
diff --git a/hswaw/smsgw/main.go b/hswaw/smsgw/main.go
new file mode 100644
index 0000000..a0a6a07
--- /dev/null
+++ b/hswaw/smsgw/main.go
@@ -0,0 +1,310 @@
+package main
+
+import (
+	"context"
+	"crypto/hmac"
+	"crypto/sha1"
+	"encoding/base64"
+	"flag"
+	"fmt"
+	"net/http"
+	"net/url"
+	"regexp"
+	"sort"
+	"strings"
+	"time"
+
+	"code.hackerspace.pl/hscloud/go/mirko"
+	"github.com/golang/glog"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+
+	pb "code.hackerspace.pl/hscloud/hswaw/smsgw/proto"
+)
+
+var (
+	flagTwilioSID           string
+	flagTwilioToken         string
+	flagTwilioFriendlyPhone string
+
+	flagWebhookListen string
+	flagWebhookPublic string
+)
+
+func init() {
+	flag.Set("logtostderr", "true")
+}
+
+// server implements the SMSGateway gRPC service and the Twilio SMS webhook.
+type server struct {
+	dispatcher *dispatcher
+}
+
+// webhookURL returns the public URL that Twilio should deliver incoming SMS
+// messages to. main makes sure flagWebhookPublic ends with a slash.
+func webhookURL() string {
+	return flagWebhookPublic + "sms"
+}
+
+// ourPhoneNumber returns the Twilio incoming phone number whose friendly name
+// is friendly, or an error if the account has no such number.
+func ourPhoneNumber(ctx context.Context, t *twilio, friendly string) (*incomingPhoneNumber, error) {
+	ipn, err := t.getIncomingPhoneNumbers(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	for i := range ipn {
+		if ipn[i].FriendlyName == friendly {
+			return &ipn[i], nil
+		}
+	}
+
+	return nil, fmt.Errorf("requested phone number %q not in list", friendly)
+}
+
+// ensureWebhook points the SMS webhook of our Twilio phone number at
+// webhookURL() (via POST), waits until Twilio reports the change, and then
+// re-checks it every 30 seconds until ctx is done. Any API failure, or anybody
+// else reconfiguring the webhook, terminates the process.
+func ensureWebhook(ctx context.Context, t *twilio) {
+	pn, err := ourPhoneNumber(ctx, t, flagTwilioFriendlyPhone)
+	if err != nil {
+		glog.Exitf("could not get our phone number: %v", err)
+	}
+
+	whURL := webhookURL()
+	// Both the method and the URL have to match for the webhook to be ours.
+	configured := func(n *incomingPhoneNumber) bool {
+		return n.SMSMethod == "POST" && n.SMSURL == whURL
+	}
+
+	// first setup.
+	if !configured(pn) {
+		glog.Infof("Updating webhook (is %s %q, want %s %q)", pn.SMSMethod, pn.SMSURL, "POST", whURL)
+		if err := t.updateIncomingPhoneNumberSMSWebhook(ctx, pn.SID, "POST", whURL); err != nil {
+			glog.Exitf("could not set webhook: %v", err)
+		}
+
+		// try again to check that it's actually set
+		for {
+			pn, err = ourPhoneNumber(ctx, t, flagTwilioFriendlyPhone)
+			if err != nil {
+				glog.Exitf("could not get our phone number: %v", err)
+			}
+			if configured(pn) {
+				break
+			}
+			glog.Infof("Webhook not yet ready, currently %s %q", pn.SMSMethod, pn.SMSURL)
+			select {
+			case <-ctx.Done():
+				return
+			case <-time.After(5 * time.Second):
+			}
+		}
+		glog.Infof("Webhook verified")
+	} else {
+		glog.Infof("Webhook up to date")
+	}
+
+	// now keep checking to make sure that nobody takes over our webhook
+	tick := time.NewTicker(30 * time.Second)
+	defer tick.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-tick.C:
+			pn, err = ourPhoneNumber(ctx, t, flagTwilioFriendlyPhone)
+			if err != nil {
+				glog.Exitf("could not get our phone number: %v", err)
+			}
+			if !configured(pn) {
+				glog.Exitf("Webhook got deconfigured, not %s %q", pn.SMSMethod, pn.SMSURL)
+			}
+		}
+	}
+}
+
+// twilioSignature computes the X-Twilio-Signature value Twilio attaches to a
+// webhook request for fullURL with the given POST parameters: the URL followed
+// by every parameter name and value (sorted by name, then value), HMAC-SHA1'd
+// with the account auth token and base64-encoded.
+func twilioSignature(authToken, fullURL string, params url.Values) string {
+	names := make([]string, 0, len(params))
+	for name := range params {
+		names = append(names, name)
+	}
+	sort.Strings(names)
+
+	var b strings.Builder
+	b.WriteString(fullURL)
+	for _, name := range names {
+		values := append([]string(nil), params[name]...)
+		sort.Strings(values)
+		for _, v := range values {
+			b.WriteString(name)
+			b.WriteString(v)
+		}
+	}
+
+	mac := hmac.New(sha1.New, []byte(authToken))
+	mac.Write([]byte(b.String()))
+	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
+}
+
+// webhookHandler receives incoming SMS notifications from Twilio and publishes
+// them to the dispatcher. Requests must be POSTs signed with our auth token
+// (X-Twilio-Signature) and carry our AccountSid; the account SID on its own
+// is not a secret, so it cannot authenticate the sender.
+func (s *server) webhookHandler(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+		return
+	}
+	if err := r.ParseForm(); err != nil {
+		glog.Errorf("webhook body parse error: %v", err)
+		http.Error(w, "invalid form", http.StatusBadRequest)
+		return
+	}
+
+	want := twilioSignature(flagTwilioToken, webhookURL(), r.PostForm)
+	if got := r.Header.Get("X-Twilio-Signature"); !hmac.Equal([]byte(got), []byte(want)) {
+		glog.Errorf("webhook got invalid signature %q", got)
+		http.Error(w, "invalid signature", http.StatusForbidden)
+		return
+	}
+
+	accountSID := r.PostForm.Get("AccountSid")
+	if accountSID != flagTwilioSID {
+		glog.Errorf("webhook got wrong account sid, got %q, wanted %q", accountSID, flagTwilioSID)
+		http.Error(w, "wrong account", http.StatusForbidden)
+		return
+	}
+
+	body := r.PostForm.Get("Body")
+	if body == "" {
+		return
+	}
+
+	from := r.PostForm.Get("From")
+
+	glog.Infof("Got SMS from %q, body %q", from, body)
+
+	s.dispatcher.publish(&sms{
+		from:      from,
+		body:      body,
+		timestamp: time.Now(),
+	})
+
+	w.WriteHeader(http.StatusOK)
+}
+
+func main() {
+	flag.StringVar(&flagTwilioSID, "twilio_sid", "", "Twilio account SID")
+	flag.StringVar(&flagTwilioToken, "twilio_token", "", "Twilio auth token")
+	flag.StringVar(&flagTwilioFriendlyPhone, "twilio_friendly_phone", "", "Twilio friendly phone number")
+
+	flag.StringVar(&flagWebhookListen, "webhook_listen", "127.0.0.1:5000", "Listen address for webhook handler")
+	flag.StringVar(&flagWebhookPublic, "webhook_public", "", "Public address for webhook handler (e.g. http://proxy.q3k.org/smsgw/)")
+	flag.Parse()
+
+	if flagTwilioSID == "" || flagTwilioToken == "" {
+		glog.Exitf("twilio_sid and twilio_token must be set")
+	}
+
+	if flagTwilioFriendlyPhone == "" {
+		glog.Exitf("twilio_friendly_phone must be set")
+	}
+
+	if flagWebhookPublic == "" {
+		glog.Exitf("webhook_public must be set")
+	}
+
+	if !strings.HasSuffix(flagWebhookPublic, "/") {
+		flagWebhookPublic += "/"
+	}
+
+	s := &server{
+		dispatcher: newDispatcher(),
+	}
+
+	m := mirko.New()
+	if err := m.Listen(); err != nil {
+		glog.Exitf("Listen(): %v", err)
+	}
+
+	webhookMux := http.NewServeMux()
+	webhookMux.HandleFunc("/sms", s.webhookHandler)
+	// Timeouts keep slow or stuck clients from holding connections forever.
+	webhookSrv := http.Server{
+		Addr:         flagWebhookListen,
+		Handler:      webhookMux,
+		ReadTimeout:  10 * time.Second,
+		WriteTimeout: 10 * time.Second,
+	}
+	go func() {
+		if err := webhookSrv.ListenAndServe(); err != nil {
+			glog.Exitf("webhook ListenAndServe: %v", err)
+		}
+	}()
+
+	t := &twilio{
+		accountSID:   flagTwilioSID,
+		accountToken: flagTwilioToken,
+	}
+	go ensureWebhook(m.Context(), t)
+	go s.dispatcher.run(m.Context())
+
+	pb.RegisterSMSGatewayServer(m.GRPC(), s)
+
+	if err := m.Serve(); err != nil {
+		glog.Exitf("Serve(): %v", err)
+	}
+
+	<-m.Done()
+}
+
+// Messages streams SMS messages whose body matches req.FilterBody (a Go/RE2
+// regexp, matched anywhere in the body; empty matches everything) until the
+// client goes away or sending fails. Only messages received by the gateway
+// after the subscription is registered are streamed.
+func (s *server) Messages(req *pb.MessagesRequest, stream pb.SMSGateway_MessagesServer) error {
+	re := regexp.MustCompile(".*")
+	if req.FilterBody != "" {
+		var err error
+		re, err = regexp.Compile(req.FilterBody)
+		if err != nil {
+			return status.Errorf(codes.InvalidArgument, "filter regexp error: %v", err)
+		}
+	}
+
+	// data is deliberately never closed: the dispatcher may still be trying to
+	// deliver to it, and a send on a closed channel panics. Closing cancel is
+	// what tells the dispatcher to drop this subscriber.
+	data := make(chan *sms)
+	cancel := make(chan struct{})
+	defer close(cancel)
+
+	s.dispatcher.subscribe(&subscriber{
+		re:     re,
+		data:   data,
+		cancel: cancel,
+	})
+
+	ctx := stream.Context()
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case d := <-data:
+			if err := stream.Send(&pb.MessagesResponse{
+				Sender:    d.from,
+				Body:      d.body,
+				Timestamp: d.timestamp.UnixNano(),
+			}); err != nil {
+				return err
+			}
+		}
+	}
+}
diff --git a/hswaw/smsgw/proto/BUILD.bazel b/hswaw/smsgw/proto/BUILD.bazel
new file mode 100644
index 0000000..be3451f
--- /dev/null
+++ b/hswaw/smsgw/proto/BUILD.bazel
@@ -0,0 +1,23 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+
+proto_library(  # NOTE(review): native rule; newer Bazel may need it loaded from @rules_proto — confirm
+ name = "proto_proto",
+ srcs = ["smsgw.proto"],
+ visibility = ["//visibility:public"],
+)
+
+go_proto_library(
+ name = "proto_go_proto",
+ compilers = ["@io_bazel_rules_go//proto:go_grpc"],  # also generates the SMSGateway gRPC server/client stubs
+ importpath = "code.hackerspace.pl/hscloud/hswaw/smsgw/proto",
+ proto = ":proto_proto",
+ visibility = ["//visibility:public"],
+)
+
+go_library(
+ name = "go_default_library",
+ embed = [":proto_go_proto"],  # lets Go targets depend on :go_default_library
+ importpath = "code.hackerspace.pl/hscloud/hswaw/smsgw/proto",  # must match go_package in smsgw.proto
+ visibility = ["//visibility:public"],
+)
diff --git a/hswaw/smsgw/proto/smsgw.proto b/hswaw/smsgw/proto/smsgw.proto
new file mode 100644
index 0000000..2a95308
--- /dev/null
+++ b/hswaw/smsgw/proto/smsgw.proto
@@ -0,0 +1,17 @@
+syntax = "proto3";
+package proto;
+option go_package = "code.hackerspace.pl/hscloud/hswaw/smsgw/proto";
+
+message MessagesRequest {
+ string filter_body = 1; // Go (RE2) regexp matched anywhere in the SMS body; empty matches everything.
+}
+
+message MessagesResponse {
+ string sender = 1; // Sender number as reported by Twilio (the "From" form field).
+ string body = 3; // SMS body. NOTE(review): tag 2 is unused — consider "reserved 2;" if intentional.
+ int64 timestamp = 4; // Time the gateway received the SMS, in UNIX nanoseconds.
+}
+
+service SMSGateway {
+ rpc Messages(MessagesRequest) returns (stream MessagesResponse); // Streams matching SMS received after subscribing.
+}
diff --git a/hswaw/smsgw/twilio.go b/hswaw/smsgw/twilio.go
new file mode 100644
index 0000000..cdc0255
--- /dev/null
+++ b/hswaw/smsgw/twilio.go
@@ -0,0 +1,115 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+)
+
+// twilioAPIBase is the root of the Twilio REST API. Pagination URIs returned
+// by the API are relative to it.
+const twilioAPIBase = "https://api.twilio.com"
+
+// twilioTimeout bounds every single Twilio REST API call. Callers pass
+// long-lived contexts, so without it a stuck connection would hang forever.
+const twilioTimeout = 30 * time.Second
+
+// twilio is a minimal Twilio REST API client, authenticating with an account
+// SID and auth token (HTTP basic auth).
+type twilio struct {
+	accountSID   string
+	accountToken string
+}
+
+// incomingPhoneNumber is the subset of a Twilio IncomingPhoneNumber resource
+// that smsgw cares about.
+type incomingPhoneNumber struct {
+	FriendlyName string `json:"friendly_name"`
+	SMSMethod    string `json:"sms_method"`
+	SMSURL       string `json:"sms_url"`
+	SID          string `json:"sid"`
+}
+
+// do performs an authenticated request against the Twilio API, bounded by
+// twilioTimeout. If form is non-nil it is sent URL-encoded as the request
+// body. On a 2xx response the JSON body is decoded into out (unless out is
+// nil); any other status is returned as an error, including Twilio's error
+// message when the response carries one.
+func (t *twilio) do(ctx context.Context, method, reqURL string, form url.Values, out interface{}) error {
+	ctx, cancel := context.WithTimeout(ctx, twilioTimeout)
+	defer cancel()
+
+	var body io.Reader
+	if form != nil {
+		body = strings.NewReader(form.Encode())
+	}
+	req, err := http.NewRequest(method, reqURL, body)
+	if err != nil {
+		return err
+	}
+	req = req.WithContext(ctx)
+	req.SetBasicAuth(t.accountSID, t.accountToken)
+	if form != nil {
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	}
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return err
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode < 200 || res.StatusCode > 299 {
+		// Twilio usually explains failures in a JSON body; fall back to the
+		// bare status code if it doesn't.
+		apiErr := struct {
+			Message string `json:"message"`
+		}{}
+		if err := json.NewDecoder(res.Body).Decode(&apiErr); err == nil && apiErr.Message != "" {
+			return fmt.Errorf("REST response error, status: %v, message: %q", res.StatusCode, apiErr.Message)
+		}
+		return fmt.Errorf("status code: %v", res.StatusCode)
+	}
+
+	if out == nil {
+		return nil
+	}
+	return json.NewDecoder(res.Body).Decode(out)
+}
+
+// getIncomingPhoneNumbers lists all incoming phone numbers of the account,
+// following Twilio's pagination so that accounts with more numbers than fit
+// on one page are fully listed.
+func (t *twilio) getIncomingPhoneNumbers(ctx context.Context) ([]incomingPhoneNumber, error) {
+	var all []incomingPhoneNumber
+	next := fmt.Sprintf("/2010-04-01/Accounts/%s/IncomingPhoneNumbers.json", url.PathEscape(t.accountSID))
+	for next != "" {
+		page := struct {
+			IPN         []incomingPhoneNumber `json:"incoming_phone_numbers"`
+			NextPageURI string                `json:"next_page_uri"`
+		}{}
+		if err := t.do(ctx, "GET", twilioAPIBase+next, nil, &page); err != nil {
+			return nil, err
+		}
+		all = append(all, page.IPN...)
+		next = page.NextPageURI
+	}
+	return all, nil
+}
+
+// updateIncomingPhoneNumberSMSWebhook sets the SMS webhook (HTTP method and
+// URL) of the incoming phone number identified by sid.
+func (t *twilio) updateIncomingPhoneNumberSMSWebhook(ctx context.Context, sid, method, whurl string) error {
+	turl := fmt.Sprintf("%s/2010-04-01/Accounts/%s/IncomingPhoneNumbers/%s.json",
+		twilioAPIBase, url.PathEscape(t.accountSID), url.PathEscape(sid))
+
+	data := url.Values{}
+	data.Set("SmsMethod", method)
+	data.Set("SmsUrl", whurl)
+
+	return t.do(ctx, "POST", turl, data, nil)
+}