#!/usr/bin/env python

from builtins import object

import datetime
from io import BytesIO
import logging
import os
import tempfile
import subprocess
import sys

from cryptography import x509
from cryptography.hazmat.backends import default_backend
import fabric

import secretstore


# Cluster API host name; also used as the kubeconfig cluster and context name.
cluster = 'k0.hswaw.net'
# Directory on remote nodes where keys, certificates and CSRs are placed.
remote_root = '/opt/hscloud'
# Root of the local repository checkout, exported by env.sh.
local_root = os.environ.get('hscloud_root')

if local_root is None:
    raise Exception("Please source env.sh")

# Module logger: emits everything from DEBUG upwards to stderr.
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG)


def decrypt(base):
    """Decrypt one secret from the repository's secret store.

    ``base`` is a file name under ``cluster/secrets/cipher``. The plaintext is
    written under the same name in ``cluster/secrets/plain``. The actual
    decryption is delegated to ``secretstore.decrypt``.
    """
    secrets_dir = os.path.join(local_root, 'cluster/secrets')
    secretstore.decrypt(
        os.path.join(secrets_dir, 'cipher', base),
        os.path.join(secrets_dir, 'plain', base),
    )


class PKI(object):
    """Local certificate authority backed by the repository's CA key pair.

    The CA certificate is ``cluster/certs/ca.crt``. The CA private key is
    ``cluster/secrets/plain/ca.key``; if that plaintext copy is missing it is
    decrypted from the secret store when the object is created.
    """

    def __init__(self):
        self.cacert = os.path.join(local_root, 'cluster/certs/ca.crt')
        self.cakey = os.path.join(local_root, 'cluster/secrets/plain/ca.key')

        if os.path.exists(self.cakey):
            return
        decrypt('ca.key')

    def sign(self, csr, crt, conf, days=365):
        """Sign the CSR file ``csr`` with the CA and write the certificate to ``crt``.

        When ``conf`` is truthy it is passed to openssl as an extension file and
        its ``SAN`` section is applied to the issued certificate. ``days`` is the
        validity period. Raises ``subprocess.CalledProcessError`` if openssl fails.
        """
        logger.info('pki: signing {} for {} days'.format(csr, days))
        argv = [
            'openssl', 'x509', '-req',
            '-in', csr,
            '-CA', self.cacert,
            '-CAkey', self.cakey,
            '-out', crt,
            '-days', str(days),
        ]
        if conf:
            argv.extend(['-extensions', 'SAN', '-extfile', conf])
        subprocess.check_call(argv)


class Subject(object):
    """X.509 distinguished name for certificates issued by this tool.

    Country, state and locality are fixed (PL / Mazowieckie / Warszawa).
    Organization, organizational unit and common name come from the caller.
    ``str()`` renders the name as ``/C=../ST=../L=../O=../OU=../CN=..``, the
    form passed to ``openssl req -subj``.
    """

    # Organization name of the hackerspace, used as the O field by several callers.
    hswaw = "Stowarzyszenie Warszawski Hackerspace"

    def __init__(self, o, ou, cn):
        self.c, self.st, self.l = 'PL', 'Mazowieckie', 'Warszawa'
        self.o, self.ou, self.cn = o, ou, cn

    @property
    def parts(self):
        """Mapping of RDN short name (C, ST, L, O, OU, CN) to its value."""
        return dict(C=self.c, ST=self.st, L=self.l, O=self.o, OU=self.ou, CN=self.cn)

    def __str__(self):
        fields = self.parts
        return ''.join(
            '/{}={}'.format(key, fields[key])
            for key in ('C', 'ST', 'L', 'O', 'OU', 'CN')
        )

def _file_exists(c, filename):
    res = c.run('stat "{}"'.format(filename), warn=True, hide=True)
    return res.exited == 0

