blob: 18a139e9c0a4d39c09296bf807a7a6494f1797b0 [file] [log] [blame]
import ipaddress
import json
import logging
import re
import subprocess
from concurrent import futures
from datetime import datetime, timezone
from pathlib import Path

import grpc
import yaml

from at.dhcp import DhcpdUpdater, DhcpLease

from .tracker_pb2 import DhcpClient, DhcpClients, HwAddrResponse
from .tracker_pb2_grpc import DhcpTrackerServicer, add_DhcpTrackerServicer_to_server

import argparse


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI: a positional path to the YAML config plus ``--verbose``."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--verbose", action="store_true", help="output more info")
    cli.add_argument("config", help="input file", type=Path)
    return cli


# Module-level parser; server() calls parser.parse_args() on it.
parser = _build_arg_parser()

logging.basicConfig(level=logging.INFO)

def lease_to_client(lease: DhcpLease) -> DhcpClient:
    """Convert a dhcpd lease into a ``DhcpClient`` protobuf message.

    Args:
        lease: lease record from ``at.dhcp``. ``hwaddr`` is a colon-separated
            hex MAC address and ``atime`` is a POSIX timestamp. ``name`` is
            copied to ``client_hostname`` and ``ip`` to ``ip_address``
            (presumably a string; confirm against at.dhcp).

    Returns:
        A ``DhcpClient`` carrying the raw MAC bytes. ``last_seen`` is a naive
        UTC ISO-8601 string with no "+00:00" suffix.
    """
    # datetime.utcfromtimestamp() is deprecated since Python 3.12. Build an
    # aware UTC datetime, then drop tzinfo so that isoformat() produces
    # exactly the same string as before (no offset suffix).
    last_seen = datetime.fromtimestamp(lease.atime, tz=timezone.utc).replace(tzinfo=None)
    return DhcpClient(
        hw_address=bytes.fromhex(lease.hwaddr.replace(':', '')),
        last_seen=last_seen.isoformat(),
        client_hostname=lease.name,
        ip_address=lease.ip,
    )

class DhcpTrackerServicer(DhcpTrackerServicer):
    """gRPC servicer exposing the DHCP lease tracker.

    The class deliberately reuses the name of the generated base class it
    inherits from. The base is resolved before the new name is bound, so from
    here on this subclass shadows the imported one.
    """

    def __init__(self, tracker: DhcpdUpdater, *args, **kwargs):
        """Wrap *tracker*; any remaining arguments go to the generated base."""
        self._tracker = tracker
        super().__init__(*args, **kwargs)

    def _authorize(self, context):
        """Abort the RPC unless the caller may see raw client addresses.

        Two kinds of caller are accepted:

        * TLS peers whose identities include ``at.hackerspace.pl``.
        * Connections whose auth context has no ``transport_security_type``
          entry. These are treated as the trusted local unix socket.

        Every other transport is rejected. ``context.abort`` raises, so this
        method returns normally only for authorized callers.
        """
        auth = context.auth_context()
        # auth_context() values are lists of bytes. The 'local' str default
        # can therefore only come from the missing-key case.
        ctype = auth.get('transport_security_type', 'local')
        logging.debug('transport security type: %r', ctype)
        if ctype == [b'ssl']:
            # peer_identities() returns None for an unauthenticated peer.
            # Treat that as "no identities" instead of raising TypeError.
            identities = context.peer_identities() or ()
            if b'at.hackerspace.pl' not in identities:
                context.abort(
                    grpc.StatusCode.PERMISSION_DENIED,
                    (
                        "Only at.hackerspace.pl is allowed to access raw "
                        "clients addresses"
                    )
                )
        elif ctype == 'local':
            # Connections from the local unix socket are trusted by default.
            # NOTE(review): some grpc releases may report [b'insecure'] for
            # plaintext ports instead of omitting the key. Such calls would
            # land in the rejecting branch below; confirm against the
            # deployed grpc version.
            pass
        else:
            context.abort(
                grpc.StatusCode.PERMISSION_DENIED,
                f"Unknown transport type: {ctype}"
            )

    def GetClients(self, request, context):
        """Return every lease from ``tracker.get_active_devices()`` as ``DhcpClients``."""
        self._authorize(context)

        clients = [
            lease_to_client(lease)
            for lease in self._tracker.get_active_devices().values()
        ]
        return DhcpClients(clients=clients)

    def GetHwAddr(self, request, context):
        """Look up the link-layer address for ``request.ip_address``.

        Queries the kernel neighbour table with ``ip -json neigh show`` and
        returns the first entry that carries an ``lladdr``. If there is no
        such entry, an empty ``HwAddrResponse`` is returned.

        Raises:
            ValueError: if the value is not a single IPv4/IPv6 address.
            subprocess.CalledProcessError: if ``ip`` exits non-zero.
        """
        self._authorize(context)
        ip_address = str(request.ip_address)
        # The character whitelist guarantees `ip` can never parse the value
        # as an option. The ipaddress check then rejects empty, partial or
        # otherwise malformed addresses that `ip` might read as a prefix.
        if not re.fullmatch(r'[0-9a-fA-F:.]+', ip_address):
            raise ValueError(f'Invalid ip address: {ip_address!r}')
        try:
            ipaddress.ip_address(ip_address)
        except ValueError:
            raise ValueError(f'Invalid ip address: {ip_address!r}') from None
        logging.info('running ip neigh on %s', ip_address)
        r = subprocess.run(
            ['ip', '-json', 'neigh', 'show', ip_address],
            check=True,
            capture_output=True,
        )
        # Tolerate empty output rather than failing in json.loads().
        neighs = json.loads(r.stdout.strip() or b'[]')
        for neigh in neighs:
            # Entries without a resolved link-layer address (e.g. FAILED or
            # INCOMPLETE) carry no 'lladdr'; skip them instead of raising
            # KeyError.
            lladdr = neigh.get('lladdr')
            if lladdr:
                return HwAddrResponse(hw_address=bytes.fromhex(lladdr.replace(':', '')))
        return HwAddrResponse(hw_address=None)

