Serge Bazanski | 9f0e1e8 | 2023-03-31 22:36:54 +0000 | [diff] [blame] | 1 | package certs |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "crypto" |
| 6 | "crypto/ed25519" |
| 7 | "crypto/rand" |
| 8 | "crypto/x509" |
| 9 | "encoding/pem" |
| 10 | "errors" |
| 11 | "fmt" |
| 12 | "log" |
| 13 | "math/big" |
| 14 | "net" |
| 15 | "os" |
| 16 | "path/filepath" |
| 17 | "time" |
| 18 | ) |
| 19 | |
// Certificate is a higher-level descriptor of an intent to generate a
// certificate and corresponding Ed25519 keypair on disk.
type Certificate struct {
	// name is a unique name for this cert, used to calculate filesystem paths.
	name string
	// root is the directory under which all certs and keys are stored.
	root string
	// duration is added to the issuance time to compute the certificate's
	// NotAfter. If zero, NotAfter is set to unknownNotAfter instead, which is
	// presumably a far-future date so the certificate effectively never
	// expires (TODO confirm against unknownNotAfter's definition).
	duration time.Duration

	// kind is the role this certificate is issued for. kindCA certificates
	// are stored with a .crt extension, all others with .cert.
	kind certificateKind

	// cn is the subject common name that's going to be produced in the X.509
	// certificate.
	cn string
	// o is the subject organization that's going to be produced in the X.509
	// certificate.
	o string
	// san are the DNS alternate names that are going to be produced in the
	// X.509 certificate.
	san []string
	// ips are the IP alternate names that are going to be produced in the
	// X.509 certificate.
	ips []net.IP

	// issuer, if set, is the certificate that will sign this certificate. If
	// not set, the certificate will be self-signed.
	issuer *Certificate
}
| 50 | |
| 51 | // Paths returns local filesystem paths to the CA certificate, certificate and |
| 52 | // key respectively. If the certificate is self signed, the CA path returned |
| 53 | // will be empty. These files might or might not live on the file system - you |
| 54 | // should first call Ensure to make sure they do. |
| 55 | func (c *Certificate) Paths() (caPath, certPath, keyPath string) { |
| 56 | if c.issuer != nil { |
| 57 | caPath = c.issuer.path(fileKindCert) |
| 58 | } |
| 59 | certPath = c.path(fileKindCert) |
| 60 | keyPath = c.path(fileKindKey) |
| 61 | return |
| 62 | } |
| 63 | |
// certificateKind is the role a Certificate is issued for. Besides selecting
// the on-disk file extension for CAs, the kinds are presumably translated
// into X.509 usages by the certificate template builder (TODO confirm).
type certificateKind string

const (
	// kindCA marks a certificate authority. It is the only kind that affects
	// the on-disk file name (CA certificates end in .crt, others in .cert).
	kindCA certificateKind = "ca"

	kindServer       certificateKind = "server"
	kindClient       certificateKind = "client"
	kindClientServer certificateKind = "client-server"
	kindProdvider    certificateKind = "prodvider"
)
| 73 | |
// fileKind selects which on-disk artifact of a Certificate path should
// locate.
type fileKind string