def openssl_config(san, base_config=None):
    """Write a temporary OpenSSL config derived from the repository's base config.

    The base config is copied verbatim. When ``san`` is non-empty, a ``[SAN]``
    section (subjectAltName, basicConstraints, keyUsage) and an
    ``[alt_names]`` section listing the entries are appended.

    Args:
        san: iterable of ``"DNS:<name>"`` / ``"IP:<address>"`` strings. Entries
            starting with neither prefix are ignored. Only the first ``:``
            separates type from value, so IPv6 addresses such as
            ``IP:fe80::1`` are kept intact.
        base_config: path of the base config; defaults to
            ``cluster/openssl.cnf`` under ``local_root``.

    Returns:
        Path of a new temporary file. The caller owns it and must delete it.
    """
    if base_config is None:
        base_config = os.path.join(local_root, 'cluster/openssl.cnf')
    with open(base_config, 'rb') as f:
        chunks = [f.read()]

    if san:
        chunks.append(b'\n[SAN]\n')
        chunks.append(b'subjectAltName = @alt_names\n')
        chunks.append(b'basicConstraints = CA:FALSE\nkeyUsage = nonRepudiation, digitalSignature, keyEncipherment\n')
        chunks.append(b'[alt_names]\n')

        ipcnt = 1
        dnscnt = 1
        for s in san:
            # Split on the first colon only: IPv6 values contain colons themselves.
            value = s.split(':', 1)[1]
            if s.startswith('DNS'):
                chunks.append('DNS.{} = {}\n'.format(dnscnt, value).encode())
                dnscnt += 1
            elif s.startswith('IP'):
                chunks.append('IP.{} = {}\n'.format(ipcnt, value).encode())
                ipcnt += 1

    # delete=False: the file must outlive this function; callers remove it.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(b''.join(chunks))
        path = f.name

    return path

def remote_cert(pki, c, fqdn, cert_name, subj, san=None, days=365):
    """Ensure a host-specific key and CA-signed certificate exist on a remote host.

    The private key is generated on the remote host (``openssl genrsa``) and is
    not copied off it by this function. Only the CSR is fetched, signed locally
    by ``pki``, and the resulting certificate uploaded. The signed certificate
    is also kept locally as ``cluster/certs/<fqdn>-<cert_name>.crt``.

    The certificate is (re)issued when the key was just generated, when no
    readable certificate exists on the host, or when it expires within 60 days.

    Args:
        pki: PKI used to sign the CSR.
        c: fabric connection to the host.
        fqdn: host name, used in log messages and the local certificate name.
        cert_name: base name of ``<remote_root>/<cert_name>.{key,crt,csr}``.
        subj: Subject of the CSR.
        san: optional list of ``"DNS:..."`` / ``"IP:..."`` subjectAltName entries.
        days: validity period of an issued certificate.

    Returns:
        True if a new certificate was issued, False if the existing one was kept.
    """
    if san is None:
        san = []

    logger.info("{}/{}: remote cert".format(fqdn, cert_name))

    remote_key = os.path.join(remote_root, '{}.key'.format(cert_name))
    remote_cert = os.path.join(remote_root, '{}.crt'.format(cert_name))
    remote_csr = os.path.join(remote_root, '{}.csr'.format(cert_name))
    remote_config = os.path.join(remote_root, 'openssl.cnf')

    generate_cert = False
    if not _file_exists(c, remote_key):
        logger.info("{}/{}: generating key".format(fqdn, cert_name))
        c.run('openssl genrsa -out "{}" 4096'.format(remote_key), hide=True)
        # A fresh key invalidates whatever certificate is already on the host.
        generate_cert = True

    b = BytesIO()
    try:
        c.get(local=b, remote=remote_cert)
        cert = x509.load_pem_x509_certificate(b.getvalue(), default_backend())
        # not_valid_after is a naive UTC datetime, so compare against naive UTC
        # "now" rather than local time.
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        delta = cert.not_valid_after - now
        logger.info("{}/{}: existing cert expiry: {}".format(fqdn, cert_name, delta))
        if delta.total_seconds() < 3600 * 24 * 60:
            logger.info("{}/{}: expires soon, regenerating".format(fqdn, cert_name))
            generate_cert = True
    except (FileNotFoundError, ValueError):
        # Missing or unparseable certificate: issue a new one.
        generate_cert = True

    if not generate_cert:
        return False

    local_cert = os.path.join(local_root, 'cluster/certs', '{}-{}.crt'.format(fqdn, cert_name))

    # NOTE(review): -reqexts SAN (here) and -extensions SAN (via pki.sign) are
    # used even when san is empty. That relies on cluster/openssl.cnf providing
    # a [SAN] section in that case; shared_cert only passes them when san is
    # non-empty. Confirm.
    csr_cmd = """
        nix-shell -p openssl --command "openssl req -new -key {remote_key} -out {remote_csr} -subj '{subj}' -config {remote_config} -reqexts SAN"
    """.format(remote_key=remote_key, remote_csr=remote_csr, subj=str(subj), remote_config=remote_config)

    local_config = openssl_config(san)
    try:
        local_csr_f = tempfile.NamedTemporaryFile(delete=False)
        local_csr = local_csr_f.name
        local_csr_f.close()
        try:
            c.put(local=local_config, remote=remote_config)
            c.run(csr_cmd)
            c.get(local=local_csr, remote=remote_csr)
            pki.sign(local_csr, local_cert, local_config, days)
            c.put(local=local_cert, remote=remote_cert)
        finally:
            # Clean up local temporaries even if a remote step or signing fails.
            os.remove(local_csr)
    finally:
        os.remove(local_config)

    return True


