blob: 7cd2c818aa5be9fc6691ef1fb2af5fcb62796732 [file] [log] [blame]
Serge Bazanski9f0e1e82023-03-31 22:36:54 +00001package certs
2
3import (
4 "bytes"
5 "crypto"
6 "crypto/ed25519"
7 "crypto/rand"
8 "crypto/x509"
9 "encoding/pem"
10 "errors"
11 "fmt"
12 "log"
13 "math/big"
14 "net"
15 "os"
16 "path/filepath"
17 "time"
18)
19
20// Certificate is a higher-level descriptor of an intent to generate a
21// certificate and corresponding Ed25519 keypair on disk.
22type Certificate struct {
23 // uniquer name for this cert, used to calculate filesystem paths.
24 name string
25 // root directory where all certs are stored.
26 root string
27 // duration used to determine TimeAfter. If not set, the certificate will
28 // never expire.
29 duration time.Duration
30
31 kind certificateKind
32
33 // cn is the subject common name that's going to be produced in the X.509
34 // certificate.
35 cn string
36 // o is the subject organziation that's going to be produced in the X.509
37 // certificate.
38 o string
39 // san are the DNS alternate names that are going to be produced in the
40 // X.509 certificate.
41 san []string
42 // ips are the IP alternate names that are going to be produced in the
43 // X.509 certificate.
44 ips []net.IP
45
46 // issuer, if set, is the certificate that will sign this certificate. If
47 // not set, the certificate will be self-signed.
48 issuer *Certificate
49}
50
51// Paths returns local filesystem paths to the CA certificate, certificate and
52// key respectively. If the certificate is self signed, the CA path returned
53// will be empty. These files might or might not live on the file system - you
54// should first call Ensure to make sure they do.
55func (c *Certificate) Paths() (caPath, certPath, keyPath string) {
56 if c.issuer != nil {
57 caPath = c.issuer.path(fileKindCert)
58 }
59 certPath = c.path(fileKindCert)
60 keyPath = c.path(fileKindKey)
61 return
62}
63
// certificateKind names the intended role of a Certificate.
type certificateKind string

// Recognised certificate roles. Apart from kindCA, which also selects the
// ".crt" file extension for the certificate on disk, how each kind shapes the
// issued certificate is decided elsewhere (presumably in template() - confirm).
const (
	kindCA           certificateKind = "ca"
	kindServer       certificateKind = "server"
	kindClient       certificateKind = "client"
	kindClientServer certificateKind = "client-server"
	kindProdvider    certificateKind = "prodvider"
)
73
// fileKind identifies one of the files that back a Certificate on disk.
type fileKind string

