cluster/identd/ident: add basic ident protocol server

This adds an ident protocol server and tests for it.

Change-Id: I830f85faa7dce4220bd7001635b20e88b4a8b417
diff --git a/cluster/identd/ident/BUILD.bazel b/cluster/identd/ident/BUILD.bazel
index b672c92..382900e 100644
--- a/cluster/identd/ident/BUILD.bazel
+++ b/cluster/identd/ident/BUILD.bazel
@@ -6,6 +6,7 @@
         "client.go",
         "request.go",
         "response.go",
+        "server.go",
     ],
     importpath = "code.hackerspace.pl/hscloud/cluster/identd/ident",
     visibility = ["//visibility:public"],
@@ -17,6 +18,7 @@
     srcs = [
         "request_test.go",
         "response_test.go",
+        "server_test.go",
     ],
     embed = [":go_default_library"],
     deps = ["@com_github_go_test_deep//:go_default_library"],
diff --git a/cluster/identd/ident/request.go b/cluster/identd/ident/request.go
index 9727893..9c94e75 100644
--- a/cluster/identd/ident/request.go
+++ b/cluster/identd/ident/request.go
@@ -2,6 +2,9 @@
 
 import (
 	"fmt"
+	"net"
+	"regexp"
+	"strconv"
 )
 
 // Request is an ident protocol request, as seen by the client or server.
@@ -12,9 +15,44 @@
 	// ServerPort is the port number on the server side of the ident protocol,
 	// ie. the port local to the ident server.
 	ServerPort uint16
+
+	// ClientAddress is the address of the ident protocol client. This is set
+	// by the ident Server before invoking handlers, and is ignored by the
+	// ident protocol Client.
+	// In handlers this can be used to ensure that responses are only returned
+	// to clients who are running on the remote side of the connection that
+	// they are querying about.
+	ClientAddress net.Addr
 }
 
 // encode encodes ths Request as per RFC1413, including the terminating \r\n.
 func (r *Request) encode() []byte {
 	return []byte(fmt.Sprintf("%d,%d\r\n", r.ServerPort, r.ClientPort))
 }
+
+var (
+	// reRequest matches request from RFC1413, but allows extra whitespace
+	// between significant tokens.
+	reRequest = regexp.MustCompile(`^\s*(\d{1,5})\s*,\s*(\d{1,5})\s*$`)
+)
+
+// decodeRequest parses the given bytes as an ident request. The data must be
+// stripped of the trailing \r\n.
+func decodeRequest(data []byte) (*Request, error) {
+	match := reRequest.FindStringSubmatch(string(data))
+	if match == nil {
+		return nil, fmt.Errorf("unparseable request")
+	}
+	serverPort, err := strconv.ParseUint(match[1], 10, 16)
+	if err != nil {
+		return nil, fmt.Errorf("invalid server port: %w", err)
+	}
+	clientPort, err := strconv.ParseUint(match[2], 10, 16)
+	if err != nil {
+		return nil, fmt.Errorf("invalid client port: %w", err)
+	}
+	return &Request{
+		ClientPort: uint16(clientPort),
+		ServerPort: uint16(serverPort),
+	}, nil
+}
diff --git a/cluster/identd/ident/response.go b/cluster/identd/ident/response.go
index 5eab431..f54b728 100644
--- a/cluster/identd/ident/response.go
+++ b/cluster/identd/ident/response.go
@@ -7,17 +7,6 @@
 	"strings"
 )
 
-var (
-	// reErrorReply matches error-reply from RFC1413, but also allows extra
-	// whitespace between significant tokens. It does not ensure that the
-	// error-type is one of the standardized values.
-	reErrorReply = regexp.MustCompile(`^\s*(\d{1,5})\s*,\s*(\d{1,5})\s*:\s*ERROR\s*:\s*(.+)$`)
-	// reIdentReply matches ident-reply from RFC1413, but also allows extra
-	// whitespace between significant tokens. It does not ensure that that
-	// opsys-field and user-id parts are RFC compliant.
-	reIdentReply = regexp.MustCompile(`^\s*(\d{1,5})\s*,\s*(\d{1,5})\s*:\s*USERID\s*:\s*([^:,]+)(,([^:]+))?\s*:(.+)$`)
-)
-
 // Response is an ident protocol response, as seen by the client or server.
 type Response struct {
 	// ClientPort is the port number on the client side of the indent protocol,
@@ -100,6 +89,47 @@
 	UserID string
 }
 
