tools/secretstore: add sync command, re-encrypt

This kills two birds with one stone:

 - update the secretstore tool to be slightly smarter about secrets, to
   the point where we can now just point it at a secret directory and
   ask it to 'sync' all secrets in there.
 - run the new sync command on all secrets to re-encrypt them for the
   current keys, as a follow-up to gerrit/328.

Change-Id: I0eec4a3e8afcd9481b0b248154983aac25657c40
diff --git a/tools/secretstore.py b/tools/secretstore.py
index 8ff5e2a..85a3164 100644
--- a/tools/secretstore.py
+++ b/tools/secretstore.py
@@ -1,12 +1,47 @@
 #!/usr/bin/env python3
 
-# A little tool to encrypt/decrypt git secrets. Kinda like password-store, but more purpose specific and portable.
+"""
+A little tool to encrypt/decrypt git secrets. Kinda like password-store, but
+more purpose specific and portable.
 
+It generally expects to work with directory structures as follows:
+
+    foo/bar/secrets/plain:  plaintext files
+                   /cipher: ciphertext files, with names corresponding to
+                            plaintext files
+
+Note: currently all plaintext/cipher files are at a single level, i.e.
+subdirectories within a /plain or /cipher directory are not supported.
+
+There are multiple secret 'roots' like this in hscloud, notably:
+
+ - cluster/secrets
+ - hswaw/kube/secrets
+
+In the future, some repository-based configuration might exist to specify these
+roots in a nicer way, possibly with different target keys per root.
+
+This tool is a bit of a Swiss Army knife, and can be used in the following ways:
+
+ - as a CLI tool to encrypt/decrypt files directly
+ - as a library for its encryption/decryption methods, and for a SecretStore
+   API, which allows for basic programmatic access to secrets, decrypting
+   things if necessary
+ - as a CLI tool to 'synchronize' a directory containing plain/cipher files,
+   which means encrypting every new or changed plaintext file (and decrypting
+   every new or changed ciphertext file), and re-encrypting all files whose
+   keys differ from the keys list defined in this file.
+
+"""
+
+import argparse
 import logging
 import os
 import sys
 import subprocess
+import tempfile
 
+# Keys that are to be used to encrypt all secret roots.
 keys = [
     "63DFE737F078657CC8A51C00C29ADD73B3563D82", # q3k
     "482FF104C29294AD1CAF827BA43890A3DE74ECC7", # inf
@@ -15,7 +50,14 @@
 ]
 
 
-logger = logging.getLogger(__name__)
+# Run as a script, __name__ is '__main__'; log under a descriptive name
+# instead, so records are attributed to 'secretstore'.
+_logger_name = 'secretstore' if __name__ == '__main__' else __name__
+logger = logging.getLogger(_logger_name)
+
+
+class CLIException(Exception):
+    """Error reported to the CLI user as a plain message, with exit status 1."""
 
 
 def encrypt(src, dst):
@@ -26,9 +68,289 @@
     cmd.append(src)
     subprocess.check_call(cmd)
 
+
 def decrypt(src, dst):
     cmd = ['gpg', '--decrypt', '--batch', '--yes', '--output', dst, src]
