| package main |
| |
| import ( |
| "context" |
| "flag" |
| "io" |
| "time" |
| |
| "code.hackerspace.pl/hscloud/bgpwtf/cccampix/pgpencryptor/gpg" |
| "code.hackerspace.pl/hscloud/bgpwtf/cccampix/pgpencryptor/hkp" |
| "code.hackerspace.pl/hscloud/bgpwtf/cccampix/pgpencryptor/model" |
| pb "code.hackerspace.pl/hscloud/bgpwtf/cccampix/proto" |
| "code.hackerspace.pl/hscloud/go/mirko" |
| "github.com/golang/glog" |
| "github.com/lib/pq" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/status" |
| ) |
| |
// Command-line configuration; the flags are registered and parsed in main.
var (
	// flagMaxClients is the number of Encrypt sessions allowed to run at
	// once (capacity of the service.clients semaphore).
	flagMaxClients int

	// flagChunkSize is the maximum size of a ciphertext chunk sent back to
	// the client.
	flagChunkSize int

	// flagHkpMaxWaitTime is the wait-time bound configured by -hkpMaxWaitTime.
	flagHkpMaxWaitTime time.Duration

	// flagDSN is the PostgreSQL connection string for the key store.
	flagDSN string
)
| |
| type service struct { |
| hkpClient hkp.Client |
| encryptorFactory gpg.EncryptorFactory |
| clients chan (struct{}) |
| model model.Model |
| } |
| |
| func (s *service) KeyInfo(ctx context.Context, req *pb.KeyInfoRequest) (*pb.KeyInfoResponse, error) { |
| var data []byte |
| var err error |
| |
| switch req.Caching { |
| case pb.KeyInfoRequest_CACHING_AUTO: |
| _, err := s.model.GetKey(ctx, req.Fingerprint) |
| if err != nil { |
| data, err = s.hkpClient.GetKeyRing(ctx, req.Fingerprint) |
| switch err { |
| case nil: |
| break |
| case hkp.ErrKeyNotFound: |
| return nil, status.Errorf(codes.NotFound, "key not found: %v", err) |
| default: |
| return nil, status.Errorf(codes.Unavailable, "failed to get key from HKP servers: %v", err) |
| } |
| } |
| case pb.KeyInfoRequest_CACHING_FORCE_REMOTE: |
| data, err = s.hkpClient.GetKeyRing(ctx, req.Fingerprint) |
| switch err { |
| case nil: |
| break |
| case hkp.ErrKeyNotFound: |
| return nil, status.Errorf(codes.NotFound, "key not found: %v", err) |
| default: |
| return nil, status.Errorf(codes.Unavailable, "failed to get key from HKP servers: %v", err) |
| } |
| case pb.KeyInfoRequest_CACHING_FORCE_LOCAL: |
| _, err := s.model.GetKey(ctx, req.Fingerprint) |
| switch err { |
| case nil: |
| break |
| case model.ErrKeyNotFound: |
| return nil, status.Errorf(codes.NotFound, "key not found: %v", err) |
| default: |
| return nil, status.Errorf(codes.Unavailable, "failed to read key from local db: %v", err) |
| } |
| default: |
| return nil, status.Errorf(codes.InvalidArgument, "caching field value is invalid") |
| } |
| |
| // successfully read fresh key from hkp, update db |
| if data != nil { |
| err := s.model.PutKey(ctx, &model.PgpKey{ |
| Fingerprint: req.Fingerprint, |
| KeyData: data, |
| Okay: true, |
| }) |
| |
| if err != nil { |
| return nil, status.Errorf(codes.Unavailable, "failed to cache key received from HKP: %v", err) |
| } |
| } |
| |
| return &pb.KeyInfoResponse{}, nil |
| } |
| |
| func (s *service) Encrypt(stream pb.PGPEncryptor_EncryptServer) error { |
| select { |
| case s.clients <- struct{}{}: |
| break |
| case <-stream.Context().Done(): |
| return status.Errorf(codes.ResourceExhausted, "PGPEncryptor to many sessions running, try again later") |
| } |
| |
| defer func() { |
| <-s.clients |
| }() |
| |
| ctx, _ := context.WithTimeout(context.Background(), flagHkpMaxWaitTime) |
| initialMessage, err := stream.Recv() |
| if err != nil { |
| return status.Errorf(codes.Canceled, "failed to read data from the client: %v", err) |
| } |
| |
| key, err := s.model.GetKey(ctx, initialMessage.Fingerprint) |
| if err != nil { |
| if err != nil { |
| switch err { |
| case model.ErrKeyNotFound: |
| return status.Errorf(codes.NotFound, "recipient key not found: %v", err) |
| default: |
| return status.Errorf(codes.Unavailable, "error when getting keyring: %v", err) |
| } |
| } |
| } |
| |
| if key.Okay == false { |
| return status.Errorf(codes.InvalidArgument, "cached key is invalid, call PGPEncryptor.KeyInfo with caching=FORCE_REMOTE to refresh") |
| } |
| |
| senderDone := make(chan struct{}) |
| errors := make(chan error) |
| enc, err := s.encryptorFactory.Get(key.Fingerprint, key.KeyData) |
| |
| if err != nil { |
| return status.Errorf(codes.Unavailable, "PGPEncryptor error while creating encryptor: %v", err) |
| } |
| defer enc.Close() |
| |
| err = enc.WritePlainText(initialMessage.Data) |
| if err != nil { |
| return err |
| } |
| |
| // keep reading incoming messages and pass them to the encryptor |
| go func() { |
| defer enc.Finish() |
| |
| for { |
| in, err := stream.Recv() |
| |
| if err != nil { |
| if err == io.EOF { |
| return |
| } |
| |
| errors <- status.Errorf(codes.Unavailable, "PGPEncryptor error while receiving message: %v", err) |
| return |
| } |
| |
| err = enc.WritePlainText(in.Data) |
| if err != nil { |
| errors <- status.Errorf(codes.Unavailable, "PGPEncryptor error while writing data to encryptor: %v", err) |
| return |
| } |
| |
| if in.Info == pb.EncryptRequest_CHUNK_LAST { |
| return |
| } |
| } |
| }() |
| |
| // start sender routine |
| go func() { |
| defer close(senderDone) |
| |
| for { |
| data, err := enc.ReadCipherText(flagChunkSize) |
| if err != nil && err != io.EOF { |
| errors <- status.Errorf(codes.Unavailable, "PGPEncryptor error while reading cipher stream: %v", err) |
| return |
| } |
| |
| info := pb.EncryptResponse_CHUNK_INFO_MORE |
| if err == io.EOF { |
| info = pb.EncryptResponse_CHUNK_LAST |
| } |
| |
| res := &pb.EncryptResponse{ |
| Data: data, |
| Info: info, |
| } |
| |
| err = stream.Send(res) |
| if err != nil { |
| errors <- status.Errorf(codes.Unavailable, "PGPEncryptor error while sending data to client: %v", err) |
| return |
| } |
| |
| if info == pb.EncryptResponse_CHUNK_LAST { |
| return |
| } |
| } |
| }() |
| |
| // sync with sender routine |
| select { |
| case <-senderDone: |
| return nil |
| case err := <-errors: |
| return err |
| } |
| } |
| |
| func main() { |
| flag.IntVar(&flagMaxClients, "maxClients", 20, "maximum number of concurrent encryption sessions") |
| flag.IntVar(&flagChunkSize, "chunkSize", 1024*8, "maximum size of chunk sent back to client") |
| flag.DurationVar(&flagHkpMaxWaitTime, "hkpMaxWaitTime", 10*time.Second, "maximum time awaiting reply from HKP") |
| flag.StringVar(&flagDSN, "dsn", "", "PostrgreSQL connection string") |
| |
| flag.DurationVar(&gpg.ExecutionTimeLimit, "gpgExecutionTimeLimit", 30*time.Second, "execution time limit for gpg commands") |
| flag.StringVar(&gpg.BinaryPath, "gpgPath", "gpg", "path to gpg binary") |
| |
| flag.DurationVar(&hkp.PerServerTimeLimit, "hkpPerServerTimeLimit", 5*time.Second, "time for HKP server to reply with key") |
| flag.IntVar(&hkp.PerServerRetryCount, "hkpPerServerRetryCount", 3, "retry count per HKP server") |
| flag.Parse() |
| |
| // Picking an existing postgres-like driver for sqlx.BindType to work |
| // See: https://github.com/jmoiron/sqlx/blob/ed7c52c43ee1e12a35efbcfea8dbae2d62a90370/bind.go#L24 |
| mirko.TraceSQL(&pq.Driver{}, "pgx") |
| mi := mirko.New() |
| |
| m, err := model.Connect(mi.Context(), "pgx", flagDSN) |
| if err != nil { |
| glog.Exitf("Failed to create model: %v", err) |
| } |
| |
| err = m.MigrateUp() |
| if err != nil { |
| glog.Exitf("Failed to migrate up: %v", err) |
| } |
| |
| if err := mi.Listen(); err != nil { |
| glog.Exitf("Listen failed: %v", err) |
| } |
| |
| s := &service{ |
| hkpClient: hkp.NewClient(), |
| encryptorFactory: gpg.CLIEncryptorFactory{}, |
| clients: make(chan struct{}, flagMaxClients), |
| model: m, |
| } |
| pb.RegisterPGPEncryptorServer(mi.GRPC(), s) |
| |
| if err := mi.Serve(); err != nil { |
| glog.Exitf("Serve failed: %v", err) |
| } |
| |
| <-mi.Done() |
| } |