blob: fdb4e34660f81c93c74b8c5c1eee5017f66534c7 [file] [log] [blame]
Sergiusz Bazanskie08e6da2018-08-27 20:40:10 +01001package main
2
3import (
4 "context"
5 "fmt"
6 "strings"
7
8 "golang.org/x/net/trace"
9 "google.golang.org/grpc"
10 "google.golang.org/grpc/codes"
11 "google.golang.org/grpc/credentials"
12 "google.golang.org/grpc/peer"
13 "google.golang.org/grpc/status"
14)
15
16type clientPKIInfo struct {
17 realm string
18 principal string
19 job string
20}
21
22func (c *clientPKIInfo) String() string {
23 return fmt.Sprintf("job=%q, principal=%q, realm=%q", c.job, c.principal, c.realm)
24}
25
26func parseClientName(realm, name string) (*clientPKIInfo, error) {
27 if !strings.HasSuffix(name, "."+realm) {
28 return nil, fmt.Errorf("invalid realm")
29 }
30 service := strings.TrimSuffix(name, "."+realm)
31 parts := strings.Split(service, ".")
32 if len(parts) != 2 {
33 return nil, fmt.Errorf("invalid service")
34 }
35 return &clientPKIInfo{
36 realm: realm,
37 principal: parts[1],
38 job: parts[0],
39 }, nil
40}
41
// pkiInfoCtxKey is the unexported type of the context key under which
// withPKIInfo stores the peer's *clientPKIInfo. A dedicated type keeps the
// key from colliding with plain string keys set by other packages, as the
// context.WithValue documentation recommends.
type pkiInfoCtxKey string

const (
	// ctxKeyPKIInfo is the context key for the authenticated peer identity.
	ctxKeyPKIInfo pkiInfoCtxKey = "hscloud-pki-info"
)
46func withPKIInfo(ctx context.Context, c *clientPKIInfo) context.Context {
47 tr, ok := trace.FromContext(ctx)
48 if ok {
49 tr.LazyPrintf("PKI Peer: %s", c.String())
50 }
51 return context.WithValue(ctx, ctxKeyPKIInfo, c)
52}
53
54func (s *server) unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
55 peer, ok := peer.FromContext(ctx)
56 if !ok {
57 s.trace(ctx, "Could not establish identity of peer.")
58 return nil, status.Error(codes.InvalidArgument, "no peer info")
59 }
60
61 authInfo, ok := peer.AuthInfo.(credentials.TLSInfo)
62 if !ok {
63 s.trace(ctx, "Could not establish TLS identity of peer.")
64 return nil, status.Error(codes.InvalidArgument, "no TLS peer info")
65 }
66
67 chains := authInfo.State.VerifiedChains
68 if len(chains) != 1 {
69 s.trace(ctx, "No trusted chain found.")
70 return nil, status.Error(codes.InvalidArgument, "invalid TLS certificate")
71 }
72 chain := chains[0]
73
74 certDNs := make([]string, len(chain))
75 for i, cert := range chain {
76 certDNs[i] = cert.Subject.String()
77 }
78 s.trace(ctx, "TLS chain: %s", strings.Join(certDNs, ", "))
79
80 clientInfo, err := parseClientName(s.opts.pkiRealm, chain[0].Subject.CommonName)
81 if err != nil {
82 s.trace(ctx, "Could not parse certificate DN: %v", err)
83 return nil, status.Error(codes.InvalidArgument, "invalid TLS CommonName")
84 }
85 ctx = withPKIInfo(ctx, clientInfo)
86
87 return handler(ctx, req)
88}