345 lines
12 KiB
Python
345 lines
12 KiB
Python
"""VPN Server management service."""
|
|
|
|
import logging
import socket
from datetime import datetime
from typing import Optional

from sqlalchemy.orm import Session

from ..models.vpn_server import VPNServer, VPNServerStatus, VPNProtocol, VPNCipher, VPNAuth, VPNCompression
from ..models.certificate_authority import CertificateAuthority
from .certificate_service import CertificateService
from ..config import get_settings
|
|
|
|
# Application settings, loaded once when this module is imported.
# NOTE(review): `settings` is not referenced by VPNServerService in this
# module; confirm whether other code imports it from here before removing.
settings = get_settings()
|
|
|
|
|
|
class VPNServerService:
    """Service for managing VPN server instances.

    Persists ``VPNServer`` rows through the injected SQLAlchemy session,
    issues server certificates via ``CertificateService``, renders OpenVPN
    server configuration files, and queries/controls running servers through
    the OpenVPN management interface.
    """

    # Where the OpenVPN management interface is reached. OpenVPN runs in host
    # network mode, so from inside this application's container the Docker
    # host is addressed as ``host.docker.internal``.
    MANAGEMENT_HOST = "host.docker.internal"
    # Socket timeout in seconds, applied to connect() and to every recv().
    MANAGEMENT_TIMEOUT = 5
    # Generated Docker container names look like "mguard-openvpn-<n>".
    CONTAINER_NAME_PREFIX = "mguard-openvpn-"

    _log = logging.getLogger(__name__)

    def __init__(self, db: Session):
        """Bind the service to a database session.

        Args:
            db: SQLAlchemy session used for all queries and commits; it is
                also handed to the ``CertificateService`` this service uses.
        """
        self.db = db
        self.cert_service = CertificateService(db)

    def create_server(
        self,
        name: str,
        hostname: str,
        ca_id: int,
        port: int = 1194,
        protocol: VPNProtocol = VPNProtocol.UDP,
        vpn_network: str = "10.8.0.0",
        vpn_netmask: str = "255.255.255.0",
        cipher: VPNCipher = VPNCipher.AES_256_GCM,
        auth: VPNAuth = VPNAuth.SHA256,
        tls_version_min: str = "1.2",
        compression: VPNCompression = VPNCompression.NONE,
        max_clients: int = 100,
        keepalive_interval: int = 10,
        keepalive_timeout: int = 60,
        management_port: int = 7505,
        is_primary: bool = False,
        tenant_id: Optional[int] = None,
        description: Optional[str] = None,
    ) -> VPNServer:
        """Create a new VPN server record and issue its certificate.

        The server row is created with status ``PENDING`` and committed; a
        server certificate and TLS-auth key are then generated and committed.

        Args:
            name: Display name; also used for the certificate common name
                (``"<name>-server"``).
            hostname: Hostname stored on the server record.
            ca_id: ID of the ``CertificateAuthority`` that signs the server
                certificate. The CA must exist and report ``is_ready``.
            port, protocol: Listening port and transport for the config.
            vpn_network, vpn_netmask: Values for the OpenVPN ``server``
                directive.
            cipher, auth, tls_version_min, compression: Crypto/transport
                settings written into the generated config.
            max_clients, keepalive_interval, keepalive_timeout: Connection
                limit and keepalive timing written into the config.
            management_port: TCP port of the OpenVPN management interface.
            is_primary: If True, any existing primary server of the same
                tenant is demoted first.
            tenant_id: Owning tenant, or None for a global server.
            description: Optional free-text description.

        Returns:
            The persisted ``VPNServer`` with certificate material populated.

        Raises:
            ValueError: If the CA does not exist or is not ready.
        """
        ca = self.db.query(CertificateAuthority).filter(
            CertificateAuthority.id == ca_id
        ).first()
        if not ca:
            raise ValueError(f"CA with id {ca_id} not found")
        if not ca.is_ready:
            raise ValueError("CA is not ready (DH parameters may still be generating)")

        container_name = self._next_container_name()

        # Only one primary per tenant: demote the current primary first.
        # (Comparing to tenant_id=None renders as IS NULL in SQLAlchemy.)
        if is_primary:
            self.db.query(VPNServer).filter(
                VPNServer.tenant_id == tenant_id,
                VPNServer.is_primary == True,  # noqa: E712 - SQL expression
            ).update({"is_primary": False})

        server = VPNServer(
            name=name,
            description=description,
            hostname=hostname,
            port=port,
            protocol=protocol,
            vpn_network=vpn_network,
            vpn_netmask=vpn_netmask,
            cipher=cipher,
            auth=auth,
            tls_version_min=tls_version_min,
            compression=compression,
            max_clients=max_clients,
            keepalive_interval=keepalive_interval,
            keepalive_timeout=keepalive_timeout,
            management_port=management_port,
            docker_container_name=container_name,
            ca_id=ca_id,
            tenant_id=tenant_id,
            is_primary=is_primary,
            status=VPNServerStatus.PENDING,
        )
        self.db.add(server)
        self.db.commit()
        self.db.refresh(server)

        # NOTE: the row is already committed here; if certificate generation
        # raises, the server stays in the database as PENDING without certs.
        self._generate_server_cert(server)
        return server

    def _next_container_name(self) -> str:
        """Return a Docker container name not used by any existing server.

        Starts from ``count + 1`` (the historical scheme) and skips forward
        past names that are already taken, which a plain ``count + 1``
        would collide with once any server has been deleted.
        """
        taken = {
            row[0]
            for row in self.db.query(VPNServer.docker_container_name).all()
        }
        index = self.db.query(VPNServer).count() + 1
        while f"{self.CONTAINER_NAME_PREFIX}{index}" in taken:
            index += 1
        return f"{self.CONTAINER_NAME_PREFIX}{index}"

    def _generate_server_cert(self, server: VPNServer):
        """Issue the server certificate/key and a TLS-auth key, then commit.

        The certificate is signed by ``server.certificate_authority`` with the
        common name ``"<server.name>-server"``.
        """
        cert_data = self.cert_service.generate_server_certificate(
            ca=server.certificate_authority,
            common_name=f"{server.name}-server",
        )
        server.server_cert = cert_data["cert"]
        server.server_key = cert_data["key"]
        server.ta_key = self.cert_service.generate_ta_key()
        self.db.commit()

    def generate_server_config(self, server: VPNServer) -> str:
        """Render the OpenVPN server configuration file for *server*.

        Certificate, key, DH and CRL paths are fixed under ``/etc/openvpn``;
        the remaining directives come from the server's stored settings.

        Returns:
            The configuration text, lines joined with newlines (no trailing
            newline).

        Raises:
            ValueError: If the server has no CA assigned.
        """
        if not server.certificate_authority:
            raise ValueError("Server has no CA assigned")

        config_lines = [
            "# OpenVPN Server Configuration",
            f"# Server: {server.name}",
            f"# Generated: {datetime.utcnow().isoformat()}",
            "",
            "# Basic settings",
            f"port {server.port}",
            f"proto {server.protocol.value}",
            "dev tun",
            "",
            "# Certificates",
            "ca /etc/openvpn/ca.crt",
            "cert /etc/openvpn/server.crt",
            "key /etc/openvpn/server.key",
            "dh /etc/openvpn/dh.pem",
            "crl-verify /etc/openvpn/crl.pem",
            "",
            "# TLS Auth",
        ]

        # tls-auth is only emitted when enabled AND a key has been generated.
        if server.tls_auth_enabled and server.ta_key:
            config_lines.append("tls-auth /etc/openvpn/ta.key 0")

        config_lines.extend([
            f"tls-version-min {server.tls_version_min}",
            "",
            "# Network",
            f"server {server.vpn_network} {server.vpn_netmask}",
            "topology subnet",
            "",
            "# Routing",
            'push "redirect-gateway def1 bypass-dhcp"',
            'push "dhcp-option DNS 8.8.8.8"',
            'push "dhcp-option DNS 8.8.4.4"',
            "",
            "# Security",
            f"cipher {server.cipher.value}",
            f"auth {server.auth.value}",
            "",
            "# Performance",
            f"keepalive {server.keepalive_interval} {server.keepalive_timeout}",
            f"max-clients {server.max_clients}",
        ])

        # OpenVPN 2.5+ deprecates compression because of the VORACLE attack.
        # With NONE no directive is written, so OpenVPN's own default applies.
        if server.compression != VPNCompression.NONE:
            config_lines.append(f"compress {server.compression.value}")
            config_lines.append("allow-compression yes")

        config_lines.extend([
            "",
            "# Persistence",
            "persist-key",
            "persist-tun",
            "",
            "# Logging",
            f"status /var/log/openvpn/status-{server.id}.log",
            f"log-append /var/log/openvpn/server-{server.id}.log",
            "verb 3",
            "",
            "# Management interface",
            # NOTE(review): this binds the management interface on all
            # interfaces with no password-file argument, so anyone who can
            # reach the port can query and control the server. Restrict it
            # with a firewall or add a pw-file.
            f"management 0.0.0.0 {server.management_port}",
            "",
            "# User/Group (for Linux)",
            "user nobody",
            "group nogroup",
            "",
            "# Client config directory",
            "client-config-dir /etc/openvpn/ccd",
            "",
            "# Scripts",
            "script-security 2",
            "client-connect /etc/openvpn/scripts/client-connect.sh",
            "client-disconnect /etc/openvpn/scripts/client-disconnect.sh",
        ])

        return "\n".join(config_lines)

    def get_connected_clients(self, server: VPNServer) -> list[dict]:
        """Return the clients currently connected to *server*.

        Sends ``status`` to the management interface. On success it stores
        the client count and check time on the server and commits.

        Returns:
            One dict per client with keys ``common_name``, ``real_address``,
            ``bytes_received``, ``bytes_sent`` and ``connected_since``. An
            empty list if the interface reported an error or the database
            update failed (the failure is logged).
        """
        response = self._send_management_command(server, "status")
        if self._is_error_response(response):
            return []

        clients = self._parse_client_list(response)

        try:
            server.connected_clients = len(clients)
            server.last_status_check = datetime.utcnow()
            self.db.commit()
        except Exception:
            self._log.exception("Failed to record connected-client count for VPN server")
            return []

        return clients

    @staticmethod
    def _parse_client_list(response: str) -> list[dict]:
        """Parse the client section of an OpenVPN ``status`` reply.

        Rows are the comma-separated lines after the ``Common Name`` header
        and before ``ROUTING TABLE``. Rows with fewer than five fields are
        skipped; byte counters that are not plain decimal numbers become 0.
        ``splitlines`` also strips the CRLF endings the interface sends.
        """
        clients = []
        in_client_list = False
        for line in response.splitlines():
            if line.startswith("ROUTING TABLE"):
                in_client_list = False
            elif line.startswith("Common Name"):
                in_client_list = True
            elif in_client_list and "," in line:
                parts = line.split(",")
                if len(parts) >= 5:
                    clients.append({
                        "common_name": parts[0],
                        "real_address": parts[1],
                        "bytes_received": int(parts[2]) if parts[2].isdecimal() else 0,
                        "bytes_sent": int(parts[3]) if parts[3].isdecimal() else 0,
                        "connected_since": parts[4],
                    })
        return clients

    def disconnect_client(self, server: VPNServer, common_name: str) -> bool:
        """Disconnect the client with *common_name* via ``kill``.

        Returns:
            True if the management interface answered with ``SUCCESS``.
            False otherwise, including when *common_name* contains a line
            break, which would inject extra management commands and is
            refused.
        """
        if "\r" in common_name or "\n" in common_name:
            self._log.warning(
                "Refusing kill for common name containing a line break: %r",
                common_name,
            )
            return False
        response = self._send_management_command(server, f"kill {common_name}")
        return any(line.startswith("SUCCESS") for line in response.splitlines())

    def get_server_status(self, server: VPNServer) -> dict:
        """Refresh and return the status of *server*.

        Sends ``state`` to the management interface and then:

        - if it reports an error, sets status ``ERROR``;
        - otherwise sets ``RUNNING`` if any reply line contains
          ``CONNECTED``, else ``STOPPED``.

        In both cases the change is committed. On the non-error path the
        connected clients are also fetched.

        Returns:
            On the non-error path: ``{"status", "connected_clients",
            "clients", "last_check"}``. On error: ``{"status": "error",
            "message": <reply>}``. On an unexpected exception (logged):
            ``{"status": "unknown", "message": <exception text>}``.
        """
        try:
            response = self._send_management_command(server, "state")

            if self._is_error_response(response):
                server.status = VPNServerStatus.ERROR
                self.db.commit()
                return {
                    "status": "error",
                    "message": response,
                }

            if any("CONNECTED" in line for line in response.splitlines()):
                server.status = VPNServerStatus.RUNNING
            else:
                server.status = VPNServerStatus.STOPPED
            server.last_status_check = datetime.utcnow()
            self.db.commit()

            clients = self.get_connected_clients(server)
            return {
                "status": server.status.value,
                "connected_clients": len(clients),
                "clients": clients,
                "last_check": server.last_status_check.isoformat() if server.last_status_check else None,
            }
        except Exception as exc:
            self._log.exception("Status check failed for VPN server")
            return {
                "status": "unknown",
                "message": str(exc),
            }

    @staticmethod
    def _is_error_response(response: str) -> bool:
        """Return True if any line of a management reply starts with ``ERROR:``.

        This covers both errors reported by OpenVPN and the transport errors
        that ``_send_management_command`` returns as ``"ERROR: <reason>"``.
        """
        return any(line.startswith("ERROR:") for line in response.splitlines())

    @staticmethod
    def _is_complete_response(data: bytes) -> bool:
        """Return True once *data* contains a line that ends a management reply.

        A reply ends with an ``END`` line (multi-line replies such as
        ``status``/``state``) or with a line starting with ``SUCCESS:`` or
        ``ERROR:`` (single-line replies such as ``kill``). Only complete lines
        are inspected, so a terminator split across recv() chunks is still
        found once its newline arrives.
        """
        # The last element is an unterminated (possibly empty) partial line.
        for raw_line in data.split(b"\n")[:-1]:
            line = raw_line.strip()
            if line == b"END" or line.startswith((b"SUCCESS:", b"ERROR:")):
                return True
        return False

    def _send_management_command(self, server: VPNServer, command: str) -> str:
        """Send one command to the server's OpenVPN management interface.

        Connects to ``MANAGEMENT_HOST:server.management_port`` and discards
        the first chunk received (the greeting banner). It then sends
        *command* and reads until ``_is_complete_response`` is satisfied or
        the peer closes the connection. The socket is always closed.

        Returns:
            The decoded reply; undecodable bytes are replaced. Socket
            failures (refused connection, timeout, DNS failure, ...) are not
            raised but returned as ``"ERROR: <reason>"``.
        """
        address = (self.MANAGEMENT_HOST, server.management_port)
        try:
            with socket.create_connection(address, timeout=self.MANAGEMENT_TIMEOUT) as sock:
                sock.recv(1024)  # greeting banner
                sock.sendall(f"{command}\n".encode())

                response = b""
                while not self._is_complete_response(response):
                    chunk = sock.recv(4096)
                    if not chunk:
                        break
                    response += chunk
        except OSError as exc:
            return f"ERROR: {exc}"

        return response.decode(errors="replace")

    def update_all_server_status(self):
        """Refresh the stored status of every server whose ``is_active`` is true."""
        active_servers = self.db.query(VPNServer).filter(
            VPNServer.is_active == True  # noqa: E712 - SQLAlchemy expression
        ).all()
        for server in active_servers:
            self.get_server_status(server)

    def get_server_by_id(self, server_id: int) -> Optional[VPNServer]:
        """Return the VPN server with *server_id*, or None if there is none."""
        return self.db.query(VPNServer).filter(
            VPNServer.id == server_id
        ).first()

    def get_servers_for_tenant(self, tenant_id: Optional[int] = None) -> list[VPNServer]:
        """Return the servers visible to a tenant, ordered by name.

        With a tenant ID: that tenant's servers plus global servers (whose
        ``tenant_id`` is NULL). Without one: only global servers.
        """
        query = self.db.query(VPNServer)
        if tenant_id is not None:
            query = query.filter(
                (VPNServer.tenant_id == tenant_id) | VPNServer.tenant_id.is_(None)
            )
        else:
            query = query.filter(VPNServer.tenant_id.is_(None))
        return query.order_by(VPNServer.name).all()
|