def server():
    """Run the DHCP tracker gRPC server; blocks until it terminates.

    Parses the command line with ``parser`` and loads the YAML config file it
    names. Starts a ``DhcpdUpdater`` on ``LEASE_FILE`` with ``TIMEOUT``. The
    gRPC API is then served on up to two listeners:

    * ``GRPC_TLS_ADDRESS``: a mutual-TLS port. The server key and cert are
      read from ``key.pem`` and ``cert.pem`` in ``GRPC_TLS_CERT_DIR``
      (default ``cert``). Client certificates are verified against
      ``GRPC_TLS_CA_CERT`` (default ``ca.pem``).
    * ``GRPC_UNIX_SOCKET``: a plaintext listener on that unix socket path.

    If neither is configured, an error is logged and the function returns
    without starting the server.
    """
    args = parser.parse_args()
    if args.verbose:
        # Honour --verbose by lowering the root logger to DEBUG.
        logging.getLogger().setLevel(logging.DEBUG)

    config = yaml.safe_load(args.config.read_text())
    tracker = DhcpdUpdater(config['LEASE_FILE'], config['TIMEOUT'])
    tracker.start()

    # Named grpc_server so it does not shadow this function's own name.
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    add_DhcpTrackerServicer_to_server(DhcpTrackerServicer(tracker), grpc_server)

    tls_address = config.get('GRPC_TLS_ADDRESS')
    if tls_address:
        cert_dir = Path(config.get('GRPC_TLS_CERT_DIR', 'cert'))
        ca_cert = Path(config.get('GRPC_TLS_CA_CERT', 'ca.pem')).read_bytes()

        server_credentials = grpc.ssl_server_credentials(
            private_key_certificate_chain_pairs=((
                cert_dir.joinpath('key.pem').read_bytes(),
                cert_dir.joinpath('cert.pem').read_bytes(),
            ),),
            root_certificates=ca_cert,
            # Clients must present a certificate signed by ca_cert.
            require_client_auth=True,
        )
        grpc_server.add_secure_port(tls_address, server_credentials)

    unix_socket = config.get('GRPC_UNIX_SOCKET')
    if unix_socket:
        grpc_server.add_insecure_port(f'unix://{unix_socket}')

    if not (tls_address or unix_socket):
        logging.error(
            'neither GRPC_TLS_ADDRESS nor GRPC_UNIX_SOCKET is configured; '
            'not starting grpc server'
        )
        return

    logging.info('starting grpc server ...')
    grpc_server.start()
    grpc_server.wait_for_termination()