# openvpn-endpoint-server/server/app/services/vpn_server_service.py
# (file-viewer header: 345 lines, 12 KiB, Python)
"""VPN Server management service."""
import logging
import socket
from datetime import datetime
from typing import Optional

from sqlalchemy.orm import Session

from ..models.vpn_server import VPNServer, VPNServerStatus, VPNProtocol, VPNCipher, VPNAuth, VPNCompression
from ..models.certificate_authority import CertificateAuthority
from .certificate_service import CertificateService
from ..config import get_settings
# Application settings, resolved once when this module is imported.
# NOTE(review): not referenced by the code visible in this module — confirm
# whether it is still needed here.
settings = get_settings()
class VPNServerService:
    """Service for managing VPN server instances.

    Creates ``VPNServer`` records (including the server certificate and
    TLS-Auth key), renders OpenVPN server configuration files, and queries or
    controls running servers through the OpenVPN management interface.
    """

    # Logger for best-effort operations that deliberately do not raise.
    _logger = logging.getLogger(__name__)

    # Host used to reach the management interface. OpenVPN runs in host
    # network mode, so from inside a container it is reached via
    # host.docker.internal, which resolves to the Docker host.
    MANAGEMENT_HOST = "host.docker.internal"
    # Socket timeout in seconds for management-interface connect/recv.
    MANAGEMENT_TIMEOUT = 5

    def __init__(self, db: Session):
        """Bind the service to a SQLAlchemy session.

        Args:
            db: Session used for all queries and commits.
        """
        self.db = db
        self.cert_service = CertificateService(db)

    def create_server(
        self,
        name: str,
        hostname: str,
        ca_id: int,
        port: int = 1194,
        protocol: VPNProtocol = VPNProtocol.UDP,
        vpn_network: str = "10.8.0.0",
        vpn_netmask: str = "255.255.255.0",
        cipher: VPNCipher = VPNCipher.AES_256_GCM,
        auth: VPNAuth = VPNAuth.SHA256,
        tls_version_min: str = "1.2",
        compression: VPNCompression = VPNCompression.NONE,
        max_clients: int = 100,
        keepalive_interval: int = 10,
        keepalive_timeout: int = 60,
        management_port: int = 7505,
        is_primary: bool = False,
        tenant_id: Optional[int] = None,
        description: Optional[str] = None
    ) -> VPNServer:
        """Create a new VPN server record and generate its credentials.

        The network, crypto, keepalive and management arguments are stored
        as-is on the new ``VPNServer`` record.

        Args:
            name: Server name; the server certificate CN is ``"<name>-server"``.
            ca_id: ID of the CertificateAuthority that signs the server cert.
            is_primary: If True, every other primary server with the same
                ``tenant_id`` is demoted first.
            tenant_id: Owning tenant, or None.

        Returns:
            The committed server with status PENDING, server certificate,
            key and TLS-Auth key populated.

        Raises:
            ValueError: If the CA does not exist or is not ready.
        """
        ca = self.db.query(CertificateAuthority).filter(
            CertificateAuthority.id == ca_id
        ).first()
        if not ca:
            raise ValueError(f"CA with id {ca_id} not found")
        if not ca.is_ready:
            raise ValueError("CA is not ready (DH parameters may still be generating)")

        container_name = self._next_container_name()

        # Only one primary server per tenant: demote the current one(s).
        if is_primary:
            self.db.query(VPNServer).filter(
                VPNServer.tenant_id == tenant_id,
                VPNServer.is_primary == True,  # noqa: E712 - SQL expression
            ).update({"is_primary": False})

        server = VPNServer(
            name=name,
            description=description,
            hostname=hostname,
            port=port,
            protocol=protocol,
            vpn_network=vpn_network,
            vpn_netmask=vpn_netmask,
            cipher=cipher,
            auth=auth,
            tls_version_min=tls_version_min,
            compression=compression,
            max_clients=max_clients,
            keepalive_interval=keepalive_interval,
            keepalive_timeout=keepalive_timeout,
            management_port=management_port,
            docker_container_name=container_name,
            ca_id=ca_id,
            tenant_id=tenant_id,
            is_primary=is_primary,
            status=VPNServerStatus.PENDING
        )
        self.db.add(server)
        self.db.commit()
        self.db.refresh(server)

        self._generate_server_cert(server)
        return server

    def _next_container_name(self) -> str:
        """Return an unused ``mguard-openvpn-<n>`` Docker container name.

        Numbering starts at (number of servers + 1). Numbers whose name is
        already taken are skipped, which happens once any server has been
        deleted and the count no longer matches the highest suffix in use.
        """
        index = self.db.query(VPNServer).count() + 1
        while True:
            candidate = f"mguard-openvpn-{index}"
            taken = self.db.query(VPNServer).filter(
                VPNServer.docker_container_name == candidate
            ).first()
            if taken is None:
                return candidate
            index += 1

    def _generate_server_cert(self, server: VPNServer):
        """Generate and store the server certificate, key and TLS-Auth key.

        The certificate is issued by ``server.certificate_authority`` with
        CN ``"<server.name>-server"``. Changes are committed.
        """
        ca = server.certificate_authority

        cert_data = self.cert_service.generate_server_certificate(
            ca=ca,
            common_name=f"{server.name}-server"
        )
        server.server_cert = cert_data["cert"]
        server.server_key = cert_data["key"]

        server.ta_key = self.cert_service.generate_ta_key()
        self.db.commit()

    def generate_server_config(self, server: VPNServer) -> str:
        """Render the OpenVPN server configuration file for ``server``.

        Certificate/key paths refer to fixed files under ``/etc/openvpn``.
        The ``tls-auth`` line is emitted only when TLS-Auth is enabled and a
        TA key exists. A ``compress`` line is emitted only for non-NONE
        compression.

        Returns:
            The configuration text, lines joined with ``"\\n"``.

        Raises:
            ValueError: If the server has no CA assigned.
        """
        if not server.certificate_authority:
            raise ValueError("Server has no CA assigned")

        config_lines = [
            "# OpenVPN Server Configuration",
            f"# Server: {server.name}",
            f"# Generated: {datetime.utcnow().isoformat()}",
            "",
            "# Basic settings",
            f"port {server.port}",
            f"proto {server.protocol.value}",
            "dev tun",
            "",
            "# Certificates",
            "ca /etc/openvpn/ca.crt",
            "cert /etc/openvpn/server.crt",
            "key /etc/openvpn/server.key",
            "dh /etc/openvpn/dh.pem",
            "crl-verify /etc/openvpn/crl.pem",
            "",
            "# TLS Auth",
        ]

        if server.tls_auth_enabled and server.ta_key:
            config_lines.append("tls-auth /etc/openvpn/ta.key 0")

        config_lines.extend([
            f"tls-version-min {server.tls_version_min}",
            "",
            "# Network",
            f"server {server.vpn_network} {server.vpn_netmask}",
            "topology subnet",
            "",
            "# Routing",
            'push "redirect-gateway def1 bypass-dhcp"',
            'push "dhcp-option DNS 8.8.8.8"',
            'push "dhcp-option DNS 8.8.4.4"',
            "",
            "# Security",
            f"cipher {server.cipher.value}",
            f"auth {server.auth.value}",
            "",
            "# Performance",
            f"keepalive {server.keepalive_interval} {server.keepalive_timeout}",
            f"max-clients {server.max_clients}",
        ])

        # Compression is deprecated in OpenVPN 2.5+ (VORACLE attack), so it is
        # only enabled when explicitly requested. With NONE no directive is
        # written and OpenVPN defaults to no compression.
        if server.compression != VPNCompression.NONE:
            config_lines.append(f"compress {server.compression.value}")
            config_lines.append("allow-compression yes")

        config_lines.extend([
            "",
            "# Persistence",
            "persist-key",
            "persist-tun",
            "",
            "# Logging",
            f"status /var/log/openvpn/status-{server.id}.log",
            f"log-append /var/log/openvpn/server-{server.id}.log",
            "verb 3",
            "",
            "# Management interface",
            # NOTE(review): binds on all interfaces with no password file, so
            # anyone who can reach this port can issue management commands
            # (e.g. kill clients). Consider restricting the bind address or
            # adding a password file if the port is reachable externally.
            f"management 0.0.0.0 {server.management_port}",
            "",
            "# User/Group (for Linux)",
            "user nobody",
            "group nogroup",
            "",
            "# Client config directory",
            "client-config-dir /etc/openvpn/ccd",
            "",
            "# Scripts",
            "script-security 2",
            "client-connect /etc/openvpn/scripts/client-connect.sh",
            "client-disconnect /etc/openvpn/scripts/client-disconnect.sh",
        ])

        return "\n".join(config_lines)

    def get_connected_clients(self, server: VPNServer) -> list[dict]:
        """Get the list of currently connected VPN clients.

        Sends the management ``status`` command and parses its client list.
        On success, ``server.connected_clients`` and
        ``server.last_status_check`` are updated and committed.

        Returns:
            One dict per client with keys ``common_name``, ``real_address``,
            ``bytes_received``, ``bytes_sent`` and ``connected_since``. An
            empty list is returned when the interface reports an error or on
            any other failure (best-effort; unexpected failures are logged).
        """
        try:
            response = self._send_management_command(server, "status")
            if self._is_error_response(response):
                return []

            clients = self._parse_client_list(response)

            server.connected_clients = len(clients)
            server.last_status_check = datetime.utcnow()
            self.db.commit()
            return clients
        except Exception:
            self._logger.exception("Failed to fetch connected VPN clients")
            return []

    @staticmethod
    def _parse_client_list(response: str) -> list[dict]:
        """Parse the client section of a management ``status`` reply.

        Rows are read between the ``Common Name,...`` header and the
        ``ROUTING TABLE`` marker. Rows with fewer than five fields are
        skipped. Non-numeric byte counters become 0.
        """
        clients = []
        in_client_list = False
        # splitlines() also removes the "\r" of CRLF line endings, which
        # split('\n') would leave attached to the last field.
        for line in response.splitlines():
            if line.startswith("ROUTING TABLE"):
                in_client_list = False
            elif line.startswith("Common Name"):
                in_client_list = True
            elif in_client_list and "," in line:
                parts = line.split(",")
                if len(parts) >= 5:
                    clients.append({
                        "common_name": parts[0],
                        "real_address": parts[1],
                        "bytes_received": int(parts[2]) if parts[2].isdigit() else 0,
                        "bytes_sent": int(parts[3]) if parts[3].isdigit() else 0,
                        "connected_since": parts[4],
                    })
        return clients

    def disconnect_client(self, server: VPNServer, common_name: str) -> bool:
        """Disconnect a specific VPN client with the management ``kill`` command.

        Args:
            server: Server whose management interface is used.
            common_name: Client to disconnect.

        Returns:
            True only if OpenVPN replied with a ``SUCCESS:`` line. An empty
            name, or one containing a line break, is rejected with False
            without contacting the server: the management protocol is
            line-based, so a newline would inject additional commands.
        """
        if not common_name or "\n" in common_name or "\r" in common_name:
            return False
        response = self._send_management_command(server, f"kill {common_name}")
        return any(line.startswith("SUCCESS:") for line in response.splitlines())

    def get_server_status(self, server: VPNServer) -> dict:
        """Refresh ``server.status`` from the management interface and report it.

        An error reply (including an unreachable interface) sets the status
        to ERROR. Otherwise it is RUNNING when the ``state`` reply reports
        CONNECTED, else STOPPED. Changes are committed, then the connected
        clients are fetched as well.

        Returns:
            ``{"status", "connected_clients", "clients", "last_check"}`` on
            success; ``{"status": "error", "message": ...}`` for an error
            reply; ``{"status": "unknown", "message": ...}`` if anything else
            fails (the failure is logged).
        """
        try:
            response = self._send_management_command(server, "state")

            if self._is_error_response(response):
                server.status = VPNServerStatus.ERROR
                self.db.commit()
                return {
                    "status": "error",
                    "message": response
                }

            if self._state_is_connected(response):
                server.status = VPNServerStatus.RUNNING
            else:
                server.status = VPNServerStatus.STOPPED

            server.last_status_check = datetime.utcnow()
            self.db.commit()

            clients = self.get_connected_clients(server)

            return {
                "status": server.status.value,
                "connected_clients": len(clients),
                "clients": clients,
                "last_check": server.last_status_check.isoformat() if server.last_status_check else None
            }
        except Exception as e:
            self._logger.exception("Failed to get VPN server status")
            return {
                "status": "unknown",
                "message": str(e)
            }

    @staticmethod
    def _state_is_connected(response: str) -> bool:
        """Return True if a ``state`` reply reports the CONNECTED state.

        State lines are comma-separated with the state name in the second
        field; it is compared exactly rather than searched for as a
        substring of the whole line.
        """
        for line in response.splitlines():
            fields = line.split(",")
            if len(fields) > 1 and fields[1] == "CONNECTED":
                return True
        return False

    @staticmethod
    def _is_error_response(response: str) -> bool:
        """Return True if a management reply contains an ``ERROR:`` line.

        This covers errors reported by OpenVPN as well as the
        ``"ERROR: ..."`` strings ``_send_management_command`` returns on
        failure. Matching whole lines avoids false positives from client
        names that merely contain the word.
        """
        return any(line.startswith("ERROR:") for line in response.splitlines())

    @staticmethod
    def _is_reply_terminator(line: bytes) -> bool:
        """Return True if ``line`` (without line ending) ends a management reply."""
        return line == b"END" or line.startswith((b"SUCCESS:", b"ERROR:"))

    def _send_management_command(self, server: VPNServer, command: str) -> str:
        """Send one command to the server's OpenVPN management interface.

        A new TCP connection is opened per call and is always closed. The
        first chunk received on connect (the greeting) is discarded, then
        ``command`` is sent. The reply is read until a complete terminator
        line (``END``, ``SUCCESS:...`` or ``ERROR:...``) arrives or the peer
        closes the connection.

        Args:
            server: Server whose ``management_port`` is contacted on
                ``MANAGEMENT_HOST``.
            command: A single management command, without trailing newline.

        Returns:
            The decoded reply (undecodable bytes replaced). This method does
            not raise: on any failure it returns ``"ERROR: <reason>"``.
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(self.MANAGEMENT_TIMEOUT)
                sock.connect((self.MANAGEMENT_HOST, server.management_port))

                # Discard the welcome message.
                sock.recv(1024)

                # sendall: send() may transmit only part of the buffer.
                sock.sendall(f"{command}\n".encode())

                response = bytearray()
                pending = b""  # bytes of a line whose "\n" has not arrived yet
                while True:
                    data = sock.recv(4096)
                    if not data:
                        break
                    response += data
                    # Check only fully received lines: a terminator split over
                    # two recv() calls is still recognised, and "END" etc.
                    # inside a client name does not end the read early.
                    *lines, pending = (pending + data).split(b"\n")
                    if any(self._is_reply_terminator(line.rstrip(b"\r")) for line in lines):
                        break

            return response.decode(errors="replace")
        except Exception as e:
            return f"ERROR: {str(e)}"

    def update_all_server_status(self):
        """Refresh the status of every server with ``is_active`` set."""
        servers = self.db.query(VPNServer).filter(
            VPNServer.is_active == True  # noqa: E712 - SQL expression
        ).all()

        for server in servers:
            self.get_server_status(server)

    def get_server_by_id(self, server_id: int) -> Optional[VPNServer]:
        """Get a VPN server by ID, or None if it does not exist."""
        return self.db.query(VPNServer).filter(
            VPNServer.id == server_id
        ).first()

    def get_servers_for_tenant(self, tenant_id: Optional[int] = None) -> list[VPNServer]:
        """Get the VPN servers visible to a tenant, ordered by name.

        Args:
            tenant_id: When given (any value other than None, including 0),
                the tenant's own servers plus those with no tenant are
                returned. When None, only servers with no tenant are returned.
        """
        query = self.db.query(VPNServer)

        if tenant_id is not None:
            query = query.filter(
                (VPNServer.tenant_id == tenant_id) | VPNServer.tenant_id.is_(None)
            )
        else:
            query = query.filter(VPNServer.tenant_id.is_(None))

        return query.order_by(VPNServer.name).all()