bgpwtf/cccampix/pgpencryptor: implement service

TODO:
  * tests

Change-Id: I5d0506542070236a8ee879fcb54bc9518e23b5e3
diff --git a/bgpwtf/cccampix/pgpencryptor/main.go b/bgpwtf/cccampix/pgpencryptor/main.go
index d8e410a..3d73e01 100644
--- a/bgpwtf/cccampix/pgpencryptor/main.go
+++ b/bgpwtf/cccampix/pgpencryptor/main.go
@@ -3,35 +3,248 @@
 import (
 	"context"
 	"flag"
-
 	"github.com/golang/glog"
+	"github.com/lib/pq"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
+	"io"
+	"time"
 
+	"code.hackerspace.pl/hscloud/bgpwtf/cccampix/pgpencryptor/gpg"
+	"code.hackerspace.pl/hscloud/bgpwtf/cccampix/pgpencryptor/hkp"
+	"code.hackerspace.pl/hscloud/bgpwtf/cccampix/pgpencryptor/model"
 	pb "code.hackerspace.pl/hscloud/bgpwtf/cccampix/proto"
 	"code.hackerspace.pl/hscloud/go/mirko"
 )
 
var (
	// flagMaxClients bounds the number of concurrent Encrypt sessions
	// (capacity of the service.clients semaphore channel).
	flagMaxClients int
	// flagChunkSize is the maximum size of a ciphertext chunk sent back
	// to the client in a single EncryptResponse.
	flagChunkSize int
	// flagHkpMaxWaitTime bounds how long an Encrypt call waits on backend
	// lookups before its context times out.
	flagHkpMaxWaitTime time.Duration
	// flagDSN is the PostgreSQL connection string used by the key model.
	flagDSN string
)
+
// service implements the PGPEncryptor gRPC service: key lookup/caching
// (KeyInfo) and streaming encryption (Encrypt).
type service struct {
	// hkpClient fetches keyrings from remote HKP keyservers.
	hkpClient hkp.Client
	// encryptorFactory creates a per-session encryptor for a given key.
	encryptorFactory gpg.EncryptorFactory
	// clients acts as a counting semaphore limiting concurrent Encrypt
	// sessions; its capacity is set from -maxClients in main.
	clients chan (struct{})
	// model is the local database used to cache PGP keys.
	model model.Model
}
 
 func (s *service) KeyInfo(ctx context.Context, req *pb.KeyInfoRequest) (*pb.KeyInfoResponse, error) {
-	return nil, status.Error(codes.Unimplemented, "not implemented yet")
+	var data []byte
+	var err error
+
+	switch req.Caching {
+	case pb.KeyInfoRequest_CACHING_AUTO:
+		_, err := s.model.GetKey(ctx, req.Fingerprint)
+		if err != nil {
+			data, err = s.hkpClient.GetKeyRing(ctx, req.Fingerprint)
+			switch err {
+			case nil:
+				break
+			case hkp.ErrKeyNotFound:
+				return nil, status.Errorf(codes.NotFound, "key not found: %v", err)
+			default:
+				return nil, status.Errorf(codes.Unavailable, "failed to get key from HKP servers: %v", err)
+			}
+		}
+	case pb.KeyInfoRequest_CACHING_FORCE_REMOTE:
+		data, err = s.hkpClient.GetKeyRing(ctx, req.Fingerprint)
+		switch err {
+		case nil:
+			break
+		case hkp.ErrKeyNotFound:
+			return nil, status.Errorf(codes.NotFound, "key not found: %v", err)
+		default:
+			return nil, status.Errorf(codes.Unavailable, "failed to get key from HKP servers: %v", err)
+		}
+	case pb.KeyInfoRequest_CACHING_FORCE_LOCAL:
+		_, err := s.model.GetKey(ctx, req.Fingerprint)
+		switch err {
+		case nil:
+			break
+		case model.ErrKeyNotFound:
+			return nil, status.Errorf(codes.NotFound, "key not found: %v", err)
+		default:
+			return nil, status.Errorf(codes.Unavailable, "failed to read key from local db: %v", err)
+		}
+	default:
+		return nil, status.Errorf(codes.InvalidArgument, "caching field value is invalid")
+	}
+
+	// successfully read fresh key from hkp, update db
+	if data != nil {
+		err := s.model.PutKey(ctx, &model.PgpKey{
+			Fingerprint: req.Fingerprint,
+			KeyData:     data,
+			Okay:        true,
+		})
+
+		if err != nil {
+			return nil, status.Errorf(codes.Unavailable, "failed to cache key received from HKP: %v", err)
+		}
+	}
+
+	return &pb.KeyInfoResponse{}, nil
 }
 
 func (s *service) Encrypt(stream pb.PGPEncryptor_EncryptServer) error {
-	return status.Error(codes.Unimplemented, "not implemented yet")
+	select {
+	case s.clients <- struct{}{}:
+		break
+	case <-stream.Context().Done():
+		return status.Errorf(codes.ResourceExhausted, "PGPEncryptor to many sessions running, try again later")
+	}
+
+	defer func() {
+		<-s.clients
+	}()
+
+	ctx, _ := context.WithTimeout(context.Background(), flagHkpMaxWaitTime)
+	initialMessage, err := stream.Recv()
+	if err != nil {
+		return status.Errorf(codes.Canceled, "failed to read data from the client: %v", err)
+	}
+
+	key, err := s.model.GetKey(ctx, initialMessage.Fingerprint)
+	if err != nil {
+		if err != nil {
+			switch err {
+			case model.ErrKeyNotFound:
+				return status.Errorf(codes.NotFound, "recipient key not found: %v", err)
+			default:
+				return status.Errorf(codes.Unavailable, "error when getting keyring: %v", err)
+			}
+		}
+	}
+
+	if key.Okay == false {
+		return status.Errorf(codes.InvalidArgument, "cached key is invalid, call PGPEncryptor.KeyInfo with caching=FORCE_REMOTE to refresh")
+	}
+
+	senderDone := make(chan struct{})
+	errors := make(chan error)
+	enc, err := s.encryptorFactory.Get(key.Fingerprint, key.KeyData)
+	defer enc.Close()
+
+	if err != nil {
+		return status.Errorf(codes.Unavailable, "PGPEncryptor error while creating encryptor: %v", err)
+	}
+
+	err = enc.WritePlainText(initialMessage.Data)
+	if err != nil {
+		return err
+	}
+
+	// keep reading incoming messages and pass them to the encryptor
+	go func() {
+		defer enc.Finish()
+
+		for {
+			in, err := stream.Recv()
+
+			if err != nil {
+				if err == io.EOF {
+					return
+				}
+
+				errors <- status.Errorf(codes.Unavailable, "PGPEncryptor error while receiving message: %v", err)
+				return
+			}
+
+			err = enc.WritePlainText(in.Data)
+			if err != nil {
+				errors <- status.Errorf(codes.Unavailable, "PGPEncryptor error while writing data to encryptor: %v", err)
+				return
+			}
+
+			if in.Info == pb.EncryptRequest_CHUNK_LAST {
+				return
+			}
+		}
+	}()
+
+	// start sender routine
+	go func() {
+		defer close(senderDone)
+
+		for {
+			data, err := enc.ReadCipherText(flagChunkSize)
+			if err != nil && err != io.EOF {
+				errors <- status.Errorf(codes.Unavailable, "PGPEncryptor error while reading cipher stream: %v", err)
+				return
+			}
+
+			info := pb.EncryptResponse_CHUNK_INFO_MORE
+			if err == io.EOF {
+				info = pb.EncryptResponse_CHUNK_LAST
+			}
+
+			res := &pb.EncryptResponse{
+				Data: data,
+				Info: info,
+			}
+
+			err = stream.Send(res)
+			if err != nil {
+				errors <- status.Errorf(codes.Unavailable, "PGPEncryptor error while sending data to client: %v", err)
+				return
+			}
+
+			if info == pb.EncryptResponse_CHUNK_LAST {
+				return
+			}
+		}
+	}()
+
+	// sync with sender routine
+	select {
+	case <-senderDone:
+		return nil
+	case err := <-errors:
+		return err
+	}
 }
 
 func main() {
+	flag.IntVar(&flagMaxClients, "maxClients", 20, "maximum number of concurrent encryption sessions")
+	flag.IntVar(&flagChunkSize, "chunkSize", 1024*8, "maximum size of chunk sent back to client")
+	flag.DurationVar(&flagHkpMaxWaitTime, "hkpMaxWaitTime", 10*time.Second, "maximum time awaiting reply from HKP")
+	flag.StringVar(&flagDSN, "dsn", "", "PostrgreSQL connection string")
+
+	flag.DurationVar(&gpg.ExecutionTimeLimit, "gpgExecutionTimeLimit", 30*time.Second, "execution time limit for gpg commands")
+	flag.StringVar(&gpg.BinaryPath, "gpgPath", "gpg", "path to gpg binary")
+
+	flag.DurationVar(&hkp.PerServerTimeLimit, "hkpPerServerTimeLimit", 5*time.Second, "time for HKP server to reply with key")
+	flag.IntVar(&hkp.PerServerRetryCount, "hkpPerServerRetryCount", 3, "retry count per HKP server")
 	flag.Parse()
+
+	// Picking an existing postgres-like driver for sqlx.BindType to work
+	// See: https://github.com/jmoiron/sqlx/blob/ed7c52c43ee1e12a35efbcfea8dbae2d62a90370/bind.go#L24
+	mirko.TraceSQL(&pq.Driver{}, "pgx")
 	mi := mirko.New()
 
+	m, err := model.Connect(mi.Context(), "pgx", flagDSN)
+	if err != nil {
+		glog.Exitf("Failed to create model: %v", err)
+	}
+
+	err = m.MigrateUp()
+	if err != nil {
+		glog.Exitf("Failed to migrate up: %v", err)
+	}
+
 	if err := mi.Listen(); err != nil {
 		glog.Exitf("Listen failed: %v", err)
 	}
 
-	s := &service{}
+	s := &service{
+		hkpClient:        hkp.NewClient(),
+		encryptorFactory: gpg.CLIEncryptorFactory{},
+		clients:          make(chan struct{}, flagMaxClients),
+		model:            m,
+	}
 	pb.RegisterPGPEncryptorServer(mi.GRPC(), s)
 
 	if err := mi.Serve(); err != nil {