# encoding: utf-8
import json
import logging
import os
from six import StringIO
import subprocess


logger = logging.getLogger(__name__)


_std_subj = {
    "C": "PL",
    "ST": "Mazowieckie",
    "L": "Warsaw",
    "O": "Warsaw Hackerspace",
    "OU": "clustercfg",
}

_ca_csr = {
  "CN": "Prototype Test Certificate Authority",
  "key": {
    "algo": "rsa",
    "size": 2048
  },
  "names": [ _std_subj ],
}

_ca_config = {
  "signing": {
    "default": {
      "expiry": "168h"
    },
    "profiles": {
      "server": {
        "expiry": "8760h",
        "usages": [
          "signing",
          "key encipherment",
          "server auth"
        ]
      },
      "client": {
        "expiry": "8760h",
        "usages": [
          "signing",
          "key encipherment",
          "client auth"
        ]
      },
      "client-server": {
        "expiry": "8760h",
        "usages": [
          "signing",
          "key encipherment",
          "server auth",
          "client auth"
        ]
      }
    }
  }
}


class CAException(Exception):
    """Error type for the certificate-authority helpers in this module."""


class CA(object):
    """A certificate authority driven by the ``cfssl`` command-line tool.

    The CA private key is kept in a secret store under ``ca-<short>.key``.
    The CA certificate is written to ``<certdir>/ca-<short>.crt``. If the
    secret key does not exist yet, a new CA is generated on construction.
    """

    # Executable used for every cfssl invocation. Override it on a subclass
    # or an instance when cfssl is not on $PATH.
    cfssl_bin = 'cfssl'

    def __init__(self, secretstore, certdir, short, cn):
        """Bind the CA to its storage locations and make sure it exists.

        :param secretstore: store for private keys. This class uses its
            ``exists(name)``, ``open(name, mode)`` and ``plaintext(name)``
            methods; ``plaintext`` presumably returns a readable file path
            (it is passed to cfssl's ``-ca-key`` flag) — confirm.
        :param certdir: directory holding public certificates.
        :param short: short identifier embedded in the CA file names.
        :param cn: Common Name for the CA certificate.
        """
        self.ss = secretstore
        self.cdir = certdir
        self.short = short
        self.cn = cn
        self._init_ca()

    def __str__(self):
        return 'CN={} ({})'.format(self.cn, self.short)

    @property
    def _secret_key(self):
        """Secret-store name of the CA private key."""
        return 'ca-{}.key'.format(self.short)

    @property
    def _cert(self):
        """Filesystem path of the CA certificate."""
        return os.path.join(self.cdir, 'ca-{}.crt'.format(self.short))

    @property
    def cert_data(self):
        """Contents of the CA certificate file."""
        with open(self._cert) as f:
            return f.read()

    def _cfssl_call(self, args, obj=None, stdin=None):
        """Run ``cfssl <args>`` and return its decoded JSON output.

        :param args: command-line arguments passed after the cfssl binary.
        :param obj: if not None, JSON-serialised and fed to cfssl on stdin.
            Takes precedence over ``stdin``.
        :param stdin: text fed to cfssl on stdin when ``obj`` is None. If
            both are None, cfssl receives empty input.
        :returns: the parsed JSON document that cfssl printed on stdout.
        :raises CAException: if cfssl exits non-zero or prints output that is
            not valid JSON.
        """
        if obj is not None:
            stdin = json.dumps(obj)
        # Build the payload before spawning, so that a serialisation error
        # cannot leave an orphaned child process behind.
        payload = b'' if stdin is None else stdin.encode('utf-8')

        p = subprocess.Popen([self.cfssl_bin] + list(args),
                             stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        outs, errs = p.communicate(payload)
        if p.returncode != 0:
            raise CAException(
                'cfssl failed. stderr: %r, stdout: %r, code: %r' % (
                    errs, outs, p.returncode))

        try:
            return json.loads(outs.decode('utf-8'))
        except ValueError:
            # Covers JSON decode errors and invalid UTF-8 on both py2 and py3.
            raise CAException(
                'cfssl returned non-JSON output. stderr: %r, stdout: %r' % (
                    errs, outs))

    def _init_ca(self):
        """Generate the CA key and certificate unless the key already exists.

        Only the secret key's presence is checked. If it exists, nothing is
        regenerated, even when the certificate file is missing.
        """
        if self.ss.exists(self._secret_key):
            return

        # A shallow copy is enough: only the top-level CN is overridden.
        ca_csr = dict(_ca_csr)
        ca_csr['CN'] = self.cn

        logger.info("{}: Generating CA...".format(self))
        out = self._cfssl_call(['gencert', '-initca', '-'], obj=ca_csr)

        f = self.ss.open(self._secret_key, 'w')
        try:
            f.write(out['key'])
        finally:
            f.close()

        with open(self._cert, 'w') as f:
            f.write(out['cert'])

    def gen_key(self, hosts, o=_std_subj['O'], ou=_std_subj['OU'], save=None):
        """Generate a new RSA-4096 private key and CSR with ``cfssl genkey``.

        :param hosts: non-empty list of host names. The first entry becomes
            the CN, and the whole list is sent as the request's ``hosts``.
        :param o: subject Organization. Passing ``None`` explicitly sends
            ``"O": null`` to cfssl instead of using the default.
        :param ou: subject Organizational Unit (same caveat as ``o``).
        :param save: if given, the secret-store name under which the private
            key is written.
        :returns: ``(key, csr)``, taken from the ``key`` and ``csr`` fields of
            cfssl's JSON output.
        """
        cfg = {
            "CN": hosts[0],
            "hosts": hosts,
            "key": {
                "algo": "rsa",
                "size": 4096,
            },
            "names": [
                {
                    "C": _std_subj["C"],
                    "ST": _std_subj["ST"],
                    "L": _std_subj["L"],
                    "O": o,
                    "OU": ou,
                },
            ],
        }
        # NOTE(review): this merges the CA signing policy into the genkey
        # request. It is unclear whether `cfssl genkey` uses it — confirm.
        cfg.update(_ca_config)
        logger.info("{}: Generating key/CSR for {}".format(self, hosts))
        out = self._cfssl_call(['genkey', '-'], obj=cfg)

        key, csr = out['key'], out['csr']
        if save is not None:
            logger.info("{}: Saving new key to secret {}".format(self, save))
            f = self.ss.open(save, 'w')
            try:
                f.write(key)
            finally:
                f.close()

        return key, csr

    def sign(self, csr, save=None, profile='client-server'):
        """Sign a CSR with this CA using ``cfssl sign``.

        :param csr: CSR text, fed to cfssl on stdin.
        :param save: if given, the file name (relative to the certificate
            directory) where the signed certificate is written.
        :param profile: signing profile name passed to cfssl as
            ``-profile``. The default is ``'client-server'``.
        :returns: the ``cert`` field of cfssl's JSON output.
        """
        logger.info("{}: Signing CSR".format(self))
        cakey = self.ss.plaintext(self._secret_key)
        # NOTE(review): no -config is passed, so the profiles defined in
        # _ca_config are not supplied to cfssl here. cfssl may fall back to
        # its built-in policy for unknown profiles — confirm.
        out = self._cfssl_call(['sign', '-ca=' + self._cert,
                                '-ca-key=' + cakey,
                                '-profile=' + profile, '-'], stdin=csr)
        cert = out['cert']
        if save is not None:
            name = os.path.join(self.cdir, save)
            logger.info("{}: Saving new certificate to {}".format(self, name))
            with open(name, 'w') as f:
                f.write(cert)

        return cert

    def upload(self, c, remote_cert):
        """Copy the CA certificate to ``remote_cert`` via ``c.put``.

        :param c: connection exposing ``put(local=..., remote=...)``
            (presumably a Fabric connection — confirm).
        :param remote_cert: destination path on the remote side.
        """
        logger.info("Uploading CA {} to {}".format(self, remote_cert))
        c.put(local=self._cert, remote=remote_cert)

    def make_cert(self, *a, **kw):
        """Return ``ManagedCertificate(self, *a, **kw)`` issued by this CA."""
        return ManagedCertificate(self, *a, **kw)


class ManagedCertificate(object):
    """A private key and certificate issued by a ``CA`` and kept on disk.

    The key is stored in the CA's secret store as ``<name>.key``. The
    certificate is stored in the CA's certificate directory as
    ``<name>.cert``. Both are generated on construction unless both already
    exist.
    """

    def __init__(self, ca, name, hosts, o=None, ou=None):
        """Describe the certificate and make sure it exists.

        :param ca: issuing CA.
        :param name: base name for the key and certificate files.
        :param hosts: host names forwarded to ``ca.gen_key``.
        :param o: subject Organization forwarded to ``ca.gen_key``.
        :param ou: subject Organizational Unit forwarded to ``ca.gen_key``.
        """
        self.ca = ca

        self.hosts = hosts
        self.name = name
        self.key = '{}.key'.format(name)
        self.cert = '{}.cert'.format(name)
        # NOTE(review): None is forwarded explicitly to ca.gen_key, which
        # overrides its O/OU defaults instead of falling back to them —
        # confirm this is intended.
        self.o = o
        self.ou = ou

        self.ensure()

    def __str__(self):
        return '{}'.format(self.name)

    @property
    def key_exists(self):
        """Whether the private key is present in the secret store."""
        return self.ca.ss.exists(self.key)

    @property
    def key_data(self):
        """Contents of the private key, read from the CA's secret store."""
        # The secret store's open() returns a file-like object (keys are
        # written through it with .write()). Read from that object directly
        # rather than passing it to the builtin open().
        f = self.ca.ss.open(self.key)
        try:
            return f.read()
        finally:
            f.close()

    @property
    def key_path(self):
        """Path returned by the secret store's ``plaintext()`` for the key."""
        return self.ca.ss.plaintext(self.key)

    @property
    def cert_path(self):
        """Filesystem path of the certificate in the CA's cert directory."""
        return os.path.join(self.ca.cdir, self.cert)

    @property
    def cert_exists(self):
        """Whether the certificate file exists on disk."""
        return os.path.exists(self.cert_path)

    @property
    def cert_data(self):
        """Contents of the certificate file."""
        with open(self.cert_path) as f:
            return f.read()

    def ensure(self):
        """Generate a key and sign a certificate unless both already exist.

        If either one is missing, a new key is generated and saved under
        ``self.key``, and a new certificate is signed and saved under
        ``self.cert``.
        """
        if self.key_exists and self.cert_exists:
            return

        logger.info("{}: Generating...".format(self))
        _, csr = self.ca.gen_key(self.hosts, o=self.o, ou=self.ou,
                                 save=self.key)
        self.ca.sign(csr, save=self.cert)

    def upload(self, c, remote_cert, remote_key, concat_ca=False):
        """Copy the certificate and private key via ``c.put``.

        :param c: connection exposing ``put(local=..., remote=...)``. Here
            ``local`` is passed either a path or a file-like object.
        :param remote_cert: destination path for the certificate.
        :param remote_key: destination path for the private key.
        :param concat_ca: if true, upload the certificate followed by the CA
            certificate as a single file.
        """
        logger.info("Uploading Cert {} to {} & {}".format(self, remote_cert, remote_key))
        if concat_ca:
            f = StringIO(self.cert_data + self.ca.cert_data)
            c.put(local=f, remote=remote_cert)
        else:
            c.put(local=self.cert_path, remote=remote_cert)
        c.put(local=self.key_path, remote=remote_key)

    def upload_pki(self, c, pki, concat_ca=False):
        """Call ``upload`` with the remote paths in ``pki['cert']`` and ``pki['key']``."""
        self.upload(c, pki['cert'], pki['key'], concat_ca)
