readme refreshed
This commit is contained in:
@@ -11,6 +11,7 @@ from ..models.user import User, UserRole
|
||||
from ..models.access import UserGatewayAccess, ConnectionLog
|
||||
from ..services.vpn_service import VPNService
|
||||
from ..services.firewall_service import FirewallService
|
||||
from ..services.client_vpn_profile_service import ClientVPNProfileService
|
||||
from .deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
@@ -92,9 +93,15 @@ def connect_to_endpoint(
|
||||
detail="Endpoint not found"
|
||||
)
|
||||
|
||||
# NOTE: Dynamic VPN config generation has been replaced by VPN profiles.
|
||||
# Gateways should have pre-provisioned VPN profiles.
|
||||
# This endpoint now just logs the connection intent.
|
||||
# Get or create VPN profile for user
|
||||
client_profile_service = ClientVPNProfileService(db)
|
||||
vpn_config = client_profile_service.get_vpn_config_for_user(current_user)
|
||||
|
||||
if not vpn_config:
|
||||
return ConnectResponse(
|
||||
success=False,
|
||||
message="No VPN server available. Please contact administrator."
|
||||
)
|
||||
|
||||
# Log connection
|
||||
connection = ConnectionLog(
|
||||
@@ -107,10 +114,14 @@ def connect_to_endpoint(
|
||||
db.commit()
|
||||
db.refresh(connection)
|
||||
|
||||
# Set up firewall rules for this connection
|
||||
# The client's VPN IP will be determined after VPN connects
|
||||
# For now, we'll configure firewall rules when the client-connect script runs
|
||||
|
||||
return ConnectResponse(
|
||||
success=True,
|
||||
message="Connection logged. Use the gateway's VPN profile configuration to connect.",
|
||||
vpn_config=None, # VPN config is now obtained through gateway VPN profiles
|
||||
message="VPN configuration ready. Connect to access endpoint.",
|
||||
vpn_config=vpn_config,
|
||||
target_ip=endpoint.internal_ip,
|
||||
target_port=endpoint.port,
|
||||
connection_id=connection.id
|
||||
|
||||
@@ -9,12 +9,34 @@ from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..database import get_db
|
||||
from ..models.vpn_server import VPNServer, VPNServerStatus
|
||||
from ..models.vpn_profile import VPNProfile
|
||||
from ..models.client_vpn_profile import ClientVPNProfile
|
||||
from ..models.gateway import Gateway
|
||||
from ..models.access import ConnectionLog
|
||||
from ..models.endpoint import Endpoint
|
||||
from ..services.vpn_server_service import VPNServerService
|
||||
from ..services.vpn_sync_service import VPNSyncService
|
||||
from ..services.certificate_service import CertificateService
|
||||
from ..services.firewall_service import FirewallService
|
||||
|
||||
|
||||
class ClientConnectedRequest(BaseModel):
|
||||
"""Request when VPN client connects."""
|
||||
common_name: str
|
||||
real_ip: str
|
||||
vpn_ip: str
|
||||
|
||||
|
||||
class ClientDisconnectedRequest(BaseModel):
|
||||
"""Request when VPN client disconnects."""
|
||||
common_name: str
|
||||
bytes_received: int = 0
|
||||
bytes_sent: int = 0
|
||||
duration: int = 0
|
||||
|
||||
# Log directory (shared volume with OpenVPN container)
|
||||
LOG_DIR = Path("/var/log/openvpn")
|
||||
@@ -361,3 +383,203 @@ async def debug_sync(db: Session = Depends(get_db)):
|
||||
"connected_clients": connected_clients,
|
||||
"sync_result": sync_result
|
||||
}
|
||||
|
||||
|
||||
@router.get("/vpn-servers/{server_id}/ccd")
|
||||
async def get_ccd_files(
|
||||
server_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get all CCD (Client Config Directory) files for a VPN server.
|
||||
|
||||
CCD files contain iroute directives that tell OpenVPN which client
|
||||
handles traffic for which subnet.
|
||||
"""
|
||||
server = db.query(VPNServer).filter(VPNServer.id == server_id).first()
|
||||
|
||||
if not server:
|
||||
raise HTTPException(status_code=404, detail="Server not found")
|
||||
|
||||
service = VPNServerService(db)
|
||||
ccd_files = service.get_all_ccd_files(server)
|
||||
|
||||
return {
|
||||
"server_id": server_id,
|
||||
"count": len(ccd_files),
|
||||
"files": ccd_files
|
||||
}
|
||||
|
||||
|
||||
@router.get("/vpn-servers/{server_id}/ccd/{common_name}")
|
||||
async def get_ccd_file(
|
||||
server_id: int,
|
||||
common_name: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get a single CCD file for a specific client."""
|
||||
from ..models.vpn_profile import VPNProfile
|
||||
|
||||
profile = db.query(VPNProfile).filter(
|
||||
VPNProfile.vpn_server_id == server_id,
|
||||
VPNProfile.cert_cn == common_name
|
||||
).first()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
service = VPNServerService(db)
|
||||
content = service.generate_ccd_file(profile)
|
||||
|
||||
if not content:
|
||||
raise HTTPException(status_code=404, detail="No CCD content for this profile")
|
||||
|
||||
return Response(content=content, media_type="text/plain")
|
||||
|
||||
|
||||
@router.post("/vpn-servers/{server_id}/client-connected")
|
||||
async def client_connected(
|
||||
server_id: int,
|
||||
request: ClientConnectedRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Handle VPN client connection event.
|
||||
|
||||
Called by OpenVPN client-connect script when a client connects.
|
||||
Sets up firewall rules for pending connections.
|
||||
"""
|
||||
# Determine if this is a gateway or a user client
|
||||
cn = request.common_name
|
||||
vpn_ip = request.vpn_ip
|
||||
|
||||
# Check if it's a gateway profile
|
||||
gateway_profile = db.query(VPNProfile).filter(
|
||||
VPNProfile.cert_cn == cn
|
||||
).first()
|
||||
|
||||
if gateway_profile:
|
||||
# Update gateway status
|
||||
gateway = gateway_profile.gateway
|
||||
gateway.is_online = True
|
||||
gateway.vpn_ip = vpn_ip
|
||||
gateway_profile.vpn_ip = vpn_ip
|
||||
gateway_profile.last_connection = db.func.now()
|
||||
db.commit()
|
||||
return {"status": "ok", "type": "gateway", "gateway_id": gateway.id}
|
||||
|
||||
# Check if it's a client profile (user/technician)
|
||||
client_profile = db.query(ClientVPNProfile).filter(
|
||||
ClientVPNProfile.cert_cn == cn
|
||||
).first()
|
||||
|
||||
if client_profile:
|
||||
# Update client profile with VPN IP
|
||||
client_profile.vpn_ip = vpn_ip
|
||||
client_profile.last_connection = db.func.now()
|
||||
db.commit()
|
||||
|
||||
# Find pending connections for this user and set firewall rules
|
||||
pending_connections = db.query(ConnectionLog).filter(
|
||||
ConnectionLog.user_id == client_profile.user_id,
|
||||
ConnectionLog.disconnected_at.is_(None),
|
||||
ConnectionLog.vpn_ip.is_(None) # Not yet processed
|
||||
).all()
|
||||
|
||||
firewall = FirewallService()
|
||||
rules_created = 0
|
||||
|
||||
for conn in pending_connections:
|
||||
endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first()
|
||||
gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first()
|
||||
|
||||
if endpoint and gateway and gateway.vpn_ip:
|
||||
# Set firewall rule to allow client to reach endpoint via gateway
|
||||
success = firewall.allow_connection(
|
||||
client_vpn_ip=vpn_ip,
|
||||
gateway_vpn_ip=gateway.vpn_ip,
|
||||
target_ip=endpoint.internal_ip,
|
||||
target_port=endpoint.port,
|
||||
protocol=endpoint.protocol.value
|
||||
)
|
||||
|
||||
if success:
|
||||
conn.vpn_ip = vpn_ip
|
||||
rules_created += 1
|
||||
|
||||
db.commit()
|
||||
return {
|
||||
"status": "ok",
|
||||
"type": "client",
|
||||
"user_id": client_profile.user_id,
|
||||
"rules_created": rules_created
|
||||
}
|
||||
|
||||
return {"status": "unknown_client", "common_name": cn}
|
||||
|
||||
|
||||
@router.post("/vpn-servers/{server_id}/client-disconnected")
|
||||
async def client_disconnected(
|
||||
server_id: int,
|
||||
request: ClientDisconnectedRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Handle VPN client disconnection event.
|
||||
|
||||
Called by OpenVPN client-disconnect script when a client disconnects.
|
||||
Cleans up firewall rules.
|
||||
"""
|
||||
cn = request.common_name
|
||||
|
||||
# Check if it's a gateway profile
|
||||
gateway_profile = db.query(VPNProfile).filter(
|
||||
VPNProfile.cert_cn == cn
|
||||
).first()
|
||||
|
||||
if gateway_profile:
|
||||
# Mark gateway as offline
|
||||
gateway = gateway_profile.gateway
|
||||
gateway.is_online = False
|
||||
db.commit()
|
||||
return {"status": "ok", "type": "gateway", "gateway_id": gateway.id}
|
||||
|
||||
# Check if it's a client profile
|
||||
client_profile = db.query(ClientVPNProfile).filter(
|
||||
ClientVPNProfile.cert_cn == cn
|
||||
).first()
|
||||
|
||||
if client_profile:
|
||||
vpn_ip = client_profile.vpn_ip
|
||||
|
||||
# Find active connections for this user and clean up firewall rules
|
||||
if vpn_ip:
|
||||
active_connections = db.query(ConnectionLog).filter(
|
||||
ConnectionLog.user_id == client_profile.user_id,
|
||||
ConnectionLog.disconnected_at.is_(None),
|
||||
ConnectionLog.vpn_ip == vpn_ip
|
||||
).all()
|
||||
|
||||
firewall = FirewallService()
|
||||
|
||||
for conn in active_connections:
|
||||
endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first()
|
||||
gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first()
|
||||
|
||||
if endpoint and gateway:
|
||||
# Remove firewall rule
|
||||
firewall.revoke_connection(
|
||||
client_vpn_ip=vpn_ip,
|
||||
gateway_vpn_ip=gateway.vpn_ip or "",
|
||||
target_ip=endpoint.internal_ip,
|
||||
target_port=endpoint.port,
|
||||
protocol=endpoint.protocol.value
|
||||
)
|
||||
|
||||
# Mark connection as disconnected
|
||||
conn.disconnected_at = db.func.now()
|
||||
|
||||
# Clear VPN IP from profile
|
||||
client_profile.vpn_ip = None
|
||||
db.commit()
|
||||
|
||||
return {"status": "ok", "type": "client", "user_id": client_profile.user_id}
|
||||
|
||||
return {"status": "unknown_client", "common_name": cn}
|
||||
|
||||
@@ -8,6 +8,7 @@ from .access import UserGatewayAccess, UserEndpointAccess, ConnectionLog
|
||||
from .certificate_authority import CertificateAuthority, CAStatus, CAAlgorithm
|
||||
from .vpn_server import VPNServer, VPNProtocol, VPNCipher, VPNAuth, VPNCompression, VPNServerStatus
|
||||
from .vpn_profile import VPNProfile, VPNProfileStatus
|
||||
from .client_vpn_profile import ClientVPNProfile, ClientVPNProfileStatus
|
||||
from .vpn_connection_log import VPNConnectionLog
|
||||
|
||||
__all__ = [
|
||||
@@ -32,5 +33,7 @@ __all__ = [
|
||||
"VPNServerStatus",
|
||||
"VPNProfile",
|
||||
"VPNProfileStatus",
|
||||
"ClientVPNProfile",
|
||||
"ClientVPNProfileStatus",
|
||||
"VPNConnectionLog",
|
||||
]
|
||||
|
||||
@@ -69,6 +69,7 @@ class CertificateAuthority(Base):
|
||||
created_by = relationship("User", foreign_keys=[created_by_id])
|
||||
vpn_servers = relationship("VPNServer", back_populates="certificate_authority")
|
||||
vpn_profiles = relationship("VPNProfile", back_populates="certificate_authority")
|
||||
client_vpn_profiles = relationship("ClientVPNProfile", back_populates="certificate_authority")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<CertificateAuthority(id={self.id}, name='{self.name}', status='{self.status}')>"
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Client VPN Profile model for user/technician VPN connections."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum as PyEnum
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Enum, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class ClientVPNProfileStatus(str, PyEnum):
|
||||
"""Client VPN Profile status."""
|
||||
PENDING = "pending" # Certificate being generated
|
||||
ACTIVE = "active" # Ready to use
|
||||
EXPIRED = "expired" # Certificate expired
|
||||
REVOKED = "revoked" # Certificate revoked
|
||||
|
||||
|
||||
class ClientVPNProfile(Base):
|
||||
"""VPN Profile for a user/technician to connect to the VPN server."""
|
||||
|
||||
__tablename__ = "client_vpn_profiles"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
vpn_server_id = Column(Integer, ForeignKey("vpn_servers.id"), nullable=False)
|
||||
ca_id = Column(Integer, ForeignKey("certificate_authorities.id"), nullable=False)
|
||||
|
||||
# Certificate data
|
||||
cert_cn = Column(String(255), nullable=False, unique=True) # Common Name
|
||||
client_cert = Column(Text, nullable=True) # Client certificate PEM
|
||||
client_key = Column(Text, nullable=True) # Client private key PEM
|
||||
|
||||
# Status
|
||||
status = Column(Enum(ClientVPNProfileStatus), default=ClientVPNProfileStatus.PENDING)
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Validity
|
||||
valid_from = Column(DateTime, nullable=True)
|
||||
valid_until = Column(DateTime, nullable=True)
|
||||
|
||||
# VPN IP assigned (updated when connected)
|
||||
vpn_ip = Column(String(15), nullable=True)
|
||||
|
||||
# Tracking
|
||||
last_connection = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="client_vpn_profile")
|
||||
vpn_server = relationship("VPNServer", back_populates="client_vpn_profiles")
|
||||
certificate_authority = relationship("CertificateAuthority", back_populates="client_vpn_profiles")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ClientVPNProfile(id={self.id}, user_id={self.user_id}, cert_cn='{self.cert_cn}')>"
|
||||
|
||||
@property
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if profile is ready for use."""
|
||||
return (
|
||||
self.status == ClientVPNProfileStatus.ACTIVE and
|
||||
self.client_cert is not None and
|
||||
self.client_key is not None
|
||||
)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if certificate is expired."""
|
||||
if self.valid_until:
|
||||
return datetime.utcnow() > self.valid_until
|
||||
return False
|
||||
@@ -47,6 +47,12 @@ class User(Base):
|
||||
primaryjoin="User.id == UserEndpointAccess.user_id"
|
||||
)
|
||||
connection_logs = relationship("ConnectionLog", back_populates="user")
|
||||
client_vpn_profile = relationship(
|
||||
"ClientVPNProfile",
|
||||
back_populates="user",
|
||||
uselist=False, # One-to-one relationship
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User(id={self.id}, username='{self.username}', role='{self.role}')>"
|
||||
|
||||
@@ -106,6 +106,7 @@ class VPNServer(Base):
|
||||
tenant = relationship("Tenant", back_populates="vpn_servers")
|
||||
certificate_authority = relationship("CertificateAuthority", back_populates="vpn_servers")
|
||||
vpn_profiles = relationship("VPNProfile", back_populates="vpn_server")
|
||||
client_vpn_profiles = relationship("ClientVPNProfile", back_populates="vpn_server")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<VPNServer(id={self.id}, name='{self.name}', {self.hostname}:{self.port}/{self.protocol.value})>"
|
||||
|
||||
@@ -6,6 +6,7 @@ from .firewall_service import FirewallService
|
||||
from .certificate_service import CertificateService
|
||||
from .vpn_server_service import VPNServerService
|
||||
from .vpn_profile_service import VPNProfileService
|
||||
from .client_vpn_profile_service import ClientVPNProfileService
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
@@ -14,4 +15,5 @@ __all__ = [
|
||||
"CertificateService",
|
||||
"VPNServerService",
|
||||
"VPNProfileService",
|
||||
"ClientVPNProfileService",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
"""Client VPN Profile management service for user/technician VPN connections."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models.client_vpn_profile import ClientVPNProfile, ClientVPNProfileStatus
|
||||
from ..models.vpn_server import VPNServer
|
||||
from ..models.user import User
|
||||
from ..models.gateway import Gateway
|
||||
from ..models.access import UserGatewayAccess
|
||||
from ..models.certificate_authority import CertificateAuthority
|
||||
from .certificate_service import CertificateService
|
||||
from ..config import get_settings
|
||||
import ipaddress
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class ClientVPNProfileService:
|
||||
"""Service for managing VPN profiles for users/technicians."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.cert_service = CertificateService(db)
|
||||
|
||||
def get_or_create_profile(self, user: User) -> Optional[ClientVPNProfile]:
|
||||
"""Get existing profile for user or create a new one."""
|
||||
# Check if user already has a profile
|
||||
if user.client_vpn_profile and user.client_vpn_profile.is_ready:
|
||||
return user.client_vpn_profile
|
||||
|
||||
# Find active VPN server (prefer primary, then any active)
|
||||
server = self.db.query(VPNServer).filter(
|
||||
VPNServer.is_active == True,
|
||||
VPNServer.is_primary == True
|
||||
).first()
|
||||
|
||||
if not server:
|
||||
server = self.db.query(VPNServer).filter(
|
||||
VPNServer.is_active == True
|
||||
).first()
|
||||
|
||||
if not server:
|
||||
return None
|
||||
|
||||
if not server.certificate_authority or not server.certificate_authority.is_ready:
|
||||
return None
|
||||
|
||||
# Create new profile
|
||||
return self.create_profile(user, server.id)
|
||||
|
||||
def create_profile(
|
||||
self,
|
||||
user: User,
|
||||
vpn_server_id: int
|
||||
) -> ClientVPNProfile:
|
||||
"""Create a new VPN profile for a user."""
|
||||
|
||||
# Get VPN server
|
||||
server = self.db.query(VPNServer).filter(
|
||||
VPNServer.id == vpn_server_id
|
||||
).first()
|
||||
|
||||
if not server:
|
||||
raise ValueError(f"VPN Server with id {vpn_server_id} not found")
|
||||
|
||||
if not server.certificate_authority.is_ready:
|
||||
raise ValueError("CA is not ready (DH parameters may still be generating)")
|
||||
|
||||
# Generate unique common name
|
||||
cert_cn = f"client-{user.username}-{user.id}"
|
||||
|
||||
# Check if profile already exists
|
||||
existing = self.db.query(ClientVPNProfile).filter(
|
||||
ClientVPNProfile.user_id == user.id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Delete old profile and create new
|
||||
self.db.delete(existing)
|
||||
self.db.commit()
|
||||
|
||||
# Create profile
|
||||
profile = ClientVPNProfile(
|
||||
user_id=user.id,
|
||||
vpn_server_id=vpn_server_id,
|
||||
ca_id=server.ca_id,
|
||||
cert_cn=cert_cn,
|
||||
status=ClientVPNProfileStatus.PENDING
|
||||
)
|
||||
|
||||
self.db.add(profile)
|
||||
self.db.commit()
|
||||
self.db.refresh(profile)
|
||||
|
||||
# Generate client certificate
|
||||
self._generate_client_cert(profile)
|
||||
|
||||
return profile
|
||||
|
||||
def _generate_client_cert(self, profile: ClientVPNProfile):
|
||||
"""Generate client certificate for profile."""
|
||||
ca = profile.certificate_authority
|
||||
|
||||
cert_data = self.cert_service.generate_client_certificate(
|
||||
ca=ca,
|
||||
common_name=profile.cert_cn
|
||||
)
|
||||
|
||||
profile.client_cert = cert_data["cert"]
|
||||
profile.client_key = cert_data["key"]
|
||||
profile.valid_from = cert_data["valid_from"]
|
||||
profile.valid_until = cert_data["valid_until"]
|
||||
profile.status = ClientVPNProfileStatus.ACTIVE
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def generate_client_config(self, profile: ClientVPNProfile, split_tunnel: bool = True) -> str:
|
||||
"""Generate OpenVPN client configuration (.ovpn) for a profile.
|
||||
|
||||
Args:
|
||||
profile: The client VPN profile
|
||||
split_tunnel: If True, only VPN traffic goes through tunnel (default).
|
||||
If False, all traffic goes through VPN (no internet).
|
||||
"""
|
||||
|
||||
if not profile.is_ready:
|
||||
raise ValueError("Profile is not ready")
|
||||
|
||||
server = profile.vpn_server
|
||||
ca = profile.certificate_authority
|
||||
|
||||
config_lines = [
|
||||
"# OpenVPN Client Configuration",
|
||||
f"# User: {profile.user.username}",
|
||||
f"# Server: {server.name}",
|
||||
f"# Generated: {datetime.utcnow().isoformat()}",
|
||||
f"# Split Tunneling: {'Enabled' if split_tunnel else 'Disabled'}",
|
||||
"",
|
||||
"client",
|
||||
"dev tun",
|
||||
f"proto {server.protocol.value}",
|
||||
f"remote {server.hostname} {server.port}",
|
||||
"",
|
||||
"resolv-retry infinite",
|
||||
"nobind",
|
||||
"persist-key",
|
||||
"persist-tun",
|
||||
"",
|
||||
"remote-cert-tls server",
|
||||
f"cipher {server.cipher.value}",
|
||||
f"auth {server.auth.value}",
|
||||
"",
|
||||
"verb 3",
|
||||
"",
|
||||
]
|
||||
|
||||
# Split Tunneling Configuration
|
||||
if split_tunnel:
|
||||
config_lines.extend([
|
||||
"# Split Tunneling - Only VPN traffic through tunnel",
|
||||
"# Ignore any redirect-gateway push from server",
|
||||
"pull-filter ignore \"redirect-gateway\"",
|
||||
"",
|
||||
])
|
||||
|
||||
# Add route for VPN network
|
||||
vpn_network = server.vpn_network
|
||||
vpn_netmask = server.vpn_netmask
|
||||
config_lines.append(f"# Route for VPN network")
|
||||
config_lines.append(f"route {vpn_network} {vpn_netmask}")
|
||||
config_lines.append("")
|
||||
|
||||
# Add routes for gateway subnets the user has access to
|
||||
routes_added = self._get_gateway_routes(profile.user)
|
||||
if routes_added:
|
||||
config_lines.append("# Routes for accessible gateway networks")
|
||||
for route in routes_added:
|
||||
config_lines.append(f"route {route['network']} {route['netmask']} # {route['name']}")
|
||||
config_lines.append("")
|
||||
|
||||
# Add CA certificate
|
||||
config_lines.extend([
|
||||
"<ca>",
|
||||
ca.ca_cert.strip(),
|
||||
"</ca>",
|
||||
"",
|
||||
])
|
||||
|
||||
# Add client certificate
|
||||
config_lines.extend([
|
||||
"<cert>",
|
||||
profile.client_cert.strip(),
|
||||
"</cert>",
|
||||
"",
|
||||
])
|
||||
|
||||
# Add client key
|
||||
config_lines.extend([
|
||||
"<key>",
|
||||
profile.client_key.strip(),
|
||||
"</key>",
|
||||
"",
|
||||
])
|
||||
|
||||
# Add TLS-Auth key if enabled
|
||||
if server.tls_auth_enabled and server.ta_key:
|
||||
config_lines.extend([
|
||||
"key-direction 1",
|
||||
"<tls-auth>",
|
||||
server.ta_key.strip(),
|
||||
"</tls-auth>",
|
||||
])
|
||||
|
||||
return "\n".join(config_lines)
|
||||
|
||||
def _get_gateway_routes(self, user: User) -> list[dict]:
|
||||
"""Get routes for all gateway subnets the user has access to."""
|
||||
routes = []
|
||||
seen_networks = set()
|
||||
|
||||
# Get gateways the user has access to
|
||||
if user.is_admin:
|
||||
# Admins have access to all gateways in their tenant
|
||||
gateways = self.db.query(Gateway).filter(
|
||||
Gateway.tenant_id == user.tenant_id
|
||||
).all()
|
||||
else:
|
||||
# Regular users - check access table
|
||||
access_entries = self.db.query(UserGatewayAccess).filter(
|
||||
UserGatewayAccess.user_id == user.id
|
||||
).all()
|
||||
gateway_ids = [a.gateway_id for a in access_entries]
|
||||
gateways = self.db.query(Gateway).filter(
|
||||
Gateway.id.in_(gateway_ids)
|
||||
).all() if gateway_ids else []
|
||||
|
||||
for gateway in gateways:
|
||||
# Add route for gateway's VPN subnet if defined
|
||||
if gateway.vpn_subnet and gateway.vpn_subnet not in seen_networks:
|
||||
try:
|
||||
network = ipaddress.ip_network(gateway.vpn_subnet, strict=False)
|
||||
routes.append({
|
||||
"network": str(network.network_address),
|
||||
"netmask": str(network.netmask),
|
||||
"name": gateway.name
|
||||
})
|
||||
seen_networks.add(gateway.vpn_subnet)
|
||||
except ValueError:
|
||||
pass # Invalid subnet, skip
|
||||
|
||||
return routes
|
||||
|
||||
def get_vpn_config_for_user(self, user: User) -> Optional[str]:
|
||||
"""Get or create VPN config for a user."""
|
||||
profile = self.get_or_create_profile(user)
|
||||
if not profile:
|
||||
return None
|
||||
|
||||
return self.generate_client_config(profile)
|
||||
|
||||
def revoke_profile(self, profile: ClientVPNProfile, reason: str = "unspecified"):
|
||||
"""Revoke a user's VPN profile certificate."""
|
||||
if profile.client_cert:
|
||||
self.cert_service.revoke_certificate(
|
||||
ca=profile.certificate_authority,
|
||||
cert_pem=profile.client_cert,
|
||||
reason=reason
|
||||
)
|
||||
|
||||
profile.status = ClientVPNProfileStatus.REVOKED
|
||||
profile.is_active = False
|
||||
self.db.commit()
|
||||
|
||||
def renew_profile(self, profile: ClientVPNProfile):
|
||||
"""Renew the certificate for a profile."""
|
||||
# Revoke old certificate
|
||||
if profile.client_cert:
|
||||
self.cert_service.revoke_certificate(
|
||||
ca=profile.certificate_authority,
|
||||
cert_pem=profile.client_cert,
|
||||
reason="superseded"
|
||||
)
|
||||
|
||||
# Generate new certificate
|
||||
self._generate_client_cert(profile)
|
||||
self.db.commit()
|
||||
@@ -7,8 +7,11 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from ..models.vpn_server import VPNServer, VPNServerStatus, VPNProtocol, VPNCipher, VPNAuth, VPNCompression
|
||||
from ..models.certificate_authority import CertificateAuthority
|
||||
from ..models.vpn_profile import VPNProfile
|
||||
from ..models.gateway import Gateway
|
||||
from .certificate_service import CertificateService
|
||||
from ..config import get_settings
|
||||
import ipaddress
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
@@ -156,6 +159,17 @@ class VPNServerService:
|
||||
'push "dhcp-option DNS 8.8.8.8"',
|
||||
'push "dhcp-option DNS 8.8.4.4"',
|
||||
"",
|
||||
])
|
||||
|
||||
# Routes for gateway subnets (required for iroute to work)
|
||||
gateway_routes = self._get_gateway_subnet_routes(server)
|
||||
if gateway_routes:
|
||||
config_lines.append("# Routes for gateway subnets")
|
||||
for route in gateway_routes:
|
||||
config_lines.append(f"route {route['network']} {route['netmask']} # {route['gateway_name']}")
|
||||
config_lines.append("")
|
||||
|
||||
config_lines.extend([
|
||||
"# Security",
|
||||
f"cipher {server.cipher.value}",
|
||||
f"auth {server.auth.value}",
|
||||
@@ -201,6 +215,81 @@ class VPNServerService:
|
||||
|
||||
return "\n".join(config_lines)
|
||||
|
||||
def _get_gateway_subnet_routes(self, server: VPNServer) -> list[dict]:
|
||||
"""Get all gateway subnets that need routes on this server."""
|
||||
routes = []
|
||||
seen_networks = set()
|
||||
|
||||
# Find all VPN profiles that use this server
|
||||
profiles = self.db.query(VPNProfile).filter(
|
||||
VPNProfile.vpn_server_id == server.id
|
||||
).all()
|
||||
|
||||
for profile in profiles:
|
||||
gateway = profile.gateway
|
||||
if gateway and gateway.vpn_subnet and gateway.vpn_subnet not in seen_networks:
|
||||
try:
|
||||
network = ipaddress.ip_network(gateway.vpn_subnet, strict=False)
|
||||
routes.append({
|
||||
"network": str(network.network_address),
|
||||
"netmask": str(network.netmask),
|
||||
"gateway_name": gateway.name
|
||||
})
|
||||
seen_networks.add(gateway.vpn_subnet)
|
||||
except ValueError:
|
||||
pass # Invalid subnet, skip
|
||||
|
||||
return routes
|
||||
|
||||
def generate_ccd_file(self, profile: VPNProfile) -> str | None:
|
||||
"""Generate CCD (Client Config Directory) file content for a gateway profile.
|
||||
|
||||
This file tells OpenVPN that traffic for the gateway's subnet
|
||||
should be routed through this client.
|
||||
"""
|
||||
gateway = profile.gateway
|
||||
if not gateway or not gateway.vpn_subnet:
|
||||
return None
|
||||
|
||||
lines = [
|
||||
f"# CCD file for gateway: {gateway.name}",
|
||||
f"# Profile: {profile.name}",
|
||||
f"# Common Name: {profile.cert_cn}",
|
||||
"",
|
||||
]
|
||||
|
||||
try:
|
||||
network = ipaddress.ip_network(gateway.vpn_subnet, strict=False)
|
||||
# iroute tells OpenVPN to route this subnet through this client
|
||||
lines.append(f"iroute {network.network_address} {network.netmask}")
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# Optional: assign static IP to this gateway
|
||||
if profile.vpn_ip:
|
||||
lines.append(f"# Static IP assignment (if needed)")
|
||||
lines.append(f"# ifconfig-push {profile.vpn_ip} {profile.vpn_server.vpn_netmask}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def get_all_ccd_files(self, server: VPNServer) -> dict[str, str]:
|
||||
"""Get all CCD files for a VPN server.
|
||||
|
||||
Returns a dict mapping cert_cn to CCD file content.
|
||||
"""
|
||||
ccd_files = {}
|
||||
|
||||
profiles = self.db.query(VPNProfile).filter(
|
||||
VPNProfile.vpn_server_id == server.id
|
||||
).all()
|
||||
|
||||
for profile in profiles:
|
||||
content = self.generate_ccd_file(profile)
|
||||
if content:
|
||||
ccd_files[profile.cert_cn] = content
|
||||
|
||||
return ccd_files
|
||||
|
||||
def get_connected_clients(self, server: VPNServer) -> list[dict]:
|
||||
"""Get list of currently connected VPN clients."""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user