"""Client VPN Profile management service for user/technician VPN connections."""

import ipaddress
import logging
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy.orm import Session

from ..models.client_vpn_profile import ClientVPNProfile, ClientVPNProfileStatus
from ..models.vpn_server import VPNServer
from ..models.user import User
from ..models.gateway import Gateway
from ..models.access import UserGatewayAccess
from ..models.certificate_authority import CertificateAuthority
from .certificate_service import CertificateService
from ..config import get_settings

logger = logging.getLogger(__name__)

settings = get_settings()


def _comment_safe(value: object) -> str:
    """Collapse *value* onto a single printable line for use in a ``#`` comment.

    Names embedded in the generated .ovpn file come from the database. A
    newline (or other control character) in one of them would otherwise end
    the comment and start a new, attacker-chosen directive line.
    """
    text = "".join(ch if ch.isprintable() else " " for ch in str(value))
    return " ".join(text.split())


def _inline_block(tag: str, content: str) -> list[str]:
    """Return config lines embedding *content* as an inline ``<tag>...</tag>`` block.

    OpenVPN only accepts inline certificate/key material when it is wrapped
    in the tag naming the option it replaces (ca, cert, key, tls-auth, ...).
    """
    return [f"<{tag}>", content.strip(), f"</{tag}>", ""]