const (
	// fileKindCert is the PEM-encoded X.509 certificate.
	fileKindCert fileKind = "cert"
	// fileKindKey is the plaintext PEM-encoded private key.
	fileKindKey fileKind = "key"
	// fileKindKeyEncrypted is the encrypted copy of the private key, which
	// has to be decrypted with secretstore before it can be used.
	fileKindKeyEncrypted fileKind = "key-encrypted"
)
| 81 | |
| 82 | // path returns the path to the generated fileKind for this Certificate. |
| 83 | func (c *Certificate) path(k fileKind) string { |
| 84 | switch k { |
| 85 | case fileKindKeyEncrypted: |
| 86 | return filepath.Join(c.root, "secrets", "cipher", c.name+".key") |
| 87 | case fileKindKey: |
| 88 | return filepath.Join(c.root, "secrets", "plain", c.name+".key") |
| 89 | case fileKindCert: |
| 90 | // clustercfg.py compat: CA certs end in .crt, non-CA certs end in .cert. |
| 91 | // We're keeping this accidental convention to avoid spurious nix rebuilds |
| 92 | // when migrating. |
| 93 | // |
| 94 | // Feel free to fix it if it annoys you. |
| 95 | extension := ".cert" |
| 96 | if c.kind == kindCA { |
| 97 | extension = ".crt" |
| 98 | } |
| 99 | return filepath.Join(c.root, "certs", c.name+extension) |
| 100 | default: |
| 101 | panic("unexpected file kind type " + k) |
| 102 | } |
| 103 | } |
| 104 | |
| 105 | // ensureKey loads or generates-then-saves the private key for this |
| 106 | // Certificate. |
| 107 | func (c *Certificate) ensureKey() (crypto.Signer, error) { |
| 108 | path := c.path(fileKindKey) |
| 109 | _, err := os.Stat(path) |
| 110 | switch { |
| 111 | case err == nil: |
| 112 | return c.loadKey() |
| 113 | case errors.Is(err, os.ErrNotExist): |
| 114 | epath := c.path(fileKindKeyEncrypted) |
| 115 | if _, err = os.Stat(epath); err == nil { |
| 116 | return nil, fmt.Errorf("plaintext key at %q not found, but exists encrypted at %q - please decrypt using secretstore", path, epath) |
| 117 | } |
| 118 | return c.generateKey() |
| 119 | default: |
| 120 | return nil, fmt.Errorf("could not read key: %w", err) |
| 121 | } |
| 122 | } |
| 123 | |
| 124 | func (c *Certificate) loadKey() (crypto.Signer, error) { |
| 125 | path := c.path(fileKindKey) |
| 126 | bytes, err := os.ReadFile(path) |
| 127 | if err != nil { |
| 128 | return nil, err |
| 129 | } |
| 130 | block, _ := pem.Decode(bytes) |
| 131 | if block == nil { |
| 132 | return nil, fmt.Errorf("no PEM block found") |
| 133 | } |
| 134 | if block.Type != "PRIVATE KEY" { |
| 135 | return nil, fmt.Errorf("unexpected PEM block: %q", block.Type) |
| 136 | } |
| 137 | key, err := x509.ParsePKCS8PrivateKey(block.Bytes) |
| 138 | if err != nil { |
| 139 | return nil, err |
| 140 | } |
| 141 | if k, ok := key.(ed25519.PrivateKey); ok { |
| 142 | return k, nil |
| 143 | } |
| 144 | return nil, fmt.Errorf("not an ED25519 key") |
| 145 | } |
| 146 | |
| 147 | func (c *Certificate) generateKey() (crypto.Signer, error) { |
| 148 | _, priv, err := ed25519.GenerateKey(rand.Reader) |
| 149 | if err != nil { |
| 150 | return nil, err |
| 151 | } |
| 152 | |
| 153 | pkcs8, err := x509.MarshalPKCS8PrivateKey(priv) |
| 154 | if err != nil { |
| 155 | return nil, err |
| 156 | } |
| 157 | |
| 158 | block := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: pkcs8}) |
| 159 | path := c.path(fileKindKey) |
| 160 | os.MkdirAll(filepath.Dir(path), 0700) |
| 161 | log.Printf("Saving %s key to %s ...", c.name, path) |
| 162 | if err := os.WriteFile(path, block, 0600); err != nil { |
| 163 | return nil, err |
| 164 | } |
| 165 | |
| 166 | return priv, nil |
| 167 | } |
| 168 | |
| 169 | // ensureCert loads or generates-then-saves the X.509 certificate for the |
| 170 | // Certificate. |
| 171 | func (c *Certificate) ensureCert() (*x509.Certificate, error) { |
| 172 | path := c.path(fileKindCert) |
| 173 | _, err := os.Stat(path) |
| 174 | switch { |
| 175 | case err == nil: |
| 176 | cert, err := c.loadCert() |
| 177 | switch err { |
| 178 | case nil: |
| 179 | return cert, nil |
| 180 | case errExpired: |
| 181 | return c.generateCert() |
| 182 | default: |
| 183 | return nil, err |
| 184 | } |
| 185 | case errors.Is(err, os.ErrNotExist): |
| 186 | return c.generateCert() |
| 187 | default: |
| 188 | return nil, fmt.Errorf("could not read cert: %w", err) |
| 189 | } |
| 190 | } |
| 191 | |
| 192 | func (c *Certificate) generateCert() (*x509.Certificate, error) { |
| 193 | serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 127) |
| 194 | serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) |
| 195 | if err != nil { |
| 196 | return nil, err |
| 197 | } |
| 198 | |
| 199 | notAfter := unknownNotAfter |
| 200 | if c.duration != 0 { |
| 201 | notAfter = time.Now().Add(c.duration) |
| 202 | } |
| 203 | template := c.template() |
| 204 | template.SerialNumber = serialNumber |
| 205 | template.NotBefore = time.Now() |
| 206 | template.NotAfter = notAfter |
| 207 | |
| 208 | parent := template |
| 209 | skey, err := c.ensureKey() |
| 210 | if err != nil { |
| 211 | return nil, fmt.Errorf("when ensuring key: %w", err) |
| 212 | } |
| 213 | pkey := skey.Public() |
| 214 | caskey := skey |
| 215 | if c.issuer != nil { |
| 216 | caskey, err = c.issuer.ensureKey() |
| 217 | if err != nil { |
| 218 | return nil, fmt.Errorf("when ensuring CA key: %w", err) |
| 219 | } |
| 220 | cacert, err := c.issuer.ensureCert() |
| 221 | if err != nil { |
| 222 | return nil, fmt.Errorf("when ensuring CA cert: %w", err) |
| 223 | } |
| 224 | parent = cacert |
| 225 | } |
| 226 | |
| 227 | bytes, err := x509.CreateCertificate(rand.Reader, template, parent, pkey, caskey) |
| 228 | if err != nil { |
| 229 | return nil, fmt.Errorf("issuing certificate failed: %w", err) |
| 230 | } |
| 231 | |
| 232 | block := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: bytes}) |
| 233 | path := c.path(fileKindCert) |
| 234 | os.MkdirAll(filepath.Dir(path), 0700) |
| 235 | log.Printf("Saving %s cert to %s ...", c.name, path) |
| 236 | if err := os.WriteFile(path, block, 0600); err != nil { |
| 237 | return nil, err |
| 238 | } |
| 239 | |
| 240 | return x509.ParseCertificate(bytes) |
| 241 | } |
| 242 | |
// errExpired signals that a certificate exists on disk but is already past,
// or within an hour of, its NotAfter, and should therefore be reissued.
var errExpired error = errors.New("certificate expired")
| 245 | |
| 246 | func (c *Certificate) loadCert() (*x509.Certificate, error) { |
| 247 | path := c.path(fileKindCert) |
| 248 | b, err := os.ReadFile(path) |
| 249 | if err != nil { |
| 250 | return nil, err |
| 251 | } |
| 252 | |
| 253 | block, _ := pem.Decode(b) |
| 254 | if block == nil { |
| 255 | return nil, fmt.Errorf("no PEM block found") |
| 256 | } |
| 257 | if block.Type != "CERTIFICATE" { |
| 258 | return nil, fmt.Errorf("unexpected PEM block: %q", block.Type) |
| 259 | } |
| 260 | cert, err := x509.ParseCertificate(block.Bytes) |
| 261 | if err != nil { |
| 262 | return nil, err |
| 263 | } |
| 264 | if time.Now().Add(time.Hour).After(cert.NotAfter) { |
| 265 | return nil, errExpired |
| 266 | } |
| 267 | pkey, ok := cert.PublicKey.(ed25519.PublicKey) |
| 268 | if !ok { |
| 269 | return nil, fmt.Errorf("not a ED25519 cert") |
| 270 | } |
| 271 | skey, err := c.ensureKey() |
| 272 | if err != nil { |
| 273 | return nil, fmt.Errorf("when ensuring key: %w", err) |
| 274 | } |
| 275 | if !bytes.Equal(pkey, skey.Public().(ed25519.PublicKey)) { |
| 276 | return nil, fmt.Errorf("issued for different key") |
| 277 | } |
| 278 | |
| 279 | template := c.template() |
| 280 | if err := compareCertData(template, cert); err != nil { |
| 281 | return nil, err |
| 282 | } |
| 283 | return cert, nil |
| 284 | } |
| 285 | |
| 286 | // Ensure makes sure the given Certificate (and all of its' issuers) have |
| 287 | // corresponding private keys and X.509 certificates on disk, generating things |
| 288 | // as necessary. |
| 289 | func (c *Certificate) Ensure() error { |
| 290 | cert, err := c.ensureCert() |
| 291 | if err != nil { |
| 292 | return fmt.Errorf("when ensuring cert %s: %w", c.name, err) |
| 293 | } |
| 294 | _ = cert |
| 295 | |
| 296 | return nil |
| 297 | } |