"""Certificate management service for PKI operations."""

import logging
import os
import secrets
import subprocess
import threading
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional

from cryptography import x509
from cryptography.x509.oid import ExtendedKeyUsageOID, ExtensionOID, NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from sqlalchemy.orm import Session

from ..models.certificate_authority import CertificateAuthority, CAStatus
from ..models.vpn_profile import VPNProfile, VPNProfileStatus
from ..config import get_settings

settings = get_settings()
logger = logging.getLogger(__name__)

# Lifetime (nextUpdate - lastUpdate) of every CRL this service signs.
CRL_VALIDITY_DAYS = 180
# get_crl() re-signs a stored CRL once its nextUpdate is at most this many
# days away, so consumers that check nextUpdate keep accepting it.
CRL_REFRESH_MARGIN_DAYS = 30
# RSA key size for client certificates (server keys follow the CA key size).
_CLIENT_KEY_SIZE = 2048
# Hard timeout for background DH parameter generation, in seconds.
_DH_TIMEOUT_SECONDS = 3600
# How much of openssl's stderr is copied into a failed CA's description.
_MAX_ERROR_DETAIL_CHARS = 500


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value as the deprecated ``datetime.utcnow()``, so timestamps written
    by this module keep the naive-UTC form they had before. The cryptography
    builders interpret naive datetimes as UTC.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _spki_der(public_key) -> bytes:
    """Return the DER SubjectPublicKeyInfo encoding of *public_key* (for equality checks)."""
    return public_key.public_bytes(
        serialization.Encoding.DER,
        serialization.PublicFormat.SubjectPublicKeyInfo,
    )


def _parse_revocation_reason(reason: Optional[str]) -> Optional[x509.ReasonFlags]:
    """Translate a caller-supplied revocation reason into ``x509.ReasonFlags``.

    Accepts either the RFC 5280 spelling used as the enum value
    (``"keyCompromise"``) or the Python member name (``"key_compromise"``).

    Returns:
        The matching flag, or ``None`` when no CRLReason entry extension should
        be written: for an empty value, for ``"unspecified"`` (RFC 5280 says
        the extension SHOULD be absent rather than carry ``unspecified``), and
        for unrecognised strings, which are logged as a warning.
    """
    if not reason:
        return None
    text = reason.strip()
    try:
        flag = x509.ReasonFlags(text)
    except ValueError:
        try:
            flag = x509.ReasonFlags[text.lower()]
        except KeyError:
            logger.warning(
                "Unknown revocation reason %r; recording it as unspecified", reason
            )
            return None
    return None if flag is x509.ReasonFlags.unspecified else flag


