first commit

This commit is contained in:
Stefan Hacker
2026-02-02 09:46:35 +01:00
commit 6901dc369b
98 changed files with 13030 additions and 0 deletions
+17
View File
@@ -0,0 +1,17 @@
"""Business logic services."""
from .auth_service import AuthService
from .vpn_service import VPNService
from .firewall_service import FirewallService
from .certificate_service import CertificateService
from .vpn_server_service import VPNServerService
from .vpn_profile_service import VPNProfileService
__all__ = [
"AuthService",
"VPNService",
"FirewallService",
"CertificateService",
"VPNServerService",
"VPNProfileService",
]
+86
View File
@@ -0,0 +1,86 @@
"""Authentication service."""
from datetime import datetime
from sqlalchemy.orm import Session
from ..models.user import User
from ..schemas.user import UserCreate, Token
from ..utils.security import (
verify_password, get_password_hash,
create_access_token, create_refresh_token, decode_token
)
class AuthService:
"""Service for authentication operations."""
def __init__(self, db: Session):
self.db = db
def authenticate_user(self, username: str, password: str) -> User | None:
"""Authenticate user with username and password."""
user = self.db.query(User).filter(User.username == username).first()
if not user:
return None
if not verify_password(password, user.password_hash):
return None
if not user.is_active:
return None
# Update last login
user.last_login = datetime.utcnow()
self.db.commit()
return user
def create_tokens(self, user: User) -> Token:
"""Create access and refresh tokens for user."""
access_token = create_access_token(
user_id=user.id,
username=user.username,
role=user.role.value,
tenant_id=user.tenant_id
)
refresh_token = create_refresh_token(user_id=user.id)
return Token(
access_token=access_token,
refresh_token=refresh_token
)
def refresh_tokens(self, refresh_token: str) -> Token | None:
"""Refresh access token using refresh token."""
payload = decode_token(refresh_token)
if not payload:
return None
if payload.get("type") != "refresh":
return None
user_id = payload.get("sub")
user = self.db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
return None
return self.create_tokens(user)
def create_user(self, user_data: UserCreate) -> User:
"""Create a new user."""
user = User(
username=user_data.username,
email=user_data.email,
password_hash=get_password_hash(user_data.password),
full_name=user_data.full_name,
role=user_data.role,
tenant_id=user_data.tenant_id
)
self.db.add(user)
self.db.commit()
self.db.refresh(user)
return user
def get_user_by_id(self, user_id: int) -> User | None:
"""Get user by ID."""
return self.db.query(User).filter(User.id == user_id).first()
def get_user_by_username(self, username: str) -> User | None:
"""Get user by username."""
return self.db.query(User).filter(User.username == username).first()
+588
View File
@@ -0,0 +1,588 @@
"""Certificate management service for PKI operations."""
import os
import subprocess
import threading
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from cryptography import x509
from cryptography.x509.oid import NameOID, ExtensionOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from sqlalchemy.orm import Session
from ..models.certificate_authority import CertificateAuthority, CAStatus
from ..models.vpn_profile import VPNProfile, VPNProfileStatus
from ..config import get_settings
settings = get_settings()
class CertificateService:
"""Service for managing certificates and CAs."""
def __init__(self, db: Session):
self.db = db
def create_ca(
self,
name: str,
key_size: int = 4096,
validity_days: int = 3650,
organization: str = "mGuard VPN",
country: str = "DE",
state: str = "NRW",
city: str = "Dortmund",
tenant_id: Optional[int] = None,
created_by_id: Optional[int] = None,
is_default: bool = False
) -> CertificateAuthority:
"""Create a new Certificate Authority."""
# Generate CA private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=key_size,
backend=default_backend()
)
# Build CA certificate
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
x509.NameAttribute(NameOID.LOCALITY_NAME, city),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization),
x509.NameAttribute(NameOID.COMMON_NAME, f"{name} CA"),
])
valid_from = datetime.utcnow()
valid_until = valid_from + timedelta(days=validity_days)
ca_cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(private_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(valid_from)
.not_valid_after(valid_until)
.add_extension(
x509.BasicConstraints(ca=True, path_length=0),
critical=True,
)
.add_extension(
x509.KeyUsage(
digital_signature=True,
key_cert_sign=True,
crl_sign=True,
key_encipherment=False,
content_commitment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
.add_extension(
x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()),
critical=False,
)
.sign(private_key, hashes.SHA256(), default_backend())
)
# Serialize to PEM
ca_cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode('utf-8')
ca_key_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
).decode('utf-8')
# If this is default, unset other defaults for this tenant
if is_default:
self.db.query(CertificateAuthority).filter(
CertificateAuthority.tenant_id == tenant_id,
CertificateAuthority.is_default == True
).update({"is_default": False})
# Create CA record
ca = CertificateAuthority(
name=name,
tenant_id=tenant_id,
ca_cert=ca_cert_pem,
ca_key=ca_key_pem,
key_size=key_size,
valid_from=valid_from,
valid_until=valid_until,
is_default=is_default,
status=CAStatus.PENDING, # Will be ACTIVE after DH generation
created_by_id=created_by_id,
dh_generating=True
)
self.db.add(ca)
self.db.commit()
self.db.refresh(ca)
# Start DH parameter generation in background
self._generate_dh_async(ca.id, key_size)
return ca
def _generate_dh_async(self, ca_id: int, key_size: int):
"""Generate DH parameters in background thread."""
def generate():
try:
# Use openssl for DH generation (faster than pure Python)
result = subprocess.run(
["openssl", "dhparam", "-out", "-", str(key_size)],
capture_output=True,
text=True,
timeout=3600 # 1 hour max
)
if result.returncode == 0:
dh_pem = result.stdout
# Update CA in database (new session needed for thread)
from ..database import SessionLocal
db = SessionLocal()
try:
ca = db.query(CertificateAuthority).filter(
CertificateAuthority.id == ca_id
).first()
if ca:
ca.dh_params = dh_pem
ca.dh_generating = False
ca.status = CAStatus.ACTIVE
db.commit()
finally:
db.close()
except Exception as e:
# Log error and mark CA as failed
from ..database import SessionLocal
db = SessionLocal()
try:
ca = db.query(CertificateAuthority).filter(
CertificateAuthority.id == ca_id
).first()
if ca:
ca.dh_generating = False
ca.description = f"DH generation failed: {str(e)}"
db.commit()
finally:
db.close()
thread = threading.Thread(target=generate, daemon=True)
thread.start()
def import_ca(
self,
name: str,
ca_cert_pem: str,
ca_key_pem: str,
dh_params_pem: Optional[str] = None,
tenant_id: Optional[int] = None,
created_by_id: Optional[int] = None
) -> CertificateAuthority:
"""Import an existing CA from PEM files."""
# Parse certificate to extract metadata
ca_cert = x509.load_pem_x509_certificate(
ca_cert_pem.encode('utf-8'),
default_backend()
)
# Validate it's a CA certificate
try:
basic_constraints = ca_cert.extensions.get_extension_for_oid(
ExtensionOID.BASIC_CONSTRAINTS
)
if not basic_constraints.value.ca:
raise ValueError("Certificate is not a CA certificate")
except x509.extensions.ExtensionNotFound:
raise ValueError("Certificate missing BasicConstraints extension")
# Get key size from private key
private_key = serialization.load_pem_private_key(
ca_key_pem.encode('utf-8'),
password=None,
backend=default_backend()
)
key_size = private_key.key_size
ca = CertificateAuthority(
name=name,
tenant_id=tenant_id,
ca_cert=ca_cert_pem,
ca_key=ca_key_pem,
dh_params=dh_params_pem,
key_size=key_size,
valid_from=ca_cert.not_valid_before_utc,
valid_until=ca_cert.not_valid_after_utc,
status=CAStatus.ACTIVE if dh_params_pem else CAStatus.PENDING,
created_by_id=created_by_id,
dh_generating=dh_params_pem is None
)
self.db.add(ca)
self.db.commit()
self.db.refresh(ca)
# Generate DH if not provided
if not dh_params_pem:
self._generate_dh_async(ca.id, key_size)
return ca
def generate_server_certificate(
self,
ca: CertificateAuthority,
common_name: str,
validity_days: int = 825
) -> dict:
"""Generate a server certificate from the CA."""
if not ca.is_ready:
raise ValueError("CA is not ready for issuing certificates")
# Load CA key and cert
ca_key = serialization.load_pem_private_key(
ca.ca_key.encode('utf-8'),
password=None,
backend=default_backend()
)
ca_cert = x509.load_pem_x509_certificate(
ca.ca_cert.encode('utf-8'),
default_backend()
)
# Generate server private key
server_key = rsa.generate_private_key(
public_exponent=65537,
key_size=ca.key_size,
backend=default_backend()
)
# Build server certificate
valid_from = datetime.utcnow()
valid_until = valid_from + timedelta(days=validity_days)
subject = x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
])
server_cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(ca_cert.subject)
.public_key(server_key.public_key())
.serial_number(ca.next_serial)
.not_valid_before(valid_from)
.not_valid_after(valid_until)
.add_extension(
x509.BasicConstraints(ca=False, path_length=None),
critical=True,
)
.add_extension(
x509.KeyUsage(
digital_signature=True,
key_encipherment=True,
key_cert_sign=False,
crl_sign=False,
content_commitment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
.add_extension(
x509.ExtendedKeyUsage([
x509.oid.ExtendedKeyUsageOID.SERVER_AUTH,
]),
critical=False,
)
.add_extension(
x509.SubjectAlternativeName([
x509.DNSName(common_name),
]),
critical=False,
)
.sign(ca_key, hashes.SHA256(), default_backend())
)
# Increment serial
ca.next_serial += 1
self.db.commit()
# Serialize to PEM
cert_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode('utf-8')
key_pem = server_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
).decode('utf-8')
return {
"cert": cert_pem,
"key": key_pem,
"valid_from": valid_from,
"valid_until": valid_until,
"serial": ca.next_serial - 1
}
def generate_client_certificate(
self,
ca: CertificateAuthority,
common_name: str,
validity_days: int = 365
) -> dict:
"""Generate a client certificate from the CA."""
if not ca.is_ready:
raise ValueError("CA is not ready for issuing certificates")
# Load CA key and cert
ca_key = serialization.load_pem_private_key(
ca.ca_key.encode('utf-8'),
password=None,
backend=default_backend()
)
ca_cert = x509.load_pem_x509_certificate(
ca.ca_cert.encode('utf-8'),
default_backend()
)
# Generate client private key
client_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048, # Client keys can be smaller
backend=default_backend()
)
# Build client certificate
valid_from = datetime.utcnow()
valid_until = valid_from + timedelta(days=validity_days)
subject = x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
])
client_cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(ca_cert.subject)
.public_key(client_key.public_key())
.serial_number(ca.next_serial)
.not_valid_before(valid_from)
.not_valid_after(valid_until)
.add_extension(
x509.BasicConstraints(ca=False, path_length=None),
critical=True,
)
.add_extension(
x509.KeyUsage(
digital_signature=True,
key_encipherment=True,
key_cert_sign=False,
crl_sign=False,
content_commitment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
.add_extension(
x509.ExtendedKeyUsage([
x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
]),
critical=False,
)
.sign(ca_key, hashes.SHA256(), default_backend())
)
# Increment serial
ca.next_serial += 1
self.db.commit()
# Serialize to PEM
cert_pem = client_cert.public_bytes(serialization.Encoding.PEM).decode('utf-8')
key_pem = client_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
).decode('utf-8')
return {
"cert": cert_pem,
"key": key_pem,
"valid_from": valid_from,
"valid_until": valid_until,
"serial": ca.next_serial - 1
}
def generate_ta_key(self) -> str:
"""Generate TLS-Auth key using OpenVPN."""
try:
result = subprocess.run(
["openvpn", "--genkey", "secret", "/dev/stdout"],
capture_output=True,
text=True,
timeout=30
)
if result.returncode == 0:
return result.stdout
else:
raise RuntimeError(f"Failed to generate TA key: {result.stderr}")
except FileNotFoundError:
# OpenVPN not installed, generate a simple key
import secrets
key_data = secrets.token_hex(256)
return f"""#
# 2048 bit OpenVPN static key
#
-----BEGIN OpenVPN Static key V1-----
{key_data[:64]}
{key_data[64:128]}
{key_data[128:192]}
{key_data[192:256]}
{key_data[256:320]}
{key_data[320:384]}
{key_data[384:448]}
{key_data[448:512]}
-----END OpenVPN Static key V1-----
"""
def revoke_certificate(
self,
ca: CertificateAuthority,
cert_pem: str,
reason: str = "unspecified"
) -> bool:
"""Revoke a certificate and update CRL."""
# Parse certificate to get serial number
cert = x509.load_pem_x509_certificate(
cert_pem.encode('utf-8'),
default_backend()
)
# Load CA key
ca_key = serialization.load_pem_private_key(
ca.ca_key.encode('utf-8'),
password=None,
backend=default_backend()
)
# Load existing CRL or create new one
revoked_certs = []
if ca.crl:
try:
existing_crl = x509.load_pem_x509_crl(
ca.crl.encode('utf-8'),
default_backend()
)
for revoked in existing_crl:
revoked_certs.append(revoked)
except Exception:
pass
# Add new revoked certificate
revoked_cert = (
x509.RevokedCertificateBuilder()
.serial_number(cert.serial_number)
.revocation_date(datetime.utcnow())
.build()
)
revoked_certs.append(revoked_cert)
# Load CA cert
ca_cert = x509.load_pem_x509_certificate(
ca.ca_cert.encode('utf-8'),
default_backend()
)
# Build new CRL
crl_builder = (
x509.CertificateRevocationListBuilder()
.issuer_name(ca_cert.subject)
.last_update(datetime.utcnow())
.next_update(datetime.utcnow() + timedelta(days=180))
)
for revoked in revoked_certs:
crl_builder = crl_builder.add_revoked_certificate(revoked)
crl = crl_builder.sign(ca_key, hashes.SHA256(), default_backend())
# Update CA
ca.crl = crl.public_bytes(serialization.Encoding.PEM).decode('utf-8')
ca.crl_updated_at = datetime.utcnow()
self.db.commit()
return True
def get_crl(self, ca: CertificateAuthority) -> str:
"""Get the CRL for a CA, generating if needed."""
if not ca.crl:
# Generate empty CRL
ca_key = serialization.load_pem_private_key(
ca.ca_key.encode('utf-8'),
password=None,
backend=default_backend()
)
ca_cert = x509.load_pem_x509_certificate(
ca.ca_cert.encode('utf-8'),
default_backend()
)
crl = (
x509.CertificateRevocationListBuilder()
.issuer_name(ca_cert.subject)
.last_update(datetime.utcnow())
.next_update(datetime.utcnow() + timedelta(days=180))
.sign(ca_key, hashes.SHA256(), default_backend())
)
ca.crl = crl.public_bytes(serialization.Encoding.PEM).decode('utf-8')
ca.crl_updated_at = datetime.utcnow()
self.db.commit()
return ca.crl
def get_expiring_certificates(
self,
ca: CertificateAuthority,
days: int = 30
) -> list[VPNProfile]:
"""Get profiles with certificates expiring within given days."""
threshold = datetime.utcnow() + timedelta(days=days)
return self.db.query(VPNProfile).filter(
VPNProfile.ca_id == ca.id,
VPNProfile.valid_until <= threshold,
VPNProfile.status == VPNProfileStatus.ACTIVE
).all()
def get_default_ca(self, tenant_id: Optional[int] = None) -> Optional[CertificateAuthority]:
"""Get the default CA for a tenant (or global)."""
ca = self.db.query(CertificateAuthority).filter(
CertificateAuthority.tenant_id == tenant_id,
CertificateAuthority.is_default == True,
CertificateAuthority.status == CAStatus.ACTIVE
).first()
if not ca and tenant_id:
# Fall back to global CA
ca = self.db.query(CertificateAuthority).filter(
CertificateAuthority.tenant_id == None,
CertificateAuthority.is_default == True,
CertificateAuthority.status == CAStatus.ACTIVE
).first()
return ca
+129
View File
@@ -0,0 +1,129 @@
"""Firewall management service for dynamic iptables rules."""
import subprocess
from typing import Literal
class FirewallService:
"""Service for managing iptables firewall rules."""
# Chain name for our VPN rules
VPN_CHAIN = "MGUARD_VPN"
def __init__(self):
self._ensure_chain_exists()
def _run_iptables(self, args: list[str], check: bool = True) -> subprocess.CompletedProcess:
"""Run iptables command."""
cmd = ["iptables"] + args
return subprocess.run(cmd, capture_output=True, text=True, check=check)
def _ensure_chain_exists(self):
"""Ensure our custom chain exists."""
# Check if chain exists
result = self._run_iptables(["-L", self.VPN_CHAIN], check=False)
if result.returncode != 0:
# Create chain
self._run_iptables(["-N", self.VPN_CHAIN], check=False)
# Add jump to our chain from FORWARD
self._run_iptables(["-I", "FORWARD", "-j", self.VPN_CHAIN], check=False)
def allow_connection(
self,
client_vpn_ip: str,
gateway_vpn_ip: str,
target_ip: str,
target_port: int,
protocol: Literal["tcp", "udp"] = "tcp"
) -> bool:
"""Allow connection from client through gateway to target endpoint."""
try:
# Rule 1: Allow client to reach target through gateway
self._run_iptables([
"-A", self.VPN_CHAIN,
"-s", client_vpn_ip,
"-d", target_ip,
"-p", protocol,
"--dport", str(target_port),
"-j", "ACCEPT"
])
# Rule 2: Allow return traffic
self._run_iptables([
"-A", self.VPN_CHAIN,
"-s", target_ip,
"-d", client_vpn_ip,
"-p", protocol,
"--sport", str(target_port),
"-j", "ACCEPT"
])
# Add NAT/masquerade if needed for routing through gateway
self._run_iptables([
"-t", "nat",
"-A", "POSTROUTING",
"-s", client_vpn_ip,
"-d", target_ip,
"-j", "MASQUERADE"
], check=False)
return True
except subprocess.CalledProcessError:
return False
def revoke_connection(
self,
client_vpn_ip: str,
gateway_vpn_ip: str,
target_ip: str,
target_port: int,
protocol: Literal["tcp", "udp"] = "tcp"
) -> bool:
"""Remove firewall rules for a connection."""
try:
# Remove forward rules
self._run_iptables([
"-D", self.VPN_CHAIN,
"-s", client_vpn_ip,
"-d", target_ip,
"-p", protocol,
"--dport", str(target_port),
"-j", "ACCEPT"
], check=False)
self._run_iptables([
"-D", self.VPN_CHAIN,
"-s", target_ip,
"-d", client_vpn_ip,
"-p", protocol,
"--sport", str(target_port),
"-j", "ACCEPT"
], check=False)
# Remove NAT rule
self._run_iptables([
"-t", "nat",
"-D", "POSTROUTING",
"-s", client_vpn_ip,
"-d", target_ip,
"-j", "MASQUERADE"
], check=False)
return True
except subprocess.CalledProcessError:
return False
def list_rules(self) -> list[str]:
"""List all rules in our VPN chain."""
result = self._run_iptables(["-L", self.VPN_CHAIN, "-n", "-v"], check=False)
if result.returncode == 0:
return result.stdout.strip().split('\n')
return []
def flush_rules(self) -> bool:
"""Remove all rules from our VPN chain."""
try:
self._run_iptables(["-F", self.VPN_CHAIN])
return True
except subprocess.CalledProcessError:
return False
+282
View File
@@ -0,0 +1,282 @@
"""VPN Profile management service."""
from datetime import datetime
from typing import Optional
from sqlalchemy.orm import Session
from ..models.vpn_profile import VPNProfile, VPNProfileStatus
from ..models.vpn_server import VPNServer
from ..models.gateway import Gateway
from ..models.certificate_authority import CertificateAuthority
from .certificate_service import CertificateService
from ..config import get_settings
settings = get_settings()
class VPNProfileService:
"""Service for managing VPN profiles for gateways."""
def __init__(self, db: Session):
self.db = db
self.cert_service = CertificateService(db)
def create_profile(
self,
gateway_id: int,
vpn_server_id: int,
name: str,
priority: int = 1,
description: Optional[str] = None
) -> VPNProfile:
"""Create a new VPN profile for a gateway."""
# Get gateway
gateway = self.db.query(Gateway).filter(
Gateway.id == gateway_id
).first()
if not gateway:
raise ValueError(f"Gateway with id {gateway_id} not found")
# Get VPN server
server = self.db.query(VPNServer).filter(
VPNServer.id == vpn_server_id
).first()
if not server:
raise ValueError(f"VPN Server with id {vpn_server_id} not found")
if not server.certificate_authority.is_ready:
raise ValueError("CA is not ready (DH parameters may still be generating)")
# Generate unique common name
base_cn = f"gw-{gateway.name.lower().replace(' ', '-')}"
profile_count = self.db.query(VPNProfile).filter(
VPNProfile.gateway_id == gateway_id
).count()
cert_cn = f"{base_cn}-{profile_count + 1}"
# Create profile
profile = VPNProfile(
gateway_id=gateway_id,
vpn_server_id=vpn_server_id,
ca_id=server.ca_id,
name=name,
description=description,
cert_cn=cert_cn,
priority=priority,
status=VPNProfileStatus.PENDING
)
self.db.add(profile)
self.db.commit()
self.db.refresh(profile)
# Generate client certificate
self._generate_client_cert(profile)
return profile
def _generate_client_cert(self, profile: VPNProfile):
"""Generate client certificate for profile."""
ca = profile.certificate_authority
cert_data = self.cert_service.generate_client_certificate(
ca=ca,
common_name=profile.cert_cn
)
profile.client_cert = cert_data["cert"]
profile.client_key = cert_data["key"]
profile.valid_from = cert_data["valid_from"]
profile.valid_until = cert_data["valid_until"]
profile.status = VPNProfileStatus.ACTIVE
self.db.commit()
def generate_client_config(self, profile: VPNProfile) -> str:
"""Generate OpenVPN client configuration (.ovpn) for a profile."""
if not profile.is_ready:
raise ValueError("Profile is not ready for provisioning")
server = profile.vpn_server
ca = profile.certificate_authority
config_lines = [
"# OpenVPN Client Configuration",
f"# Profile: {profile.name}",
f"# Gateway: {profile.gateway.name}",
f"# Server: {server.name}",
f"# Generated: {datetime.utcnow().isoformat()}",
"",
"client",
"dev tun",
f"proto {server.protocol.value}",
f"remote {server.hostname} {server.port}",
"",
"resolv-retry infinite",
"nobind",
"persist-key",
"persist-tun",
"",
"remote-cert-tls server",
f"cipher {server.cipher.value}",
f"auth {server.auth.value}",
"",
"verb 3",
"",
]
# Add CA certificate
config_lines.extend([
"<ca>",
ca.ca_cert.strip(),
"</ca>",
"",
])
# Add client certificate
config_lines.extend([
"<cert>",
profile.client_cert.strip(),
"</cert>",
"",
])
# Add client key
config_lines.extend([
"<key>",
profile.client_key.strip(),
"</key>",
"",
])
# Add TLS-Auth key if enabled
if server.tls_auth_enabled and server.ta_key:
config_lines.extend([
"key-direction 1",
"<tls-auth>",
server.ta_key.strip(),
"</tls-auth>",
])
return "\n".join(config_lines)
def provision_profile(self, profile: VPNProfile) -> str:
"""Mark profile as provisioned and return config."""
config = self.generate_client_config(profile)
profile.status = VPNProfileStatus.PROVISIONED
profile.provisioned_at = datetime.utcnow()
profile.gateway.is_provisioned = True
self.db.commit()
return config
def set_priority(self, profile: VPNProfile, new_priority: int):
"""Set profile priority and reorder others if needed."""
gateway_id = profile.gateway_id
# Get all profiles for this gateway ordered by priority
profiles = self.db.query(VPNProfile).filter(
VPNProfile.gateway_id == gateway_id,
VPNProfile.id != profile.id
).order_by(VPNProfile.priority).all()
# Update priorities
current_priority = 1
for p in profiles:
if current_priority == new_priority:
current_priority += 1
p.priority = current_priority
current_priority += 1
profile.priority = new_priority
self.db.commit()
def get_profiles_for_gateway(self, gateway_id: int) -> list[VPNProfile]:
"""Get all VPN profiles for a gateway, ordered by priority."""
return self.db.query(VPNProfile).filter(
VPNProfile.gateway_id == gateway_id
).order_by(VPNProfile.priority).all()
def get_active_profiles_for_gateway(self, gateway_id: int) -> list[VPNProfile]:
"""Get active VPN profiles for a gateway."""
return self.db.query(VPNProfile).filter(
VPNProfile.gateway_id == gateway_id,
VPNProfile.is_active == True,
VPNProfile.status.in_([VPNProfileStatus.ACTIVE, VPNProfileStatus.PROVISIONED])
).order_by(VPNProfile.priority).all()
def revoke_profile(self, profile: VPNProfile, reason: str = "unspecified"):
"""Revoke a VPN profile's certificate."""
if profile.client_cert:
self.cert_service.revoke_certificate(
ca=profile.certificate_authority,
cert_pem=profile.client_cert,
reason=reason
)
profile.status = VPNProfileStatus.REVOKED
profile.is_active = False
self.db.commit()
def renew_profile(self, profile: VPNProfile):
"""Renew the certificate for a profile."""
# Revoke old certificate
if profile.client_cert:
self.cert_service.revoke_certificate(
ca=profile.certificate_authority,
cert_pem=profile.client_cert,
reason="superseded"
)
# Generate new certificate
self._generate_client_cert(profile)
# Reset provisioned status
profile.status = VPNProfileStatus.ACTIVE
profile.provisioned_at = None
self.db.commit()
def get_profile_by_id(self, profile_id: int) -> Optional[VPNProfile]:
"""Get a VPN profile by ID."""
return self.db.query(VPNProfile).filter(
VPNProfile.id == profile_id
).first()
def delete_profile(self, profile: VPNProfile):
"""Delete a VPN profile."""
# Revoke certificate first
if profile.client_cert and profile.status != VPNProfileStatus.REVOKED:
self.cert_service.revoke_certificate(
ca=profile.certificate_authority,
cert_pem=profile.client_cert,
reason="cessationOfOperation"
)
self.db.delete(profile)
self.db.commit()
def generate_all_configs_for_gateway(self, gateway_id: int) -> list[tuple[str, str]]:
"""Generate configs for all active profiles of a gateway.
Returns list of tuples: (filename, config_content)
"""
profiles = self.get_active_profiles_for_gateway(gateway_id)
configs = []
for profile in profiles:
try:
config = self.generate_client_config(profile)
filename = f"{profile.gateway.name}-{profile.name}.ovpn"
filename = filename.lower().replace(' ', '-')
configs.append((filename, config))
except Exception:
continue
return configs
+344
View File
@@ -0,0 +1,344 @@
"""VPN Server management service."""
import socket
from datetime import datetime
from typing import Optional
from sqlalchemy.orm import Session
from ..models.vpn_server import VPNServer, VPNServerStatus, VPNProtocol, VPNCipher, VPNAuth, VPNCompression
from ..models.certificate_authority import CertificateAuthority
from .certificate_service import CertificateService
from ..config import get_settings
settings = get_settings()
class VPNServerService:
"""Service for managing VPN server instances."""
def __init__(self, db: Session):
self.db = db
self.cert_service = CertificateService(db)
def create_server(
self,
name: str,
hostname: str,
ca_id: int,
port: int = 1194,
protocol: VPNProtocol = VPNProtocol.UDP,
vpn_network: str = "10.8.0.0",
vpn_netmask: str = "255.255.255.0",
cipher: VPNCipher = VPNCipher.AES_256_GCM,
auth: VPNAuth = VPNAuth.SHA256,
tls_version_min: str = "1.2",
compression: VPNCompression = VPNCompression.NONE,
max_clients: int = 100,
keepalive_interval: int = 10,
keepalive_timeout: int = 60,
management_port: int = 7505,
is_primary: bool = False,
tenant_id: Optional[int] = None,
description: Optional[str] = None
) -> VPNServer:
"""Create a new VPN server instance."""
# Get CA
ca = self.db.query(CertificateAuthority).filter(
CertificateAuthority.id == ca_id
).first()
if not ca:
raise ValueError(f"CA with id {ca_id} not found")
if not ca.is_ready:
raise ValueError("CA is not ready (DH parameters may still be generating)")
# Generate container name
existing_count = self.db.query(VPNServer).count()
container_name = f"mguard-openvpn-{existing_count + 1}"
# If setting as primary, unset other primaries
if is_primary:
self.db.query(VPNServer).filter(
VPNServer.tenant_id == tenant_id,
VPNServer.is_primary == True
).update({"is_primary": False})
# Create server record
server = VPNServer(
name=name,
description=description,
hostname=hostname,
port=port,
protocol=protocol,
vpn_network=vpn_network,
vpn_netmask=vpn_netmask,
cipher=cipher,
auth=auth,
tls_version_min=tls_version_min,
compression=compression,
max_clients=max_clients,
keepalive_interval=keepalive_interval,
keepalive_timeout=keepalive_timeout,
management_port=management_port,
docker_container_name=container_name,
ca_id=ca_id,
tenant_id=tenant_id,
is_primary=is_primary,
status=VPNServerStatus.PENDING
)
self.db.add(server)
self.db.commit()
self.db.refresh(server)
# Generate server certificate
self._generate_server_cert(server)
return server
def _generate_server_cert(self, server: VPNServer):
"""Generate server certificate and TA key."""
ca = server.certificate_authority
# Generate server certificate
cert_data = self.cert_service.generate_server_certificate(
ca=ca,
common_name=f"{server.name}-server"
)
server.server_cert = cert_data["cert"]
server.server_key = cert_data["key"]
# Generate TLS-Auth key
server.ta_key = self.cert_service.generate_ta_key()
self.db.commit()
def generate_server_config(self, server: VPNServer) -> str:
"""Generate OpenVPN server configuration file."""
if not server.certificate_authority:
raise ValueError("Server has no CA assigned")
config_lines = [
"# OpenVPN Server Configuration",
f"# Server: {server.name}",
f"# Generated: {datetime.utcnow().isoformat()}",
"",
"# Basic settings",
f"port {server.port}",
f"proto {server.protocol.value}",
"dev tun",
"",
"# Certificates",
"ca /etc/openvpn/ca.crt",
"cert /etc/openvpn/server.crt",
"key /etc/openvpn/server.key",
"dh /etc/openvpn/dh.pem",
"crl-verify /etc/openvpn/crl.pem",
"",
"# TLS Auth",
]
if server.tls_auth_enabled and server.ta_key:
config_lines.append("tls-auth /etc/openvpn/ta.key 0")
config_lines.extend([
f"tls-version-min {server.tls_version_min}",
"",
"# Network",
f"server {server.vpn_network} {server.vpn_netmask}",
"topology subnet",
"",
"# Routing",
'push "redirect-gateway def1 bypass-dhcp"',
'push "dhcp-option DNS 8.8.8.8"',
'push "dhcp-option DNS 8.8.4.4"',
"",
"# Security",
f"cipher {server.cipher.value}",
f"auth {server.auth.value}",
"",
"# Performance",
f"keepalive {server.keepalive_interval} {server.keepalive_timeout}",
f"max-clients {server.max_clients}",
])
# Compression handling
# Note: OpenVPN 2.5+ deprecates compression due to VORACLE attack
if server.compression != VPNCompression.NONE:
config_lines.append(f"compress {server.compression.value}")
config_lines.append("allow-compression yes")
# If compression is NONE, don't add any directive - OpenVPN defaults to no compression
config_lines.extend([
"",
"# Persistence",
"persist-key",
"persist-tun",
"",
"# Logging",
f"status /var/log/openvpn/status-{server.id}.log",
f"log-append /var/log/openvpn/server-{server.id}.log",
"verb 3",
"",
"# Management interface",
f"management 0.0.0.0 {server.management_port}",
"",
"# User/Group (for Linux)",
"user nobody",
"group nogroup",
"",
"# Client config directory",
"client-config-dir /etc/openvpn/ccd",
"",
"# Scripts",
"script-security 2",
"client-connect /etc/openvpn/scripts/client-connect.sh",
"client-disconnect /etc/openvpn/scripts/client-disconnect.sh",
])
return "\n".join(config_lines)
def get_connected_clients(self, server: VPNServer) -> list[dict]:
"""Get list of currently connected VPN clients."""
try:
response = self._send_management_command(server, "status")
clients = []
if "ERROR" in response:
return clients
# Parse status output
in_client_list = False
for line in response.split('\n'):
if line.startswith("ROUTING TABLE"):
in_client_list = False
elif line.startswith("Common Name"):
in_client_list = True
continue
elif in_client_list and ',' in line:
parts = line.split(',')
if len(parts) >= 5:
clients.append({
"common_name": parts[0],
"real_address": parts[1],
"bytes_received": int(parts[2]) if parts[2].isdigit() else 0,
"bytes_sent": int(parts[3]) if parts[3].isdigit() else 0,
"connected_since": parts[4]
})
# Update connected count
server.connected_clients = len(clients)
server.last_status_check = datetime.utcnow()
self.db.commit()
return clients
except Exception as e:
return []
def disconnect_client(self, server: VPNServer, common_name: str) -> bool:
"""Disconnect a specific VPN client."""
response = self._send_management_command(server, f"kill {common_name}")
return "SUCCESS" in response
def get_server_status(self, server: VPNServer) -> dict:
"""Get detailed server status."""
try:
# Try to connect to management interface
response = self._send_management_command(server, "state")
if "ERROR" in response:
server.status = VPNServerStatus.ERROR
self.db.commit()
return {
"status": "error",
"message": response
}
# Parse state
for line in response.split('\n'):
if 'CONNECTED' in line:
server.status = VPNServerStatus.RUNNING
break
else:
server.status = VPNServerStatus.STOPPED
server.last_status_check = datetime.utcnow()
self.db.commit()
clients = self.get_connected_clients(server)
return {
"status": server.status.value,
"connected_clients": len(clients),
"clients": clients,
"last_check": server.last_status_check.isoformat() if server.last_status_check else None
}
except Exception as e:
return {
"status": "unknown",
"message": str(e)
}
def _send_management_command(self, server: VPNServer, command: str) -> str:
"""Send command to OpenVPN management interface."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
# Connect to management interface
# OpenVPN runs in host network mode, so we connect via host.docker.internal
# This resolves to the Docker host from within containers
host = "host.docker.internal"
sock.connect((host, server.management_port))
# Read welcome message
sock.recv(1024)
# Send command
sock.send(f"{command}\n".encode())
# Read response
response = b""
while True:
data = sock.recv(4096)
if not data:
break
response += data
if b"END" in data or b"SUCCESS" in data or b"ERROR" in data:
break
sock.close()
return response.decode()
except Exception as e:
return f"ERROR: {str(e)}"
def update_all_server_status(self):
"""Update status for all active servers."""
servers = self.db.query(VPNServer).filter(
VPNServer.is_active == True
).all()
for server in servers:
self.get_server_status(server)
def get_server_by_id(self, server_id: int) -> Optional[VPNServer]:
"""Get a VPN server by ID."""
return self.db.query(VPNServer).filter(
VPNServer.id == server_id
).first()
def get_servers_for_tenant(self, tenant_id: Optional[int] = None) -> list[VPNServer]:
"""Get all VPN servers for a tenant."""
query = self.db.query(VPNServer)
if tenant_id:
query = query.filter(
(VPNServer.tenant_id == tenant_id) | (VPNServer.tenant_id == None)
)
else:
query = query.filter(VPNServer.tenant_id == None)
return query.order_by(VPNServer.name).all()
+95
View File
@@ -0,0 +1,95 @@
"""VPN management service for OpenVPN operations.
This service provides basic OpenVPN management interface communication.
For PKI/certificate management, use CertificateService.
For VPN server management, use VPNServerService.
"""
import socket
from ..config import get_settings
settings = get_settings()
class VPNService:
"""Service for OpenVPN management interface operations."""
def __init__(self, host: str = None, port: int = None):
self.management_host = host or settings.openvpn_management_host
self.management_port = port or settings.openvpn_management_port
def _send_management_command(self, command: str) -> str:
"""Send command to OpenVPN management interface."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
sock.connect((self.management_host, self.management_port))
# Read welcome message
sock.recv(1024)
# Send command
sock.send(f"{command}\n".encode())
# Read response
response = b""
while True:
data = sock.recv(4096)
if not data:
break
response += data
if b"END" in data or b"SUCCESS" in data or b"ERROR" in data:
break
sock.close()
return response.decode()
except Exception as e:
return f"ERROR: {str(e)}"
def get_connected_clients(self) -> list[dict]:
"""Get list of currently connected VPN clients."""
response = self._send_management_command("status")
clients = []
if "ERROR" in response:
return clients
# Parse status output
in_client_list = False
for line in response.split('\n'):
if line.startswith("ROUTING TABLE"):
in_client_list = False
elif line.startswith("Common Name"):
in_client_list = True
continue
elif in_client_list and ',' in line:
parts = line.split(',')
if len(parts) >= 5:
clients.append({
"common_name": parts[0],
"real_address": parts[1],
"bytes_received": int(parts[2]) if parts[2].isdigit() else 0,
"bytes_sent": int(parts[3]) if parts[3].isdigit() else 0,
"connected_since": parts[4]
})
return clients
def disconnect_client(self, common_name: str) -> bool:
"""Disconnect a specific VPN client."""
response = self._send_management_command(f"kill {common_name}")
return "SUCCESS" in response
def get_server_status(self) -> dict:
"""Get OpenVPN server status."""
response = self._send_management_command("state")
if "ERROR" in response:
return {"status": "error", "message": response}
# Parse state
for line in response.split('\n'):
if 'CONNECTED' in line:
return {"status": "running", "message": "Server is running"}
return {"status": "unknown", "message": response}
+182
View File
@@ -0,0 +1,182 @@
"""VPN Sync Service for synchronizing VPN connection state with database."""
from datetime import datetime
from sqlalchemy.orm import Session
from ..models.vpn_server import VPNServer
from ..models.vpn_profile import VPNProfile
from ..models.vpn_connection_log import VPNConnectionLog
from ..models.gateway import Gateway
from .vpn_server_service import VPNServerService
class VPNSyncService:
"""Service to sync VPN connections with gateway status and logs."""
def __init__(self, db: Session):
self.db = db
self.vpn_service = VPNServerService(db)
def sync_all_connections(self) -> dict:
"""
Sync all VPN connections across all servers.
Returns summary of changes.
"""
result = {
"servers_checked": 0,
"clients_found": 0,
"gateways_online": 0,
"gateways_offline": 0,
"new_connections": 0,
"closed_connections": 0,
"errors": []
}
# Get all active VPN servers
servers = self.db.query(VPNServer).filter(VPNServer.is_active == True).all()
# Collect all currently connected CNs
connected_cns = set()
for server in servers:
result["servers_checked"] += 1
try:
clients = self.vpn_service.get_connected_clients(server)
result["clients_found"] += len(clients)
for client in clients:
cn = client.get("common_name")
if cn:
connected_cns.add(cn)
self._handle_connected_client(server, client, result)
except Exception as e:
result["errors"].append(f"Server {server.name}: {str(e)}")
# Mark disconnected profiles/gateways
self._handle_disconnected_profiles(connected_cns, result)
self.db.commit()
return result
def _handle_connected_client(self, server: VPNServer, client: dict, result: dict):
"""Handle a connected VPN client - update profile, gateway, and log."""
cn = client.get("common_name")
real_address = client.get("real_address")
bytes_rx = client.get("bytes_received", 0)
bytes_tx = client.get("bytes_sent", 0)
connected_since = client.get("connected_since")
# Find profile by CN
profile = self.db.query(VPNProfile).filter(
VPNProfile.cert_cn == cn,
VPNProfile.vpn_server_id == server.id
).first()
if not profile:
# Try finding by CN across all servers (in case of migration)
profile = self.db.query(VPNProfile).filter(
VPNProfile.cert_cn == cn
).first()
if not profile:
return # Unknown client, skip
gateway = profile.gateway
# Update gateway status
if not gateway.is_online:
gateway.is_online = True
gateway.last_seen = datetime.utcnow()
result["gateways_online"] += 1
# Update profile last connection
profile.last_connection = datetime.utcnow()
# Check for existing active connection log
active_log = self.db.query(VPNConnectionLog).filter(
VPNConnectionLog.vpn_profile_id == profile.id,
VPNConnectionLog.disconnected_at.is_(None)
).first()
if not active_log:
# Create new connection log
log = VPNConnectionLog(
vpn_profile_id=profile.id,
vpn_server_id=server.id,
gateway_id=gateway.id,
common_name=cn,
real_address=real_address,
connected_at=datetime.utcnow(),
bytes_received=bytes_rx,
bytes_sent=bytes_tx
)
self.db.add(log)
result["new_connections"] += 1
else:
# Update traffic stats
active_log.bytes_received = bytes_rx
active_log.bytes_sent = bytes_tx
active_log.real_address = real_address
def _handle_disconnected_profiles(self, connected_cns: set, result: dict):
"""Mark profiles as disconnected if their CN is no longer connected."""
# Find all active connection logs
active_logs = self.db.query(VPNConnectionLog).filter(
VPNConnectionLog.disconnected_at.is_(None)
).all()
for log in active_logs:
if log.common_name not in connected_cns:
# Mark as disconnected
log.disconnected_at = datetime.utcnow()
result["closed_connections"] += 1
# Update gateway status
gateway = log.gateway
# Check if gateway has any other active connections
other_active = self.db.query(VPNConnectionLog).filter(
VPNConnectionLog.gateway_id == gateway.id,
VPNConnectionLog.id != log.id,
VPNConnectionLog.disconnected_at.is_(None)
).first()
if not other_active:
gateway.is_online = False
result["gateways_offline"] += 1
def get_profile_connection_logs(
self,
profile_id: int,
limit: int = 50
) -> list[VPNConnectionLog]:
"""Get connection logs for a specific profile."""
return self.db.query(VPNConnectionLog).filter(
VPNConnectionLog.vpn_profile_id == profile_id
).order_by(VPNConnectionLog.connected_at.desc()).limit(limit).all()
def get_gateway_connection_logs(
self,
gateway_id: int,
limit: int = 50
) -> list[VPNConnectionLog]:
"""Get connection logs for a gateway (all profiles)."""
return self.db.query(VPNConnectionLog).filter(
VPNConnectionLog.gateway_id == gateway_id
).order_by(VPNConnectionLog.connected_at.desc()).limit(limit).all()
def get_server_connection_logs(
self,
server_id: int,
limit: int = 50
) -> list[VPNConnectionLog]:
"""Get connection logs for a VPN server."""
return self.db.query(VPNConnectionLog).filter(
VPNConnectionLog.vpn_server_id == server_id
).order_by(VPNConnectionLog.connected_at.desc()).limit(limit).all()
def get_active_connections(self) -> list[VPNConnectionLog]:
"""Get all currently active VPN connections."""
return self.db.query(VPNConnectionLog).filter(
VPNConnectionLog.disconnected_at.is_(None)
).order_by(VPNConnectionLog.connected_at.desc()).all()