+// encode encodes the given Response. If the Response is unencodable/malformed,
+// nil is returned.
+func (r *Response) encode() []byte {
+	// Both Error and Ident cannot be set at once.
+	if r.Error != "" && r.Ident != nil {
+		return nil
+	}
+
+	if r.Error != "" {
+		if !r.Error.IsError() {
+			return nil
+		}
+		return []byte(fmt.Sprintf("%d,%d:ERROR:%s\r\n", r.ServerPort, r.ClientPort, r.Error))
+	}
+	if r.Ident != nil {
+		id := r.Ident
+		os := id.OperatingSystem
+		if os == "" {
+			return nil
+		}
+		// For compatibility, do not set US-ASCII explicitly.
+		if id.CharacterSet != "" && id.CharacterSet != "US-ASCII" {
+			os += "," + id.CharacterSet
+		}
+		return []byte(fmt.Sprintf("%d,%d:USERID:%s:%s\r\n", r.ServerPort, r.ClientPort, os, id.UserID))
+	}
+	// Malformed response, return nil.
+	return nil
+}
+
+var (
+	// reErrorReply matches error-reply from RFC1413, but also allows extra
+	// whitespace between significant tokens. It does not ensure that the
+	// error-type is one of the standardized values.
+	reErrorReply = regexp.MustCompile(`^\s*(\d{1,5})\s*,\s*(\d{1,5})\s*:\s*ERROR\s*:\s*(.+)$`)
+	// reIdentReply matches ident-reply from RFC1413, but also allows extra
+	// whitespace between significant tokens. It does not ensure that the
+	// opsys-field and user-id parts are RFC compliant.
+	reIdentReply = regexp.MustCompile(`^\s*(\d{1,5})\s*,\s*(\d{1,5})\s*:\s*USERID\s*:\s*([^:,]+)(,([^:]+))?\s*:(.+)$`)
+)
+
 // decodeResponse parses the given bytes as an ident response. The data must be
 // stripped of the trailing \r\n.
 func decodeResponse(data []byte) (*Response, error) {
diff --git a/cluster/identd/ident/response_test.go b/cluster/identd/ident/response_test.go
index e768d46..3ef5137 100644
--- a/cluster/identd/ident/response_test.go
+++ b/cluster/identd/ident/response_test.go
@@ -59,3 +59,84 @@
 		}
 	}
 }