class CertificateService:
    """Service for managing certificates and CAs.

    CA keys, certificates, DH parameters and CRLs are stored as PEM text on
    ``CertificateAuthority`` rows of the session passed to the constructor.
    Methods that modify a CA commit that session.
    """

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cert_to_pem(cert: x509.Certificate) -> str:
        """Serialize a certificate to a PEM string."""
        return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")

    @staticmethod
    def _key_to_pem(key) -> str:
        """Serialize a private key as *unencrypted* traditional-OpenSSL PEM."""
        return key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        ).decode("utf-8")

    @staticmethod
    def _load_ca_material(ca: CertificateAuthority):
        """Parse the CA's stored PEM key and certificate.

        Returns:
            ``(ca_private_key, ca_certificate)``.
        """
        ca_key = serialization.load_pem_private_key(
            ca.ca_key.encode("utf-8"), password=None
        )
        ca_cert = x509.load_pem_x509_certificate(ca.ca_cert.encode("utf-8"))
        return ca_key, ca_cert

    @staticmethod
    def _authority_key_identifier(ca_cert: x509.Certificate, ca_key) -> x509.AuthorityKeyIdentifier:
        """Build the AKI for objects signed by this CA.

        Reuses the CA certificate's own SubjectKeyIdentifier when present, so
        the identifiers match even for imported CAs whose SKI was computed
        differently. Otherwise the identifier is derived from the CA public key.
        """
        try:
            ski = ca_cert.extensions.get_extension_for_class(
                x509.SubjectKeyIdentifier
            ).value
        except x509.ExtensionNotFound:
            return x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key())
        return x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ski)

    # ------------------------------------------------------------------
    # CA lifecycle
    # ------------------------------------------------------------------

    def create_ca(
        self,
        name: str,
        key_size: int = 4096,
        validity_days: int = 3650,
        organization: str = "mGuard VPN",
        country: str = "DE",
        state: str = "NRW",
        city: str = "Dortmund",
        tenant_id: Optional[int] = None,
        created_by_id: Optional[int] = None,
        is_default: bool = False,
    ) -> CertificateAuthority:
        """Create a new self-signed Certificate Authority.

        Generates an RSA key pair and a self-signed CA certificate. Both are
        stored as PEM on a new ``CertificateAuthority`` row; the key is stored
        unencrypted. Background DH parameter generation is then started. The CA
        stays ``PENDING`` (``dh_generating=True``) until that finishes.

        Args:
            name: CA display name; the certificate CN is ``"<name> CA"``.
            key_size: RSA key size in bits, also used for the DH parameters.
            validity_days: Certificate lifetime starting now.
            organization, country, state, city: Subject name attributes.
            tenant_id: Owning tenant, or ``None`` for a global CA.
            created_by_id: ID of the creating user, stored on the row.
            is_default: Make this the tenant's default CA. Other defaults of
                the same tenant are cleared in the same transaction.

        Returns:
            The committed and refreshed CA record.
        """
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)

        # Self-signed: subject and issuer are the same name.
        subject = issuer = x509.Name(
            [
                x509.NameAttribute(NameOID.COUNTRY_NAME, country),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
                x509.NameAttribute(NameOID.LOCALITY_NAME, city),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization),
                x509.NameAttribute(NameOID.COMMON_NAME, f"{name} CA"),
            ]
        )
        valid_from = _utcnow()
        valid_until = valid_from + timedelta(days=validity_days)

        ca_cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            # path_length=0: this CA may only sign end-entity certificates.
            .add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_cert_sign=True,
                    crl_sign=True,
                    key_encipherment=False,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()),
                critical=False,
            )
            .sign(private_key, hashes.SHA256())
        )

        if is_default:
            # At most one default per tenant; committed together with the new row.
            self.db.query(CertificateAuthority).filter(
                CertificateAuthority.tenant_id == tenant_id,
                CertificateAuthority.is_default == True,  # noqa: E712 (SQL expression)
            ).update({"is_default": False})

        ca = CertificateAuthority(
            name=name,
            tenant_id=tenant_id,
            ca_cert=self._cert_to_pem(ca_cert),
            ca_key=self._key_to_pem(private_key),
            key_size=key_size,
            valid_from=valid_from,
            valid_until=valid_until,
            is_default=is_default,
            status=CAStatus.PENDING,  # Will be ACTIVE after DH generation
            created_by_id=created_by_id,
            dh_generating=True,
        )
        self.db.add(ca)
        self.db.commit()
        self.db.refresh(ca)

        # The row must be committed first so the worker thread's session can see it.
        self._generate_dh_async(ca.id, key_size)
        return ca

    def _generate_dh_async(self, ca_id: int, key_size: int):
        """Generate DH parameters for CA *ca_id* in a background daemon thread.

        On success the CA row gets ``dh_params``, ``dh_generating=False`` and
        status ``ACTIVE``. On any failure (openssl missing, non-zero exit,
        timeout, DB error) ``dh_generating`` is cleared and the error is written
        to ``description``; the status is left unchanged.

        The thread opens its own DB session rather than reusing ``self.db``.
        """

        def update_ca(**fields) -> None:
            # Fresh session per update; the request session is not used in the thread.
            from ..database import SessionLocal

            db = SessionLocal()
            try:
                ca = db.query(CertificateAuthority).filter(
                    CertificateAuthority.id == ca_id
                ).first()
                if ca:
                    for attr, value in fields.items():
                        setattr(ca, attr, value)
                    db.commit()
            finally:
                db.close()

        def generate() -> None:
            try:
                # Use openssl for DH generation (faster than pure Python).
                result = subprocess.run(
                    ["openssl", "dhparam", "-out", "-", str(key_size)],
                    capture_output=True,
                    text=True,
                    timeout=_DH_TIMEOUT_SECONDS,
                )
                if result.returncode != 0:
                    # Previously ignored, leaving the CA stuck in PENDING forever.
                    detail = result.stderr.strip()[-_MAX_ERROR_DETAIL_CHARS:]
                    raise RuntimeError(
                        f"openssl dhparam exited with code {result.returncode}: {detail}"
                    )
                update_ca(
                    dh_params=result.stdout,
                    dh_generating=False,
                    status=CAStatus.ACTIVE,
                )
            except Exception as e:
                logger.exception("DH parameter generation failed for CA %s", ca_id)
                try:
                    update_ca(
                        dh_generating=False,
                        description=f"DH generation failed: {str(e)}",
                    )
                except Exception:
                    # Nothing else to fall back to inside a daemon thread.
                    logger.exception("Could not record DH failure for CA %s", ca_id)

        threading.Thread(target=generate, daemon=True).start()

    def import_ca(
        self,
        name: str,
        ca_cert_pem: str,
        ca_key_pem: str,
        dh_params_pem: Optional[str] = None,
        tenant_id: Optional[int] = None,
        created_by_id: Optional[int] = None,
    ) -> CertificateAuthority:
        """Import an existing CA from PEM files.

        Args:
            name: Display name of the CA.
            ca_cert_pem: PEM CA certificate; must carry BasicConstraints ca=True.
            ca_key_pem: Unencrypted PEM private key matching the certificate.
            dh_params_pem: Optional PEM DH parameters. When omitted, the CA is
                stored as ``PENDING`` and parameters are generated in the
                background.
            tenant_id: Owning tenant, or ``None`` for a global CA.
            created_by_id: ID of the importing user, stored on the row.

        Returns:
            The committed and refreshed CA record.

        Raises:
            ValueError: The certificate is not a CA certificate, lacks
                BasicConstraints, or does not match the private key. PEM parse
                errors from cryptography propagate unchanged.
        """
        ca_cert = x509.load_pem_x509_certificate(ca_cert_pem.encode("utf-8"))

        try:
            basic_constraints = ca_cert.extensions.get_extension_for_oid(
                ExtensionOID.BASIC_CONSTRAINTS
            )
        except x509.ExtensionNotFound:
            raise ValueError("Certificate missing BasicConstraints extension") from None
        if not basic_constraints.value.ca:
            raise ValueError("Certificate is not a CA certificate")

        private_key = serialization.load_pem_private_key(
            ca_key_pem.encode("utf-8"), password=None
        )
        # A mismatched key would sign certificates that never verify against this CA.
        if _spki_der(private_key.public_key()) != _spki_der(ca_cert.public_key()):
            raise ValueError("Private key does not match the CA certificate")
        key_size = private_key.key_size

        ca = CertificateAuthority(
            name=name,
            tenant_id=tenant_id,
            ca_cert=ca_cert_pem,
            ca_key=ca_key_pem,
            dh_params=dh_params_pem,
            key_size=key_size,
            valid_from=ca_cert.not_valid_before_utc,
            valid_until=ca_cert.not_valid_after_utc,
            status=CAStatus.ACTIVE if dh_params_pem else CAStatus.PENDING,
            created_by_id=created_by_id,
            dh_generating=dh_params_pem is None,
        )
        self.db.add(ca)
        self.db.commit()
        self.db.refresh(ca)

        if not dh_params_pem:
            self._generate_dh_async(ca.id, key_size)
        return ca

    # ------------------------------------------------------------------
    # Certificate issuance
    # ------------------------------------------------------------------

    def _issue_certificate(
        self,
        ca: CertificateAuthority,
        common_name: str,
        validity_days: int,
        *,
        key_size: int,
        extended_key_usage,
        include_dns_san: bool,
    ) -> dict:
        """Issue an end-entity certificate signed by *ca*.

        Consumes ``ca.next_serial`` as the serial number, increments it and
        commits.

        Returns:
            Dict with ``cert`` and ``key`` (PEM strings), ``valid_from``,
            ``valid_until`` (naive UTC) and ``serial``.

        Raises:
            ValueError: ``ca.is_ready`` is false.
        """
        if not ca.is_ready:
            raise ValueError("CA is not ready for issuing certificates")

        ca_key, ca_cert = self._load_ca_material(ca)
        leaf_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)

        valid_from = _utcnow()
        valid_until = valid_from + timedelta(days=validity_days)
        # Capture the serial up front. After commit, SQLAlchemy (with its
        # default expire_on_commit) reloads attributes from the DB, and that
        # re-read could observe another writer's increment.
        serial = ca.next_serial

        builder = (
            x509.CertificateBuilder()
            .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)]))
            .issuer_name(ca_cert.subject)
            .public_key(leaf_key.public_key())
            .serial_number(serial)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_encipherment=True,
                    key_cert_sign=False,
                    crl_sign=False,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(x509.ExtendedKeyUsage([extended_key_usage]), critical=False)
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(leaf_key.public_key()),
                critical=False,
            )
            .add_extension(self._authority_key_identifier(ca_cert, ca_key), critical=False)
        )
        if include_dns_san:
            builder = builder.add_extension(
                x509.SubjectAlternativeName([x509.DNSName(common_name)]),
                critical=False,
            )
        cert = builder.sign(ca_key, hashes.SHA256())

        ca.next_serial = serial + 1
        self.db.commit()

        return {
            "cert": self._cert_to_pem(cert),
            "key": self._key_to_pem(leaf_key),
            "valid_from": valid_from,
            "valid_until": valid_until,
            "serial": serial,
        }

    def generate_server_certificate(
        self,
        ca: CertificateAuthority,
        common_name: str,
        validity_days: int = 825,
    ) -> dict:
        """Generate a server (serverAuth) certificate from the CA.

        The RSA key size equals the CA's key size. *common_name* is also added
        as a DNS subjectAltName.

        Returns:
            Dict with ``cert``, ``key``, ``valid_from``, ``valid_until``, ``serial``.

        Raises:
            ValueError: The CA is not ready for issuing certificates.
        """
        return self._issue_certificate(
            ca,
            common_name,
            validity_days,
            key_size=ca.key_size,
            extended_key_usage=ExtendedKeyUsageOID.SERVER_AUTH,
            include_dns_san=True,
        )

    def generate_client_certificate(
        self,
        ca: CertificateAuthority,
        common_name: str,
        validity_days: int = 365,
    ) -> dict:
        """Generate a client (clientAuth) certificate from the CA.

        Uses a 2048-bit RSA key regardless of the CA key size. No SAN is added.

        Returns:
            Dict with ``cert``, ``key``, ``valid_from``, ``valid_until``, ``serial``.

        Raises:
            ValueError: The CA is not ready for issuing certificates.
        """
        return self._issue_certificate(
            ca,
            common_name,
            validity_days,
            key_size=_CLIENT_KEY_SIZE,
            extended_key_usage=ExtendedKeyUsageOID.CLIENT_AUTH,
            include_dns_san=False,
        )

    # ------------------------------------------------------------------
    # TLS-Auth key
    # ------------------------------------------------------------------

    @staticmethod
    def _generate_static_key_fallback() -> str:
        """Build a 2048-bit OpenVPN static key (V1 format) with ``secrets``.

        Hex digits are laid out as 16 lines of 32 characters.
        """
        key_hex = secrets.token_hex(256)  # 256 random bytes -> 512 hex digits
        body = "\n".join(key_hex[i:i + 32] for i in range(0, len(key_hex), 32))
        return (
            "#\n"
            "# 2048 bit OpenVPN static key\n"
            "#\n"
            "-----BEGIN OpenVPN Static key V1-----\n"
            f"{body}\n"
            "-----END OpenVPN Static key V1-----\n"
        )

    def generate_ta_key(self) -> str:
        """Generate a TLS-Auth key.

        Runs ``openvpn --genkey secret /dev/stdout``. If the ``openvpn`` binary
        is not installed, an equivalent key is generated locally instead.

        Raises:
            RuntimeError: openvpn is installed but exited with a non-zero code.
            subprocess.TimeoutExpired: openvpn ran longer than 30 seconds.
        """
        try:
            result = subprocess.run(
                ["openvpn", "--genkey", "secret", "/dev/stdout"],
                capture_output=True,
                text=True,
                timeout=30,
            )
        except FileNotFoundError:
            # OpenVPN not installed.
            return self._generate_static_key_fallback()

        if result.returncode != 0:
            raise RuntimeError(f"Failed to generate TA key: {result.stderr}")
        return result.stdout

    # ------------------------------------------------------------------
    # Revocation / CRL
    # ------------------------------------------------------------------

    @staticmethod
    def _load_revoked_entries(ca: CertificateAuthority) -> list:
        """Return the revoked-certificate entries of the CA's stored CRL.

        Returns an empty list when the CA has no CRL yet.

        Raises:
            ValueError: A stored CRL exists but cannot be parsed. Re-signing
                would silently drop all prior revocations, so this refuses.
        """
        if not ca.crl:
            return []
        try:
            existing_crl = x509.load_pem_x509_crl(ca.crl.encode("utf-8"))
            return list(existing_crl)
        except ValueError as err:
            raise ValueError(
                f"Stored CRL of CA {ca.id} cannot be parsed; refusing to overwrite it"
            ) from err

    @staticmethod
    def _crl_needs_refresh(crl_pem: str) -> bool:
        """Whether a stored CRL's nextUpdate is within ``CRL_REFRESH_MARGIN_DAYS``.

        An unparseable CRL or one without nextUpdate is reported as not needing
        a refresh. This leaves the stored value untouched.
        """
        try:
            crl = x509.load_pem_x509_crl(crl_pem.encode("utf-8"))
        except ValueError:
            logger.warning("Stored CRL cannot be parsed; returning it unchanged")
            return False
        next_update = crl.next_update_utc
        if next_update is None:
            return False
        remaining = next_update - datetime.now(timezone.utc)
        return remaining <= timedelta(days=CRL_REFRESH_MARGIN_DAYS)

    def _store_crl(self, ca: CertificateAuthority, revoked_certs: list) -> None:
        """Sign a fresh CRL with *revoked_certs*, store it on *ca* and commit."""
        ca_key, ca_cert = self._load_ca_material(ca)
        now = _utcnow()
        builder = (
            x509.CertificateRevocationListBuilder()
            .issuer_name(ca_cert.subject)
            .last_update(now)
            .next_update(now + timedelta(days=CRL_VALIDITY_DAYS))
            .add_extension(self._authority_key_identifier(ca_cert, ca_key), critical=False)
        )
        for entry in revoked_certs:
            builder = builder.add_revoked_certificate(entry)
        crl = builder.sign(ca_key, hashes.SHA256())

        ca.crl = crl.public_bytes(serialization.Encoding.PEM).decode("utf-8")
        ca.crl_updated_at = now
        self.db.commit()

    def revoke_certificate(
        self,
        ca: CertificateAuthority,
        cert_pem: str,
        reason: str = "unspecified",
    ) -> bool:
        """Revoke a certificate and re-sign the CA's CRL.

        Existing CRL entries are kept. A serial that is already revoked is not
        added twice, but the CRL is still re-signed.

        Args:
            ca: The issuing CA.
            cert_pem: PEM of the certificate to revoke (only its serial is used).
            reason: RFC 5280 reason, e.g. ``"keyCompromise"`` or
                ``"key_compromise"``. ``"unspecified"`` and unknown values
                produce an entry without a CRLReason extension.

        Returns:
            Always ``True``.

        Raises:
            ValueError: The stored CRL cannot be parsed.
        """
        cert = x509.load_pem_x509_certificate(cert_pem.encode("utf-8"))
        revoked_certs = self._load_revoked_entries(ca)

        if any(entry.serial_number == cert.serial_number for entry in revoked_certs):
            logger.info(
                "Serial %s already revoked by CA %s; re-signing CRL only",
                cert.serial_number,
                ca.id,
            )
        else:
            entry_builder = (
                x509.RevokedCertificateBuilder()
                .serial_number(cert.serial_number)
                .revocation_date(_utcnow())
            )
            reason_flag = _parse_revocation_reason(reason)
            if reason_flag is not None:
                entry_builder = entry_builder.add_extension(
                    x509.CRLReason(reason_flag), critical=False
                )
            revoked_certs.append(entry_builder.build())

        self._store_crl(ca, revoked_certs)
        return True

    def get_crl(self, ca: CertificateAuthority) -> str:
        """Get the CA's CRL as PEM, creating or refreshing it as needed.

        An empty CRL is created when none exists. A parseable stored CRL whose
        nextUpdate is within ``CRL_REFRESH_MARGIN_DAYS`` is re-signed with the
        same revoked entries.
        """
        if not ca.crl:
            self._store_crl(ca, [])
        elif self._crl_needs_refresh(ca.crl):
            self._store_crl(ca, self._load_revoked_entries(ca))
        return ca.crl

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_expiring_certificates(
        self,
        ca: CertificateAuthority,
        days: int = 30,
    ) -> list[VPNProfile]:
        """Get ACTIVE profiles of *ca* whose certificate ends within *days*.

        Profiles whose ``valid_until`` already lies in the past also match.
        """
        threshold = _utcnow() + timedelta(days=days)
        return self.db.query(VPNProfile).filter(
            VPNProfile.ca_id == ca.id,
            VPNProfile.valid_until <= threshold,
            VPNProfile.status == VPNProfileStatus.ACTIVE,
        ).all()

    def _default_ca_query(self, tenant_id: Optional[int]):
        """Query for the ACTIVE default CA of *tenant_id* (``None`` = global)."""
        return self.db.query(CertificateAuthority).filter(
            # Comparing with None renders IS NULL in SQLAlchemy.
            CertificateAuthority.tenant_id == tenant_id,
            CertificateAuthority.is_default == True,  # noqa: E712 (SQL expression)
            CertificateAuthority.status == CAStatus.ACTIVE,
        )

    def get_default_ca(self, tenant_id: Optional[int] = None) -> Optional[CertificateAuthority]:
        """Get the default CA for a tenant, falling back to the global default.

        Returns:
            The tenant's ACTIVE default CA. If there is none and a tenant was
            given (including tenant ``0``), the global ACTIVE default CA is
            returned instead. Returns ``None`` if neither exists.
        """
        ca = self._default_ca_query(tenant_id).first()
        if ca is None and tenant_id is not None:
            ca = self._default_ca_query(None).first()
        return ca