| from at.dhcp import DhcpdUpdater, DhcpLease |
| from pathlib import Path |
| import yaml |
| import grpc |
| import json |
| import re |
| import subprocess |
| import logging |
| from concurrent import futures |
| from datetime import datetime, timezone |
| |
| from .tracker_pb2 import DhcpClient, DhcpClients, HwAddrResponse |
| from .tracker_pb2_grpc import DhcpTrackerServicer, add_DhcpTrackerServicer_to_server |
| |
| import argparse |
# Root logger at INFO so the module's logging.info(...) messages are emitted.
logging.basicConfig(level=logging.INFO)

# Command-line interface: a positional path to the YAML config file plus an
# optional --verbose switch (store_true, so it defaults to False).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--verbose",
    action="store_true",
    help="output more info",
)
parser.add_argument(
    "config",
    help="input file",
    type=Path,
)
| |
def lease_to_client(lease: DhcpLease) -> DhcpClient:
    """Convert a DHCP lease record into a ``DhcpClient`` protobuf message.

    Args:
        lease: lease whose ``hwaddr`` is a colon-separated hex MAC string,
            ``atime`` a POSIX timestamp, plus ``name`` and ``ip``.

    Returns:
        A ``DhcpClient`` with the MAC as raw bytes and ``last_seen`` as an
        ISO-8601 string carrying an explicit ``+00:00`` UTC offset.
    """
    return DhcpClient(
        # "aa:bb:cc:..." -> b"\xaa\xbb\xcc..."
        hw_address=bytes.fromhex(lease.hwaddr.replace(':', '')),
        # fromtimestamp with an explicit tz yields an aware UTC datetime in one
        # step; datetime.utcfromtimestamp is deprecated since Python 3.12.
        last_seen=datetime.fromtimestamp(lease.atime, tz=timezone.utc).isoformat(),
        client_hostname=lease.name,
        ip_address=lease.ip,
    )
| |
class DhcpTrackerServicer(DhcpTrackerServicer):
    """gRPC servicer answering queries about DHCP clients.

    Shares its name with the generated base class imported from
    ``tracker_pb2_grpc``: the base is resolved before the name is rebound, so
    afterwards the module-level name refers to this subclass.
    """

    def __init__(self, tracker: DhcpdUpdater, *args, **kwargs):
        # Source of lease data for GetClients.
        self._tracker = tracker
        super().__init__(*args, **kwargs)

    def _authorize(self, context):
        """Abort the RPC unless the caller is trusted.

        TLS peers must present the ``at.hackerspace.pl`` identity. Connections
        whose auth context has no ``transport_security_type`` property are
        treated as coming from the local unix socket and are trusted. Any
        other transport type is rejected.

        Raises:
            Whatever ``context.abort`` raises (it never returns) on denial.
        """
        auth = context.auth_context()
        # Auth-context values are lists of bytes (cf. the [b'ssl'] comparison
        # below), so the str 'local' can only match the fallback default.
        # NOTE(review): some gRPC versions may report [b'insecure'] for
        # connections on add_insecure_port (the unix socket); those would be
        # rejected by the final branch — confirm against the deployed gRPC.
        ctype = auth.get('transport_security_type', 'local')
        logging.debug('transport security type: %r', ctype)
        if ctype == [b'ssl']:
            # peer_identities() returns None for an unauthenticated peer.
            identities = context.peer_identities() or []
            if b'at.hackerspace.pl' not in identities:
                context.abort(
                    grpc.StatusCode.PERMISSION_DENIED,
                    (
                        "Only at.hackerspace.pl is allowed to access raw "
                        "clients addresses"
                    )
                )
        elif ctype == 'local':
            # connection from local unix socket is trusted by default
            pass
        else:
            context.abort(
                grpc.StatusCode.PERMISSION_DENIED,
                f"Unknown transport type: {ctype}"
            )

    def GetClients(self, request, context):
        """Return a ``DhcpClients`` message built from the tracker's active devices."""
        self._authorize(context)

        clients = [
            lease_to_client(lease)
            for lease in self._tracker.get_active_devices().values()
        ]
        return DhcpClients(clients=clients)

    def GetHwAddr(self, request, context):
        """Look up the link-layer address of ``request.ip_address``.

        Queries the neighbour table with ``ip -json neigh show <addr>`` and
        returns the first entry's ``lladdr`` as raw bytes, or an empty
        ``HwAddrResponse`` when no entry has one.

        Aborts with INVALID_ARGUMENT if the address is empty or contains
        characters other than hex digits, ':' and '.'.

        Raises:
            subprocess.CalledProcessError: if ``ip`` exits non-zero.
        """
        self._authorize(context)
        ip_address = str(request.ip_address)
        # Character whitelist for IPv4/IPv6 literals; it also excludes '-', so
        # the value can never be interpreted as an `ip` command-line option.
        if not re.fullmatch(r'[0-9a-fA-F:.]+', ip_address):
            context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                f'Invalid ip address: {ip_address!r}'
            )
        logging.info('running ip neigh on %s', ip_address)
        result = subprocess.run(
            ['ip', '-json', 'neigh', 'show', ip_address],
            check=True, capture_output=True,
        )
        # Guard against empty output in case `ip` prints nothing for no match.
        for neigh in json.loads(result.stdout or b'[]'):
            # Unresolved entries can lack 'lladdr'; skip them instead of
            # failing with KeyError.
            lladdr = neigh.get('lladdr')
            if lladdr:
                return HwAddrResponse(
                    hw_address=bytes.fromhex(lladdr.replace(':', '')))
        return HwAddrResponse()