-    subprocess.check_call(cmd)
+    out, err = (None, None) if dst == '-' else (subprocess.PIPE, subprocess.STDOUT)  # '-' = plaintext to stdout
+    subprocess.run(cmd, check=True, stdout=out, stderr=err)  # otherwise capture and drop gpg's chatter
+
+
+def _encryption_key_for_fingerprint(fp):
+    """
+    Returns the long key ID of the encryption subkey (the 'sub' key carrying
+    the [E] usage flag) for a given GPG fingerprint (eg. one from the 'keys'
+    list), as listed by the local gpg keyring.
+
+    Raises Exception if no such subkey is listed (CalledProcessError if gpg fails).
+    """
+    cmd = ['gpg', '-k', '--keyid-format', 'long', fp]
+    res = subprocess.check_output(cmd).decode()
+
+    # Sample output:
+    #   pub   rsa4096/70FD60197E195C26 2014-02-22 [SC] [expires: 2021-02-05]
+    #         0879F9FCA1C836677BB808C870FD60197E195C26
+    #   uid                 [ultimate] Bartosz Stebel <bartoszstebel@gmail.com>
+    #   uid                 [ultimate] Bartosz Stebel <implr@hackerspace.pl>
+    #   sub   rsa4096/E203C94E5CEBB3EF 2014-02-22 [E] [expires: 2021-02-05]
+    #
+    # We want the first 'sub' line whose 4th column is a bracketed usage
+    # list containing 'E', and from it the key ID after 'algo/'.
+    for line in res.split('\n'):
+        # split() without arguments also drops surrounding whitespace, so
+        # blank lines simply yield no fields.
+        fields = line.split()
+        if len(fields) < 4 or fields[0] != 'sub':
+            continue
+
+        usage = fields[3]
+        is_bracketed = usage.startswith('[') and usage.endswith(']')
+        if not is_bracketed or 'E' not in usage.strip('[]'):
+            continue
+
+        # Okay, we found the encryption key.
+        return fields[1].split('/')[1]
+
+    raise Exception("Could not find encryption key ID for fingerprint {}".format(fp))
+
+
+_encryption_keys_cache = None  # memoized result of encryption_keys()
+def encryption_keys():
+    """
+    Return the encryption subkey IDs for all fingerprints in 'keys', in order.
+    Computed once (one gpg invocation per key) and cached for later calls.
+    """
+    global _encryption_keys_cache
+    if _encryption_keys_cache is None:
+        _encryption_keys_cache = list(map(_encryption_key_for_fingerprint, keys))
+    return _encryption_keys_cache
+
+
+def encrypted_for(path):
+    """
+    Return the key IDs that a given GPG ciphertext file is encrypted for, as
+    found in the ':pubkey enc packet:' entries of `gpg --list-packets`
+    (deduplicated, in no particular order).
+    """
+    # '--pinentry-mode cancel' makes any passphrase prompt fail immediately
+    # (see 'Operation cancelled' below); we only want the packet listing.
+    cmd = ['gpg', '--pinentry-mode', 'cancel', '--list-packets', path]
+    res = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode()
+    # NOTE(review): check_output raises if gpg exits non-zero; confirm the
+    # failed decryption attempt below doesn't make --list-packets exit != 0.
+
+    # Sample output:
+    #  gpg: encrypted with 4096-bit RSA key, ID E203C94E5CEBB3EF, created 2014-02-22
+    #        "Bartosz Stebel <bartoszstebel@gmail.com>"
+    #  gpg: encrypted with 2048-bit RSA key, ID 5C1B6B69E9F5EABE, created 2013-01-29
+    #        "Piotr Dobrowolski <piotr.tytus.dobrowolski@gmail.com>"
+    #  gpg: encrypted with 2048-bit RSA key, ID 386E893E110BC55B, created 2012-01-10
+    #        "Sergiusz Bazanski (Low Latency Consulting) <serge@lowlatency.ie>"
+    #  gpg: public key decryption failed: Operation cancelled
+    #  gpg: decryption failed: No secret key
+    #  # off=0 ctb=85 tag=1 hlen=3 plen=268
+    #  :pubkey enc packet: version 3, algo 1, keyid 386E893E110BC55B
+    #  	data: [2047 bits]
+    #  # off=271 ctb=85 tag=1 hlen=3 plen=268
+    #  :pubkey enc packet: version 3, algo 1, keyid 5C1B6B69E9F5EABE
+    #  	data: [2048 bits]
+    #  # off=542 ctb=85 tag=1 hlen=3 plen=524
+    #  :pubkey enc packet: version 3, algo 1, keyid E203C94E5CEBB3EF
+    #  	data: [4095 bits]
+    #  # off=1069 ctb=d2 tag=18 hlen=2 plen=121 new-ctb
+    #  :encrypted data packet:
+    #  	length: 121
+    #  	mdc_method: 2
+
+    # A matching line splits into words as:
+    #   [':pubkey', 'enc', 'packet:', 'version', '3,', 'algo', '1,', 'keyid', ID]
+    found = set()
+    for line in res.split('\n'):
+        words = line.split()
+        if len(words) < 9:
+            continue
+        if words[:4] != [':pubkey', 'enc', 'packet:', 'version']:
+            continue
+        if words[7] != 'keyid':
+            continue
+        found.add(words[8])
+
+    # sync() only uses the result as a set, so order doesn't matter there.
+    return list(found)
+
+
+class SyncAction:
+    """
+    Interface (with no-op defaults) for the steps `secretstore sync` plans.
+
+    An action wraps one side-effect bearing OS operation (running gpg, moving
+    a file, ...) in act(), and can describe() that operation as a human
+    readable string without performing it. Dry runs of sync only log these
+    descriptions.
+    """
+    def describe(self):
+        """Return a human readable summary of what act() does."""
+        return ""
+
+    def act(self):
+        """Perform the action (a no-op in this base class)."""
+
+class SyncActionEncrypt(SyncAction):
+    """Encrypt plaintext `src` into ciphertext `dst`; `reason` is for describe()."""
+
+    def __init__(self, src, dst, reason):
+        self.src, self.dst, self.reason = src, dst, reason
+
+    def describe(self):
+        return f'Encrypting {os.path.split(self.src)[-1]} ({self.reason})'
+
+    def act(self):
+        return encrypt(self.src, self.dst)
+
+
+class SyncActionDecrypt(SyncAction):
+    """Decrypt ciphertext `src` into plaintext `dst`; `reason` is for describe()."""
+
+    def __init__(self, src, dst, reason):
+        self.src, self.dst, self.reason = src, dst, reason
+
+    def describe(self):
+        return f'Decrypting {os.path.split(self.src)[-1]} ({self.reason})'
+
+    def act(self):
+        return decrypt(self.src, self.dst)
+
+
+def sync(path: str, dry: bool):
+    """
+    Synchronize (decrypt and encrypt what's needed) a given secrets directory.
+
+    `path` must contain plain/ and/or cipher/ (a missing one is created, except
+    on dry runs). One-sided files are encrypted/decrypted to the other side,
+    differing pairs are resolved in favour of the newer mtime, and ciphertexts
+    not encrypted for exactly encryption_keys() are re-encrypted. With dry=True,
+    the planned actions are only logged. Raises CLIException if neither
+    subdirectory exists, or if a differing pair has identical mtimes.
+    """
+    # Work on an absolute path, just to make things safer. abspath() also
+    # normalizes the path, which removes any trailing slashes.
+    path = os.path.abspath(path)
+    plain_path = os.path.join(path, "plain")
+    cipher_path = os.path.join(path, "cipher")
+
+    # Ensure that at least one of the plain/cipher paths exist.
+    plain_exists = os.path.exists(plain_path)
+    cipher_exists = os.path.exists(cipher_path)
+    if not plain_exists and not cipher_exists:
+        raise CLIException('Given directory must contain a plain/ or cipher/ subdirectory.')
+
+    # Make missing directories -- but a dry run must not modify anything.
+    for directory, exists in ((plain_path, plain_exists), (cipher_path, cipher_exists)):
+        if not exists and not dry:
+            os.mkdir(directory)
+
+    # Helper listing regular files; a directory that didn't exist has none.
+    def files_in(directory, exists):
+        if not exists:
+            return set()
+        return {f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))}
+
+    # List files on both sides. These are sets, as we test membership below.
+    plain_files = files_in(plain_path, plain_exists) - {'.gitignore'}
+    cipher_files = files_in(cipher_path, cipher_exists)
+
+    # Helper function to turn a short filename within a directory to a pair
+    # of plain/cipher full paths.
+    def pc(p):
+        return os.path.join(plain_path, p), os.path.join(cipher_path, p)
+
+    # First, for every possible file name (no matter if only available as
+    # plain, as cipher, or as both), figure out which side is fresher based
+    # on file presence, contents and mtime. Names are visited in sorted order
+    # so that the resulting list of actions is deterministic.
+    fresher = {}  # type: Dict[str, str]
+    for p in sorted(plain_files | cipher_files):
+        # Handle the easy case when the file only exists on one side.
+        if p not in cipher_files:
+            fresher[p] = 'plain'
+            continue
+        if p not in plain_files:
+            fresher[p] = 'cipher'
+            continue
+
+        # Otherwise, we have both the cipher and plain version. Decrypt the
+        # ciphertext into a temporary file and compare it with the plaintext;
+        # if they match, both sides are equal.
+        plain, cipher = pc(p)
+        f = tempfile.NamedTemporaryFile(delete=False)
+        f.close()
+        try:
+            decrypt(cipher, f.name)
+            with open(f.name, 'rb') as fd:
+                decrypted_data = fd.read()
+        finally:
+            # Never leave decrypted secrets behind, even if gpg or the read
+            # fails (guarded, in case the file is already gone by then).
+            if os.path.exists(f.name):
+                os.unlink(f.name)
+        with open(plain, 'rb') as fc:
+            current_data = fc.read()
+
+        if decrypted_data == current_data:
+            fresher[p] = 'equal'
+            continue
+
+        # The plain and cipher versions differ. Let's choose based on mtime.
+        mtime_plain = os.path.getmtime(plain)
+        mtime_cipher = os.path.getmtime(cipher)
+
+        if mtime_plain > mtime_cipher:
+            fresher[p] = 'plain'
+        elif mtime_cipher > mtime_plain:
+            fresher[p] = 'cipher'
+        else:
+            raise CLIException(f'cipher/plain stalemate on {p}: contents differ, but files have same mtime')
+
+    # Find all files that need to be re-encrypted for changed keys, ie. whose
+    # ciphertext isn't encrypted for exactly the current encryption keys.
+    reencrypt = set()
+    for p in cipher_files:
+        _, cipher = pc(p)
+        current = set(encrypted_for(cipher))
+        want = set(encryption_keys())
+        if current != want:
+            reencrypt.add(p)
+
+    # Okay, now actually construct the list of actions that brings the
+    # directory pair to a stable state. First, decrypt all fresh==cipher files.
+    actions = []  # type: List[SyncAction]
+    for p, v in fresher.items():
+        if v == 'cipher':
+            plain, cipher = pc(p)
+            actions.append(SyncActionDecrypt(cipher, plain, "cipher version is newer"))
+
+    # Then, encrypt all fresh==plain files, and make note of what those are.
+    encrypted = set()
+    for p, v in fresher.items():
+        if v == 'plain':
+            plain, cipher = pc(p)
+            actions.append(SyncActionEncrypt(plain, cipher, "plain version is newer"))
+            encrypted.add(p)
+
+    # Finally, re-encrypt all the files that aren't already being encrypted.
+    for p in sorted(reencrypt - encrypted):
+        plain, cipher = pc(p)
+        actions.append(SyncActionEncrypt(plain, cipher, "needs to be re-encrypted for different keys"))
+
+    if not actions:
+        logger.info('Nothing to do!')
+    elif dry:
+        logger.info('Would perform the following:')
+    else:
+        logger.info('Running actions...')
+    for a in actions:
+        logger.info(a.describe())
+        if not dry:
+            a.act()
 
 
 class SecretStoreMissing(Exception):
