blob: 062c9344b48e7e66c862c8d183af3e14aa6709aa [file] [log] [blame]
package pki
import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"

	"github.com/golang/glog"
)
// DeveloperCredentialsLocation returns the path of the directory containing
// HSPKI credentials on developer machines. These are provisioned by
// //cluster/prodaccess, and are used if available.
//
// The directory is "hspki" inside os.UserConfigDir(). An error is returned if
// the user configuration directory cannot be determined.
func DeveloperCredentialsLocation() (string, error) {
	cfgDir, err := os.UserConfigDir()
	if err != nil {
		return "", fmt.Errorf("UserConfigDir: %w", err)
	}
	// filepath.Join uses the OS-native separator and cleans the result, so a
	// trailing separator on cfgDir does not produce a doubled slash.
	return filepath.Join(cfgDir, "hspki"), nil
}
// DeveloperCredentialsPrincipal returns the principal/DN for which the local
// developer credentials are provisioned. This is the Subject CommonName of the
// first certificate in the developer client certificate, after confirming
// that the certificate and key load together as a TLS key pair.
func DeveloperCredentialsPrincipal() (string, error) {
	dc, err := loadDeveloperCredentials()
	if err != nil {
		return "", fmt.Errorf("when loading developer credentials: %w", err)
	}

	// X509KeyPair decodes the PEM data and checks that the private key matches
	// the certificate, so a mismatched cert/key pair is rejected here.
	kp, err := tls.X509KeyPair(dc.cert, dc.key)
	if err != nil {
		return "", fmt.Errorf("when loading developer client cert: %w", err)
	}

	// kp.Certificate holds DER-encoded certificates, leaf first. X509KeyPair
	// fails when no certificate is present, so index 0 always exists here.
	leafDER := kp.Certificate[0]
	leaf, err := x509.ParseCertificate(leafDER)
	if err != nil {
		return "", fmt.Errorf("when parsing developer client cert: %w", err)
	}
	return leaf.Subject.CommonName, nil
}
// creds holds the raw contents of the three files making up a set of PKI
// credentials, exactly as read from disk.
type creds struct {
	// ca is the CA certificate file contents (not parsed in this file;
	// presumably PEM, confirm with consumers).
	ca []byte
	// cert and key are the client certificate and private key; they are
	// consumed as a PEM pair by tls.X509KeyPair.
	cert, key []byte
}
// loadDeveloperCredentials reads the CA certificate (ca.crt), client
// certificate (tls.crt) and client key (tls.key) from the directory returned
// by DeveloperCredentialsLocation.
func loadDeveloperCredentials() (*creds, error) {
	dir, err := DeveloperCredentialsLocation()
	if err != nil {
		return nil, fmt.Errorf("DeveloperCredentialsLocation: %w", err)
	}
	return readCredentialFiles(
		filepath.Join(dir, "ca.crt"),
		filepath.Join(dir, "tls.crt"),
		filepath.Join(dir, "tls.key"),
	)
}

// loadFlagCredentials reads the CA certificate, client certificate and client
// key from the paths held in flagCAPath, flagCertificatePath and flagKeyPath.
func loadFlagCredentials() (*creds, error) {
	return readCredentialFiles(flagCAPath, flagCertificatePath, flagKeyPath)
}

// readCredentialFiles reads the three credential files into a new creds.
// It stops at the first path that is empty or unreadable and returns an
// error; no partially populated creds is ever returned.
func readCredentialFiles(caPath, certPath, keyPath string) (*creds, error) {
	var c creds
	for _, f := range []struct {
		desc   string
		path   string
		target *[]byte
	}{
		{"CA certificate", caPath, &c.ca},
		{"client certificate", certPath, &c.cert},
		{"client key", keyPath, &c.key},
	} {
		// An unset flag yields an empty path; report that directly instead of
		// the opaque `ReadFile(""): open : no such file or directory`.
		if f.path == "" {
			return nil, fmt.Errorf("no path configured for %s", f.desc)
		}
		data, err := ioutil.ReadFile(f.path)
		if err != nil {
			return nil, fmt.Errorf("ReadFile(%q): %w", f.path, err)
		}
		*f.target = data
	}
	return &c, nil
}
// loadCredentials returns the first set of PKI credentials that loads
// successfully. It tries the developer credentials directory first and the
// flag-configured paths second. Each failed source is logged as a warning.
// If every source fails, a single error carrying a setup hint is returned.
func loadCredentials() (*creds, error) {
	sources := []struct {
		load    func() (*creds, error)
		warnFmt string
	}{
		{loadDeveloperCredentials, "Could not load developer PKI credentials: %v"},
		{loadFlagCredentials, "Could not load flag-defined PKI credentials: %v"},
	}
	for _, src := range sources {
		c, err := src.load()
		if err != nil {
			glog.Warningf(src.warnFmt, err)
			continue
		}
		return c, nil
	}
	return nil, fmt.Errorf("could not load PKI credentials (hint: run `prodaccess` to set up developer certs or add `-- -hspki_disable` to bazel run command)")
}