class ClientVPNProfileService:
    """Service for managing VPN profiles for users/technicians.

    Operates on the SQLAlchemy session given to the constructor; methods that
    create, revoke or renew profiles commit their changes to that session.
    """

    def __init__(self, db: Session):
        self.db = db
        self.cert_service = CertificateService(db)

    def get_or_create_profile(self, user: User) -> Optional[ClientVPNProfile]:
        """Return the user's ready VPN profile, creating a new one if needed.

        An existing profile is reused only when it reports ``is_ready``.
        Otherwise a new profile is issued against the active primary VPN
        server, falling back to any active server when no primary exists.

        Args:
            user: The user/technician the profile belongs to.

        Returns:
            The profile, or ``None`` when there is no active VPN server or the
            chosen server's CA is missing or not ready.
        """
        current = user.client_vpn_profile
        if current and current.is_ready:
            return current

        server = self._find_active_server()
        if not server:
            return None

        ca = server.certificate_authority
        if not ca or not ca.is_ready:
            return None

        return self.create_profile(user, server.id)

    def _find_active_server(self) -> Optional[VPNServer]:
        """Return the active primary VPN server, else any active server, else None."""
        active = self.db.query(VPNServer).filter(VPNServer.is_active == True)  # noqa: E712
        return (
            active.filter(VPNServer.is_primary == True).first()  # noqa: E712
            or active.first()
        )

    def create_profile(
        self,
        user: User,
        vpn_server_id: int
    ) -> ClientVPNProfile:
        """Create (or replace) a user's VPN profile and issue its client certificate.

        Any existing profile of the user is deleted. Its certificate is revoked
        first (reason ``"superseded"``) unless the profile is already REVOKED,
        so a replaced certificate does not stay valid without any record of it.

        Args:
            user: The user/technician the profile is for.
            vpn_server_id: Primary key of the VPN server to bind the profile to.

        Returns:
            The new profile, after certificate generation.

        Raises:
            ValueError: If the server does not exist, or its CA is missing or
                not ready.
        """
        server = self.db.query(VPNServer).filter(
            VPNServer.id == vpn_server_id
        ).first()
        if not server:
            raise ValueError(f"VPN Server with id {vpn_server_id} not found")

        # Guard against a missing CA as well, mirroring get_or_create_profile.
        ca = server.certificate_authority
        if not ca or not ca.is_ready:
            raise ValueError("CA is not ready (DH parameters may still be generating)")

        # Deterministic common name per user.
        cert_cn = f"client-{user.username}-{user.id}"

        existing = self.db.query(ClientVPNProfile).filter(
            ClientVPNProfile.user_id == user.id
        ).first()
        if existing:
            if existing.status != ClientVPNProfileStatus.REVOKED:
                self._revoke_cert(
                    existing.certificate_authority, existing.client_cert, "superseded"
                )
            self.db.delete(existing)
            # Flush (not commit) so the DELETE is issued before the INSERT below
            # (presumably user_id is unique - confirm against the model) while the
            # delete and the new profile still commit in a single transaction.
            self.db.flush()

        profile = ClientVPNProfile(
            user_id=user.id,
            vpn_server_id=vpn_server_id,
            ca_id=server.ca_id,
            cert_cn=cert_cn,
            status=ClientVPNProfileStatus.PENDING,
        )
        self.db.add(profile)
        self.db.commit()
        self.db.refresh(profile)

        self._generate_client_cert(profile)
        return profile

    def _generate_client_cert(self, profile: ClientVPNProfile) -> None:
        """Issue a client certificate from the profile's CA and mark the profile ACTIVE.

        Stores the certificate, private key and validity window on the profile
        and commits.
        """
        cert_data = self.cert_service.generate_client_certificate(
            ca=profile.certificate_authority,
            common_name=profile.cert_cn,
        )

        profile.client_cert = cert_data["cert"]
        profile.client_key = cert_data["key"]
        profile.valid_from = cert_data["valid_from"]
        profile.valid_until = cert_data["valid_until"]
        profile.status = ClientVPNProfileStatus.ACTIVE

        self.db.commit()

    def _revoke_cert(self, ca: CertificateAuthority, cert_pem: Optional[str], reason: str) -> None:
        """Revoke *cert_pem* with *ca*; does nothing when there is no certificate."""
        if cert_pem:
            self.cert_service.revoke_certificate(
                ca=ca,
                cert_pem=cert_pem,
                reason=reason,
            )

    def generate_client_config(self, profile: ClientVPNProfile, split_tunnel: bool = True) -> str:
        """Generate OpenVPN client configuration (.ovpn) for a profile.

        The CA certificate, client certificate, client key and (when enabled
        on the server) the tls-auth key are embedded inline.

        Args:
            profile: The client VPN profile
            split_tunnel: If True, only VPN traffic goes through tunnel (default):
                any ``redirect-gateway`` pushed by the server is ignored.
                If False, pushed ``redirect-gateway`` options are honoured, so
                all traffic goes through the VPN when the server pushes it.

        Returns:
            The configuration text, one directive per line.

        Raises:
            ValueError: If the profile is not ready.
        """
        if not profile.is_ready:
            raise ValueError("Profile is not ready")

        server = profile.vpn_server
        ca = profile.certificate_authority

        config_lines = [
            "# OpenVPN Client Configuration",
            f"# User: {_comment_safe(profile.user.username)}",
            f"# Server: {_comment_safe(server.name)}",
            f"# Generated: {datetime.now(timezone.utc).isoformat()}",
            f"# Split Tunneling: {'Enabled' if split_tunnel else 'Disabled'}",
            "",
            "client",
            "dev tun",
            f"proto {server.protocol.value}",
            f"remote {server.hostname} {server.port}",
            "",
            "resolv-retry infinite",
            "nobind",
            "persist-key",
            "persist-tun",
            "",
            "remote-cert-tls server",
            f"cipher {server.cipher.value}",
            f"auth {server.auth.value}",
            "",
            "verb 3",
            "",
        ]

        if split_tunnel:
            config_lines.extend([
                "# Split Tunneling - Only VPN traffic through tunnel",
                "# Ignore any redirect-gateway push from server",
                "pull-filter ignore \"redirect-gateway\"",
                "",
            ])

        config_lines.extend([
            "# Route for VPN network",
            f"route {server.vpn_network} {server.vpn_netmask}",
            "",
        ])

        gateway_routes = self._get_gateway_routes(profile.user)
        if gateway_routes:
            config_lines.append("# Routes for accessible gateway networks")
            for route in gateway_routes:
                # Comment on its own line: the name is DB data and must never
                # share a line with a directive.
                config_lines.append(f"# {_comment_safe(route['name'])}")
                config_lines.append(f"route {route['network']} {route['netmask']}")
            config_lines.append("")

        config_lines += _inline_block("ca", ca.ca_cert)
        config_lines += _inline_block("cert", profile.client_cert)
        config_lines += _inline_block("key", profile.client_key)

        if server.tls_auth_enabled and server.ta_key:
            # Client side of tls-auth uses key direction 1.
            config_lines.append("key-direction 1")
            config_lines += _inline_block("tls-auth", server.ta_key)

        return "\n".join(config_lines)

    def _get_gateway_routes(self, user: User) -> list[dict]:
        """Get IPv4 routes for all gateway VPN subnets the user has access to.

        Admins get every gateway of their tenant; other users get the gateways
        listed for them in ``UserGatewayAccess``. Gateways without a subnet are
        ignored. Invalid and non-IPv4 subnets are skipped with a warning,
        because the ``route`` directive only takes IPv4 networks. Subnets that
        normalise to the same network yield a single route.

        Returns:
            Dicts with ``network`` and ``netmask`` (dotted-quad strings) and
            ``name`` (the gateway name).
        """
        if user.is_admin:
            # Admins have access to all gateways in their tenant
            gateways = self.db.query(Gateway).filter(
                Gateway.tenant_id == user.tenant_id
            ).all()
        else:
            access_entries = self.db.query(UserGatewayAccess).filter(
                UserGatewayAccess.user_id == user.id
            ).all()
            gateway_ids = [entry.gateway_id for entry in access_entries]
            gateways = self.db.query(Gateway).filter(
                Gateway.id.in_(gateway_ids)
            ).all() if gateway_ids else []

        routes = []
        seen_networks = set()
        for gateway in gateways:
            if not gateway.vpn_subnet:
                continue
            try:
                network = ipaddress.ip_network(gateway.vpn_subnet, strict=False)
            except ValueError:
                logger.warning(
                    "Skipping gateway %r: invalid vpn_subnet %r",
                    gateway.name, gateway.vpn_subnet,
                )
                continue
            if network.version != 4:
                logger.warning(
                    "Skipping gateway %r: vpn_subnet %s is not IPv4",
                    gateway.name, network,
                )
                continue
            # De-duplicate on the normalised network, not the raw string.
            if network in seen_networks:
                continue
            seen_networks.add(network)
            routes.append({
                "network": str(network.network_address),
                "netmask": str(network.netmask),
                "name": gateway.name,
            })

        return routes

    def get_vpn_config_for_user(self, user: User) -> Optional[str]:
        """Get or create the user's profile and return its .ovpn config.

        Returns ``None`` when no profile could be obtained.
        """
        profile = self.get_or_create_profile(user)
        if not profile:
            return None
        return self.generate_client_config(profile)

    def revoke_profile(self, profile: ClientVPNProfile, reason: str = "unspecified"):
        """Revoke a user's VPN profile certificate and deactivate the profile.

        Sets the status to REVOKED and ``is_active`` to False, then commits.
        """
        self._revoke_cert(profile.certificate_authority, profile.client_cert, reason)
        profile.status = ClientVPNProfileStatus.REVOKED
        profile.is_active = False
        self.db.commit()

    def renew_profile(self, profile: ClientVPNProfile):
        """Renew the certificate for a profile.

        The new certificate is issued (and committed) before the old one is
        revoked with reason ``"superseded"``. If issuance fails, the profile
        therefore keeps its existing certificate instead of a revoked one
        marked ACTIVE. A certificate already revoked via ``revoke_profile``
        is not revoked a second time.
        """
        old_cert = profile.client_cert
        already_revoked = profile.status == ClientVPNProfileStatus.REVOKED

        self._generate_client_cert(profile)
        # NOTE(review): is_active is not restored here, so a renewed profile that
        # was previously revoked stays is_active=False - confirm this is intended.

        if not already_revoked:
            self._revoke_cert(profile.certificate_authority, old_cert, "superseded")
        self.db.commit()