first commit
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
"""Business logic services."""
|
||||
|
||||
from .auth_service import AuthService
|
||||
from .vpn_service import VPNService
|
||||
from .firewall_service import FirewallService
|
||||
from .certificate_service import CertificateService
|
||||
from .vpn_server_service import VPNServerService
|
||||
from .vpn_profile_service import VPNProfileService
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
"VPNService",
|
||||
"FirewallService",
|
||||
"CertificateService",
|
||||
"VPNServerService",
|
||||
"VPNProfileService",
|
||||
]
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Authentication service."""
|
||||
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from ..models.user import User
|
||||
from ..schemas.user import UserCreate, Token
|
||||
from ..utils.security import (
|
||||
verify_password, get_password_hash,
|
||||
create_access_token, create_refresh_token, decode_token
|
||||
)
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for authentication operations."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def authenticate_user(self, username: str, password: str) -> User | None:
|
||||
"""Authenticate user with username and password."""
|
||||
user = self.db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.password_hash):
|
||||
return None
|
||||
if not user.is_active:
|
||||
return None
|
||||
|
||||
# Update last login
|
||||
user.last_login = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
return user
|
||||
|
||||
def create_tokens(self, user: User) -> Token:
|
||||
"""Create access and refresh tokens for user."""
|
||||
access_token = create_access_token(
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
role=user.role.value,
|
||||
tenant_id=user.tenant_id
|
||||
)
|
||||
refresh_token = create_refresh_token(user_id=user.id)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token
|
||||
)
|
||||
|
||||
def refresh_tokens(self, refresh_token: str) -> Token | None:
|
||||
"""Refresh access token using refresh token."""
|
||||
payload = decode_token(refresh_token)
|
||||
if not payload:
|
||||
return None
|
||||
if payload.get("type") != "refresh":
|
||||
return None
|
||||
|
||||
user_id = payload.get("sub")
|
||||
user = self.db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
|
||||
return self.create_tokens(user)
|
||||
|
||||
def create_user(self, user_data: UserCreate) -> User:
|
||||
"""Create a new user."""
|
||||
user = User(
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
password_hash=get_password_hash(user_data.password),
|
||||
full_name=user_data.full_name,
|
||||
role=user_data.role,
|
||||
tenant_id=user_data.tenant_id
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> User | None:
|
||||
"""Get user by ID."""
|
||||
return self.db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
def get_user_by_username(self, username: str) -> User | None:
|
||||
"""Get user by username."""
|
||||
return self.db.query(User).filter(User.username == username).first()
|
||||
@@ -0,0 +1,588 @@
|
||||
"""Certificate management service for PKI operations."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from cryptography import x509
|
||||
from cryptography.x509.oid import NameOID, ExtensionOID
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models.certificate_authority import CertificateAuthority, CAStatus
|
||||
from ..models.vpn_profile import VPNProfile, VPNProfileStatus
|
||||
from ..config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class CertificateService:
|
||||
"""Service for managing certificates and CAs."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create_ca(
|
||||
self,
|
||||
name: str,
|
||||
key_size: int = 4096,
|
||||
validity_days: int = 3650,
|
||||
organization: str = "mGuard VPN",
|
||||
country: str = "DE",
|
||||
state: str = "NRW",
|
||||
city: str = "Dortmund",
|
||||
tenant_id: Optional[int] = None,
|
||||
created_by_id: Optional[int] = None,
|
||||
is_default: bool = False
|
||||
) -> CertificateAuthority:
|
||||
"""Create a new Certificate Authority."""
|
||||
|
||||
# Generate CA private key
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=key_size,
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
# Build CA certificate
|
||||
subject = issuer = x509.Name([
|
||||
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
|
||||
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
|
||||
x509.NameAttribute(NameOID.LOCALITY_NAME, city),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, f"{name} CA"),
|
||||
])
|
||||
|
||||
valid_from = datetime.utcnow()
|
||||
valid_until = valid_from + timedelta(days=validity_days)
|
||||
|
||||
ca_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(issuer)
|
||||
.public_key(private_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(valid_from)
|
||||
.not_valid_after(valid_until)
|
||||
.add_extension(
|
||||
x509.BasicConstraints(ca=True, path_length=0),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
key_encipherment=False,
|
||||
content_commitment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()),
|
||||
critical=False,
|
||||
)
|
||||
.sign(private_key, hashes.SHA256(), default_backend())
|
||||
)
|
||||
|
||||
# Serialize to PEM
|
||||
ca_cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode('utf-8')
|
||||
ca_key_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
).decode('utf-8')
|
||||
|
||||
# If this is default, unset other defaults for this tenant
|
||||
if is_default:
|
||||
self.db.query(CertificateAuthority).filter(
|
||||
CertificateAuthority.tenant_id == tenant_id,
|
||||
CertificateAuthority.is_default == True
|
||||
).update({"is_default": False})
|
||||
|
||||
# Create CA record
|
||||
ca = CertificateAuthority(
|
||||
name=name,
|
||||
tenant_id=tenant_id,
|
||||
ca_cert=ca_cert_pem,
|
||||
ca_key=ca_key_pem,
|
||||
key_size=key_size,
|
||||
valid_from=valid_from,
|
||||
valid_until=valid_until,
|
||||
is_default=is_default,
|
||||
status=CAStatus.PENDING, # Will be ACTIVE after DH generation
|
||||
created_by_id=created_by_id,
|
||||
dh_generating=True
|
||||
)
|
||||
|
||||
self.db.add(ca)
|
||||
self.db.commit()
|
||||
self.db.refresh(ca)
|
||||
|
||||
# Start DH parameter generation in background
|
||||
self._generate_dh_async(ca.id, key_size)
|
||||
|
||||
return ca
|
||||
|
||||
def _generate_dh_async(self, ca_id: int, key_size: int):
|
||||
"""Generate DH parameters in background thread."""
|
||||
def generate():
|
||||
try:
|
||||
# Use openssl for DH generation (faster than pure Python)
|
||||
result = subprocess.run(
|
||||
["openssl", "dhparam", "-out", "-", str(key_size)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3600 # 1 hour max
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
dh_pem = result.stdout
|
||||
# Update CA in database (new session needed for thread)
|
||||
from ..database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ca = db.query(CertificateAuthority).filter(
|
||||
CertificateAuthority.id == ca_id
|
||||
).first()
|
||||
if ca:
|
||||
ca.dh_params = dh_pem
|
||||
ca.dh_generating = False
|
||||
ca.status = CAStatus.ACTIVE
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
# Log error and mark CA as failed
|
||||
from ..database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ca = db.query(CertificateAuthority).filter(
|
||||
CertificateAuthority.id == ca_id
|
||||
).first()
|
||||
if ca:
|
||||
ca.dh_generating = False
|
||||
ca.description = f"DH generation failed: {str(e)}"
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
thread = threading.Thread(target=generate, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def import_ca(
|
||||
self,
|
||||
name: str,
|
||||
ca_cert_pem: str,
|
||||
ca_key_pem: str,
|
||||
dh_params_pem: Optional[str] = None,
|
||||
tenant_id: Optional[int] = None,
|
||||
created_by_id: Optional[int] = None
|
||||
) -> CertificateAuthority:
|
||||
"""Import an existing CA from PEM files."""
|
||||
|
||||
# Parse certificate to extract metadata
|
||||
ca_cert = x509.load_pem_x509_certificate(
|
||||
ca_cert_pem.encode('utf-8'),
|
||||
default_backend()
|
||||
)
|
||||
|
||||
# Validate it's a CA certificate
|
||||
try:
|
||||
basic_constraints = ca_cert.extensions.get_extension_for_oid(
|
||||
ExtensionOID.BASIC_CONSTRAINTS
|
||||
)
|
||||
if not basic_constraints.value.ca:
|
||||
raise ValueError("Certificate is not a CA certificate")
|
||||
except x509.extensions.ExtensionNotFound:
|
||||
raise ValueError("Certificate missing BasicConstraints extension")
|
||||
|
||||
# Get key size from private key
|
||||
private_key = serialization.load_pem_private_key(
|
||||
ca_key_pem.encode('utf-8'),
|
||||
password=None,
|
||||
backend=default_backend()
|
||||
)
|
||||
key_size = private_key.key_size
|
||||
|
||||
ca = CertificateAuthority(
|
||||
name=name,
|
||||
tenant_id=tenant_id,
|
||||
ca_cert=ca_cert_pem,
|
||||
ca_key=ca_key_pem,
|
||||
dh_params=dh_params_pem,
|
||||
key_size=key_size,
|
||||
valid_from=ca_cert.not_valid_before_utc,
|
||||
valid_until=ca_cert.not_valid_after_utc,
|
||||
status=CAStatus.ACTIVE if dh_params_pem else CAStatus.PENDING,
|
||||
created_by_id=created_by_id,
|
||||
dh_generating=dh_params_pem is None
|
||||
)
|
||||
|
||||
self.db.add(ca)
|
||||
self.db.commit()
|
||||
self.db.refresh(ca)
|
||||
|
||||
# Generate DH if not provided
|
||||
if not dh_params_pem:
|
||||
self._generate_dh_async(ca.id, key_size)
|
||||
|
||||
return ca
|
||||
|
||||
def generate_server_certificate(
|
||||
self,
|
||||
ca: CertificateAuthority,
|
||||
common_name: str,
|
||||
validity_days: int = 825
|
||||
) -> dict:
|
||||
"""Generate a server certificate from the CA."""
|
||||
|
||||
if not ca.is_ready:
|
||||
raise ValueError("CA is not ready for issuing certificates")
|
||||
|
||||
# Load CA key and cert
|
||||
ca_key = serialization.load_pem_private_key(
|
||||
ca.ca_key.encode('utf-8'),
|
||||
password=None,
|
||||
backend=default_backend()
|
||||
)
|
||||
ca_cert = x509.load_pem_x509_certificate(
|
||||
ca.ca_cert.encode('utf-8'),
|
||||
default_backend()
|
||||
)
|
||||
|
||||
# Generate server private key
|
||||
server_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=ca.key_size,
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
# Build server certificate
|
||||
valid_from = datetime.utcnow()
|
||||
valid_until = valid_from + timedelta(days=validity_days)
|
||||
|
||||
subject = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||
])
|
||||
|
||||
server_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(server_key.public_key())
|
||||
.serial_number(ca.next_serial)
|
||||
.not_valid_before(valid_from)
|
||||
.not_valid_after(valid_until)
|
||||
.add_extension(
|
||||
x509.BasicConstraints(ca=False, path_length=None),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
key_encipherment=True,
|
||||
key_cert_sign=False,
|
||||
crl_sign=False,
|
||||
content_commitment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage([
|
||||
x509.oid.ExtendedKeyUsageOID.SERVER_AUTH,
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([
|
||||
x509.DNSName(common_name),
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(ca_key, hashes.SHA256(), default_backend())
|
||||
)
|
||||
|
||||
# Increment serial
|
||||
ca.next_serial += 1
|
||||
self.db.commit()
|
||||
|
||||
# Serialize to PEM
|
||||
cert_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode('utf-8')
|
||||
key_pem = server_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
).decode('utf-8')
|
||||
|
||||
return {
|
||||
"cert": cert_pem,
|
||||
"key": key_pem,
|
||||
"valid_from": valid_from,
|
||||
"valid_until": valid_until,
|
||||
"serial": ca.next_serial - 1
|
||||
}
|
||||
|
||||
def generate_client_certificate(
|
||||
self,
|
||||
ca: CertificateAuthority,
|
||||
common_name: str,
|
||||
validity_days: int = 365
|
||||
) -> dict:
|
||||
"""Generate a client certificate from the CA."""
|
||||
|
||||
if not ca.is_ready:
|
||||
raise ValueError("CA is not ready for issuing certificates")
|
||||
|
||||
# Load CA key and cert
|
||||
ca_key = serialization.load_pem_private_key(
|
||||
ca.ca_key.encode('utf-8'),
|
||||
password=None,
|
||||
backend=default_backend()
|
||||
)
|
||||
ca_cert = x509.load_pem_x509_certificate(
|
||||
ca.ca_cert.encode('utf-8'),
|
||||
default_backend()
|
||||
)
|
||||
|
||||
# Generate client private key
|
||||
client_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048, # Client keys can be smaller
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
# Build client certificate
|
||||
valid_from = datetime.utcnow()
|
||||
valid_until = valid_from + timedelta(days=validity_days)
|
||||
|
||||
subject = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||
])
|
||||
|
||||
client_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(client_key.public_key())
|
||||
.serial_number(ca.next_serial)
|
||||
.not_valid_before(valid_from)
|
||||
.not_valid_after(valid_until)
|
||||
.add_extension(
|
||||
x509.BasicConstraints(ca=False, path_length=None),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
key_encipherment=True,
|
||||
key_cert_sign=False,
|
||||
crl_sign=False,
|
||||
content_commitment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage([
|
||||
x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(ca_key, hashes.SHA256(), default_backend())
|
||||
)
|
||||
|
||||
# Increment serial
|
||||
ca.next_serial += 1
|
||||
self.db.commit()
|
||||
|
||||
# Serialize to PEM
|
||||
cert_pem = client_cert.public_bytes(serialization.Encoding.PEM).decode('utf-8')
|
||||
key_pem = client_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
).decode('utf-8')
|
||||
|
||||
return {
|
||||
"cert": cert_pem,
|
||||
"key": key_pem,
|
||||
"valid_from": valid_from,
|
||||
"valid_until": valid_until,
|
||||
"serial": ca.next_serial - 1
|
||||
}
|
||||
|
||||
def generate_ta_key(self) -> str:
|
||||
"""Generate TLS-Auth key using OpenVPN."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["openvpn", "--genkey", "secret", "/dev/stdout"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
else:
|
||||
raise RuntimeError(f"Failed to generate TA key: {result.stderr}")
|
||||
except FileNotFoundError:
|
||||
# OpenVPN not installed, generate a simple key
|
||||
import secrets
|
||||
key_data = secrets.token_hex(256)
|
||||
return f"""#
|
||||
# 2048 bit OpenVPN static key
|
||||
#
|
||||
-----BEGIN OpenVPN Static key V1-----
|
||||
{key_data[:64]}
|
||||
{key_data[64:128]}
|
||||
{key_data[128:192]}
|
||||
{key_data[192:256]}
|
||||
{key_data[256:320]}
|
||||
{key_data[320:384]}
|
||||
{key_data[384:448]}
|
||||
{key_data[448:512]}
|
||||
-----END OpenVPN Static key V1-----
|
||||
"""
|
||||
|
||||
def revoke_certificate(
|
||||
self,
|
||||
ca: CertificateAuthority,
|
||||
cert_pem: str,
|
||||
reason: str = "unspecified"
|
||||
) -> bool:
|
||||
"""Revoke a certificate and update CRL."""
|
||||
# Parse certificate to get serial number
|
||||
cert = x509.load_pem_x509_certificate(
|
||||
cert_pem.encode('utf-8'),
|
||||
default_backend()
|
||||
)
|
||||
|
||||
# Load CA key
|
||||
ca_key = serialization.load_pem_private_key(
|
||||
ca.ca_key.encode('utf-8'),
|
||||
password=None,
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
# Load existing CRL or create new one
|
||||
revoked_certs = []
|
||||
if ca.crl:
|
||||
try:
|
||||
existing_crl = x509.load_pem_x509_crl(
|
||||
ca.crl.encode('utf-8'),
|
||||
default_backend()
|
||||
)
|
||||
for revoked in existing_crl:
|
||||
revoked_certs.append(revoked)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Add new revoked certificate
|
||||
revoked_cert = (
|
||||
x509.RevokedCertificateBuilder()
|
||||
.serial_number(cert.serial_number)
|
||||
.revocation_date(datetime.utcnow())
|
||||
.build()
|
||||
)
|
||||
revoked_certs.append(revoked_cert)
|
||||
|
||||
# Load CA cert
|
||||
ca_cert = x509.load_pem_x509_certificate(
|
||||
ca.ca_cert.encode('utf-8'),
|
||||
default_backend()
|
||||
)
|
||||
|
||||
# Build new CRL
|
||||
crl_builder = (
|
||||
x509.CertificateRevocationListBuilder()
|
||||
.issuer_name(ca_cert.subject)
|
||||
.last_update(datetime.utcnow())
|
||||
.next_update(datetime.utcnow() + timedelta(days=180))
|
||||
)
|
||||
|
||||
for revoked in revoked_certs:
|
||||
crl_builder = crl_builder.add_revoked_certificate(revoked)
|
||||
|
||||
crl = crl_builder.sign(ca_key, hashes.SHA256(), default_backend())
|
||||
|
||||
# Update CA
|
||||
ca.crl = crl.public_bytes(serialization.Encoding.PEM).decode('utf-8')
|
||||
ca.crl_updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
return True
|
||||
|
||||
def get_crl(self, ca: CertificateAuthority) -> str:
|
||||
"""Get the CRL for a CA, generating if needed."""
|
||||
if not ca.crl:
|
||||
# Generate empty CRL
|
||||
ca_key = serialization.load_pem_private_key(
|
||||
ca.ca_key.encode('utf-8'),
|
||||
password=None,
|
||||
backend=default_backend()
|
||||
)
|
||||
ca_cert = x509.load_pem_x509_certificate(
|
||||
ca.ca_cert.encode('utf-8'),
|
||||
default_backend()
|
||||
)
|
||||
|
||||
crl = (
|
||||
x509.CertificateRevocationListBuilder()
|
||||
.issuer_name(ca_cert.subject)
|
||||
.last_update(datetime.utcnow())
|
||||
.next_update(datetime.utcnow() + timedelta(days=180))
|
||||
.sign(ca_key, hashes.SHA256(), default_backend())
|
||||
)
|
||||
|
||||
ca.crl = crl.public_bytes(serialization.Encoding.PEM).decode('utf-8')
|
||||
ca.crl_updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
return ca.crl
|
||||
|
||||
def get_expiring_certificates(
|
||||
self,
|
||||
ca: CertificateAuthority,
|
||||
days: int = 30
|
||||
) -> list[VPNProfile]:
|
||||
"""Get profiles with certificates expiring within given days."""
|
||||
threshold = datetime.utcnow() + timedelta(days=days)
|
||||
|
||||
return self.db.query(VPNProfile).filter(
|
||||
VPNProfile.ca_id == ca.id,
|
||||
VPNProfile.valid_until <= threshold,
|
||||
VPNProfile.status == VPNProfileStatus.ACTIVE
|
||||
).all()
|
||||
|
||||
def get_default_ca(self, tenant_id: Optional[int] = None) -> Optional[CertificateAuthority]:
|
||||
"""Get the default CA for a tenant (or global)."""
|
||||
ca = self.db.query(CertificateAuthority).filter(
|
||||
CertificateAuthority.tenant_id == tenant_id,
|
||||
CertificateAuthority.is_default == True,
|
||||
CertificateAuthority.status == CAStatus.ACTIVE
|
||||
).first()
|
||||
|
||||
if not ca and tenant_id:
|
||||
# Fall back to global CA
|
||||
ca = self.db.query(CertificateAuthority).filter(
|
||||
CertificateAuthority.tenant_id == None,
|
||||
CertificateAuthority.is_default == True,
|
||||
CertificateAuthority.status == CAStatus.ACTIVE
|
||||
).first()
|
||||
|
||||
return ca
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Firewall management service for dynamic iptables rules."""
|
||||
|
||||
import subprocess
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class FirewallService:
|
||||
"""Service for managing iptables firewall rules."""
|
||||
|
||||
# Chain name for our VPN rules
|
||||
VPN_CHAIN = "MGUARD_VPN"
|
||||
|
||||
def __init__(self):
|
||||
self._ensure_chain_exists()
|
||||
|
||||
def _run_iptables(self, args: list[str], check: bool = True) -> subprocess.CompletedProcess:
|
||||
"""Run iptables command."""
|
||||
cmd = ["iptables"] + args
|
||||
return subprocess.run(cmd, capture_output=True, text=True, check=check)
|
||||
|
||||
def _ensure_chain_exists(self):
|
||||
"""Ensure our custom chain exists."""
|
||||
# Check if chain exists
|
||||
result = self._run_iptables(["-L", self.VPN_CHAIN], check=False)
|
||||
if result.returncode != 0:
|
||||
# Create chain
|
||||
self._run_iptables(["-N", self.VPN_CHAIN], check=False)
|
||||
# Add jump to our chain from FORWARD
|
||||
self._run_iptables(["-I", "FORWARD", "-j", self.VPN_CHAIN], check=False)
|
||||
|
||||
def allow_connection(
|
||||
self,
|
||||
client_vpn_ip: str,
|
||||
gateway_vpn_ip: str,
|
||||
target_ip: str,
|
||||
target_port: int,
|
||||
protocol: Literal["tcp", "udp"] = "tcp"
|
||||
) -> bool:
|
||||
"""Allow connection from client through gateway to target endpoint."""
|
||||
try:
|
||||
# Rule 1: Allow client to reach target through gateway
|
||||
self._run_iptables([
|
||||
"-A", self.VPN_CHAIN,
|
||||
"-s", client_vpn_ip,
|
||||
"-d", target_ip,
|
||||
"-p", protocol,
|
||||
"--dport", str(target_port),
|
||||
"-j", "ACCEPT"
|
||||
])
|
||||
|
||||
# Rule 2: Allow return traffic
|
||||
self._run_iptables([
|
||||
"-A", self.VPN_CHAIN,
|
||||
"-s", target_ip,
|
||||
"-d", client_vpn_ip,
|
||||
"-p", protocol,
|
||||
"--sport", str(target_port),
|
||||
"-j", "ACCEPT"
|
||||
])
|
||||
|
||||
# Add NAT/masquerade if needed for routing through gateway
|
||||
self._run_iptables([
|
||||
"-t", "nat",
|
||||
"-A", "POSTROUTING",
|
||||
"-s", client_vpn_ip,
|
||||
"-d", target_ip,
|
||||
"-j", "MASQUERADE"
|
||||
], check=False)
|
||||
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
def revoke_connection(
|
||||
self,
|
||||
client_vpn_ip: str,
|
||||
gateway_vpn_ip: str,
|
||||
target_ip: str,
|
||||
target_port: int,
|
||||
protocol: Literal["tcp", "udp"] = "tcp"
|
||||
) -> bool:
|
||||
"""Remove firewall rules for a connection."""
|
||||
try:
|
||||
# Remove forward rules
|
||||
self._run_iptables([
|
||||
"-D", self.VPN_CHAIN,
|
||||
"-s", client_vpn_ip,
|
||||
"-d", target_ip,
|
||||
"-p", protocol,
|
||||
"--dport", str(target_port),
|
||||
"-j", "ACCEPT"
|
||||
], check=False)
|
||||
|
||||
self._run_iptables([
|
||||
"-D", self.VPN_CHAIN,
|
||||
"-s", target_ip,
|
||||
"-d", client_vpn_ip,
|
||||
"-p", protocol,
|
||||
"--sport", str(target_port),
|
||||
"-j", "ACCEPT"
|
||||
], check=False)
|
||||
|
||||
# Remove NAT rule
|
||||
self._run_iptables([
|
||||
"-t", "nat",
|
||||
"-D", "POSTROUTING",
|
||||
"-s", client_vpn_ip,
|
||||
"-d", target_ip,
|
||||
"-j", "MASQUERADE"
|
||||
], check=False)
|
||||
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
def list_rules(self) -> list[str]:
|
||||
"""List all rules in our VPN chain."""
|
||||
result = self._run_iptables(["-L", self.VPN_CHAIN, "-n", "-v"], check=False)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip().split('\n')
|
||||
return []
|
||||
|
||||
def flush_rules(self) -> bool:
|
||||
"""Remove all rules from our VPN chain."""
|
||||
try:
|
||||
self._run_iptables(["-F", self.VPN_CHAIN])
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
@@ -0,0 +1,282 @@
|
||||
"""VPN Profile management service."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models.vpn_profile import VPNProfile, VPNProfileStatus
|
||||
from ..models.vpn_server import VPNServer
|
||||
from ..models.gateway import Gateway
|
||||
from ..models.certificate_authority import CertificateAuthority
|
||||
from .certificate_service import CertificateService
|
||||
from ..config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class VPNProfileService:
|
||||
"""Service for managing VPN profiles for gateways."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.cert_service = CertificateService(db)
|
||||
|
||||
def create_profile(
|
||||
self,
|
||||
gateway_id: int,
|
||||
vpn_server_id: int,
|
||||
name: str,
|
||||
priority: int = 1,
|
||||
description: Optional[str] = None
|
||||
) -> VPNProfile:
|
||||
"""Create a new VPN profile for a gateway."""
|
||||
|
||||
# Get gateway
|
||||
gateway = self.db.query(Gateway).filter(
|
||||
Gateway.id == gateway_id
|
||||
).first()
|
||||
|
||||
if not gateway:
|
||||
raise ValueError(f"Gateway with id {gateway_id} not found")
|
||||
|
||||
# Get VPN server
|
||||
server = self.db.query(VPNServer).filter(
|
||||
VPNServer.id == vpn_server_id
|
||||
).first()
|
||||
|
||||
if not server:
|
||||
raise ValueError(f"VPN Server with id {vpn_server_id} not found")
|
||||
|
||||
if not server.certificate_authority.is_ready:
|
||||
raise ValueError("CA is not ready (DH parameters may still be generating)")
|
||||
|
||||
# Generate unique common name
|
||||
base_cn = f"gw-{gateway.name.lower().replace(' ', '-')}"
|
||||
profile_count = self.db.query(VPNProfile).filter(
|
||||
VPNProfile.gateway_id == gateway_id
|
||||
).count()
|
||||
|
||||
cert_cn = f"{base_cn}-{profile_count + 1}"
|
||||
|
||||
# Create profile
|
||||
profile = VPNProfile(
|
||||
gateway_id=gateway_id,
|
||||
vpn_server_id=vpn_server_id,
|
||||
ca_id=server.ca_id,
|
||||
name=name,
|
||||
description=description,
|
||||
cert_cn=cert_cn,
|
||||
priority=priority,
|
||||
status=VPNProfileStatus.PENDING
|
||||
)
|
||||
|
||||
self.db.add(profile)
|
||||
self.db.commit()
|
||||
self.db.refresh(profile)
|
||||
|
||||
# Generate client certificate
|
||||
self._generate_client_cert(profile)
|
||||
|
||||
return profile
|
||||
|
||||
def _generate_client_cert(self, profile: VPNProfile):
|
||||
"""Generate client certificate for profile."""
|
||||
ca = profile.certificate_authority
|
||||
|
||||
cert_data = self.cert_service.generate_client_certificate(
|
||||
ca=ca,
|
||||
common_name=profile.cert_cn
|
||||
)
|
||||
|
||||
profile.client_cert = cert_data["cert"]
|
||||
profile.client_key = cert_data["key"]
|
||||
profile.valid_from = cert_data["valid_from"]
|
||||
profile.valid_until = cert_data["valid_until"]
|
||||
profile.status = VPNProfileStatus.ACTIVE
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def generate_client_config(self, profile: VPNProfile) -> str:
|
||||
"""Generate OpenVPN client configuration (.ovpn) for a profile."""
|
||||
|
||||
if not profile.is_ready:
|
||||
raise ValueError("Profile is not ready for provisioning")
|
||||
|
||||
server = profile.vpn_server
|
||||
ca = profile.certificate_authority
|
||||
|
||||
config_lines = [
|
||||
"# OpenVPN Client Configuration",
|
||||
f"# Profile: {profile.name}",
|
||||
f"# Gateway: {profile.gateway.name}",
|
||||
f"# Server: {server.name}",
|
||||
f"# Generated: {datetime.utcnow().isoformat()}",
|
||||
"",
|
||||
"client",
|
||||
"dev tun",
|
||||
f"proto {server.protocol.value}",
|
||||
f"remote {server.hostname} {server.port}",
|
||||
"",
|
||||
"resolv-retry infinite",
|
||||
"nobind",
|
||||
"persist-key",
|
||||
"persist-tun",
|
||||
"",
|
||||
"remote-cert-tls server",
|
||||
f"cipher {server.cipher.value}",
|
||||
f"auth {server.auth.value}",
|
||||
"",
|
||||
"verb 3",
|
||||
"",
|
||||
]
|
||||
|
||||
# Add CA certificate
|
||||
config_lines.extend([
|
||||
"<ca>",
|
||||
ca.ca_cert.strip(),
|
||||
"</ca>",
|
||||
"",
|
||||
])
|
||||
|
||||
# Add client certificate
|
||||
config_lines.extend([
|
||||
"<cert>",
|
||||
profile.client_cert.strip(),
|
||||
"</cert>",
|
||||
"",
|
||||
])
|
||||
|
||||
# Add client key
|
||||
config_lines.extend([
|
||||
"<key>",
|
||||
profile.client_key.strip(),
|
||||
"</key>",
|
||||
"",
|
||||
])
|
||||
|
||||
# Add TLS-Auth key if enabled
|
||||
if server.tls_auth_enabled and server.ta_key:
|
||||
config_lines.extend([
|
||||
"key-direction 1",
|
||||
"<tls-auth>",
|
||||
server.ta_key.strip(),
|
||||
"</tls-auth>",
|
||||
])
|
||||
|
||||
return "\n".join(config_lines)
|
||||
|
||||
def provision_profile(self, profile: VPNProfile) -> str:
|
||||
"""Mark profile as provisioned and return config."""
|
||||
config = self.generate_client_config(profile)
|
||||
|
||||
profile.status = VPNProfileStatus.PROVISIONED
|
||||
profile.provisioned_at = datetime.utcnow()
|
||||
profile.gateway.is_provisioned = True
|
||||
|
||||
self.db.commit()
|
||||
|
||||
return config
|
||||
|
||||
def set_priority(self, profile: VPNProfile, new_priority: int):
|
||||
"""Set profile priority and reorder others if needed."""
|
||||
gateway_id = profile.gateway_id
|
||||
|
||||
# Get all profiles for this gateway ordered by priority
|
||||
profiles = self.db.query(VPNProfile).filter(
|
||||
VPNProfile.gateway_id == gateway_id,
|
||||
VPNProfile.id != profile.id
|
||||
).order_by(VPNProfile.priority).all()
|
||||
|
||||
# Update priorities
|
||||
current_priority = 1
|
||||
for p in profiles:
|
||||
if current_priority == new_priority:
|
||||
current_priority += 1
|
||||
p.priority = current_priority
|
||||
current_priority += 1
|
||||
|
||||
profile.priority = new_priority
|
||||
self.db.commit()
|
||||
|
||||
def get_profiles_for_gateway(self, gateway_id: int) -> list[VPNProfile]:
|
||||
"""Get all VPN profiles for a gateway, ordered by priority."""
|
||||
return self.db.query(VPNProfile).filter(
|
||||
VPNProfile.gateway_id == gateway_id
|
||||
).order_by(VPNProfile.priority).all()
|
||||
|
||||
def get_active_profiles_for_gateway(self, gateway_id: int) -> list[VPNProfile]:
|
||||
"""Get active VPN profiles for a gateway."""
|
||||
return self.db.query(VPNProfile).filter(
|
||||
VPNProfile.gateway_id == gateway_id,
|
||||
VPNProfile.is_active == True,
|
||||
VPNProfile.status.in_([VPNProfileStatus.ACTIVE, VPNProfileStatus.PROVISIONED])
|
||||
).order_by(VPNProfile.priority).all()
|
||||
|
||||
def revoke_profile(self, profile: VPNProfile, reason: str = "unspecified"):
|
||||
"""Revoke a VPN profile's certificate."""
|
||||
if profile.client_cert:
|
||||
self.cert_service.revoke_certificate(
|
||||
ca=profile.certificate_authority,
|
||||
cert_pem=profile.client_cert,
|
||||
reason=reason
|
||||
)
|
||||
|
||||
profile.status = VPNProfileStatus.REVOKED
|
||||
profile.is_active = False
|
||||
self.db.commit()
|
||||
|
||||
def renew_profile(self, profile: VPNProfile):
|
||||
"""Renew the certificate for a profile."""
|
||||
# Revoke old certificate
|
||||
if profile.client_cert:
|
||||
self.cert_service.revoke_certificate(
|
||||
ca=profile.certificate_authority,
|
||||
cert_pem=profile.client_cert,
|
||||
reason="superseded"
|
||||
)
|
||||
|
||||
# Generate new certificate
|
||||
self._generate_client_cert(profile)
|
||||
|
||||
# Reset provisioned status
|
||||
profile.status = VPNProfileStatus.ACTIVE
|
||||
profile.provisioned_at = None
|
||||
self.db.commit()
|
||||
|
||||
def get_profile_by_id(self, profile_id: int) -> Optional[VPNProfile]:
|
||||
"""Get a VPN profile by ID."""
|
||||
return self.db.query(VPNProfile).filter(
|
||||
VPNProfile.id == profile_id
|
||||
).first()
|
||||
|
||||
def delete_profile(self, profile: VPNProfile):
|
||||
"""Delete a VPN profile."""
|
||||
# Revoke certificate first
|
||||
if profile.client_cert and profile.status != VPNProfileStatus.REVOKED:
|
||||
self.cert_service.revoke_certificate(
|
||||
ca=profile.certificate_authority,
|
||||
cert_pem=profile.client_cert,
|
||||
reason="cessationOfOperation"
|
||||
)
|
||||
|
||||
self.db.delete(profile)
|
||||
self.db.commit()
|
||||
|
||||
def generate_all_configs_for_gateway(self, gateway_id: int) -> list[tuple[str, str]]:
|
||||
"""Generate configs for all active profiles of a gateway.
|
||||
|
||||
Returns list of tuples: (filename, config_content)
|
||||
"""
|
||||
profiles = self.get_active_profiles_for_gateway(gateway_id)
|
||||
configs = []
|
||||
|
||||
for profile in profiles:
|
||||
try:
|
||||
config = self.generate_client_config(profile)
|
||||
filename = f"{profile.gateway.name}-{profile.name}.ovpn"
|
||||
filename = filename.lower().replace(' ', '-')
|
||||
configs.append((filename, config))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return configs
|
||||
@@ -0,0 +1,344 @@
|
||||
"""VPN Server management service."""
|
||||
|
||||
import socket
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models.vpn_server import VPNServer, VPNServerStatus, VPNProtocol, VPNCipher, VPNAuth, VPNCompression
|
||||
from ..models.certificate_authority import CertificateAuthority
|
||||
from .certificate_service import CertificateService
|
||||
from ..config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class VPNServerService:
|
||||
"""Service for managing VPN server instances."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.cert_service = CertificateService(db)
|
||||
|
||||
def create_server(
|
||||
self,
|
||||
name: str,
|
||||
hostname: str,
|
||||
ca_id: int,
|
||||
port: int = 1194,
|
||||
protocol: VPNProtocol = VPNProtocol.UDP,
|
||||
vpn_network: str = "10.8.0.0",
|
||||
vpn_netmask: str = "255.255.255.0",
|
||||
cipher: VPNCipher = VPNCipher.AES_256_GCM,
|
||||
auth: VPNAuth = VPNAuth.SHA256,
|
||||
tls_version_min: str = "1.2",
|
||||
compression: VPNCompression = VPNCompression.NONE,
|
||||
max_clients: int = 100,
|
||||
keepalive_interval: int = 10,
|
||||
keepalive_timeout: int = 60,
|
||||
management_port: int = 7505,
|
||||
is_primary: bool = False,
|
||||
tenant_id: Optional[int] = None,
|
||||
description: Optional[str] = None
|
||||
) -> VPNServer:
|
||||
"""Create a new VPN server instance."""
|
||||
|
||||
# Get CA
|
||||
ca = self.db.query(CertificateAuthority).filter(
|
||||
CertificateAuthority.id == ca_id
|
||||
).first()
|
||||
|
||||
if not ca:
|
||||
raise ValueError(f"CA with id {ca_id} not found")
|
||||
|
||||
if not ca.is_ready:
|
||||
raise ValueError("CA is not ready (DH parameters may still be generating)")
|
||||
|
||||
# Generate container name
|
||||
existing_count = self.db.query(VPNServer).count()
|
||||
container_name = f"mguard-openvpn-{existing_count + 1}"
|
||||
|
||||
# If setting as primary, unset other primaries
|
||||
if is_primary:
|
||||
self.db.query(VPNServer).filter(
|
||||
VPNServer.tenant_id == tenant_id,
|
||||
VPNServer.is_primary == True
|
||||
).update({"is_primary": False})
|
||||
|
||||
# Create server record
|
||||
server = VPNServer(
|
||||
name=name,
|
||||
description=description,
|
||||
hostname=hostname,
|
||||
port=port,
|
||||
protocol=protocol,
|
||||
vpn_network=vpn_network,
|
||||
vpn_netmask=vpn_netmask,
|
||||
cipher=cipher,
|
||||
auth=auth,
|
||||
tls_version_min=tls_version_min,
|
||||
compression=compression,
|
||||
max_clients=max_clients,
|
||||
keepalive_interval=keepalive_interval,
|
||||
keepalive_timeout=keepalive_timeout,
|
||||
management_port=management_port,
|
||||
docker_container_name=container_name,
|
||||
ca_id=ca_id,
|
||||
tenant_id=tenant_id,
|
||||
is_primary=is_primary,
|
||||
status=VPNServerStatus.PENDING
|
||||
)
|
||||
|
||||
self.db.add(server)
|
||||
self.db.commit()
|
||||
self.db.refresh(server)
|
||||
|
||||
# Generate server certificate
|
||||
self._generate_server_cert(server)
|
||||
|
||||
return server
|
||||
|
||||
def _generate_server_cert(self, server: VPNServer):
|
||||
"""Generate server certificate and TA key."""
|
||||
ca = server.certificate_authority
|
||||
|
||||
# Generate server certificate
|
||||
cert_data = self.cert_service.generate_server_certificate(
|
||||
ca=ca,
|
||||
common_name=f"{server.name}-server"
|
||||
)
|
||||
|
||||
server.server_cert = cert_data["cert"]
|
||||
server.server_key = cert_data["key"]
|
||||
|
||||
# Generate TLS-Auth key
|
||||
server.ta_key = self.cert_service.generate_ta_key()
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def generate_server_config(self, server: VPNServer) -> str:
|
||||
"""Generate OpenVPN server configuration file."""
|
||||
if not server.certificate_authority:
|
||||
raise ValueError("Server has no CA assigned")
|
||||
|
||||
config_lines = [
|
||||
"# OpenVPN Server Configuration",
|
||||
f"# Server: {server.name}",
|
||||
f"# Generated: {datetime.utcnow().isoformat()}",
|
||||
"",
|
||||
"# Basic settings",
|
||||
f"port {server.port}",
|
||||
f"proto {server.protocol.value}",
|
||||
"dev tun",
|
||||
"",
|
||||
"# Certificates",
|
||||
"ca /etc/openvpn/ca.crt",
|
||||
"cert /etc/openvpn/server.crt",
|
||||
"key /etc/openvpn/server.key",
|
||||
"dh /etc/openvpn/dh.pem",
|
||||
"crl-verify /etc/openvpn/crl.pem",
|
||||
"",
|
||||
"# TLS Auth",
|
||||
]
|
||||
|
||||
if server.tls_auth_enabled and server.ta_key:
|
||||
config_lines.append("tls-auth /etc/openvpn/ta.key 0")
|
||||
|
||||
config_lines.extend([
|
||||
f"tls-version-min {server.tls_version_min}",
|
||||
"",
|
||||
"# Network",
|
||||
f"server {server.vpn_network} {server.vpn_netmask}",
|
||||
"topology subnet",
|
||||
"",
|
||||
"# Routing",
|
||||
'push "redirect-gateway def1 bypass-dhcp"',
|
||||
'push "dhcp-option DNS 8.8.8.8"',
|
||||
'push "dhcp-option DNS 8.8.4.4"',
|
||||
"",
|
||||
"# Security",
|
||||
f"cipher {server.cipher.value}",
|
||||
f"auth {server.auth.value}",
|
||||
"",
|
||||
"# Performance",
|
||||
f"keepalive {server.keepalive_interval} {server.keepalive_timeout}",
|
||||
f"max-clients {server.max_clients}",
|
||||
])
|
||||
|
||||
# Compression handling
|
||||
# Note: OpenVPN 2.5+ deprecates compression due to VORACLE attack
|
||||
if server.compression != VPNCompression.NONE:
|
||||
config_lines.append(f"compress {server.compression.value}")
|
||||
config_lines.append("allow-compression yes")
|
||||
# If compression is NONE, don't add any directive - OpenVPN defaults to no compression
|
||||
|
||||
config_lines.extend([
|
||||
"",
|
||||
"# Persistence",
|
||||
"persist-key",
|
||||
"persist-tun",
|
||||
"",
|
||||
"# Logging",
|
||||
f"status /var/log/openvpn/status-{server.id}.log",
|
||||
f"log-append /var/log/openvpn/server-{server.id}.log",
|
||||
"verb 3",
|
||||
"",
|
||||
"# Management interface",
|
||||
f"management 0.0.0.0 {server.management_port}",
|
||||
"",
|
||||
"# User/Group (for Linux)",
|
||||
"user nobody",
|
||||
"group nogroup",
|
||||
"",
|
||||
"# Client config directory",
|
||||
"client-config-dir /etc/openvpn/ccd",
|
||||
"",
|
||||
"# Scripts",
|
||||
"script-security 2",
|
||||
"client-connect /etc/openvpn/scripts/client-connect.sh",
|
||||
"client-disconnect /etc/openvpn/scripts/client-disconnect.sh",
|
||||
])
|
||||
|
||||
return "\n".join(config_lines)
|
||||
|
||||
def get_connected_clients(self, server: VPNServer) -> list[dict]:
|
||||
"""Get list of currently connected VPN clients."""
|
||||
try:
|
||||
response = self._send_management_command(server, "status")
|
||||
clients = []
|
||||
|
||||
if "ERROR" in response:
|
||||
return clients
|
||||
|
||||
# Parse status output
|
||||
in_client_list = False
|
||||
for line in response.split('\n'):
|
||||
if line.startswith("ROUTING TABLE"):
|
||||
in_client_list = False
|
||||
elif line.startswith("Common Name"):
|
||||
in_client_list = True
|
||||
continue
|
||||
elif in_client_list and ',' in line:
|
||||
parts = line.split(',')
|
||||
if len(parts) >= 5:
|
||||
clients.append({
|
||||
"common_name": parts[0],
|
||||
"real_address": parts[1],
|
||||
"bytes_received": int(parts[2]) if parts[2].isdigit() else 0,
|
||||
"bytes_sent": int(parts[3]) if parts[3].isdigit() else 0,
|
||||
"connected_since": parts[4]
|
||||
})
|
||||
|
||||
# Update connected count
|
||||
server.connected_clients = len(clients)
|
||||
server.last_status_check = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
return clients
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
def disconnect_client(self, server: VPNServer, common_name: str) -> bool:
|
||||
"""Disconnect a specific VPN client."""
|
||||
response = self._send_management_command(server, f"kill {common_name}")
|
||||
return "SUCCESS" in response
|
||||
|
||||
def get_server_status(self, server: VPNServer) -> dict:
|
||||
"""Get detailed server status."""
|
||||
try:
|
||||
# Try to connect to management interface
|
||||
response = self._send_management_command(server, "state")
|
||||
|
||||
if "ERROR" in response:
|
||||
server.status = VPNServerStatus.ERROR
|
||||
self.db.commit()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": response
|
||||
}
|
||||
|
||||
# Parse state
|
||||
for line in response.split('\n'):
|
||||
if 'CONNECTED' in line:
|
||||
server.status = VPNServerStatus.RUNNING
|
||||
break
|
||||
else:
|
||||
server.status = VPNServerStatus.STOPPED
|
||||
|
||||
server.last_status_check = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
clients = self.get_connected_clients(server)
|
||||
|
||||
return {
|
||||
"status": server.status.value,
|
||||
"connected_clients": len(clients),
|
||||
"clients": clients,
|
||||
"last_check": server.last_status_check.isoformat() if server.last_status_check else None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unknown",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
def _send_management_command(self, server: VPNServer, command: str) -> str:
|
||||
"""Send command to OpenVPN management interface."""
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(5)
|
||||
|
||||
# Connect to management interface
|
||||
# OpenVPN runs in host network mode, so we connect via host.docker.internal
|
||||
# This resolves to the Docker host from within containers
|
||||
host = "host.docker.internal"
|
||||
sock.connect((host, server.management_port))
|
||||
|
||||
# Read welcome message
|
||||
sock.recv(1024)
|
||||
|
||||
# Send command
|
||||
sock.send(f"{command}\n".encode())
|
||||
|
||||
# Read response
|
||||
response = b""
|
||||
while True:
|
||||
data = sock.recv(4096)
|
||||
if not data:
|
||||
break
|
||||
response += data
|
||||
if b"END" in data or b"SUCCESS" in data or b"ERROR" in data:
|
||||
break
|
||||
|
||||
sock.close()
|
||||
return response.decode()
|
||||
except Exception as e:
|
||||
return f"ERROR: {str(e)}"
|
||||
|
||||
def update_all_server_status(self):
|
||||
"""Update status for all active servers."""
|
||||
servers = self.db.query(VPNServer).filter(
|
||||
VPNServer.is_active == True
|
||||
).all()
|
||||
|
||||
for server in servers:
|
||||
self.get_server_status(server)
|
||||
|
||||
def get_server_by_id(self, server_id: int) -> Optional[VPNServer]:
|
||||
"""Get a VPN server by ID."""
|
||||
return self.db.query(VPNServer).filter(
|
||||
VPNServer.id == server_id
|
||||
).first()
|
||||
|
||||
def get_servers_for_tenant(self, tenant_id: Optional[int] = None) -> list[VPNServer]:
|
||||
"""Get all VPN servers for a tenant."""
|
||||
query = self.db.query(VPNServer)
|
||||
|
||||
if tenant_id:
|
||||
query = query.filter(
|
||||
(VPNServer.tenant_id == tenant_id) | (VPNServer.tenant_id == None)
|
||||
)
|
||||
else:
|
||||
query = query.filter(VPNServer.tenant_id == None)
|
||||
|
||||
return query.order_by(VPNServer.name).all()
|
||||
@@ -0,0 +1,95 @@
|
||||
"""VPN management service for OpenVPN operations.
|
||||
|
||||
This service provides basic OpenVPN management interface communication.
|
||||
For PKI/certificate management, use CertificateService.
|
||||
For VPN server management, use VPNServerService.
|
||||
"""
|
||||
|
||||
import socket
|
||||
from ..config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class VPNService:
|
||||
"""Service for OpenVPN management interface operations."""
|
||||
|
||||
def __init__(self, host: str = None, port: int = None):
|
||||
self.management_host = host or settings.openvpn_management_host
|
||||
self.management_port = port or settings.openvpn_management_port
|
||||
|
||||
def _send_management_command(self, command: str) -> str:
|
||||
"""Send command to OpenVPN management interface."""
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(5)
|
||||
sock.connect((self.management_host, self.management_port))
|
||||
|
||||
# Read welcome message
|
||||
sock.recv(1024)
|
||||
|
||||
# Send command
|
||||
sock.send(f"{command}\n".encode())
|
||||
|
||||
# Read response
|
||||
response = b""
|
||||
while True:
|
||||
data = sock.recv(4096)
|
||||
if not data:
|
||||
break
|
||||
response += data
|
||||
if b"END" in data or b"SUCCESS" in data or b"ERROR" in data:
|
||||
break
|
||||
|
||||
sock.close()
|
||||
return response.decode()
|
||||
except Exception as e:
|
||||
return f"ERROR: {str(e)}"
|
||||
|
||||
def get_connected_clients(self) -> list[dict]:
|
||||
"""Get list of currently connected VPN clients."""
|
||||
response = self._send_management_command("status")
|
||||
clients = []
|
||||
|
||||
if "ERROR" in response:
|
||||
return clients
|
||||
|
||||
# Parse status output
|
||||
in_client_list = False
|
||||
for line in response.split('\n'):
|
||||
if line.startswith("ROUTING TABLE"):
|
||||
in_client_list = False
|
||||
elif line.startswith("Common Name"):
|
||||
in_client_list = True
|
||||
continue
|
||||
elif in_client_list and ',' in line:
|
||||
parts = line.split(',')
|
||||
if len(parts) >= 5:
|
||||
clients.append({
|
||||
"common_name": parts[0],
|
||||
"real_address": parts[1],
|
||||
"bytes_received": int(parts[2]) if parts[2].isdigit() else 0,
|
||||
"bytes_sent": int(parts[3]) if parts[3].isdigit() else 0,
|
||||
"connected_since": parts[4]
|
||||
})
|
||||
|
||||
return clients
|
||||
|
||||
def disconnect_client(self, common_name: str) -> bool:
|
||||
"""Disconnect a specific VPN client."""
|
||||
response = self._send_management_command(f"kill {common_name}")
|
||||
return "SUCCESS" in response
|
||||
|
||||
def get_server_status(self) -> dict:
|
||||
"""Get OpenVPN server status."""
|
||||
response = self._send_management_command("state")
|
||||
|
||||
if "ERROR" in response:
|
||||
return {"status": "error", "message": response}
|
||||
|
||||
# Parse state
|
||||
for line in response.split('\n'):
|
||||
if 'CONNECTED' in line:
|
||||
return {"status": "running", "message": "Server is running"}
|
||||
|
||||
return {"status": "unknown", "message": response}
|
||||
@@ -0,0 +1,182 @@
|
||||
"""VPN Sync Service for synchronizing VPN connection state with database."""
|
||||
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from ..models.vpn_server import VPNServer
|
||||
from ..models.vpn_profile import VPNProfile
|
||||
from ..models.vpn_connection_log import VPNConnectionLog
|
||||
from ..models.gateway import Gateway
|
||||
from .vpn_server_service import VPNServerService
|
||||
|
||||
|
||||
class VPNSyncService:
|
||||
"""Service to sync VPN connections with gateway status and logs."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.vpn_service = VPNServerService(db)
|
||||
|
||||
def sync_all_connections(self) -> dict:
|
||||
"""
|
||||
Sync all VPN connections across all servers.
|
||||
Returns summary of changes.
|
||||
"""
|
||||
result = {
|
||||
"servers_checked": 0,
|
||||
"clients_found": 0,
|
||||
"gateways_online": 0,
|
||||
"gateways_offline": 0,
|
||||
"new_connections": 0,
|
||||
"closed_connections": 0,
|
||||
"errors": []
|
||||
}
|
||||
|
||||
# Get all active VPN servers
|
||||
servers = self.db.query(VPNServer).filter(VPNServer.is_active == True).all()
|
||||
|
||||
# Collect all currently connected CNs
|
||||
connected_cns = set()
|
||||
|
||||
for server in servers:
|
||||
result["servers_checked"] += 1
|
||||
try:
|
||||
clients = self.vpn_service.get_connected_clients(server)
|
||||
result["clients_found"] += len(clients)
|
||||
|
||||
for client in clients:
|
||||
cn = client.get("common_name")
|
||||
if cn:
|
||||
connected_cns.add(cn)
|
||||
self._handle_connected_client(server, client, result)
|
||||
|
||||
except Exception as e:
|
||||
result["errors"].append(f"Server {server.name}: {str(e)}")
|
||||
|
||||
# Mark disconnected profiles/gateways
|
||||
self._handle_disconnected_profiles(connected_cns, result)
|
||||
|
||||
self.db.commit()
|
||||
return result
|
||||
|
||||
def _handle_connected_client(self, server: VPNServer, client: dict, result: dict):
|
||||
"""Handle a connected VPN client - update profile, gateway, and log."""
|
||||
cn = client.get("common_name")
|
||||
real_address = client.get("real_address")
|
||||
bytes_rx = client.get("bytes_received", 0)
|
||||
bytes_tx = client.get("bytes_sent", 0)
|
||||
connected_since = client.get("connected_since")
|
||||
|
||||
# Find profile by CN
|
||||
profile = self.db.query(VPNProfile).filter(
|
||||
VPNProfile.cert_cn == cn,
|
||||
VPNProfile.vpn_server_id == server.id
|
||||
).first()
|
||||
|
||||
if not profile:
|
||||
# Try finding by CN across all servers (in case of migration)
|
||||
profile = self.db.query(VPNProfile).filter(
|
||||
VPNProfile.cert_cn == cn
|
||||
).first()
|
||||
|
||||
if not profile:
|
||||
return # Unknown client, skip
|
||||
|
||||
gateway = profile.gateway
|
||||
|
||||
# Update gateway status
|
||||
if not gateway.is_online:
|
||||
gateway.is_online = True
|
||||
gateway.last_seen = datetime.utcnow()
|
||||
result["gateways_online"] += 1
|
||||
|
||||
# Update profile last connection
|
||||
profile.last_connection = datetime.utcnow()
|
||||
|
||||
# Check for existing active connection log
|
||||
active_log = self.db.query(VPNConnectionLog).filter(
|
||||
VPNConnectionLog.vpn_profile_id == profile.id,
|
||||
VPNConnectionLog.disconnected_at.is_(None)
|
||||
).first()
|
||||
|
||||
if not active_log:
|
||||
# Create new connection log
|
||||
log = VPNConnectionLog(
|
||||
vpn_profile_id=profile.id,
|
||||
vpn_server_id=server.id,
|
||||
gateway_id=gateway.id,
|
||||
common_name=cn,
|
||||
real_address=real_address,
|
||||
connected_at=datetime.utcnow(),
|
||||
bytes_received=bytes_rx,
|
||||
bytes_sent=bytes_tx
|
||||
)
|
||||
self.db.add(log)
|
||||
result["new_connections"] += 1
|
||||
else:
|
||||
# Update traffic stats
|
||||
active_log.bytes_received = bytes_rx
|
||||
active_log.bytes_sent = bytes_tx
|
||||
active_log.real_address = real_address
|
||||
|
||||
def _handle_disconnected_profiles(self, connected_cns: set, result: dict):
|
||||
"""Mark profiles as disconnected if their CN is no longer connected."""
|
||||
# Find all active connection logs
|
||||
active_logs = self.db.query(VPNConnectionLog).filter(
|
||||
VPNConnectionLog.disconnected_at.is_(None)
|
||||
).all()
|
||||
|
||||
for log in active_logs:
|
||||
if log.common_name not in connected_cns:
|
||||
# Mark as disconnected
|
||||
log.disconnected_at = datetime.utcnow()
|
||||
result["closed_connections"] += 1
|
||||
|
||||
# Update gateway status
|
||||
gateway = log.gateway
|
||||
|
||||
# Check if gateway has any other active connections
|
||||
other_active = self.db.query(VPNConnectionLog).filter(
|
||||
VPNConnectionLog.gateway_id == gateway.id,
|
||||
VPNConnectionLog.id != log.id,
|
||||
VPNConnectionLog.disconnected_at.is_(None)
|
||||
).first()
|
||||
|
||||
if not other_active:
|
||||
gateway.is_online = False
|
||||
result["gateways_offline"] += 1
|
||||
|
||||
def get_profile_connection_logs(
|
||||
self,
|
||||
profile_id: int,
|
||||
limit: int = 50
|
||||
) -> list[VPNConnectionLog]:
|
||||
"""Get connection logs for a specific profile."""
|
||||
return self.db.query(VPNConnectionLog).filter(
|
||||
VPNConnectionLog.vpn_profile_id == profile_id
|
||||
).order_by(VPNConnectionLog.connected_at.desc()).limit(limit).all()
|
||||
|
||||
def get_gateway_connection_logs(
|
||||
self,
|
||||
gateway_id: int,
|
||||
limit: int = 50
|
||||
) -> list[VPNConnectionLog]:
|
||||
"""Get connection logs for a gateway (all profiles)."""
|
||||
return self.db.query(VPNConnectionLog).filter(
|
||||
VPNConnectionLog.gateway_id == gateway_id
|
||||
).order_by(VPNConnectionLog.connected_at.desc()).limit(limit).all()
|
||||
|
||||
def get_server_connection_logs(
|
||||
self,
|
||||
server_id: int,
|
||||
limit: int = 50
|
||||
) -> list[VPNConnectionLog]:
|
||||
"""Get connection logs for a VPN server."""
|
||||
return self.db.query(VPNConnectionLog).filter(
|
||||
VPNConnectionLog.vpn_server_id == server_id
|
||||
).order_by(VPNConnectionLog.connected_at.desc()).limit(limit).all()
|
||||
|
||||
def get_active_connections(self) -> list[VPNConnectionLog]:
|
||||
"""Get all currently active VPN connections."""
|
||||
return self.db.query(VPNConnectionLog).filter(
|
||||
VPNConnectionLog.disconnected_at.is_(None)
|
||||
).order_by(VPNConnectionLog.connected_at.desc()).all()
|
||||
Reference in New Issue
Block a user