def shared_cert(pki, c, fqdn, cert_name, subj, san=None, days=365):
    """Ensure a cluster-wide key/certificate pair exists locally, then upload it to a host.

    Unlike remote_cert, the key is kept in the repository. It lives at
    ``cluster/secrets/plain/<cert_name>.key``, decrypted from the secret store
    when possible and generated locally otherwise. The certificate lives at
    ``cluster/certs/<cert_name>.crt``. The certificate is (re)issued when a new
    key was generated, when it does not exist, or when it expires within
    60 days. Both files are then uploaded to ``<remote_root>/<cert_name>.{key,crt}``.

    Args:
        pki: PKI used to sign the CSR.
        c: fabric connection to the host receiving the files.
        fqdn: host name, used in log messages.
        cert_name: base name of the key/certificate files.
        subj: Subject of the CSR.
        san: optional list of ``"DNS:..."`` / ``"IP:..."`` subjectAltName entries.
        days: validity period of an issued certificate.

    Returns:
        Always True (the files are uploaded unconditionally).
    """
    if san is None:
        san = []

    logger.info("{}/{}: shared cert".format(fqdn, cert_name))

    local_key = os.path.join(local_root, 'cluster/secrets/plain', '{}.key'.format(cert_name))
    local_cert = os.path.join(local_root, 'cluster/certs', '{}.crt'.format(cert_name))
    remote_key = os.path.join(remote_root, '{}.key'.format(cert_name))
    remote_cert = os.path.join(remote_root, '{}.crt'.format(cert_name))

    generate_cert = False
    if not os.path.exists(local_key):
        try:
            decrypt('{}.key'.format(cert_name))
        except subprocess.CalledProcessError:
            # Decryption failed (presumably no encrypted copy exists yet):
            # create a fresh key, which also requires a fresh certificate.
            logger.info("{}/{}: generating key".format(fqdn, cert_name))
            subprocess.check_call([
                'openssl', 'genrsa', '-out', local_key, '4096',
            ])
            generate_cert = True

    if os.path.exists(local_cert):
        with open(local_cert, 'rb') as f:
            cert = x509.load_pem_x509_certificate(f.read(), default_backend())
        # not_valid_after is a naive UTC datetime, so compare against naive UTC
        # "now" rather than local time.
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        delta = cert.not_valid_after - now
        logger.info("{}/{}: existing cert expiry: {}".format(fqdn, cert_name, delta))
        if delta.total_seconds() < 3600 * 24 * 60:
            logger.info("{}/{}: expires soon, regenerating".format(fqdn, cert_name))
            generate_cert = True
    else:
        generate_cert = True

    if generate_cert:
        local_csr_f = tempfile.NamedTemporaryFile(delete=False)
        local_csr = local_csr_f.name
        local_csr_f.close()
        try:
            local_config = openssl_config(san)
            try:
                subprocess.check_call([
                    'openssl', 'req', '-new',
                    '-key', local_key,
                    '-out', local_csr,
                    '-subj', str(subj),
                    '-config', local_config,
                ] + ([
                    '-reqexts', 'SAN',
                ] if san else []))

                # The [SAN] section only exists in the config when san is non-empty.
                pki.sign(local_csr, local_cert, local_config if san else None, days)
            finally:
                os.remove(local_config)
        finally:
            # Remove temporaries even if openssl or signing fails.
            os.remove(local_csr)

    c.put(local=local_key, remote=remote_key)
    c.put(local=local_cert, remote=remote_cert)

    return True