| |
def server():
    """Parse the command line, start the lease tracker and serve gRPC.

    YAML config keys:
      LEASE_FILE, TIMEOUT -- passed to DhcpdUpdater (required; KeyError if missing).
      GRPC_TLS_ADDRESS    -- if set, listen here with mutual TLS.
      GRPC_TLS_CERT_DIR   -- directory holding key.pem / cert.pem (default 'cert').
      GRPC_TLS_CA_CERT    -- CA certificate used to verify clients (default 'ca.pem').
      GRPC_UNIX_SOCKET    -- if set, also listen on this unix socket without TLS.

    Blocks until the gRPC server terminates. If neither listener is
    configured, logs an error and returns without starting the server.
    """
    args = parser.parse_args()
    if args.verbose:
        # --verbose lowers the root logger threshold so debug output is shown.
        logging.getLogger().setLevel(logging.DEBUG)

    config = yaml.safe_load(args.config.read_text())
    tracker = DhcpdUpdater(config['LEASE_FILE'], config['TIMEOUT'])
    tracker.start()

    # Named grpc_server so it does not shadow this function's own name.
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    add_DhcpTrackerServicer_to_server(DhcpTrackerServicer(tracker), grpc_server)

    tls_address = config.get('GRPC_TLS_ADDRESS')
    if tls_address:
        cert_dir = Path(config.get('GRPC_TLS_CERT_DIR', 'cert'))
        ca_cert = Path(config.get('GRPC_TLS_CA_CERT', 'ca.pem')).read_bytes()

        server_credentials = grpc.ssl_server_credentials(
            private_key_certificate_chain_pairs=((
                cert_dir.joinpath('key.pem').read_bytes(),
                cert_dir.joinpath('cert.pem').read_bytes(),
            ),),
            root_certificates=ca_cert,
            # Mutual TLS: clients must present a certificate signed by ca_cert.
            require_client_auth=True,
        )
        grpc_server.add_secure_port(tls_address, server_credentials)

    unix_socket = config.get('GRPC_UNIX_SOCKET')
    if unix_socket:
        grpc_server.add_insecure_port(f'unix://{unix_socket}')

    if not (tls_address or unix_socket):
        logging.error(
            'neither GRPC_TLS_ADDRESS nor GRPC_UNIX_SOCKET is configured; '
            'not starting grpc server')
        return

    logging.info('starting grpc server ...')
    grpc_server.start()
    grpc_server.wait_for_termination()