| package main |
| |
| import ( |
| "context" |
| "fmt" |
| "strings" |
| |
| "golang.org/x/net/trace" |
| "google.golang.org/grpc" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/credentials" |
| "google.golang.org/grpc/peer" |
| "google.golang.org/grpc/status" |
| ) |
| |
// clientPKIInfo identifies a gRPC client by the labels encoded in the
// CommonName of its TLS client certificate (see parseClientName).
type clientPKIInfo struct {
	// realm is the PKI realm the name was scoped to; job and principal are
	// the first and second dot-separated labels that precede it.
	realm, principal, job string
}

// String renders the identity as job="...", principal="...", realm="...",
// with each value Go-quoted.
func (c *clientPKIInfo) String() string {
	const format = "job=%q, principal=%q, realm=%q"
	return fmt.Sprintf(format, c.job, c.principal, c.realm)
}
| |
| func parseClientName(realm, name string) (*clientPKIInfo, error) { |
| if !strings.HasSuffix(name, "."+realm) { |
| return nil, fmt.Errorf("invalid realm") |
| } |
| service := strings.TrimSuffix(name, "."+realm) |
| parts := strings.Split(service, ".") |
| if len(parts) != 2 { |
| return nil, fmt.Errorf("invalid service") |
| } |
| return &clientPKIInfo{ |
| realm: realm, |
| principal: parts[1], |
| job: parts[0], |
| }, nil |
| } |
| |
// pkiContextKey is an unexported type for the context keys defined here.
// Using a dedicated type rather than a plain string means no other package
// can read or overwrite the value by passing an equal string key, as advised
// by the context.WithValue documentation.
type pkiContextKey string

// ctxKeyPKIInfo is the context key under which withPKIInfo stores the
// caller's *clientPKIInfo.
const (
	ctxKeyPKIInfo pkiContextKey = "hscloud-pki-info"
)
| |
| func withPKIInfo(ctx context.Context, c *clientPKIInfo) context.Context { |
| tr, ok := trace.FromContext(ctx) |
| if ok { |
| tr.LazyPrintf("PKI Peer: %s", c.String()) |
| } |
| return context.WithValue(ctx, ctxKeyPKIInfo, c) |
| } |
| |
| func (s *server) unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { |
| peer, ok := peer.FromContext(ctx) |
| if !ok { |
| s.trace(ctx, "Could not establish identity of peer.") |
| return nil, status.Error(codes.InvalidArgument, "no peer info") |
| } |
| |
| authInfo, ok := peer.AuthInfo.(credentials.TLSInfo) |
| if !ok { |
| s.trace(ctx, "Could not establish TLS identity of peer.") |
| return nil, status.Error(codes.InvalidArgument, "no TLS peer info") |
| } |
| |
| chains := authInfo.State.VerifiedChains |
| if len(chains) != 1 { |
| s.trace(ctx, "No trusted chain found.") |
| return nil, status.Error(codes.InvalidArgument, "invalid TLS certificate") |
| } |
| chain := chains[0] |
| |
| certDNs := make([]string, len(chain)) |
| for i, cert := range chain { |
| certDNs[i] = cert.Subject.String() |
| } |
| s.trace(ctx, "TLS chain: %s", strings.Join(certDNs, ", ")) |
| |
| clientInfo, err := parseClientName(s.opts.pkiRealm, chain[0].Subject.CommonName) |
| if err != nil { |
| s.trace(ctx, "Could not parse certificate DN: %v", err) |
| return nil, status.Error(codes.InvalidArgument, "invalid TLS CommonName") |
| } |
| ctx = withPKIInfo(ctx, clientInfo) |
| |
| return handler(ctx, req) |
| } |