def configure_k8s(username, ca, cert, key):
    """Point the local kubectl configuration at the cluster with the given credentials.

    Four ``kubectl config`` invocations run in order:

    1. Register the cluster, with the CA embedded and the server at
       ``https://<cluster>:4001``.
    2. Register ``username`` with the embedded client certificate and key.
    3. Create a context (named after the cluster) joining the two.
    4. Switch to that context.

    A failing step raises ``subprocess.CalledProcessError``.
    """
    def kubectl_config(*argv):
        subprocess.check_call(['kubectl', 'config'] + list(argv))

    kubectl_config(
        'set-cluster', cluster,
        '--certificate-authority=' + ca,
        '--embed-certs=true',
        '--server=https://' + cluster + ':4001',
    )
    kubectl_config(
        'set-credentials', username,
        '--client-certificate=' + cert,
        '--client-key=' + key,
        '--embed-certs=true',
    )
    kubectl_config(
        'set-context', cluster,
        '--cluster=' + cluster,
        '--user=' + username,
    )
    kubectl_config('use-context', cluster)

def admincreds(args):
    """Issue (or reuse) a short-lived cluster-admin client certificate and configure kubectl.

    Usage: ``admincreds <username>``. The key, certificate and CSR live in
    ``<local_root>/.kubectl/admin.{key,crt,csr}``. The certificate has
    O=system:masters and CN=<username> and is signed for 5 days. It is
    reissued in these cases:

    - it is missing;
    - the key was just generated;
    - it expires within 24 hours;
    - it was issued for a different username.

    Returns:
        1 on a usage error, otherwise None.
    """
    if len(args) != 1:
        sys.stderr.write("Usage: admincreds q3k\n")
        return 1
    username = args[0]

    pki = PKI()

    local_key = os.path.join(local_root, '.kubectl/admin.key')
    local_cert = os.path.join(local_root, '.kubectl/admin.crt')
    local_csr = os.path.join(local_root, '.kubectl/admin.csr')

    kubectl = os.path.join(local_root, '.kubectl')
    if not os.path.exists(kubectl):
        os.mkdir(kubectl)

    generate_cert = False
    if not os.path.exists(local_key):
        subprocess.check_call([
            'openssl', 'genrsa', '-out', local_key, '4096',
        ])
        generate_cert = True

    if os.path.exists(local_cert):
        with open(local_cert, 'rb') as f:
            cert = x509.load_pem_x509_certificate(f.read(), default_backend())
        # not_valid_after is a naive UTC datetime, so compare against naive UTC
        # "now" rather than local time.
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        delta = cert.not_valid_after - now
        logger.info("admin: existing cert expiry: {}".format(delta))
        if delta.total_seconds() < 3600 * 24:
            logger.info("admin: expires soon, regenerating")
            generate_cert = True
        # The certificate's CN is set to the username below, and Kubernetes x509
        # client auth takes the user name from the CN. A certificate left over
        # from another username must therefore not be reused under this one.
        cns = [a.value for a in cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)]
        if cns != [username]:
            logger.info("admin: existing cert is for {}, regenerating".format(cns))
            generate_cert = True
    else:
        generate_cert = True

    if not generate_cert:
        return configure_k8s(username, pki.cacert, local_cert, local_key)

    subj = Subject('system:masters', "Kubernetes Admin Account for {}".format(username), username)

    subprocess.check_call([
        'openssl', 'req', '-new',
        '-key', local_key,
        '-out', local_csr,
        '-subj', str(subj),
    ])

    # No extension config: admin client certs carry no SAN.
    pki.sign(local_csr, local_cert, None, 5)

    configure_k8s(username, pki.cacert, local_cert, local_key)