const (
	// fileKindCert is the PEM-encoded X.509 certificate.
	fileKindCert fileKind = "cert"
	// fileKindKey is the plaintext PEM-encoded PKCS#8 private key.
	fileKindKey fileKind = "key"
	// fileKindKeyEncrypted is the encrypted copy of the private key, which
	// has to be decrypted (via secretstore) before it can be used.
	fileKindKeyEncrypted fileKind = "key-encrypted"
)
81
82// path returns the path to the generated fileKind for this Certificate.
83func (c *Certificate) path(k fileKind) string {
84 switch k {
85 case fileKindKeyEncrypted:
86 return filepath.Join(c.root, "secrets", "cipher", c.name+".key")
87 case fileKindKey:
88 return filepath.Join(c.root, "secrets", "plain", c.name+".key")
89 case fileKindCert:
90 // clustercfg.py compat: CA certs end in .crt, non-CA certs end in .cert.
91 // We're keeping this accidental convention to avoid spurious nix rebuilds
92 // when migrating.
93 //
94 // Feel free to fix it if it annoys you.
95 extension := ".cert"
96 if c.kind == kindCA {
97 extension = ".crt"
98 }
99 return filepath.Join(c.root, "certs", c.name+extension)
100 default:
101 panic("unexpected file kind type " + k)
102 }
103}
104
105// ensureKey loads or generates-then-saves the private key for this
106// Certificate.
107func (c *Certificate) ensureKey() (crypto.Signer, error) {
108 path := c.path(fileKindKey)
109 _, err := os.Stat(path)
110 switch {
111 case err == nil:
112 return c.loadKey()
113 case errors.Is(err, os.ErrNotExist):
114 epath := c.path(fileKindKeyEncrypted)
115 if _, err = os.Stat(epath); err == nil {
116 return nil, fmt.Errorf("plaintext key at %q not found, but exists encrypted at %q - please decrypt using secretstore", path, epath)
117 }
118 return c.generateKey()
119 default:
120 return nil, fmt.Errorf("could not read key: %w", err)
121 }
122}
123
124func (c *Certificate) loadKey() (crypto.Signer, error) {
125 path := c.path(fileKindKey)
126 bytes, err := os.ReadFile(path)
127 if err != nil {
128 return nil, err
129 }
130 block, _ := pem.Decode(bytes)
131 if block == nil {
132 return nil, fmt.Errorf("no PEM block found")
133 }
134 if block.Type != "PRIVATE KEY" {
135 return nil, fmt.Errorf("unexpected PEM block: %q", block.Type)
136 }
137 key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
138 if err != nil {
139 return nil, err
140 }
141 if k, ok := key.(ed25519.PrivateKey); ok {
142 return k, nil
143 }
144 return nil, fmt.Errorf("not an ED25519 key")
145}
146
147func (c *Certificate) generateKey() (crypto.Signer, error) {
148 _, priv, err := ed25519.GenerateKey(rand.Reader)
149 if err != nil {
150 return nil, err
151 }
152
153 pkcs8, err := x509.MarshalPKCS8PrivateKey(priv)
154 if err != nil {
155 return nil, err
156 }
157
158 block := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: pkcs8})
159 path := c.path(fileKindKey)
160 os.MkdirAll(filepath.Dir(path), 0700)
161 log.Printf("Saving %s key to %s ...", c.name, path)
162 if err := os.WriteFile(path, block, 0600); err != nil {
163 return nil, err
164 }
165
166 return priv, nil
167}
168
169// ensureCert loads or generates-then-saves the X.509 certificate for the
170// Certificate.
171func (c *Certificate) ensureCert() (*x509.Certificate, error) {
172 path := c.path(fileKindCert)
173 _, err := os.Stat(path)
174 switch {
175 case err == nil:
176 cert, err := c.loadCert()
177 switch err {
178 case nil:
179 return cert, nil
180 case errExpired:
181 return c.generateCert()
182 default:
183 return nil, err
184 }
185 case errors.Is(err, os.ErrNotExist):
186 return c.generateCert()
187 default:
188 return nil, fmt.Errorf("could not read cert: %w", err)
189 }
190}
191
192func (c *Certificate) generateCert() (*x509.Certificate, error) {
193 serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 127)
194 serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
195 if err != nil {
196 return nil, err
197 }
198
199 notAfter := unknownNotAfter
200 if c.duration != 0 {
201 notAfter = time.Now().Add(c.duration)
202 }
203 template := c.template()
204 template.SerialNumber = serialNumber
205 template.NotBefore = time.Now()
206 template.NotAfter = notAfter
207
208 parent := template
209 skey, err := c.ensureKey()
210 if err != nil {
211 return nil, fmt.Errorf("when ensuring key: %w", err)
212 }
213 pkey := skey.Public()
214 caskey := skey
215 if c.issuer != nil {
216 caskey, err = c.issuer.ensureKey()
217 if err != nil {
218 return nil, fmt.Errorf("when ensuring CA key: %w", err)
219 }
220 cacert, err := c.issuer.ensureCert()
221 if err != nil {
222 return nil, fmt.Errorf("when ensuring CA cert: %w", err)
223 }
224 parent = cacert
225 }
226
227 bytes, err := x509.CreateCertificate(rand.Reader, template, parent, pkey, caskey)
228 if err != nil {
229 return nil, fmt.Errorf("issuing certificate failed: %w", err)
230 }
231
232 block := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: bytes})
233 path := c.path(fileKindCert)
234 os.MkdirAll(filepath.Dir(path), 0700)
235 log.Printf("Saving %s cert to %s ...", c.name, path)
236 if err := os.WriteFile(path, block, 0600); err != nil {
237 return nil, err
238 }
239
240 return x509.ParseCertificate(bytes)
241}
242
var (
	// errExpired is the sentinel loadCert reports for a certificate that
	// exists on disk but has expired or is about to, telling ensureCert to
	// reissue it rather than fail.
	errExpired = errors.New("certificate expired")
)
245
246func (c *Certificate) loadCert() (*x509.Certificate, error) {
247 path := c.path(fileKindCert)
248 b, err := os.ReadFile(path)
249 if err != nil {
250 return nil, err
251 }
252
253 block, _ := pem.Decode(b)
254 if block == nil {
255 return nil, fmt.Errorf("no PEM block found")
256 }
257 if block.Type != "CERTIFICATE" {
258 return nil, fmt.Errorf("unexpected PEM block: %q", block.Type)
259 }
260 cert, err := x509.ParseCertificate(block.Bytes)
261 if err != nil {
262 return nil, err
263 }
264 if time.Now().Add(time.Hour).After(cert.NotAfter) {
265 return nil, errExpired
266 }
267 pkey, ok := cert.PublicKey.(ed25519.PublicKey)
268 if !ok {
269 return nil, fmt.Errorf("not a ED25519 cert")
270 }
271 skey, err := c.ensureKey()
272 if err != nil {
273 return nil, fmt.Errorf("when ensuring key: %w", err)
274 }
275 if !bytes.Equal(pkey, skey.Public().(ed25519.PublicKey)) {
276 return nil, fmt.Errorf("issued for different key")
277 }
278
279 template := c.template()
280 if err := compareCertData(template, cert); err != nil {
281 return nil, err
282 }
283 return cert, nil
284}
285
286// Ensure makes sure the given Certificate (and all of its' issuers) have
287// corresponding private keys and X.509 certificates on disk, generating things
288// as necessary.
289func (c *Certificate) Ensure() error {
290 cert, err := c.ensureCert()
291 if err != nil {
292 return fmt.Errorf("when ensuring cert %s: %w", c.name, err)
293 }
294 _ = cert
295
296 return nil
297}