# openvpn-endpoint-server/server/app/services/vpn_sync_service.py
#
# 183 lines
# 6.5 KiB
# Python

"""VPN Sync Service for synchronizing VPN connection state with database."""
from datetime import datetime
from sqlalchemy.orm import Session
from ..models.vpn_server import VPNServer
from ..models.vpn_profile import VPNProfile
from ..models.vpn_connection_log import VPNConnectionLog
from ..models.gateway import Gateway
from .vpn_server_service import VPNServerService
class VPNSyncService:
    """Service to sync VPN connections with gateway status and logs.

    A sync pass asks every active VPN server for its connected clients,
    opens or refreshes a ``VPNConnectionLog`` for each known client, closes
    the logs of clients that are no longer reported, and flips the
    ``is_online`` flag of the affected gateways.
    """

    def __init__(self, db: Session):
        """Store the session and build the VPN server helper on top of it."""
        self.db = db
        self.vpn_service = VPNServerService(db)

    def sync_all_connections(self) -> dict:
        """Sync all VPN connections across all active servers.

        Returns:
            A summary dict with the counters ``servers_checked``,
            ``clients_found``, ``gateways_online``, ``gateways_offline``,
            ``new_connections``, ``closed_connections`` and a list of
            human-readable ``errors``.
        """
        result = {
            "servers_checked": 0,
            "clients_found": 0,
            "gateways_online": 0,
            "gateways_offline": 0,
            "new_connections": 0,
            "closed_connections": 0,
            "errors": [],
        }

        # SQLAlchemy column comparison, not a plain Python boolean test.
        servers = (
            self.db.query(VPNServer)
            .filter(VPNServer.is_active == True)  # noqa: E712
            .all()
        )

        connected_cns = set()
        # Servers whose client list could not be read (or processed) in
        # this pass. Their open logs must not be closed: "no answer" is not
        # the same as "nobody connected".
        failed_server_ids = set()

        for server in servers:
            result["servers_checked"] += 1
            try:
                clients = self.vpn_service.get_connected_clients(server)
                result["clients_found"] += len(clients)
                for client in clients:
                    cn = client.get("common_name")
                    if cn:
                        connected_cns.add(cn)
                        self._handle_connected_client(server, client, result)
            except Exception as e:
                failed_server_ids.add(server.id)
                result["errors"].append(f"Server {server.name}: {e}")

        # Close logs of clients that are no longer reported by any server.
        self._handle_disconnected_profiles(
            connected_cns, result, failed_server_ids
        )

        self.db.commit()
        return result

    def _handle_connected_client(
        self, server: VPNServer, client: dict, result: dict
    ):
        """Handle a connected VPN client - update profile, gateway, and log.

        Args:
            server: The VPN server that reported the client.
            client: Client status dict; the keys read here are
                ``common_name``, ``real_address``, ``bytes_received`` and
                ``bytes_sent``.
            result: Summary dict of the running sync pass; its counters and
                ``errors`` list are updated in place.
        """
        cn = client.get("common_name")
        real_address = client.get("real_address")
        bytes_rx = client.get("bytes_received", 0)
        bytes_tx = client.get("bytes_sent", 0)

        # Prefer the profile registered on the reporting server.
        profile = (
            self.db.query(VPNProfile)
            .filter(
                VPNProfile.cert_cn == cn,
                VPNProfile.vpn_server_id == server.id,
            )
            .first()
        )
        if profile is None:
            # Try finding by CN across all servers (in case of migration).
            profile = (
                self.db.query(VPNProfile)
                .filter(VPNProfile.cert_cn == cn)
                .first()
            )
        if profile is None:
            return  # Unknown client, skip

        gateway = profile.gateway
        if gateway is None:
            # Without a gateway there is nothing to mark online and no
            # gateway id for the log. Report it and move on so the other
            # clients of this server are still processed.
            result["errors"].append(
                f"Profile {cn}: no gateway assigned, client skipped"
            )
            return

        # One timestamp for everything touched on behalf of this client.
        now = datetime.utcnow()

        if not gateway.is_online:
            gateway.is_online = True
            result["gateways_online"] += 1
        gateway.last_seen = now

        profile.last_connection = now

        # An open log (no disconnect time) means the session is already
        # being tracked.
        active_log = (
            self.db.query(VPNConnectionLog)
            .filter(
                VPNConnectionLog.vpn_profile_id == profile.id,
                VPNConnectionLog.disconnected_at.is_(None),
            )
            .first()
        )

        if active_log is None:
            # NOTE: connected_at is the moment this sync first saw the
            # client, not a connection start time reported by the server.
            self.db.add(
                VPNConnectionLog(
                    vpn_profile_id=profile.id,
                    vpn_server_id=server.id,
                    gateway_id=gateway.id,
                    common_name=cn,
                    real_address=real_address,
                    connected_at=now,
                    bytes_received=bytes_rx,
                    bytes_sent=bytes_tx,
                )
            )
            result["new_connections"] += 1
        else:
            # Refresh traffic stats and the client's current address.
            active_log.bytes_received = bytes_rx
            active_log.bytes_sent = bytes_tx
            active_log.real_address = real_address

    def _handle_disconnected_profiles(
        self,
        connected_cns: set,
        result: dict,
        skip_server_ids: "set | None" = None,
    ):
        """Mark profiles as disconnected if their CN is no longer connected.

        Args:
            connected_cns: Common names reported as connected in this pass.
            result: Summary dict of the running sync pass; updated in place.
            skip_server_ids: Ids of servers whose state is unknown in this
                pass. Open logs on those servers are left untouched and
                keep their gateway online. Defaults to no servers.
        """
        unknown_servers = skip_server_ids or ()

        open_logs = (
            self.db.query(VPNConnectionLog)
            .filter(VPNConnectionLog.disconnected_at.is_(None))
            .all()
        )

        # Split the open logs into those that stay open and those to close.
        # Gateways that keep at least one open log must stay online.
        stale_logs = []
        live_gateway_ids = set()
        for log in open_logs:
            still_open = (
                log.common_name in connected_cns
                or log.vpn_server_id in unknown_servers
            )
            if still_open:
                live_gateway_ids.add(log.gateway_id)
            else:
                stale_logs.append(log)

        now = datetime.utcnow()
        offline_gateway_ids = set()
        for log in stale_logs:
            log.disconnected_at = now
            result["closed_connections"] += 1

            gateway = log.gateway
            if gateway is None:
                continue  # Log without a gateway; closing it is enough.
            if log.gateway_id in live_gateway_ids:
                continue  # Another connection keeps this gateway online.

            gateway.is_online = False
            # Several logs of one gateway may close in the same pass;
            # report the gateway only once.
            if log.gateway_id not in offline_gateway_ids:
                offline_gateway_ids.add(log.gateway_id)
                result["gateways_offline"] += 1

    def _recent_logs(self, *criteria, limit=None) -> list[VPNConnectionLog]:
        """Return logs matching ``criteria``, newest connection first.

        ``limit`` caps the number of rows; ``None`` returns all matches.
        """
        query = (
            self.db.query(VPNConnectionLog)
            .filter(*criteria)
            .order_by(VPNConnectionLog.connected_at.desc())
        )
        if limit is not None:
            query = query.limit(limit)
        return query.all()

    def get_profile_connection_logs(
        self,
        profile_id: int,
        limit: int = 50
    ) -> list[VPNConnectionLog]:
        """Get connection logs for a specific profile, newest first."""
        return self._recent_logs(
            VPNConnectionLog.vpn_profile_id == profile_id, limit=limit
        )

    def get_gateway_connection_logs(
        self,
        gateway_id: int,
        limit: int = 50
    ) -> list[VPNConnectionLog]:
        """Get connection logs for a gateway (all profiles), newest first."""
        return self._recent_logs(
            VPNConnectionLog.gateway_id == gateway_id, limit=limit
        )

    def get_server_connection_logs(
        self,
        server_id: int,
        limit: int = 50
    ) -> list[VPNConnectionLog]:
        """Get connection logs for a VPN server, newest first."""
        return self._recent_logs(
            VPNConnectionLog.vpn_server_id == server_id, limit=limit
        )

    def get_active_connections(self) -> list[VPNConnectionLog]:
        """Get all currently active VPN connections, newest first."""
        return self._recent_logs(VPNConnectionLog.disconnected_at.is_(None))