def nodestrap(args):
    """Bootstrap one node: provision its certificates, then rebuild NixOS on it.

    Usage: ``nodestrap <fqdn>``. Connects as root to the node via fabric and
    ensures two groups of certificates are present:

    - per-node certificates (``node``, ``kube-node``), keyed on the node;
    - shared Kubernetes certificates.

    It then runs ``nixos-rebuild switch``.

    Returns:
        1 on a usage error, otherwise None.
    """
    if len(args) != 1:
        sys.stderr.write("Usage: nodestrap bc01n01.hswaw.net\n")
        return 1
    fqdn = args[0]

    logger.info("Nodestrapping {}...".format(fqdn))

    conn = fabric.Connection('root@{}'.format(fqdn))
    pki = PKI()

    # Per-node certificates: keys are generated on the node itself.
    changed = [
        remote_cert(pki, conn, fqdn, "node",
                    Subject(Subject.hswaw, 'Node Certificate', fqdn)),
        remote_cert(pki, conn, fqdn, "kube-node",
                    Subject('system:nodes', 'Kubelet Certificate', 'system:node:' + fqdn),
                    san=["DNS:" + fqdn]),
    ]

    # Cluster-wide certificates shared by every node.
    for component in ('controller-manager', 'proxy', 'scheduler'):
        identity = 'system:kube-{}'.format(component)
        # NOTE(review): "Kuberneter" looks like a typo for "Kubernetes"; it is
        # left unchanged because it ends up in issued certificate subjects.
        unit = 'Kuberneter Component {}'.format(component)
        changed.append(shared_cert(pki, conn, fqdn, 'kube-{}'.format(component),
                                   Subject(identity, unit, identity)))
    changed.append(shared_cert(pki, conn, fqdn, 'kube-apiserver',
                               Subject(Subject.hswaw, 'Kubernetes API', cluster),
                               san=['IP:10.10.12.1', 'DNS:' + cluster]))
    changed.append(shared_cert(pki, conn, fqdn, 'kube-serviceaccounts',
                               Subject(Subject.hswaw, 'Kubernetes Service Account Signer', 'service-accounts')))
    changed.append(shared_cert(pki, conn, fqdn, 'kube-calico',
                               Subject(Subject.hswaw, 'Kubernetes Calico Account', 'calico')))

    # NOTE(review): this is currently unused; the rebuild below runs regardless
    # of whether any certificate changed.
    modified = any(changed)

    conn.run('nixos-rebuild switch')

def usage():
    """Write the top-level command-line usage line to stderr."""
    program = sys.argv[0]
    sys.stderr.write("Usage: {} <nodestrap|admincreds>\n".format(program))

def main():
    """Dispatch ``sys.argv[1]`` to the matching subcommand.

    The remaining arguments (``sys.argv[2:]``) are passed to the subcommand and
    its return value is returned. When no mode or an unknown mode is given,
    usage is printed and 1 is returned.
    """
    if len(sys.argv) < 2:
        usage()
        return 1

    commands = {
        "nodestrap": nodestrap,
        "admincreds": admincreds,
    }
    handler = commands.get(sys.argv[1])
    if handler is None:
        usage()
        return 1
    return handler(sys.argv[2:])

if __name__ == '__main__':
    # main() returns None on success paths; map that (and any falsy value) to 0.
    status = main()
    sys.exit(status if status else 0)