@@ -75,18 +397,43 @@
 
 
 def main():
-    if len(sys.argv) < 3 or sys.argv[1] not in ('encrypt', 'decrypt'):
-        sys.stderr.write("Usage: {} encrypt/decrypt file\n".format(sys.argv[0]))
-        sys.stderr.flush()
-        return 1
+    """Entry point of the secretstore CLI (encrypt/decrypt/sync subcommands)."""
+    parser = argparse.ArgumentParser(description='Manage hscloud git-based secrets.')
+    subparsers = parser.add_subparsers(dest='mode')
 
-    action = sys.argv[1]
-    src = sys.argv[2]
+    # 'output' is optional and defaults to '-' (stdout), eg. `decrypt FILE`.
+    parser_decrypt = subparsers.add_parser('decrypt', help='decrypt a single secret file')
+    parser_decrypt.add_argument('input', type=str, help='encrypted file path')
+    parser_decrypt.add_argument('output', type=str, nargs='?', default='-', help='decrypted file path (or - for stdout)')
 
-    if action == 'encrypt':
-        encrypt(src, '-')
-    else:
-        decrypt(src, '-')
+    parser_encrypt = subparsers.add_parser('encrypt', help='encrypt a single secret file')
+    parser_encrypt.add_argument('input', type=str, help='plaintext file path')
+    parser_encrypt.add_argument('output', type=str, nargs='?', default='-', help='encrypted file path (or - for stdout)')
+
+    parser_sync = subparsers.add_parser('sync', help='Synchronize a canonically formatted secrets/{plain,cipher} directory')
+    parser_sync.add_argument('dir', type=str, help='Path to secrets directory to synchronize')
+    parser_sync.add_argument('--dry', action='store_true', help='only log what would be done')
+
+    logging.basicConfig(level='INFO')
+    args = parser.parse_args()
+
+    if args.mode is None:
+        parser.print_help()
+        sys.exit(1)
+
+    try:
+        if args.mode == 'encrypt':
+            encrypt(args.input, args.output)
+        elif args.mode == 'decrypt':
+            decrypt(args.input, args.output)
+        elif args.mode == 'sync':
+            sync(args.dir, dry=args.dry)
+        else:
+            # Unreachable: argparse only accepts the subcommands defined above.
+            raise Exception('invalid mode {}'.format(args.mode))
+    except CLIException as e:
+        logger.error(e)
+        sys.exit(1)
 
 if __name__ == '__main__':
     sys.exit(main() or 0)