"""VPN Server management service.""" import socket from datetime import datetime from typing import Optional from sqlalchemy.orm import Session from ..models.vpn_server import VPNServer, VPNServerStatus, VPNProtocol, VPNCipher, VPNAuth, VPNCompression from ..models.certificate_authority import CertificateAuthority from ..models.vpn_profile import VPNProfile from ..models.gateway import Gateway from .certificate_service import CertificateService from ..config import get_settings import ipaddress settings = get_settings() class VPNServerService: """Service for managing VPN server instances.""" def __init__(self, db: Session): self.db = db self.cert_service = CertificateService(db) def create_server( self, name: str, hostname: str, ca_id: int, port: int = 1194, protocol: VPNProtocol = VPNProtocol.UDP, vpn_network: str = "10.8.0.0", vpn_netmask: str = "255.255.255.0", cipher: VPNCipher = VPNCipher.AES_256_GCM, auth: VPNAuth = VPNAuth.SHA256, tls_version_min: str = "1.2", compression: VPNCompression = VPNCompression.NONE, max_clients: int = 100, keepalive_interval: int = 10, keepalive_timeout: int = 60, management_port: int = 7505, is_primary: bool = False, tenant_id: Optional[int] = None, description: Optional[str] = None ) -> VPNServer: """Create a new VPN server instance.""" # Get CA ca = self.db.query(CertificateAuthority).filter( CertificateAuthority.id == ca_id ).first() if not ca: raise ValueError(f"CA with id {ca_id} not found") if not ca.is_ready: raise ValueError("CA is not ready (DH parameters may still be generating)") # Generate container name existing_count = self.db.query(VPNServer).count() container_name = f"mguard-openvpn-{existing_count + 1}" # If setting as primary, unset other primaries if is_primary: self.db.query(VPNServer).filter( VPNServer.tenant_id == tenant_id, VPNServer.is_primary == True ).update({"is_primary": False}) # Create server record server = VPNServer( name=name, description=description, hostname=hostname, port=port, protocol=protocol, vpn_network=vpn_network, vpn_netmask=vpn_netmask, cipher=cipher, auth=auth, tls_version_min=tls_version_min, compression=compression, max_clients=max_clients, keepalive_interval=keepalive_interval, keepalive_timeout=keepalive_timeout, management_port=management_port, docker_container_name=container_name, ca_id=ca_id, tenant_id=tenant_id, is_primary=is_primary, status=VPNServerStatus.PENDING ) self.db.add(server) self.db.commit() self.db.refresh(server) # Generate server certificate self._generate_server_cert(server) return server def _generate_server_cert(self, server: VPNServer): """Generate server certificate and TA key.""" ca = server.certificate_authority # Generate server certificate cert_data = self.cert_service.generate_server_certificate( ca=ca, common_name=f"{server.name}-server" ) server.server_cert = cert_data["cert"] server.server_key = cert_data["key"] # Generate TLS-Auth key server.ta_key = self.cert_service.generate_ta_key() self.db.commit() def generate_server_config(self, server: VPNServer) -> str: """Generate OpenVPN server configuration file.""" if not server.certificate_authority: raise ValueError("Server has no CA assigned") config_lines = [ "# OpenVPN Server Configuration", f"# Server: {server.name}", f"# Generated: {datetime.utcnow().isoformat()}", "", "# Basic settings", f"port {server.port}", f"proto {server.protocol.value}", "dev tun", "", "# Certificates", "ca /etc/openvpn/ca.crt", "cert /etc/openvpn/server.crt", "key /etc/openvpn/server.key", "dh /etc/openvpn/dh.pem", "crl-verify /etc/openvpn/crl.pem", "", "# TLS Auth", ] if server.tls_auth_enabled and server.ta_key: config_lines.append("tls-auth /etc/openvpn/ta.key 0") config_lines.extend([ f"tls-version-min {server.tls_version_min}", "", "# Network", f"server {server.vpn_network} {server.vpn_netmask}", "topology subnet", "", "# Routing", 'push "redirect-gateway def1 bypass-dhcp"', 'push "dhcp-option DNS 8.8.8.8"', 'push "dhcp-option DNS 8.8.4.4"', "", ]) # Routes for gateway subnets (required for iroute to work) gateway_routes = self._get_gateway_subnet_routes(server) if gateway_routes: config_lines.append("# Routes for gateway subnets") for route in gateway_routes: config_lines.append(f"route {route['network']} {route['netmask']} # {route['gateway_name']}") config_lines.append("") config_lines.extend([ "# Security", f"cipher {server.cipher.value}", f"auth {server.auth.value}", "", "# Performance", f"keepalive {server.keepalive_interval} {server.keepalive_timeout}", f"max-clients {server.max_clients}", ]) # Compression handling # Note: OpenVPN 2.5+ deprecates compression due to VORACLE attack if server.compression != VPNCompression.NONE: config_lines.append(f"compress {server.compression.value}") config_lines.append("allow-compression yes") # If compression is NONE, don't add any directive - OpenVPN defaults to no compression config_lines.extend([ "", "# Persistence", "persist-key", "persist-tun", "", "# Logging", f"status /var/log/openvpn/status-{server.id}.log", f"log-append /var/log/openvpn/server-{server.id}.log", "verb 3", "", "# Management interface", f"management 0.0.0.0 {server.management_port}", "", "# User/Group (for Linux)", "user nobody", "group nogroup", "", "# Client config directory", "client-config-dir /etc/openvpn/ccd", "", "# Scripts", "script-security 2", "client-connect /etc/openvpn/scripts/client-connect.sh", "client-disconnect /etc/openvpn/scripts/client-disconnect.sh", ]) return "\n".join(config_lines) def _get_gateway_subnet_routes(self, server: VPNServer) -> list[dict]: """Get all gateway subnets that need routes on this server.""" routes = [] seen_networks = set() # Find all VPN profiles that use this server profiles = self.db.query(VPNProfile).filter( VPNProfile.vpn_server_id == server.id ).all() for profile in profiles: gateway = profile.gateway if gateway and gateway.vpn_subnet and gateway.vpn_subnet not in seen_networks: try: network = ipaddress.ip_network(gateway.vpn_subnet, strict=False) routes.append({ "network": str(network.network_address), "netmask": str(network.netmask), "gateway_name": gateway.name }) seen_networks.add(gateway.vpn_subnet) except ValueError: pass # Invalid subnet, skip return routes def generate_ccd_file(self, profile: VPNProfile) -> str | None: """Generate CCD (Client Config Directory) file content for a gateway profile. This file tells OpenVPN that traffic for the gateway's subnet should be routed through this client. """ gateway = profile.gateway if not gateway or not gateway.vpn_subnet: return None lines = [ f"# CCD file for gateway: {gateway.name}", f"# Profile: {profile.name}", f"# Common Name: {profile.cert_cn}", "", ] try: network = ipaddress.ip_network(gateway.vpn_subnet, strict=False) # iroute tells OpenVPN to route this subnet through this client lines.append(f"iroute {network.network_address} {network.netmask}") except ValueError: return None # Optional: assign static IP to this gateway if profile.vpn_ip: lines.append(f"# Static IP assignment (if needed)") lines.append(f"# ifconfig-push {profile.vpn_ip} {profile.vpn_server.vpn_netmask}") return "\n".join(lines) def get_all_ccd_files(self, server: VPNServer) -> dict[str, str]: """Get all CCD files for a VPN server. Returns a dict mapping cert_cn to CCD file content. """ ccd_files = {} profiles = self.db.query(VPNProfile).filter( VPNProfile.vpn_server_id == server.id ).all() for profile in profiles: content = self.generate_ccd_file(profile) if content: ccd_files[profile.cert_cn] = content return ccd_files def get_connected_clients(self, server: VPNServer) -> list[dict]: """Get list of currently connected VPN clients.""" try: response = self._send_management_command(server, "status") clients = [] if "ERROR" in response: return clients # Parse status output in_client_list = False for line in response.split('\n'): if line.startswith("ROUTING TABLE"): in_client_list = False elif line.startswith("Common Name"): in_client_list = True continue elif in_client_list and ',' in line: parts = line.split(',') if len(parts) >= 5: clients.append({ "common_name": parts[0], "real_address": parts[1], "bytes_received": int(parts[2]) if parts[2].isdigit() else 0, "bytes_sent": int(parts[3]) if parts[3].isdigit() else 0, "connected_since": parts[4] }) # Update connected count server.connected_clients = len(clients) server.last_status_check = datetime.utcnow() self.db.commit() return clients except Exception as e: return [] def disconnect_client(self, server: VPNServer, common_name: str) -> bool: """Disconnect a specific VPN client.""" response = self._send_management_command(server, f"kill {common_name}") return "SUCCESS" in response def get_server_status(self, server: VPNServer) -> dict: """Get detailed server status.""" try: # Try to connect to management interface response = self._send_management_command(server, "state") if "ERROR" in response: server.status = VPNServerStatus.ERROR self.db.commit() return { "status": "error", "message": response } # Parse state for line in response.split('\n'): if 'CONNECTED' in line: server.status = VPNServerStatus.RUNNING break else: server.status = VPNServerStatus.STOPPED server.last_status_check = datetime.utcnow() self.db.commit() clients = self.get_connected_clients(server) return { "status": server.status.value, "connected_clients": len(clients), "clients": clients, "last_check": server.last_status_check.isoformat() if server.last_status_check else None } except Exception as e: return { "status": "unknown", "message": str(e) } def _send_management_command(self, server: VPNServer, command: str) -> str: """Send command to OpenVPN management interface.""" try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(5) # Connect to management interface # OpenVPN runs in host network mode, so we connect via host.docker.internal # This resolves to the Docker host from within containers host = "host.docker.internal" sock.connect((host, server.management_port)) # Read welcome message sock.recv(1024) # Send command sock.send(f"{command}\n".encode()) # Read response response = b"" while True: data = sock.recv(4096) if not data: break response += data if b"END" in data or b"SUCCESS" in data or b"ERROR" in data: break sock.close() return response.decode() except Exception as e: return f"ERROR: {str(e)}" def update_all_server_status(self): """Update status for all active servers.""" servers = self.db.query(VPNServer).filter( VPNServer.is_active == True ).all() for server in servers: self.get_server_status(server) def get_server_by_id(self, server_id: int) -> Optional[VPNServer]: """Get a VPN server by ID.""" return self.db.query(VPNServer).filter( VPNServer.id == server_id ).first() def get_servers_for_tenant(self, tenant_id: Optional[int] = None) -> list[VPNServer]: """Get all VPN servers for a tenant.""" query = self.db.query(VPNServer) if tenant_id: query = query.filter( (VPNServer.tenant_id == tenant_id) | (VPNServer.tenant_id == None) ) else: query = query.filter(VPNServer.tenant_id == None) return query.order_by(VPNServer.name).all()