Initial Commit
diff --git a/pki.go b/pki.go
new file mode 100644
index 0000000..fdb4e34
--- /dev/null
+++ b/pki.go
@@ -0,0 +1,88 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "golang.org/x/net/trace"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/peer"
+ "google.golang.org/grpc/status"
+)
+
// clientPKIInfo is the identity of a gRPC caller, derived by parseClientName
// from a certificate CommonName of the form "<job>.<principal>.<realm>".
type clientPKIInfo struct {
	realm     string // realm the name was required to end with
	principal string // second dotted label of the name
	job       string // first dotted label of the name
}

// String formats the identity as job="…", principal="…", realm="…", with each
// value Go-quoted.
func (c *clientPKIInfo) String() string {
	return fmt.Sprintf("job=%q, principal=%q, realm=%q", c.job, c.principal, c.realm)
}

// parseClientName splits name, expected to be "<job>.<principal>.<realm>",
// into its components. The name must end in "."+realm, and what precedes that
// suffix must consist of exactly two non-empty, dot-separated labels. Any other
// shape yields a nil *clientPKIInfo and a non-nil error.
func parseClientName(realm, name string) (*clientPKIInfo, error) {
	suffix := "." + realm
	if !strings.HasSuffix(name, suffix) {
		return nil, fmt.Errorf("invalid realm: name %q does not end in %q", name, suffix)
	}
	service := strings.TrimSuffix(name, suffix)
	parts := strings.Split(service, ".")
	// Reject empty labels too: ".principal" or "job." would otherwise produce
	// an identity with a blank job or principal.
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return nil, fmt.Errorf("invalid service: %q is not of the form <job>.<principal>", service)
	}
	return &clientPKIInfo{
		realm:     realm,
		principal: parts[1],
		job:       parts[0],
	}, nil
}
+
// pkiInfoCtxKey is an unexported key type for values stored in a
// context.Context. Using a dedicated type instead of a plain string means no
// other package can read or overwrite the value by accident, as recommended by
// the context.WithValue documentation.
type pkiInfoCtxKey string

const (
	// ctxKeyPKIInfo is the context key under which withPKIInfo stores the
	// caller's *clientPKIInfo.
	ctxKeyPKIInfo pkiInfoCtxKey = "hscloud-pki-info"
)
+
+func withPKIInfo(ctx context.Context, c *clientPKIInfo) context.Context {
+ tr, ok := trace.FromContext(ctx)
+ if ok {
+ tr.LazyPrintf("PKI Peer: %s", c.String())
+ }
+ return context.WithValue(ctx, ctxKeyPKIInfo, c)
+}
+
+func (s *server) unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+ peer, ok := peer.FromContext(ctx)
+ if !ok {
+ s.trace(ctx, "Could not establish identity of peer.")
+ return nil, status.Error(codes.InvalidArgument, "no peer info")
+ }
+
+ authInfo, ok := peer.AuthInfo.(credentials.TLSInfo)
+ if !ok {
+ s.trace(ctx, "Could not establish TLS identity of peer.")
+ return nil, status.Error(codes.InvalidArgument, "no TLS peer info")
+ }
+
+ chains := authInfo.State.VerifiedChains
+ if len(chains) != 1 {
+ s.trace(ctx, "No trusted chain found.")
+ return nil, status.Error(codes.InvalidArgument, "invalid TLS certificate")
+ }
+ chain := chains[0]
+
+ certDNs := make([]string, len(chain))
+ for i, cert := range chain {
+ certDNs[i] = cert.Subject.String()
+ }
+ s.trace(ctx, "TLS chain: %s", strings.Join(certDNs, ", "))
+
+ clientInfo, err := parseClientName(s.opts.pkiRealm, chain[0].Subject.CommonName)
+ if err != nil {
+ s.trace(ctx, "Could not parse certificate DN: %v", err)
+ return nil, status.Error(codes.InvalidArgument, "invalid TLS CommonName")
+ }
+ ctx = withPKIInfo(ctx, clientInfo)
+
+ return handler(ctx, req)
+}