183 lines
6.5 KiB
Python
183 lines
6.5 KiB
Python
"""VPN Sync Service for synchronizing VPN connection state with database."""
|
|
|
|
from datetime import datetime
|
|
from sqlalchemy.orm import Session
|
|
from ..models.vpn_server import VPNServer
|
|
from ..models.vpn_profile import VPNProfile
|
|
from ..models.vpn_connection_log import VPNConnectionLog
|
|
from ..models.gateway import Gateway
|
|
from .vpn_server_service import VPNServerService
|
|
|
|
|
|
class VPNSyncService:
    """Service to sync VPN connections with gateway status and logs.

    One sync pass queries every active ``VPNServer`` for its connected
    clients, opens or refreshes ``VPNConnectionLog`` rows for clients whose
    certificate CN matches a ``VPNProfile``, closes open logs whose CN is no
    longer reported, and updates the related ``Gateway.is_online`` flags.
    """

    def __init__(self, db: Session):
        self.db = db
        self.vpn_service = VPNServerService(db)

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated ``datetime.utcnow()`` (deprecated since
        Python 3.12). The result is kept naive so it stays comparable with
        timestamps previously written by this service via ``utcnow()``.
        """
        from datetime import timezone  # only this helper needs it

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def sync_all_connections(self) -> dict:
        """
        Sync all VPN connections across all servers.

        Returns a summary dict with the counters ``servers_checked``,
        ``clients_found``, ``gateways_online``, ``gateways_offline``,
        ``new_connections``, ``closed_connections`` and an ``errors`` list
        holding one message per server that could not be processed.
        All changes are committed once at the end of the pass.
        """
        result = {
            "servers_checked": 0,
            "clients_found": 0,
            "gateways_online": 0,
            "gateways_offline": 0,
            "new_connections": 0,
            "closed_connections": 0,
            "errors": [],
        }
        # One timestamp for the whole pass keeps all rows written together consistent.
        now = self._utcnow()

        # Get all active VPN servers.
        # SQLAlchemy needs `==` here (not `is`) to build the SQL comparison.
        servers = self.db.query(VPNServer).filter(
            VPNServer.is_active == True  # noqa: E712
        ).all()

        # CNs reported as connected by any server in this pass.
        connected_cns = set()
        # Servers whose client list could not be fully processed. For them,
        # "not reported" means "unknown", not "disconnected", so their open
        # logs must not be closed by this pass.
        failed_server_ids = set()

        for server in servers:
            result["servers_checked"] += 1
            try:
                clients = self.vpn_service.get_connected_clients(server)
                result["clients_found"] += len(clients)

                for client in clients:
                    cn = client.get("common_name")
                    if cn:
                        connected_cns.add(cn)
                        self._handle_connected_client(server, client, result, now=now)
            except Exception as e:
                # Deliberate best-effort boundary: record the failure and
                # continue with the remaining servers.
                failed_server_ids.add(server.id)
                result["errors"].append(f"Server {server.name}: {e}")

        # Close logs (and mark gateways offline) for clients no longer connected.
        self._handle_disconnected_profiles(
            connected_cns, result, skip_server_ids=failed_server_ids, now=now
        )

        self.db.commit()
        return result

    def _find_profile(self, cn: str, server: VPNServer):
        """Return the ``VPNProfile`` whose ``cert_cn`` is ``cn``, or ``None``.

        A profile bound to ``server`` is preferred. Otherwise any profile with
        that CN is accepted (in case of migration between servers).
        """
        profile = self.db.query(VPNProfile).filter(
            VPNProfile.cert_cn == cn,
            VPNProfile.vpn_server_id == server.id,
        ).first()
        if profile is None:
            profile = self.db.query(VPNProfile).filter(
                VPNProfile.cert_cn == cn
            ).first()
        return profile

    def _handle_connected_client(self, server: VPNServer, client: dict, result: dict, now=None):
        """Handle a connected VPN client - update profile, gateway, and log.

        Args:
            server: Server that reported the client.
            client: One entry from ``VPNServerService.get_connected_clients``.
                The keys ``common_name``, ``real_address``, ``bytes_received``
                and ``bytes_sent`` are read here.
            result: Summary dict. ``gateways_online`` and ``new_connections``
                are incremented here.
            now: Naive-UTC timestamp to record; defaults to the current time.
        """
        if now is None:
            now = self._utcnow()

        cn = client.get("common_name")
        if not cn:
            # Without a CN the client cannot be matched to a profile. Querying
            # with cert_cn == None would match profiles with a NULL CN instead.
            return
        real_address = client.get("real_address")
        bytes_rx = client.get("bytes_received", 0)
        bytes_tx = client.get("bytes_sent", 0)
        # NOTE(review): the client's "connected_since" value is available, but
        # its format is not visible here, so connected_at uses the sync time.

        profile = self._find_profile(cn, server)
        if profile is None:
            return  # Unknown client, skip

        gateway = profile.gateway

        # Count only offline -> online transitions...
        if not gateway.is_online:
            gateway.is_online = True
            result["gateways_online"] += 1
        # ...but refresh last_seen on every sighting, so it reflects the most
        # recent time the gateway was seen connected.
        gateway.last_seen = now

        profile.last_connection = now

        # Check for an existing open connection log for this profile.
        active_log = self.db.query(VPNConnectionLog).filter(
            VPNConnectionLog.vpn_profile_id == profile.id,
            VPNConnectionLog.disconnected_at.is_(None),
        ).first()

        if active_log is None:
            self.db.add(VPNConnectionLog(
                vpn_profile_id=profile.id,
                vpn_server_id=server.id,
                gateway_id=gateway.id,
                common_name=cn,
                real_address=real_address,
                connected_at=now,
                bytes_received=bytes_rx,
                bytes_sent=bytes_tx,
            ))
            result["new_connections"] += 1
        else:
            # Refresh traffic stats and address on the open log.
            active_log.bytes_received = bytes_rx
            active_log.bytes_sent = bytes_tx
            active_log.real_address = real_address

    def _handle_disconnected_profiles(self, connected_cns: set, result: dict, skip_server_ids=None, now=None):
        """Mark profiles as disconnected if their CN is no longer connected.

        Every open log (``disconnected_at IS NULL``) whose ``common_name`` is
        not in ``connected_cns`` gets ``disconnected_at`` set. A gateway is
        then marked offline only if none of its logs remains open.

        Args:
            connected_cns: CNs reported as connected in this pass.
            result: Summary dict. ``closed_connections`` and
                ``gateways_offline`` are incremented here.
            skip_server_ids: Ids of servers whose client list is unknown.
                Their open logs are left untouched.
            now: Naive-UTC timestamp to record; defaults to the current time.
        """
        if now is None:
            now = self._utcnow()
        if skip_server_ids is None:
            skip_server_ids = frozenset()

        active_logs = self.db.query(VPNConnectionLog).filter(
            VPNConnectionLog.disconnected_at.is_(None)
        ).all()

        # Gateway ids that still have at least one open log after this pass.
        still_connected_gateway_ids = set()
        # Gateways that lost at least one connection in this pass, by id.
        affected_gateways = {}

        for log in active_logs:
            if log.common_name in connected_cns or log.vpn_server_id in skip_server_ids:
                still_connected_gateway_ids.add(log.gateway_id)
                continue

            log.disconnected_at = now
            result["closed_connections"] += 1

            gateway = log.gateway
            if gateway is not None:  # guard: a missing gateway must not abort the sync
                affected_gateways[gateway.id] = gateway

        for gateway_id, gateway in affected_gateways.items():
            # Count only real online -> offline transitions.
            if gateway_id not in still_connected_gateway_ids and gateway.is_online:
                gateway.is_online = False
                result["gateways_offline"] += 1

    def _query_logs(self, *criteria, limit=None) -> list[VPNConnectionLog]:
        """Return connection logs matching all ``criteria``, newest first.

        ``limit`` caps the number of rows; ``None`` returns every match.
        """
        query = (
            self.db.query(VPNConnectionLog)
            .filter(*criteria)
            .order_by(VPNConnectionLog.connected_at.desc())
        )
        if limit is not None:
            query = query.limit(limit)
        return query.all()

    def get_profile_connection_logs(
        self,
        profile_id: int,
        limit: int = 50
    ) -> list[VPNConnectionLog]:
        """Get connection logs for a specific profile (newest first, at most ``limit``)."""
        return self._query_logs(
            VPNConnectionLog.vpn_profile_id == profile_id, limit=limit
        )

    def get_gateway_connection_logs(
        self,
        gateway_id: int,
        limit: int = 50
    ) -> list[VPNConnectionLog]:
        """Get connection logs for a gateway (all profiles; newest first, at most ``limit``)."""
        return self._query_logs(
            VPNConnectionLog.gateway_id == gateway_id, limit=limit
        )

    def get_server_connection_logs(
        self,
        server_id: int,
        limit: int = 50
    ) -> list[VPNConnectionLog]:
        """Get connection logs for a VPN server (newest first, at most ``limit``)."""
        return self._query_logs(
            VPNConnectionLog.vpn_server_id == server_id, limit=limit
        )

    def get_active_connections(self) -> list[VPNConnectionLog]:
        """Get all currently active (not yet disconnected) VPN connections, newest first."""
        return self._query_logs(VPNConnectionLog.disconnected_at.is_(None))
|