589 lines
19 KiB
Python
589 lines
19 KiB
Python
"""Certificate management service for PKI operations."""
|
|
|
|
import logging
import os
import subprocess
import threading
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional

from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import ExtensionOID, NameOID
from sqlalchemy.orm import Session

from ..config import get_settings
from ..models.certificate_authority import CAStatus, CertificateAuthority
from ..models.vpn_profile import VPNProfile, VPNProfileStatus
|
|
|
|
# Application settings, loaded once when this module is imported.
# NOTE(review): `settings` is not referenced by CertificateService in this
# file — confirm it is still needed here.
settings = get_settings()
|
|
|
|
|
|
class CertificateService:
    """Service for managing certificates and CAs.

    Wraps a SQLAlchemy session and covers CA creation/import, server and
    client certificate issuance, OpenVPN TLS-auth key generation and CRL
    maintenance. All private keys are serialized as unencrypted PEM.
    """

    # Lifetime of every CRL signed by this service.
    CRL_VALIDITY_DAYS = 180
    # get_crl() re-signs the stored CRL once fewer than this many days remain
    # before its nextUpdate, so consumers are never handed an expired CRL
    # (OpenSSL-based verifiers typically reject peers when the CRL expired).
    CRL_REFRESH_MARGIN_DAYS = 30

    _logger = logging.getLogger(__name__)

    def __init__(self, db: Session):
        # Request-scoped session; background threads open their own session.
        self.db = db

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Equivalent to the deprecated ``datetime.utcnow()``; naive values are
        interpreted as UTC by the cryptography builders.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _private_key_to_pem(private_key) -> str:
        """Serialize a private key as unencrypted traditional-OpenSSL PEM text."""
        return private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        ).decode("utf-8")

    @staticmethod
    def _load_ca_material(ca: CertificateAuthority):
        """Return ``(private_key, certificate)`` parsed from the CA's stored PEM."""
        ca_key = serialization.load_pem_private_key(
            ca.ca_key.encode("utf-8"),
            password=None,
            backend=default_backend(),
        )
        ca_cert = x509.load_pem_x509_certificate(
            ca.ca_cert.encode("utf-8"),
            default_backend(),
        )
        return ca_key, ca_cert

    @staticmethod
    def _authority_key_identifier(ca_cert: x509.Certificate) -> x509.AuthorityKeyIdentifier:
        """Build an AuthorityKeyIdentifier that points at ``ca_cert``.

        Reuses the CA's own SubjectKeyIdentifier when present so the two always
        match (an imported CA may use a non-default key-id method); otherwise
        the key id is derived from the CA public key.
        """
        try:
            ski = ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value
        except x509.ExtensionNotFound:
            return x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key())
        return x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ski)

    # ------------------------------------------------------------------ #
    # CA lifecycle
    # ------------------------------------------------------------------ #

    def create_ca(
        self,
        name: str,
        key_size: int = 4096,
        validity_days: int = 3650,
        organization: str = "mGuard VPN",
        country: str = "DE",
        state: str = "NRW",
        city: str = "Dortmund",
        tenant_id: Optional[int] = None,
        created_by_id: Optional[int] = None,
        is_default: bool = False
    ) -> CertificateAuthority:
        """Create a new self-signed Certificate Authority.

        The CA record is committed in PENDING state with ``dh_generating=True``.
        A background thread then generates DH parameters of ``key_size`` bits
        and, on success, marks the CA ACTIVE.

        Args:
            name: Display name; the certificate CN becomes ``"<name> CA"``.
            key_size: RSA modulus size of the CA key (also used for the DH params).
            validity_days: Lifetime of the CA certificate.
            organization, country, state, city: Subject DN attributes.
            tenant_id: Owning tenant, or None for a global CA.
            created_by_id: Stored as ``created_by_id`` on the record.
            is_default: Make this the default CA, clearing ``is_default`` on the
                other CAs with the same ``tenant_id``.

        Returns:
            The persisted CertificateAuthority, refreshed from the database.
        """
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend(),
        )
        public_key = private_key.public_key()

        # Self-signed: subject and issuer are the same name.
        subject = issuer = x509.Name([
            x509.NameAttribute(NameOID.COUNTRY_NAME, country),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
            x509.NameAttribute(NameOID.LOCALITY_NAME, city),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization),
            x509.NameAttribute(NameOID.COMMON_NAME, f"{name} CA"),
        ])

        valid_from = self._utcnow()
        valid_until = valid_from + timedelta(days=validity_days)

        ca_cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(public_key)
            .serial_number(x509.random_serial_number())
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            # path_length=0: this CA may sign end-entity certificates only.
            .add_extension(
                x509.BasicConstraints(ca=True, path_length=0),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_cert_sign=True,
                    crl_sign=True,
                    key_encipherment=False,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(public_key),
                critical=False,
            )
            .sign(private_key, hashes.SHA256(), default_backend())
        )

        if is_default:
            # Only one default CA per tenant (tenant_id None == global scope).
            self.db.query(CertificateAuthority).filter(
                CertificateAuthority.tenant_id == tenant_id,
                CertificateAuthority.is_default == True,  # noqa: E712 (SQL expression)
            ).update({"is_default": False})

        ca = CertificateAuthority(
            name=name,
            tenant_id=tenant_id,
            ca_cert=ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8"),
            ca_key=self._private_key_to_pem(private_key),
            key_size=key_size,
            valid_from=valid_from,
            valid_until=valid_until,
            is_default=is_default,
            status=CAStatus.PENDING,  # becomes ACTIVE once DH params exist
            created_by_id=created_by_id,
            dh_generating=True,
        )
        self.db.add(ca)
        self.db.commit()
        self.db.refresh(ca)

        self._generate_dh_async(ca.id, key_size)
        return ca

    def _generate_dh_async(self, ca_id: int, key_size: int) -> None:
        """Generate DH parameters for CA ``ca_id`` in a background daemon thread.

        Runs ``openssl dhparam`` (up to one hour). On success the CA row gets
        ``dh_params``, ``dh_generating=False`` and status ACTIVE. On any failure,
        including a non-zero openssl exit status, ``dh_generating`` is cleared
        and the error text is written to ``description``; the status is left
        unchanged.
        NOTE(review): if the process exits while the thread runs, the CA stays
        PENDING with ``dh_generating=True`` — confirm something recovers that.
        """

        def update_ca(**fields) -> None:
            # The request-scoped self.db must not be used from this thread,
            # so each update opens (and closes) its own session.
            from ..database import SessionLocal

            db = SessionLocal()
            try:
                ca = db.query(CertificateAuthority).filter(
                    CertificateAuthority.id == ca_id
                ).first()
                if ca:
                    for attr, value in fields.items():
                        setattr(ca, attr, value)
                    db.commit()
            finally:
                db.close()

        def generate() -> None:
            try:
                # openssl is much faster than generating DH params in Python.
                result = subprocess.run(
                    ["openssl", "dhparam", "-out", "-", str(key_size)],
                    capture_output=True,
                    text=True,
                    timeout=3600,  # 1 hour max
                )
                if result.returncode != 0:
                    # Previously ignored, which left the CA PENDING forever.
                    raise RuntimeError(
                        f"openssl dhparam exited with status {result.returncode}: "
                        f"{result.stderr.strip()}"
                    )
                update_ca(
                    dh_params=result.stdout,
                    dh_generating=False,
                    status=CAStatus.ACTIVE,
                )
            except Exception as e:
                # Top-level boundary of a daemon thread: log, then record the
                # failure on the CA so it is visible to users.
                self._logger.exception("DH parameter generation failed for CA %s", ca_id)
                try:
                    update_ca(
                        dh_generating=False,
                        description=f"DH generation failed: {str(e)}",
                    )
                except Exception:
                    self._logger.exception("Could not record DH failure for CA %s", ca_id)

        threading.Thread(target=generate, daemon=True).start()

    def import_ca(
        self,
        name: str,
        ca_cert_pem: str,
        ca_key_pem: str,
        dh_params_pem: Optional[str] = None,
        tenant_id: Optional[int] = None,
        created_by_id: Optional[int] = None
    ) -> CertificateAuthority:
        """Import an existing CA from PEM files.

        Without ``dh_params_pem`` the CA is stored PENDING and DH parameters are
        generated in the background, as in ``create_ca``; with them it is
        ACTIVE immediately.

        Raises:
            ValueError: if the certificate lacks BasicConstraints or is not a CA
                certificate, if the key is not RSA, or if the key does not belong
                to the certificate. The cryptography loaders also raise
                ValueError for unparsable PEM input.
        """
        ca_cert = x509.load_pem_x509_certificate(
            ca_cert_pem.encode("utf-8"),
            default_backend(),
        )

        try:
            basic_constraints = ca_cert.extensions.get_extension_for_oid(
                ExtensionOID.BASIC_CONSTRAINTS
            )
        except x509.ExtensionNotFound:
            raise ValueError("Certificate missing BasicConstraints extension") from None
        if not basic_constraints.value.ca:
            raise ValueError("Certificate is not a CA certificate")

        private_key = serialization.load_pem_private_key(
            ca_key_pem.encode("utf-8"),
            password=None,
            backend=default_backend(),
        )
        # Server keys and DH parameters are later generated with ca.key_size
        # bits, which is only meaningful for an RSA modulus size.
        if not isinstance(private_key, rsa.RSAPrivateKey):
            raise ValueError("Only RSA CA keys are supported")

        spki_der = (serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo)
        if private_key.public_key().public_bytes(*spki_der) != ca_cert.public_key().public_bytes(*spki_der):
            raise ValueError("Private key does not match the CA certificate")

        key_size = private_key.key_size

        ca = CertificateAuthority(
            name=name,
            tenant_id=tenant_id,
            ca_cert=ca_cert_pem,
            ca_key=ca_key_pem,
            dh_params=dh_params_pem,
            key_size=key_size,
            # NOTE(review): the *_utc accessors return timezone-aware values,
            # whereas create_ca stores naive UTC — confirm the columns accept both.
            valid_from=ca_cert.not_valid_before_utc,
            valid_until=ca_cert.not_valid_after_utc,
            status=CAStatus.ACTIVE if dh_params_pem else CAStatus.PENDING,
            created_by_id=created_by_id,
            dh_generating=dh_params_pem is None,
        )
        self.db.add(ca)
        self.db.commit()
        self.db.refresh(ca)

        if not dh_params_pem:
            self._generate_dh_async(ca.id, key_size)

        return ca

    # ------------------------------------------------------------------ #
    # End-entity certificates
    # ------------------------------------------------------------------ #

    def _issue_certificate(
        self,
        ca: CertificateAuthority,
        common_name: str,
        validity_days: int,
        key_size: int,
        extended_key_usage: x509.ObjectIdentifier,
        include_dns_san: bool,
    ) -> dict:
        """Issue an end-entity certificate signed by ``ca`` and commit the serial.

        Generates a fresh RSA key of ``key_size`` bits and signs a certificate
        whose subject is just ``CN=common_name``. One serial is consumed from
        ``ca.next_serial``.

        Returns:
            dict with ``cert`` and ``key`` (PEM strings), ``valid_from`` and
            ``valid_until`` (naive UTC datetimes) and ``serial`` (the serial
            number embedded in the certificate).

        Raises:
            ValueError: if the CA is not ready for issuing certificates.
        """
        if not ca.is_ready:
            raise ValueError("CA is not ready for issuing certificates")

        ca_key, ca_cert = self._load_ca_material(ca)

        leaf_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend(),
        )

        valid_from = self._utcnow()
        valid_until = valid_from + timedelta(days=validity_days)

        # Read the serial exactly once. With SQLAlchemy's default
        # expire_on_commit, `ca.next_serial - 1` after commit() is re-read from
        # the database and may not be the serial embedded in this certificate.
        # NOTE(review): concurrent issuers can still read the same serial; a
        # row lock (SELECT ... FOR UPDATE) would be needed to rule that out.
        serial = ca.next_serial

        builder = (
            x509.CertificateBuilder()
            .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)]))
            .issuer_name(ca_cert.subject)
            .public_key(leaf_key.public_key())
            .serial_number(serial)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .add_extension(
                x509.BasicConstraints(ca=False, path_length=None),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_encipherment=True,
                    key_cert_sign=False,
                    crl_sign=False,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                x509.ExtendedKeyUsage([extended_key_usage]),
                critical=False,
            )
            # Key identifiers (RFC 5280 4.2.1.1/4.2.1.2) let verifiers pick the
            # right issuer even when several CAs share a subject name.
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(leaf_key.public_key()),
                critical=False,
            )
            .add_extension(
                self._authority_key_identifier(ca_cert),
                critical=False,
            )
        )
        if include_dns_san:
            builder = builder.add_extension(
                x509.SubjectAlternativeName([x509.DNSName(common_name)]),
                critical=False,
            )
        certificate = builder.sign(ca_key, hashes.SHA256(), default_backend())

        ca.next_serial = serial + 1
        self.db.commit()

        return {
            "cert": certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8"),
            "key": self._private_key_to_pem(leaf_key),
            "valid_from": valid_from,
            "valid_until": valid_until,
            "serial": serial,
        }

    def generate_server_certificate(
        self,
        ca: CertificateAuthority,
        common_name: str,
        validity_days: int = 825
    ) -> dict:
        """Generate a server certificate (serverAuth, DNS SAN = CN) from the CA.

        The server key uses the CA's key size. See ``_issue_certificate`` for the
        returned dict and the ValueError raised when the CA is not ready.
        """
        return self._issue_certificate(
            ca,
            common_name,
            validity_days,
            key_size=ca.key_size,
            extended_key_usage=x509.oid.ExtendedKeyUsageOID.SERVER_AUTH,
            include_dns_san=True,
        )

    def generate_client_certificate(
        self,
        ca: CertificateAuthority,
        common_name: str,
        validity_days: int = 365
    ) -> dict:
        """Generate a client certificate (clientAuth, no SAN) from the CA.

        Client keys always use a 2048-bit modulus, independent of the CA key
        size. See ``_issue_certificate`` for the returned dict and the ValueError
        raised when the CA is not ready.
        """
        return self._issue_certificate(
            ca,
            common_name,
            validity_days,
            key_size=2048,
            extended_key_usage=x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
            include_dns_san=False,
        )

    def generate_ta_key(self) -> str:
        """Generate an OpenVPN static (TLS-Auth) key as PEM-like text.

        Uses ``openvpn --genkey secret`` when the binary exists. If it is not
        installed, an equivalent 2048-bit key is built from ``secrets``.

        Raises:
            RuntimeError: if openvpn runs but exits with a non-zero status.
            subprocess.TimeoutExpired: if openvpn runs longer than 30 seconds.
        """
        try:
            result = subprocess.run(
                ["openvpn", "--genkey", "secret", "/dev/stdout"],
                capture_output=True,
                text=True,
                timeout=30,
            )
        except FileNotFoundError:
            # OpenVPN not installed: 256 random bytes, laid out the way
            # `openvpn --genkey` writes them (16 lines of 32 hex digits).
            import secrets

            key_hex = secrets.token_hex(256)
            key_lines = "\n".join(key_hex[i:i + 32] for i in range(0, len(key_hex), 32))
            return (
                "#\n"
                "# 2048 bit OpenVPN static key\n"
                "#\n"
                "-----BEGIN OpenVPN Static key V1-----\n"
                f"{key_lines}\n"
                "-----END OpenVPN Static key V1-----\n"
            )

        if result.returncode != 0:
            raise RuntimeError(f"Failed to generate TA key: {result.stderr}")
        return result.stdout

    # ------------------------------------------------------------------ #
    # Revocation / CRL
    # ------------------------------------------------------------------ #

    @staticmethod
    def _parse_revocation_reason(reason: Optional[str]) -> Optional[x509.ReasonFlags]:
        """Map a reason string to ``x509.ReasonFlags``; None means "omit reasonCode".

        Accepts the RFC 5280 name (``"keyCompromise"``) or the enum member name
        (``"key_compromise"``), case-insensitively. RFC 5280 5.3.1 asks that the
        extension be omitted rather than set to ``unspecified``; unrecognized
        strings are handled the same way.
        """
        wanted = (reason or "").strip().casefold()
        for flag in x509.ReasonFlags:
            if wanted in (flag.value.casefold(), flag.name.casefold()):
                return None if flag is x509.ReasonFlags.unspecified else flag
        return None

    @staticmethod
    def _load_crl(ca: CertificateAuthority) -> Optional[x509.CertificateRevocationList]:
        """Parse the CA's stored CRL, or return None if none is stored yet.

        Raises:
            ValueError: if a stored CRL exists but cannot be parsed. Callers must
                not replace it with a fresh one, which would silently un-revoke
                every listed certificate.
        """
        if not ca.crl:
            return None
        try:
            return x509.load_pem_x509_crl(ca.crl.encode("utf-8"), default_backend())
        except ValueError as err:
            raise ValueError("Stored CRL could not be parsed; refusing to overwrite it") from err

    @classmethod
    def _crl_needs_refresh(cls, crl: x509.CertificateRevocationList) -> bool:
        """True when ``crl`` expires within ``CRL_REFRESH_MARGIN_DAYS`` (or already has)."""
        next_update = crl.next_update_utc
        if next_update is None:
            return False
        remaining = next_update - datetime.now(timezone.utc)
        return remaining < timedelta(days=cls.CRL_REFRESH_MARGIN_DAYS)

    @staticmethod
    def _next_crl_number(previous_crl: Optional[x509.CertificateRevocationList]) -> int:
        """Return the previous CRL's number plus one, or 1 if there is none."""
        if previous_crl is None:
            return 1
        try:
            current = previous_crl.extensions.get_extension_for_class(x509.CRLNumber).value.crl_number
        except x509.ExtensionNotFound:
            return 1
        return current + 1

    def _store_new_crl(
        self,
        ca: CertificateAuthority,
        ca_key,
        ca_cert: x509.Certificate,
        revoked_entries: list,
        previous_crl: Optional[x509.CertificateRevocationList],
    ) -> None:
        """Sign a CRL listing ``revoked_entries``, store it on ``ca`` and commit."""
        now = self._utcnow()
        builder = (
            x509.CertificateRevocationListBuilder()
            .issuer_name(ca_cert.subject)
            .last_update(now)
            .next_update(now + timedelta(days=self.CRL_VALIDITY_DAYS))
            # RFC 5280 5.2.1 / 5.2.3: CRLs carry the issuer's key id and a
            # monotonically increasing CRL number.
            .add_extension(self._authority_key_identifier(ca_cert), critical=False)
            .add_extension(x509.CRLNumber(self._next_crl_number(previous_crl)), critical=False)
        )
        for entry in revoked_entries:
            builder = builder.add_revoked_certificate(entry)

        crl = builder.sign(ca_key, hashes.SHA256(), default_backend())
        ca.crl = crl.public_bytes(serialization.Encoding.PEM).decode("utf-8")
        ca.crl_updated_at = now
        self.db.commit()

    def revoke_certificate(
        self,
        ca: CertificateAuthority,
        cert_pem: str,
        reason: str = "unspecified"
    ) -> bool:
        """Revoke a certificate issued by ``ca`` and re-sign the CA's CRL.

        Args:
            ca: Issuing CA whose CRL is updated.
            cert_pem: PEM of the certificate to revoke.
            reason: Revocation reason, e.g. ``"keyCompromise"`` or
                ``"key_compromise"``. ``"unspecified"`` and unrecognized values
                are recorded without a reasonCode extension.

        Returns:
            True once the new CRL has been committed. Revoking an already revoked
            serial keeps the original entry and only re-signs the CRL.

        Raises:
            ValueError: if the certificate was not issued by ``ca``, or if the
                stored CRL cannot be parsed.
        """
        cert = x509.load_pem_x509_certificate(cert_pem.encode("utf-8"), default_backend())
        ca_key, ca_cert = self._load_ca_material(ca)

        # Serials are only unique per CA (ca.next_serial), so a certificate from
        # another CA must not revoke whichever of ours shares its serial.
        try:
            cert.verify_directly_issued_by(ca_cert)
        except (ValueError, InvalidSignature) as err:
            raise ValueError("Certificate was not issued by this CA") from err

        existing_crl = self._load_crl(ca)
        revoked_entries = list(existing_crl) if existing_crl is not None else []

        if not any(entry.serial_number == cert.serial_number for entry in revoked_entries):
            entry_builder = (
                x509.RevokedCertificateBuilder()
                .serial_number(cert.serial_number)
                .revocation_date(self._utcnow())
            )
            reason_flag = self._parse_revocation_reason(reason)
            if reason_flag is not None:
                entry_builder = entry_builder.add_extension(
                    x509.CRLReason(reason_flag),
                    critical=False,
                )
            revoked_entries.append(entry_builder.build())

        self._store_new_crl(ca, ca_key, ca_cert, revoked_entries, existing_crl)
        return True

    def get_crl(self, ca: CertificateAuthority) -> str:
        """Return the CA's PEM CRL, (re)signing it when needed.

        A new CRL is signed when none is stored yet, or when the stored one
        expires within ``CRL_REFRESH_MARGIN_DAYS``. Existing revocation entries
        are carried over unchanged.

        Raises:
            ValueError: if the stored CRL cannot be parsed.
        """
        existing_crl = self._load_crl(ca)
        if existing_crl is None or self._crl_needs_refresh(existing_crl):
            ca_key, ca_cert = self._load_ca_material(ca)
            revoked_entries = list(existing_crl) if existing_crl is not None else []
            self._store_new_crl(ca, ca_key, ca_cert, revoked_entries, existing_crl)
        return ca.crl

    # ------------------------------------------------------------------ #
    # Queries
    # ------------------------------------------------------------------ #

    def get_expiring_certificates(
        self,
        ca: CertificateAuthority,
        days: int = 30
    ) -> list[VPNProfile]:
        """Return ACTIVE profiles of ``ca`` whose ``valid_until`` is within ``days``.

        There is no lower bound, so profiles already past ``valid_until`` that
        are still ACTIVE are included as well.
        """
        threshold = self._utcnow() + timedelta(days=days)
        return self.db.query(VPNProfile).filter(
            VPNProfile.ca_id == ca.id,
            VPNProfile.valid_until <= threshold,
            VPNProfile.status == VPNProfileStatus.ACTIVE,
        ).all()

    def get_default_ca(self, tenant_id: Optional[int] = None) -> Optional[CertificateAuthority]:
        """Return the ACTIVE default CA for ``tenant_id``.

        If the tenant has none, fall back to the global default CA
        (``tenant_id IS NULL``). Returns None if neither exists.
        """
        default_filters = (
            CertificateAuthority.is_default == True,  # noqa: E712 (SQL expression)
            CertificateAuthority.status == CAStatus.ACTIVE,
        )
        ca = self.db.query(CertificateAuthority).filter(
            CertificateAuthority.tenant_id == tenant_id,
            *default_filters,
        ).first()

        if ca is None and tenant_id is not None:
            ca = self.db.query(CertificateAuthority).filter(
                CertificateAuthority.tenant_id.is_(None),
                *default_filters,
            ).first()

        return ca
|