"""VPN Sync Service for synchronizing VPN connection state with database.""" from datetime import datetime from sqlalchemy.orm import Session from ..models.vpn_server import VPNServer from ..models.vpn_profile import VPNProfile from ..models.vpn_connection_log import VPNConnectionLog from ..models.gateway import Gateway from .vpn_server_service import VPNServerService class VPNSyncService: """Service to sync VPN connections with gateway status and logs.""" def __init__(self, db: Session): self.db = db self.vpn_service = VPNServerService(db) def sync_all_connections(self) -> dict: """ Sync all VPN connections across all servers. Returns summary of changes. """ result = { "servers_checked": 0, "clients_found": 0, "gateways_online": 0, "gateways_offline": 0, "new_connections": 0, "closed_connections": 0, "errors": [] } # Get all active VPN servers servers = self.db.query(VPNServer).filter(VPNServer.is_active == True).all() # Collect all currently connected CNs connected_cns = set() for server in servers: result["servers_checked"] += 1 try: clients = self.vpn_service.get_connected_clients(server) result["clients_found"] += len(clients) for client in clients: cn = client.get("common_name") if cn: connected_cns.add(cn) self._handle_connected_client(server, client, result) except Exception as e: result["errors"].append(f"Server {server.name}: {str(e)}") # Mark disconnected profiles/gateways self._handle_disconnected_profiles(connected_cns, result) self.db.commit() return result def _handle_connected_client(self, server: VPNServer, client: dict, result: dict): """Handle a connected VPN client - update profile, gateway, and log.""" cn = client.get("common_name") real_address = client.get("real_address") bytes_rx = client.get("bytes_received", 0) bytes_tx = client.get("bytes_sent", 0) connected_since = client.get("connected_since") # Find profile by CN profile = self.db.query(VPNProfile).filter( VPNProfile.cert_cn == cn, VPNProfile.vpn_server_id == server.id ).first() if not profile: # Try finding by CN across all servers (in case of migration) profile = self.db.query(VPNProfile).filter( VPNProfile.cert_cn == cn ).first() if not profile: return # Unknown client, skip gateway = profile.gateway # Update gateway status if not gateway.is_online: gateway.is_online = True gateway.last_seen = datetime.utcnow() result["gateways_online"] += 1 # Update profile last connection profile.last_connection = datetime.utcnow() # Check for existing active connection log active_log = self.db.query(VPNConnectionLog).filter( VPNConnectionLog.vpn_profile_id == profile.id, VPNConnectionLog.disconnected_at.is_(None) ).first() if not active_log: # Create new connection log log = VPNConnectionLog( vpn_profile_id=profile.id, vpn_server_id=server.id, gateway_id=gateway.id, common_name=cn, real_address=real_address, connected_at=datetime.utcnow(), bytes_received=bytes_rx, bytes_sent=bytes_tx ) self.db.add(log) result["new_connections"] += 1 else: # Update traffic stats active_log.bytes_received = bytes_rx active_log.bytes_sent = bytes_tx active_log.real_address = real_address def _handle_disconnected_profiles(self, connected_cns: set, result: dict): """Mark profiles as disconnected if their CN is no longer connected.""" # Find all active connection logs active_logs = self.db.query(VPNConnectionLog).filter( VPNConnectionLog.disconnected_at.is_(None) ).all() for log in active_logs: if log.common_name not in connected_cns: # Mark as disconnected log.disconnected_at = datetime.utcnow() result["closed_connections"] += 1 # Update gateway status gateway = log.gateway # Check if gateway has any other active connections other_active = self.db.query(VPNConnectionLog).filter( VPNConnectionLog.gateway_id == gateway.id, VPNConnectionLog.id != log.id, VPNConnectionLog.disconnected_at.is_(None) ).first() if not other_active: gateway.is_online = False result["gateways_offline"] += 1 def get_profile_connection_logs( self, profile_id: int, limit: int = 50 ) -> list[VPNConnectionLog]: """Get connection logs for a specific profile.""" return self.db.query(VPNConnectionLog).filter( VPNConnectionLog.vpn_profile_id == profile_id ).order_by(VPNConnectionLog.connected_at.desc()).limit(limit).all() def get_gateway_connection_logs( self, gateway_id: int, limit: int = 50 ) -> list[VPNConnectionLog]: """Get connection logs for a gateway (all profiles).""" return self.db.query(VPNConnectionLog).filter( VPNConnectionLog.gateway_id == gateway_id ).order_by(VPNConnectionLog.connected_at.desc()).limit(limit).all() def get_server_connection_logs( self, server_id: int, limit: int = 50 ) -> list[VPNConnectionLog]: """Get connection logs for a VPN server.""" return self.db.query(VPNConnectionLog).filter( VPNConnectionLog.vpn_server_id == server_id ).order_by(VPNConnectionLog.connected_at.desc()).limit(limit).all() def get_active_connections(self) -> list[VPNConnectionLog]: """Get all currently active VPN connections.""" return self.db.query(VPNConnectionLog).filter( VPNConnectionLog.disconnected_at.is_(None) ).order_by(VPNConnectionLog.connected_at.desc()).all()