hswaw/smsgw: implement

The SMS gateway service allows consumers to subscribe to SMS messages
received by a Twilio phone number.

This is useful for receiving SMS auth messages.

Change-Id: Ib02a4306ad0d856dd10c7ca9241d9163809e7084
diff --git a/hswaw/smsgw/BUILD.bazel b/hswaw/smsgw/BUILD.bazel
new file mode 100644
index 0000000..35b36f1
--- /dev/null
+++ b/hswaw/smsgw/BUILD.bazel
@@ -0,0 +1,31 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library", "go_test")
+
+# All of the gateway's Go code: the dispatcher (dispatcher.go), the gRPC and
+# webhook server (main.go) and the minimal Twilio REST client (twilio.go).
+go_library(
+    name = "go_default_library",
+    srcs = [
+        "dispatcher.go",
+        "main.go",
+        "twilio.go",
+    ],
+    importpath = "code.hackerspace.pl/hscloud/hswaw/smsgw",
+    visibility = ["//visibility:private"],
+    deps = [
+        "//go/mirko:go_default_library",
+        "//hswaw/smsgw/proto:go_default_library",
+        "@com_github_golang_glog//:go_default_library",
+        "@org_golang_google_grpc//codes:go_default_library",
+        "@org_golang_google_grpc//status:go_default_library",
+    ],
+)
+
+# The smsgw server binary, built from :go_default_library.
+go_binary(
+    name = "smsgw",
+    embed = [":go_default_library"],
+    visibility = ["//visibility:public"],
+)
+
+# Dispatcher unit tests (dispatcher_test.go). They are compiled into the same
+# package as the library, so they can use its unexported identifiers.
+go_test(
+    name = "go_default_test",
+    srcs = ["dispatcher_test.go"],
+    embed = [":go_default_library"],
+)
diff --git a/hswaw/smsgw/dispatcher.go b/hswaw/smsgw/dispatcher.go
new file mode 100644
index 0000000..60fba32
--- /dev/null
+++ b/hswaw/smsgw/dispatcher.go
@@ -0,0 +1,110 @@
+package main
+
+import (
+	"context"
+	"regexp"
+	"time"
+
+	"github.com/golang/glog"
+)
+
+// dispatcher fans incoming SMS messages out to every subscriber whose body
+// filter matches them. The subscriber set lives entirely inside run; other
+// goroutines only talk to it through the two channels below.
+type dispatcher struct {
+	// Messages waiting to be dispatched (written by publish).
+	incoming chan *sms
+	// Subscribers waiting to be registered (written by subscribe).
+	subscribers chan *subscriber
+}
+
+// newDispatcher returns a dispatcher with unbuffered message and subscriber
+// channels. Nothing is processed until run is started.
+func newDispatcher() *dispatcher {
+	d := &dispatcher{}
+	d.incoming = make(chan *sms)
+	d.subscribers = make(chan *subscriber)
+	return d
+}
+
+// sms is a single text message received from the upstream provider: who sent
+// it, its text, and when the gateway received it.
+type sms struct {
+	from, body string
+	// Time of receipt by the gateway.
+	timestamp time.Time
+}
+
+// subscriber is one consumer of dispatched messages, selected by a body filter.
+type subscriber struct {
+	// Only messages whose body matches this expression are delivered.
+	re *regexp.Regexp
+	// Delivery channel. The subscriber has to keep reading from it, or
+	// pending deliveries pile up.
+	data chan *sms
+	// Closed by the subscriber once it no longer wants any messages.
+	cancel chan struct{}
+}
+
+// publish queues msg for dispatching. It blocks until the run loop takes it.
+func (p *dispatcher) publish(msg *sms) {
+	in := p.incoming
+	in <- msg
+}
+
+// subscribe registers sub with the run loop. It blocks until the loop takes it.
+func (p *dispatcher) subscribe(sub *subscriber) {
+	reg := p.subscribers
+	reg <- sub
+}
+
+func (p *dispatcher) run(ctx context.Context) {
+	// Map of internal IDs to subscribers. Internal IDs are used to remove
+	// canceled subscribers easily.
+	subscriberMap := make(map[int64]*subscriber)
+	// Internal channel that will emit SIDs of subscribers that needs to be
+	// removed.
+	subscriberCancel := make(chan int64)
+
+	for {
+		select {
+
+		// Should the processor close?
+		case <-ctx.Done():
+			return
+
+		// Do we need to remove a given subscriber?
+		case sid := <-subscriberCancel:
+			delete(subscriberMap, sid)
+
+		// Do we have a new subscriber?
+		case sub := <-p.subscribers:
+			// Generate a SID. A UNIX nanosecond timestamp is enough, since
+			// we're not running in parallel.
+			sid := time.Now().UnixNano()
+			glog.V(5).Infof("New subscriber %x, regexp %v", sid, sub.re)
+
+			// Add to subscriber map.
+			subscriberMap[sid] = sub
+
+			// On sub.cancel closed, emit info that we need to delete that
+			// subscriber.
+			go func() {
+				_, _ = <-sub.cancel
+				subscriberCancel <- sid
+			}()
+
+		// Do we have a new message to dispatch?
+		case in := <-p.incoming:
+			for sid, s := range subscriberMap {
+				glog.V(10).Infof("Considering %x", sid)
+				// If this subscriber doesn't care, ignore.
+				if !s.re.MatchString(in.body) {
+					continue
+				}
+
+				// Send, non-blocking, to subscriber. This ensures that we
+				// don't get stuck if a subscriber doesn't drain fast enough.
+				go func(to *subscriber, sid int64) {
+					glog.V(10).Infof("Dispatching to %x, %v", sid, to.data)
+					to.data <- in
+					glog.V(10).Infof("Dispatched to %x", sid)
+				}(s, sid)
+			}
+		}
+	}
+}
diff --git a/hswaw/smsgw/dispatcher_test.go b/hswaw/smsgw/dispatcher_test.go
new file mode 100644
index 0000000..a5f3977
--- /dev/null
+++ b/hswaw/smsgw/dispatcher_test.go
@@ -0,0 +1,206 @@
+package main
+
+import (
+	"context"
+	"regexp"
+	"testing"
+	"time"
+)
+
+// makeDut creates a dispatcher (the device under test) and starts its run
+// loop in the background. It returns the dispatcher, the function that stops
+// it, and the context the loop runs under, which is canceled by that function.
+func makeDut() (*dispatcher, context.CancelFunc, context.Context) {
+	ctx, cancel := context.WithCancel(context.Background())
+	dut := newDispatcher()
+	go dut.run(ctx)
+	return dut, cancel, ctx
+}
+
+// expectReceived fails the test unless a message with the same sender, body
+// and timestamp as s arrives on data within 100ms.
+func expectReceived(t *testing.T, s *sms, data chan *sms) {
+	t.Helper()
+	// One-shot deadline: a Timer, not a Ticker.
+	timeout := time.NewTimer(100 * time.Millisecond)
+	defer timeout.Stop()
+	select {
+	case d := <-data:
+		if d.from != s.from {
+			t.Errorf("Received SMS from %q, wanted %q", d.from, s.from)
+		}
+		if d.body != s.body {
+			t.Errorf("Received SMS body %q, wanted %q", d.body, s.body)
+		}
+		// Equal compares instants; == would also compare location and
+		// monotonic clock readings.
+		if !d.timestamp.Equal(s.timestamp) {
+			t.Errorf("Received SMS timestamp %v, wanted %v", d.timestamp, s.timestamp)
+		}
+	case <-timeout.C:
+		t.Fatalf("Timed out waiting for message")
+	}
+}
+
+// expectEmpty fails the test if any message arrives on data within 1ms.
+func expectEmpty(t *testing.T, data chan *sms) {
+	t.Helper()
+	// One-shot deadline: a Timer, not a Ticker.
+	timeout := time.NewTimer(1 * time.Millisecond)
+	defer timeout.Stop()
+	select {
+	case <-data:
+		t.Fatalf("Received unwanted message")
+	case <-timeout.C:
+	}
+}
+
+// TestDispatcher checks that a match-all subscriber receives a published
+// message exactly once, and stops receiving messages after it cancels.
+func TestDispatcher(t *testing.T) {
+	dut, stop, _ := makeDut()
+	defer stop()
+
+	inbox := make(chan *sms)
+	done := make(chan struct{})
+	dut.subscribe(&subscriber{
+		re:     regexp.MustCompile(".*"),
+		data:   inbox,
+		cancel: done,
+	})
+
+	msg := &sms{
+		from:      "+4821372137",
+		body:      "foo",
+		timestamp: time.Now(),
+	}
+
+	// Delivered once, and only once.
+	dut.publish(msg)
+	expectReceived(t, msg, inbox)
+	expectEmpty(t, inbox)
+
+	// Unsubscribe, give the dispatcher a moment to notice, then publish
+	// again: nothing may arrive this time.
+	close(done)
+	time.Sleep(time.Millisecond)
+	dut.publish(msg)
+	expectEmpty(t, inbox)
+}
+
+// testSubscriber holds a test's own handles to a subscriber it registered:
+// the body filter, the channel it reads messages from, and the channel it
+// closes to unsubscribe.
+type testSubscriber struct {
+	re     *regexp.Regexp
+	data   chan *sms
+	cancel chan struct{}
+}
+
+// TestDispatcherFilters checks that each message reaches exactly the
+// subscribers whose regexp matches its body.
+func TestDispatcherFilters(t *testing.T) {
+	dut, stop, _ := makeDut()
+	defer stop()
+
+	subs := []*testSubscriber{
+		{re: regexp.MustCompile(".*")},
+		{re: regexp.MustCompile("foo")},
+		{re: regexp.MustCompile("bar")},
+	}
+	for _, ts := range subs {
+		ts.data = make(chan *sms)
+		ts.cancel = make(chan struct{})
+		dut.subscribe(&subscriber{re: ts.re, data: ts.data, cancel: ts.cancel})
+		defer close(ts.cancel)
+	}
+
+	// For each message body: whether subs[i] is expected to receive it.
+	cases := []struct {
+		body string
+		want []bool
+	}{
+		{"foo", []bool{true, true, false}},
+		{"bar", []bool{true, false, true}},
+		{"foobar", []bool{true, true, true}},
+	}
+	for _, c := range cases {
+		msg := &sms{
+			from:      "+4821372137",
+			body:      c.body,
+			timestamp: time.Now(),
+		}
+		dut.publish(msg)
+		for i, wanted := range c.want {
+			if wanted {
+				expectReceived(t, msg, subs[i].data)
+			} else {
+				expectEmpty(t, subs[i].data)
+			}
+		}
+	}
+}
+
+func TestDispatcherMany(t *testing.T) {
+	dut, cancelDut, _ := makeDut()
+	defer cancelDut()
+
+	subscribers := make([]*testSubscriber, 10000)
+
+	for i, _ := range subscribers {
+		s := &testSubscriber{
+			re:     regexp.MustCompile(".*"),
+			data:   make(chan *sms),
+			cancel: make(chan struct{}),
+		}
+		subscribers[i] = s
+		dut.subscribe(&subscriber{
+			re:     s.re,
+			data:   s.data,
+			cancel: s.cancel,
+		})
+		defer func(c chan struct{}) {
+			close(c)
+		}(s.cancel)
+	}
+
+	in := &sms{
+		from:      "+4821372137",
+		body:      "foo",
+		timestamp: time.Now(),
+	}
+	dut.publish(in)
+
+	for _, s := range subscribers {
+		expectReceived(t, in, s.data)
+	}
+}
+
+// TestDispatcherHammer repeatedly subscribes, receives one message and
+// unsubscribes, to shake out races in subscriber bookkeeping. It runs a
+// million iterations, so it is skipped when tests run with -short.
+func TestDispatcherHammer(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping dispatcher hammer test in short mode")
+	}
+
+	dut, cancelDut, _ := makeDut()
+	defer cancelDut()
+
+	// Compiled once instead of once per iteration; a *regexp.Regexp is safe
+	// for concurrent use by multiple goroutines.
+	matchAll := regexp.MustCompile(".*")
+
+	for i := 0; i < 1000000; i++ {
+		data := make(chan *sms)
+		cancel := make(chan struct{})
+		dut.subscribe(&subscriber{
+			re:     matchAll,
+			data:   data,
+			cancel: cancel,
+		})
+
+		in := &sms{
+			from:      "+4821372137",
+			body:      "foo",
+			timestamp: time.Now(),
+		}
+		dut.publish(in)
+		expectReceived(t, in, data)
+
+		close(cancel)
+	}
+}
diff --git a/hswaw/smsgw/main.go b/hswaw/smsgw/main.go
new file mode 100644
index 0000000..a0a6a07
--- /dev/null
+++ b/hswaw/smsgw/main.go
@@ -0,0 +1,226 @@
+package main
+
+import (
+	"context"
+	"flag"
+	"fmt"
+	"net/http"
+	"regexp"
+	"strings"
+	"time"
+
+	"code.hackerspace.pl/hscloud/go/mirko"
+	"github.com/golang/glog"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+
+	pb "code.hackerspace.pl/hscloud/hswaw/smsgw/proto"
+)
+
+// Command-line flags; registered and validated in main.
+var (
+	// Twilio account credentials, and the friendly name of the phone number
+	// whose incoming SMS messages are handled.
+	flagTwilioSID           string
+	flagTwilioToken         string
+	flagTwilioFriendlyPhone string
+
+	// Webhook HTTP server: local listen address, and the public base URL
+	// that Twilio is configured to call.
+	flagWebhookListen string
+	flagWebhookPublic string
+)
+
+func init() {
+	// Have glog write to stderr (its logtostderr flag). An error from
+	// flag.Set is deliberately ignored, as was always the case here.
+	_ = flag.Set("logtostderr", "true")
+}
+
+// server implements the SMSGateway gRPC service and the Twilio webhook
+// handler, connected through a shared dispatcher.
+type server struct {
+	dispatcher *dispatcher
+}
+
+// ourPhoneNumber fetches the account's incoming phone numbers and returns the
+// one whose friendly name equals friendly. It fails if the list cannot be
+// fetched or contains no such number.
+func ourPhoneNumber(ctx context.Context, t *twilio, friendly string) (*incomingPhoneNumber, error) {
+	numbers, err := t.getIncomingPhoneNumbers(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	for i := range numbers {
+		if numbers[i].FriendlyName == friendly {
+			return &numbers[i], nil
+		}
+	}
+	return nil, fmt.Errorf("requested phone number %q not in list", friendly)
+}
+
+// ensureWebhook makes our Twilio phone number (flagTwilioFriendlyPhone) POST
+// incoming SMS to <flagWebhookPublic>sms.
+//
+// If the webhook is not configured that way, it updates it and re-reads the
+// configuration (every 5s) until both method and URL match. It then re-checks
+// every 30s and terminates the process via glog.Exitf if the webhook changes.
+// API failures also terminate the process, unless they happen because ctx was
+// canceled, in which case ensureWebhook simply returns.
+func ensureWebhook(ctx context.Context, t *twilio) {
+	pn, err := ourPhoneNumber(ctx, t, flagTwilioFriendlyPhone)
+	if err != nil {
+		glog.Exitf("could not get our phone number: %v", err)
+	}
+
+	url := fmt.Sprintf("%ssms", flagWebhookPublic)
+
+	// first setup.
+	if pn.SMSMethod != "POST" || pn.SMSURL != url {
+		glog.Infof("Updating webhook (is %s %q, want %s %q)", pn.SMSMethod, pn.SMSURL, "POST", url)
+		if err := t.updateIncomingPhoneNumberSMSWebhook(ctx, pn.SID, "POST", url); err != nil {
+			glog.Exitf("could not set webhook: %v", err)
+		}
+
+		// Read the configuration back until it reports exactly what we set
+		// (both method and URL).
+		for {
+			pn, err = ourPhoneNumber(ctx, t, flagTwilioFriendlyPhone)
+			if err != nil {
+				if ctx.Err() != nil {
+					return
+				}
+				glog.Exitf("could not get our phone number: %v", err)
+			}
+			if pn.SMSMethod == "POST" && pn.SMSURL == url {
+				break
+			}
+			glog.Infof("Webhook not yet ready, currently %s %q", pn.SMSMethod, pn.SMSURL)
+			select {
+			case <-ctx.Done():
+				return
+			case <-time.After(5 * time.Second):
+			}
+		}
+		glog.Infof("Webhook verified")
+	} else {
+		glog.Infof("Webhook up to date")
+	}
+
+	// now keep checking to make sure that nobody takes over our webhook
+	tick := time.NewTicker(30 * time.Second)
+	defer tick.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-tick.C:
+			pn, err = ourPhoneNumber(ctx, t, flagTwilioFriendlyPhone)
+			if err != nil {
+				if ctx.Err() != nil {
+					return
+				}
+				glog.Exitf("could not get our phone number: %v", err)
+			}
+			if pn.SMSMethod != "POST" || pn.SMSURL != url {
+				glog.Exitf("Webhook got deconfigured, not %s %q", pn.SMSMethod, pn.SMSURL)
+			}
+		}
+	}
+}
+
+// webhookHandler handles the incoming-SMS webhook. It expects a form-encoded
+// POST carrying AccountSid, From and Body. Messages for our account with a
+// non-empty body are published to the dispatcher.
+//
+// Responses: 400 if the form cannot be parsed, 403 if AccountSid is not ours,
+// 200 otherwise (including for ignored empty bodies).
+//
+// NOTE(review): the only authentication is the AccountSid comparison. That
+// value also appears in plain API URLs (see twilio.go), so it is weak proof
+// that a request comes from Twilio. Consider verifying Twilio's request
+// signature (X-Twilio-Signature) with the auth token.
+func (s *server) webhookHandler(w http.ResponseWriter, r *http.Request) {
+	if err := r.ParseForm(); err != nil {
+		glog.Errorf("webhook body parse error: %v", err)
+		http.Error(w, "invalid request body", http.StatusBadRequest)
+		return
+	}
+
+	accountSID := r.PostForm.Get("AccountSid")
+	if accountSID != flagTwilioSID {
+		glog.Errorf("webhook got wrong account sid, got %q, wanted %q", accountSID, flagTwilioSID)
+		http.Error(w, "forbidden", http.StatusForbidden)
+		return
+	}
+
+	body := r.PostForm.Get("Body")
+	if body == "" {
+		return
+	}
+
+	from := r.PostForm.Get("From")
+
+	glog.Infof("Got SMS from %q, body %q", from, body)
+
+	s.dispatcher.publish(&sms{
+		from:      from,
+		body:      body,
+		timestamp: time.Now(),
+	})
+
+	w.WriteHeader(http.StatusOK)
+}
+
+// main parses and validates flags, then starts the dispatcher, the webhook
+// HTTP server, the Twilio webhook watcher (ensureWebhook) and the gRPC server
+// (via mirko).
+func main() {
+	flag.StringVar(&flagTwilioSID, "twilio_sid", "", "Twilio account SID")
+	flag.StringVar(&flagTwilioToken, "twilio_token", "", "Twilio auth token")
+	flag.StringVar(&flagTwilioFriendlyPhone, "twilio_friendly_phone", "", "Twilio friendly phone number")
+
+	flag.StringVar(&flagWebhookListen, "webhook_listen", "127.0.0.1:5000", "Listen address for webhook handler")
+	flag.StringVar(&flagWebhookPublic, "webhook_public", "", "Public address for webhook handler (wg. http://proxy.q3k.org/smsgw/)")
+	flag.Parse()
+
+	if flagTwilioSID == "" || flagTwilioToken == "" {
+		glog.Exitf("twilio_sid and twilio_token must be set")
+	}
+
+	if flagTwilioFriendlyPhone == "" {
+		glog.Exitf("twilio_friendly_phone must be set")
+	}
+
+	if flagWebhookPublic == "" {
+		glog.Exitf("webhook_public must be set")
+	}
+
+	// The webhook path is appended directly to this base URL.
+	if !strings.HasSuffix(flagWebhookPublic, "/") {
+		flagWebhookPublic += "/"
+	}
+
+	s := &server{
+		dispatcher: newDispatcher(),
+	}
+
+	m := mirko.New()
+	if err := m.Listen(); err != nil {
+		glog.Exitf("Listen(): %v", err)
+	}
+
+	// Start the dispatcher before accepting webhook calls, so that the
+	// handler's publish never waits for a loop that hasn't started yet.
+	go s.dispatcher.run(m.Context())
+
+	webhookMux := http.NewServeMux()
+	webhookMux.HandleFunc("/sms", s.webhookHandler)
+	// Timeouts keep slow or stalled clients from holding connections open
+	// indefinitely.
+	webhookSrv := &http.Server{
+		Addr:         flagWebhookListen,
+		Handler:      webhookMux,
+		ReadTimeout:  10 * time.Second,
+		WriteTimeout: 10 * time.Second,
+	}
+	go func() {
+		if err := webhookSrv.ListenAndServe(); err != nil {
+			glog.Exitf("webhook ListenAndServe: %v", err)
+		}
+	}()
+
+	t := &twilio{
+		accountSID:   flagTwilioSID,
+		accountToken: flagTwilioToken,
+	}
+	go ensureWebhook(m.Context(), t)
+
+	pb.RegisterSMSGatewayServer(m.GRPC(), s)
+
+	if err := m.Serve(); err != nil {
+		glog.Exitf("Serve(): %v", err)
+	}
+
+	<-m.Done()
+}
+
+// Messages implements the SMSGateway.Messages RPC. It streams every SMS
+// received from now on whose body matches req.FilterBody, until the client
+// goes away or a send fails. The filter is a Go regexp matched anywhere in
+// the body; an empty filter matches everything. An invalid filter yields
+// InvalidArgument.
+func (s *server) Messages(req *pb.MessagesRequest, stream pb.SMSGateway_MessagesServer) error {
+	re := regexp.MustCompile(".*")
+	if req.FilterBody != "" {
+		var err error
+		re, err = regexp.Compile(req.FilterBody)
+		if err != nil {
+			return status.Errorf(codes.InvalidArgument, "filter regexp error: %v", err)
+		}
+	}
+
+	// data is deliberately never closed here. We are its receiver, and a
+	// delivery may still be in flight when we return; a send on a closed
+	// channel would panic. Closing cancel is what unsubscribes us.
+	data := make(chan *sms)
+	cancel := make(chan struct{})
+	defer close(cancel)
+
+	s.dispatcher.subscribe(&subscriber{
+		re:     re,
+		data:   data,
+		cancel: cancel,
+	})
+
+	ctx := stream.Context()
+	for {
+		select {
+		case <-ctx.Done():
+			// Client disconnected or the call was canceled: stop streaming
+			// instead of looping forever on a dead stream.
+			return ctx.Err()
+		case d := <-data:
+			err := stream.Send(&pb.MessagesResponse{
+				Sender:    d.from,
+				Body:      d.body,
+				Timestamp: d.timestamp.UnixNano(),
+			})
+			if err != nil {
+				return err
+			}
+		}
+	}
+}
diff --git a/hswaw/smsgw/proto/BUILD.bazel b/hswaw/smsgw/proto/BUILD.bazel
new file mode 100644
index 0000000..be3451f
--- /dev/null
+++ b/hswaw/smsgw/proto/BUILD.bazel
@@ -0,0 +1,23 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+
+# Protobuf definition of the SMSGateway gRPC service.
+proto_library(
+    name = "proto_proto",
+    srcs = ["smsgw.proto"],
+    visibility = ["//visibility:public"],
+)
+
+# Go message types and gRPC stubs generated from :proto_proto.
+go_proto_library(
+    name = "proto_go_proto",
+    compilers = ["@io_bazel_rules_go//proto:go_grpc"],
+    importpath = "code.hackerspace.pl/hscloud/hswaw/smsgw/proto",
+    proto = ":proto_proto",
+    visibility = ["//visibility:public"],
+)
+
+# Exposes the generated code under the conventional go_default_library name,
+# which is what Go targets (e.g. //hswaw/smsgw) depend on.
+go_library(
+    name = "go_default_library",
+    embed = [":proto_go_proto"],
+    importpath = "code.hackerspace.pl/hscloud/hswaw/smsgw/proto",
+    visibility = ["//visibility:public"],
+)
diff --git a/hswaw/smsgw/proto/smsgw.proto b/hswaw/smsgw/proto/smsgw.proto
new file mode 100644
index 0000000..2a95308
--- /dev/null
+++ b/hswaw/smsgw/proto/smsgw.proto
@@ -0,0 +1,17 @@
+syntax = "proto3";
+package proto;
+option go_package = "code.hackerspace.pl/hscloud/hswaw/smsgw/proto";
+
+// MessagesRequest subscribes to SMS messages received by the gateway.
+message MessagesRequest {
+    // Regular expression (Go regexp syntax) matched anywhere in the message
+    // body; only matching messages are streamed. Empty matches every message.
+    // An invalid expression fails the call with INVALID_ARGUMENT.
+    string filter_body = 1;
+}
+
+// MessagesResponse is a single SMS message delivered to a subscriber.
+// NOTE(review): field number 2 is unused; consider `reserved 2;` if it was
+// ever in use.
+message MessagesResponse {
+    // Sender of the message, taken from the webhook's "From" form field.
+    string sender = 1;
+    // Message text.
+    string body = 3;
+    // When the gateway received the message, in nanoseconds since the UNIX
+    // epoch.
+    int64 timestamp = 4;
+}
+
+// SMSGateway lets clients receive SMS messages sent to the gateway's Twilio
+// phone number.
+service SMSGateway {
+    // Messages streams every SMS received while the call is open whose body
+    // matches the request's filter. Earlier messages are not replayed.
+    rpc Messages(MessagesRequest) returns (stream MessagesResponse);
+}
diff --git a/hswaw/smsgw/twilio.go b/hswaw/smsgw/twilio.go
new file mode 100644
index 0000000..cdc0255
--- /dev/null
+++ b/hswaw/smsgw/twilio.go
@@ -0,0 +1,79 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/url"
+	"strings"
+)
+
+// twilio is a minimal Twilio REST API client. It authenticates every request
+// with HTTP basic auth, using the account SID and auth token.
+type twilio struct {
+	accountSID, accountToken string
+}
+
+// incomingPhoneNumber holds the subset of an IncomingPhoneNumber resource
+// that the gateway uses, decoded from the REST API's JSON.
+type incomingPhoneNumber struct {
+	FriendlyName string `json:"friendly_name"` // human-readable name, matched against a flag
+	SMSMethod    string `json:"sms_method"`    // HTTP method of the SMS webhook
+	SMSURL       string `json:"sms_url"`       // URL of the SMS webhook
+	SID          string `json:"sid"`           // resource ID, used when updating it
+}
+
+// getIncomingPhoneNumbers lists the account's incoming phone numbers via the
+// REST API. It returns an error if the request fails, the body cannot be
+// decoded, the API reports an error message, or the status is not 200 OK.
+//
+// NOTE(review): only a single response is read. If the API paginates this
+// list, numbers beyond the first page are not returned — confirm for larger
+// accounts.
+func (t *twilio) getIncomingPhoneNumbers(ctx context.Context) ([]incomingPhoneNumber, error) {
+	// Not named url, to avoid shadowing the net/url package.
+	endpoint := fmt.Sprintf("https://api.twilio.com/2010-04-01/Accounts/%s/IncomingPhoneNumbers.json", t.accountSID)
+	req, err := http.NewRequest("GET", endpoint, nil)
+	if err != nil {
+		return nil, err
+	}
+	req = req.WithContext(ctx)
+	req.SetBasicAuth(t.accountSID, t.accountToken)
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+
+	result := struct {
+		Message string                `json:"message"`
+		Status  int64                 `json:"status"`
+		IPN     []incomingPhoneNumber `json:"incoming_phone_numbers"`
+	}{}
+
+	if err := json.NewDecoder(res.Body).Decode(&result); err != nil {
+		if res.StatusCode != http.StatusOK {
+			// A non-JSON error body (e.g. from a proxy): the status code is
+			// more useful than the decoding error.
+			return nil, fmt.Errorf("REST response error, status code: %d", res.StatusCode)
+		}
+		return nil, fmt.Errorf("decoding REST response: %v", err)
+	}
+
+	if result.Message != "" {
+		return nil, fmt.Errorf("REST response error, status: %v, message: %q", result.Status, result.Message)
+	}
+	if res.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("REST response error, status code: %d", res.StatusCode)
+	}
+
+	return result.IPN, nil
+}
+
+func (t *twilio) updateIncomingPhoneNumberSMSWebhook(ctx context.Context, sid, method, whurl string) error {
+	turl := fmt.Sprintf("https://api.twilio.com/2010-04-01/Accounts/%s/IncomingPhoneNumbers/%s.json", t.accountSID, sid)
+
+	data := url.Values{}
+	data.Set("SmsMethod", method)
+	data.Set("SmsUrl", whurl)
+
+	req, err := http.NewRequest("POST", turl, strings.NewReader(data.Encode()))
+	if err != nil {
+		return err
+	}
+	req = req.WithContext(ctx)
+	req.SetBasicAuth(t.accountSID, t.accountToken)
+	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return err
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode != 200 {
+		return fmt.Errorf("status code: %v", res.StatusCode)
+	}
+	return nil
+}