+
+// TestResponseEncode exercises the encode method of Response.
+func TestResponseEncode(t *testing.T) {
+	for i, te := range []struct {
+		r    *Response
+		want string
+	}{
+		/// Standard features
+		// 0: simple error
+		{&Response{
+			ServerPort: 123,
+			ClientPort: 234,
+			Error:      InvalidPort,
+		}, "123,234:ERROR:INVALID-PORT\r\n"},
+		// 1: simple response, no character set
+		{&Response{
+			ServerPort: 123,
+			ClientPort: 234,
+			Ident: &IdentResponse{
+				OperatingSystem: "UNIX",
+				UserID:          "q3k",
+			},
+		}, "123,234:USERID:UNIX:q3k\r\n"},
+		// 2: simple response, stripped US-ASCII
+		{&Response{
+			ServerPort: 123,
+			ClientPort: 234,
+			Ident: &IdentResponse{
+				OperatingSystem: "UNIX",
+				CharacterSet:    "US-ASCII",
+				UserID:          "q3k",
+			},
+		}, "123,234:USERID:UNIX:q3k\r\n"},
+
+		/// Unusual features
+		// 3: unusual response
+		{&Response{
+			ServerPort: 123,
+			ClientPort: 234,
+			Ident: &IdentResponse{
+				OperatingSystem: "SUN",
+				CharacterSet:    "PETSCII",
+				UserID:          "q3k",
+			},
+		}, "123,234:USERID:SUN,PETSCII:q3k\r\n"},
+		// 4: non-standard error code
+		{&Response{
+			ServerPort: 123,
+			ClientPort: 234,
+			Error:      ErrorResponse("XNOMANA"),
+		}, "123,234:ERROR:XNOMANA\r\n"},
+
+		/// Errors
+		// 5: invalid error code (should return nil)
+		{&Response{
+			ServerPort: 123,
+			ClientPort: 234,
+			Error:      ErrorResponse("NOT ENOUGH MANA"),
+		}, ""},
+		// 6: no error/ident set (should return nil)
+		{&Response{
+			ServerPort: 123,
+			ClientPort: 234,
+		}, ""},
+		// 7: both error and ident set (should return nil)
+		{&Response{
+			ServerPort: 123,
+			ClientPort: 234,
+			Ident: &IdentResponse{
+				OperatingSystem: "UNIX",
+				CharacterSet:    "US-ASCII",
+				UserID:          "q3k",
+			},
+			Error: InvalidPort,
+		}, ""},
+	} {
+		if want, got := te.want, string(te.r.encode()); want != got {
+			t.Errorf("%d: wanted %q, got %q", i, want, got)
+		}
+	}
+}
diff --git a/cluster/identd/ident/server.go b/cluster/identd/ident/server.go
new file mode 100644
index 0000000..ac985d9
--- /dev/null
+++ b/cluster/identd/ident/server.go
@@ -0,0 +1,282 @@
+package ident
+
+import (
+	"bufio"
+	"context"
+	"fmt"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/golang/glog"
+)
+
+// NewServer returns an ident Server.
+func NewServer() *Server {
+	return &Server{
+		handler: unsetHandler,
+	}
+}
+
+// Server is an ident protocol server (per RFC1413). It must be configured with
+// a HandlerFunc before Serve is called.
+// Multiple goroutines may invoke methods on Server simultaneously, but the
+// Server can only Serve one listener at a time.
+type Server struct {
+	handler HandlerFunc
+	// mu guards stopC and stop.
+	mu sync.Mutex
+	// stopC is set if Serve() is already running. If it gets closed, Serve()
+	// will quit and set stopC to nil.
+	stopC chan struct{}
+	// stop can be set to true if Serve() is not yet running but has already
+	// been requested to Stop() (eg. if Serve() is run in a goroutine which
+	// hasn't yet scheduled). It will be set back to false when Serve() sees it
+	// set and exits.
+	stop bool
+}
+
+// ResponseWriter is passed to HandlerFuncs and is used to signal to the Server
+// that the HandlerFunc wants to respond to the incoming Request in a certain
+// way.
+// Only the goroutine that the HandlerFunc has been started in may invoke
+// methods on the ResponseWriter.
+type ResponseWriter interface {
+	// SendError returns an ident ErrorResponse to the ident client. This can
+	// only be called once, and cannot be called after SendIdent.
+	SendError(ErrorResponse) error
+	// SendIdent returns an ident IdentResponse to the ident client. This can
+	// only be called once, and cannot be called after SendError.
+	SendIdent(*IdentResponse) error
+}
+
+// HandlerFunc is a function that will be called to serve a given ident
+// Request. Users of the Server must implement this and configure a Server to
+// use it by invoking Server.HandleFunc.
+// Each HandlerFunc will be started in its own goroutine. When HandlerFunc
+// returns, the Server will attempt to serve more incoming requests from the
+// ident client.
+// The Server does not limit the amount of concurrent ident connections that it
+// serves. If the Server user wishes to limit concurrency, she must do it
+// herself, eg. by using a semaphore. The Server will continue accepting new
+// connections and starting new HandlerFuncs, if the user code needs to push
+// back it should return as early as possible. There currently is no way to
+// make the Server refuse connections above some concurrency limit.
+// The Server does not impose any execution timeout on handlers. If the Server
+// user wishes to impose an execution timeout, she must do it herself, eg.
+// using context.WithTimeout or time.After.
+// The passed Context will be canceled when the Server shuts down; it is not
+// currently canceled when the ident client disconnects. The HandlerFunc must
+// return as early as it can detect that the context is done.
+type HandlerFunc func(ctx context.Context, w ResponseWriter, r *Request)
+
+// responseWriter implements ResponseWriter for a Server.
+type responseWriter struct {
+	conn      net.Conn
+	req       *Request
+	responded bool
+}
+
+// sendResponse sends a Response to the ident client. The Response must already
+// be fully populated.
+func (w *responseWriter) sendResponse(r *Response) error {
+	if w.responded {
+		return fmt.Errorf("handler already sent a response")
+	}
+	w.responded = true
+	data := r.encode()
+	if data == nil {
+		return fmt.Errorf("failed to encode response")
+	}
+	glog.V(3).Infof(" -> %q", data)
+	_, err := w.conn.Write(data)
+	if err != nil {
+		return fmt.Errorf("writing response failed: %w", err)
+	}
+	return nil
+}
+
+func (w *responseWriter) SendError(e ErrorResponse) error {
+	if !e.IsError() {
+		return fmt.Errorf("error response must contain a valid error")
+	}
+	return w.sendResponse(&Response{
+		ClientPort: w.req.ClientPort,
+		ServerPort: w.req.ServerPort,
+		Error:      e,
+	})
+}
+
+func (w *responseWriter) SendIdent(i *IdentResponse) error {
+	ir := *i
+	// TODO(q3k): enforce RFC1413 limits.
+	if ir.OperatingSystem == "" {
+		ir.OperatingSystem = "UNIX"
+	}
+	if ir.UserID == "" {
+		return fmt.Errorf("ident response must have UserID set")
+	}
+	return w.sendResponse(&Response{
+		ClientPort: w.req.ClientPort,
+		ServerPort: w.req.ServerPort,
+		Ident:      &ir,
+	})
+}
+
+var (
+	unsetHandlerErrorOnce sync.Once
+)
+
+// unsetHandler is the default handler that is configured for a Server. It
+// returns UNKNOWN-ERROR to the ident client and logs an error once if it's
+// called (telling the user about a misconfiguration / programming error).
+func unsetHandler(ctx context.Context, w ResponseWriter, r *Request) {
+	unsetHandlerErrorOnce.Do(func() {
+		glog.Errorf("Server with no handler configured - will always return UNKNOWN-ERROR")
+	})
+	w.SendError(UnknownError)
+}
+
+// HandleFunc sets the HandlerFunc that the server will call for every incoming
+// ident request. If a HandlerFunc is already set, it will be overwritten by
+// the given new function.
+func (s *Server) HandleFunc(fn HandlerFunc) {
+	s.handler = fn
+}
+
+// Serve runs the ident server, blocking until a transport-level error occurs
+// or Stop() is invoked. The returned error will be nil on Stop(), and will
+// wrap the underlying transport-level error otherwise.
+//
+// Only one invocation of Serve() can be run at a time, but Serve can be called
+// again after Stop() is called, and can be run on a different Listener - no
+// state is kept in between subsequent Serve() runs.
+func (s *Server) Serve(lis net.Listener) error {
+	s.mu.Lock()
+	if s.stopC != nil {
+		s.mu.Unlock()
+		return fmt.Errorf("cannot Serve() an already serving server")
+	}
+	// Stop() has been invoked before Serve() started.
+	if s.stop == true {
+		s.stop = false
+		s.mu.Unlock()
+		return nil
+	}
+	// Set stopC to allow Stop() calls to stop this running Serve(). It will be
+	// set to nil on exit.
+	stopC := make(chan struct{})
+	s.stopC = stopC
+	s.mu.Unlock()
+
+	defer func() {
+		s.mu.Lock()
+		s.stopC = nil
+		s.mu.Unlock()
+	}()
+
+	ctx, ctxC := context.WithCancel(context.Background())
+	for {
+		lisConnC := make(chan net.Conn)
+		lisErrC := make(chan error)
+		go func() {
+			conn, err := lis.Accept()
+			select {
+			case <-stopC:
+				// Server stopped, drop the accepted connection (if any)
+				// and return.
+				glog.V(2).Infof("Accept goroutine stopping...")
+				if err == nil {
+					conn.Close()
+				}
+				return
+			default:
+			}
+			if err == nil {
+				glog.V(5).Infof("Accept ok: %v", conn.RemoteAddr())
+				lisConnC <- conn
+			} else {
+				glog.V(5).Infof("Accept err: %v", err)
+				lisErrC <- err
+			}
+		}()
+
+		select {
+		case <-stopC:
+			// Server stopped, return.
+			ctxC()
+			return nil
+		case err := <-lisErrC:
+			ctxC()
+			// Accept() failed, return error.
+			return err
+		case conn := <-lisConnC:
+			// Accept() succeeded, serve request.
+			go s.serve(ctx, conn)
+		}
+	}
+}
+
+func (s *Server) Stop() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.stopC != nil {
+		close(s.stopC)
+	} else {
+		s.stop = true
+	}
+}
+
+func (s *Server) serve(ctx context.Context, conn net.Conn) {
+	glog.V(2).Infof("Serving connection %v", conn.RemoteAddr())
+	scanner := bufio.NewScanner(conn)
+	// The RFC does not place a limit on the request line length, only on
+	// response length. We set an arbitrary limit to 1024 bytes.
+	scanner.Buffer(nil, 1024)
+
+	for {
+		// Implement an arbitrary timeout for receiving data from client.
+		// TODO(q3k): make this configurable
+		go func() {
+			timer := time.NewTimer(10 * time.Second)
+			defer timer.Stop()
+			select {
+			case <-ctx.Done():
+				return
+			case <-timer.C:
+				glog.V(1).Infof("Connection %v: terminating on receive timeout", conn.RemoteAddr())
+				conn.Close()
+			}
+		}()
+		if !scanner.Scan() {
+			err := scanner.Err()
+			if err == nil {
+				// EOF, just return.
+				return
+			}
+			// Some other transport level error occurred, or the request line
+			// was too long. We can only log this and be done.
+			glog.V(1).Infof("Connection %v: scan failed: %v", conn.RemoteAddr(), err)
+			conn.Close()
+			return
+		}
+		data := scanner.Bytes()
+		glog.V(3).Infof(" <- %q", data)
+		req, err := decodeRequest(data)
+		if err != nil {
+			glog.V(1).Infof("Connection %v: could not decode request: %v", conn.RemoteAddr(), err)
+			conn.Close()
+			return
+		}
+		req.ClientAddress = conn.RemoteAddr()
+		rw := responseWriter{
+			req:  req,
+			conn: conn,
+		}
+		s.handler(ctx, &rw, req)
+		if !rw.responded {
+			glog.Warningf("Connection %v: handler did not send response, sending UNKNOWN-ERROR", conn.RemoteAddr())
+			rw.SendError(UnknownError)
+		}
+	}
+}
diff --git a/cluster/identd/ident/server_test.go b/cluster/identd/ident/server_test.go
new file mode 100644
index 0000000..7cf53ae
--- /dev/null
+++ b/cluster/identd/ident/server_test.go
@@ -0,0 +1,158 @@
+package ident
+
+import (
+	"bufio"
+	"context"
+	"fmt"
+	"io"
+	"net"
+	"strings"
+	"testing"
+)
+
+// loopback sets up a net.Listener on any available TCP port and returns it and
+// a dialer function that returns open connections to that listener.
+func loopback(t *testing.T) (net.Listener, func() net.Conn) {
+	t.Helper()
+
+	lis, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("Listen: %v", err)
+	}
+
+	return lis, func() net.Conn {
+		t.Helper()
+		conn, err := net.Dial("tcp", lis.Addr().String())
+		if err != nil {
+			t.Fatalf("Dial: %v", err)
+		}
+		return conn
+	}
+}
+
+// dumbHandler is a handler that returns USERID:UNIX:q3k for every request.
+func dumbHandler(ctx context.Context, w ResponseWriter, r *Request) {
+	w.SendIdent(&IdentResponse{
+		UserID: "q3k",
+	})
+}
+
+// reqResp sends an ident query to the conn and expects a response with
+// USERID:UNIX:q3k on the scanner.
+func reqResp(t *testing.T, conn net.Conn, scanner *bufio.Scanner, client, server uint16) {
+	t.Helper()
+	if _, err := fmt.Fprintf(conn, "%d,%d\r\n", client, server); err != nil {
+		t.Fatalf("Write: %v", err)
+	}
+	if !scanner.Scan() {
+		t.Fatalf("Scan: %v", scanner.Err())
+	}
+	if want, got := fmt.Sprintf("%d,%d:USERID:UNIX:q3k", client, server), scanner.Text(); want != got {
+		t.Fatalf("Wanted %q, got %q", want, got)
+	}
+}
+
+// TestServeSimple exercises the basic Server functionality: responding to
+// ident requests.
+func TestServeSimple(t *testing.T) {
+	lis, dial := loopback(t)
+	defer lis.Close()
+
+	isrv := NewServer()
+	isrv.HandleFunc(dumbHandler)
+	go isrv.Serve(lis)
+
+	conn := dial()
+	defer conn.Close()
+	scanner := bufio.NewScanner(conn)
+
+	// Send a request, expect response.
+	reqResp(t, conn, scanner, 123, 234)
+	// Send another request on the same conn, expect response.
+	reqResp(t, conn, scanner, 234, 345)
+
+	// Send another request in parallel, expect response.
+	conn2 := dial()
+	defer conn2.Close()
+	scanner2 := bufio.NewScanner(conn2)
+	reqResp(t, conn2, scanner2, 345, 456)
+}
+
+// TestServeErrors exercises situations where the server has to deal with
+// nasty/broken clients.
+func TestServeErrors(t *testing.T) {
+	lis, dial := loopback(t)
+	defer lis.Close()
+
+	isrv := NewServer()
+	isrv.HandleFunc(dumbHandler)
+	go isrv.Serve(lis)
+
+	conn := dial()
+	defer conn.Close()
+
+	// Send something that's not ident.
+	fmt.Fprintf(conn, "GET / HTTP/1.1\r\n\r\n")
+	// Expect EOF on read.
+	data := make([]byte, 100)
+	_, err := conn.Read(data)
+	if want, got := io.EOF, err; want != got {
+		t.Fatalf("Expected %v, got %v", want, got)
+	}
+
+	conn = dial()
+	defer conn.Close()
+
+	// Send a very long request line, expect to not be served.
+	fmt.Fprintf(conn, "123,%s123\r\n", strings.Repeat(" ", 4096))
+	data = make([]byte, 100)
+	_, err = conn.Read(data)
+	// In a large write, the connection will be closed by the server before
+	// we're finished writing. That will cause the connection to be reset, not
+	// just EOF'd as above.
+	if err == nil {
+		t.Fatalf("Read did not fail")
+	}
+}
+
+// TestServerRestart ensures that the server's serve/stop logic works as expected.
+func TestServerRestart(t *testing.T) {
+	lis, dial := loopback(t)
+	defer lis.Close()
+
+	isrv := NewServer()
+	isrv.HandleFunc(dumbHandler)
+
+	// Stop the server before it's even started.
+	isrv.Stop()
+
+	// The server should now exit immediately.
+	if err := isrv.Serve(lis); err != nil {
+		t.Fatalf("Serve: %v", err)
+	}
+
+	// On a subsequent run it should, however, start and serve.
+	go isrv.Serve(lis)
+
+	conn := dial()
+	defer conn.Close()
+	scanner := bufio.NewScanner(conn)
+
+	// Send a request, expect response.
+	reqResp(t, conn, scanner, 123, 234)
+
+	// Attempting another simultaneous Serve() should fail.
+	if err := isrv.Serve(lis); err == nil {
+		t.Fatal("Serve() returned nil, wanted error")
+	}
+
+	// Send a request, expect response.
+	reqResp(t, conn, scanner, 234, 345)
+
+	// Stop server, restart server.
+	isrv.Stop()
+	go isrv.Serve(lis)
+
+	// Send a request, expect response.
+	reqResp(t, conn, scanner, 345, 456)
+}