readme refreshed
This commit is contained in:
@@ -11,6 +11,7 @@ from ..models.user import User, UserRole
|
||||
from ..models.access import UserGatewayAccess, ConnectionLog
|
||||
from ..services.vpn_service import VPNService
|
||||
from ..services.firewall_service import FirewallService
|
||||
from ..services.client_vpn_profile_service import ClientVPNProfileService
|
||||
from .deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
@@ -92,9 +93,15 @@ def connect_to_endpoint(
|
||||
detail="Endpoint not found"
|
||||
)
|
||||
|
||||
# NOTE: Dynamic VPN config generation has been replaced by VPN profiles.
|
||||
# Gateways should have pre-provisioned VPN profiles.
|
||||
# This endpoint now just logs the connection intent.
|
||||
# Get or create VPN profile for user
|
||||
client_profile_service = ClientVPNProfileService(db)
|
||||
vpn_config = client_profile_service.get_vpn_config_for_user(current_user)
|
||||
|
||||
if not vpn_config:
|
||||
return ConnectResponse(
|
||||
success=False,
|
||||
message="No VPN server available. Please contact administrator."
|
||||
)
|
||||
|
||||
# Log connection
|
||||
connection = ConnectionLog(
|
||||
@@ -107,10 +114,14 @@ def connect_to_endpoint(
|
||||
db.commit()
|
||||
db.refresh(connection)
|
||||
|
||||
# Set up firewall rules for this connection
|
||||
# The client's VPN IP will be determined after VPN connects
|
||||
# For now, we'll configure firewall rules when the client-connect script runs
|
||||
|
||||
return ConnectResponse(
|
||||
success=True,
|
||||
message="Connection logged. Use the gateway's VPN profile configuration to connect.",
|
||||
vpn_config=None, # VPN config is now obtained through gateway VPN profiles
|
||||
message="VPN configuration ready. Connect to access endpoint.",
|
||||
vpn_config=vpn_config,
|
||||
target_ip=endpoint.internal_ip,
|
||||
target_port=endpoint.port,
|
||||
connection_id=connection.id
|
||||
|
||||
@@ -9,12 +9,34 @@ from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..database import get_db
|
||||
from ..models.vpn_server import VPNServer, VPNServerStatus
|
||||
from ..models.vpn_profile import VPNProfile
|
||||
from ..models.client_vpn_profile import ClientVPNProfile
|
||||
from ..models.gateway import Gateway
|
||||
from ..models.access import ConnectionLog
|
||||
from ..models.endpoint import Endpoint
|
||||
from ..services.vpn_server_service import VPNServerService
|
||||
from ..services.vpn_sync_service import VPNSyncService
|
||||
from ..services.certificate_service import CertificateService
|
||||
from ..services.firewall_service import FirewallService
|
||||
|
||||
|
||||
class ClientConnectedRequest(BaseModel):
|
||||
"""Request when VPN client connects."""
|
||||
common_name: str
|
||||
real_ip: str
|
||||
vpn_ip: str
|
||||
|
||||
|
||||
class ClientDisconnectedRequest(BaseModel):
|
||||
"""Request when VPN client disconnects."""
|
||||
common_name: str
|
||||
bytes_received: int = 0
|
||||
bytes_sent: int = 0
|
||||
duration: int = 0
|
||||
|
||||
# Log directory (shared volume with OpenVPN container)
|
||||
LOG_DIR = Path("/var/log/openvpn")
|
||||
@@ -361,3 +383,203 @@ async def debug_sync(db: Session = Depends(get_db)):
|
||||
"connected_clients": connected_clients,
|
||||
"sync_result": sync_result
|
||||
}
|
||||
|
||||
|
||||
@router.get("/vpn-servers/{server_id}/ccd")
|
||||
async def get_ccd_files(
|
||||
server_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get all CCD (Client Config Directory) files for a VPN server.
|
||||
|
||||
CCD files contain iroute directives that tell OpenVPN which client
|
||||
handles traffic for which subnet.
|
||||
"""
|
||||
server = db.query(VPNServer).filter(VPNServer.id == server_id).first()
|
||||
|
||||
if not server:
|
||||
raise HTTPException(status_code=404, detail="Server not found")
|
||||
|
||||
service = VPNServerService(db)
|
||||
ccd_files = service.get_all_ccd_files(server)
|
||||
|
||||
return {
|
||||
"server_id": server_id,
|
||||
"count": len(ccd_files),
|
||||
"files": ccd_files
|
||||
}
|
||||
|
||||
|
||||
@router.get("/vpn-servers/{server_id}/ccd/{common_name}")
|
||||
async def get_ccd_file(
|
||||
server_id: int,
|
||||
common_name: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get a single CCD file for a specific client."""
|
||||
from ..models.vpn_profile import VPNProfile
|
||||
|
||||
profile = db.query(VPNProfile).filter(
|
||||
VPNProfile.vpn_server_id == server_id,
|
||||
VPNProfile.cert_cn == common_name
|
||||
).first()
|
||||
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
service = VPNServerService(db)
|
||||
content = service.generate_ccd_file(profile)
|
||||
|
||||
if not content:
|
||||
raise HTTPException(status_code=404, detail="No CCD content for this profile")
|
||||
|
||||
return Response(content=content, media_type="text/plain")
|
||||
|
||||
|
||||
@router.post("/vpn-servers/{server_id}/client-connected")
|
||||
async def client_connected(
|
||||
server_id: int,
|
||||
request: ClientConnectedRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Handle VPN client connection event.
|
||||
|
||||
Called by OpenVPN client-connect script when a client connects.
|
||||
Sets up firewall rules for pending connections.
|
||||
"""
|
||||
# Determine if this is a gateway or a user client
|
||||
cn = request.common_name
|
||||
vpn_ip = request.vpn_ip
|
||||
|
||||
# Check if it's a gateway profile
|
||||
gateway_profile = db.query(VPNProfile).filter(
|
||||
VPNProfile.cert_cn == cn
|
||||
).first()
|
||||
|
||||
if gateway_profile:
|
||||
# Update gateway status
|
||||
gateway = gateway_profile.gateway
|
||||
gateway.is_online = True
|
||||
gateway.vpn_ip = vpn_ip
|
||||
gateway_profile.vpn_ip = vpn_ip
|
||||
gateway_profile.last_connection = db.func.now()
|
||||
db.commit()
|
||||
return {"status": "ok", "type": "gateway", "gateway_id": gateway.id}
|
||||
|
||||
# Check if it's a client profile (user/technician)
|
||||
client_profile = db.query(ClientVPNProfile).filter(
|
||||
ClientVPNProfile.cert_cn == cn
|
||||
).first()
|
||||
|
||||
if client_profile:
|
||||
# Update client profile with VPN IP
|
||||
client_profile.vpn_ip = vpn_ip
|
||||
client_profile.last_connection = db.func.now()
|
||||
db.commit()
|
||||
|
||||
# Find pending connections for this user and set firewall rules
|
||||
pending_connections = db.query(ConnectionLog).filter(
|
||||
ConnectionLog.user_id == client_profile.user_id,
|
||||
ConnectionLog.disconnected_at.is_(None),
|
||||
ConnectionLog.vpn_ip.is_(None) # Not yet processed
|
||||
).all()
|
||||
|
||||
firewall = FirewallService()
|
||||
rules_created = 0
|
||||
|
||||
for conn in pending_connections:
|
||||
endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first()
|
||||
gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first()
|
||||
|
||||
if endpoint and gateway and gateway.vpn_ip:
|
||||
# Set firewall rule to allow client to reach endpoint via gateway
|
||||
success = firewall.allow_connection(
|
||||
client_vpn_ip=vpn_ip,
|
||||
gateway_vpn_ip=gateway.vpn_ip,
|
||||
target_ip=endpoint.internal_ip,
|
||||
target_port=endpoint.port,
|
||||
protocol=endpoint.protocol.value
|
||||
)
|
||||
|
||||
if success:
|
||||
conn.vpn_ip = vpn_ip
|
||||
rules_created += 1
|
||||
|
||||
db.commit()
|
||||
return {
|
||||
"status": "ok",
|
||||
"type": "client",
|
||||
"user_id": client_profile.user_id,
|
||||
"rules_created": rules_created
|
||||
}
|
||||
|
||||
return {"status": "unknown_client", "common_name": cn}
|
||||
|
||||
|
||||
@router.post("/vpn-servers/{server_id}/client-disconnected")
|
||||
async def client_disconnected(
|
||||
server_id: int,
|
||||
request: ClientDisconnectedRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Handle VPN client disconnection event.
|
||||
|
||||
Called by OpenVPN client-disconnect script when a client disconnects.
|
||||
Cleans up firewall rules.
|
||||
"""
|
||||
cn = request.common_name
|
||||
|
||||
# Check if it's a gateway profile
|
||||
gateway_profile = db.query(VPNProfile).filter(
|
||||
VPNProfile.cert_cn == cn
|
||||
).first()
|
||||
|
||||
if gateway_profile:
|
||||
# Mark gateway as offline
|
||||
gateway = gateway_profile.gateway
|
||||
gateway.is_online = False
|
||||
db.commit()
|
||||
return {"status": "ok", "type": "gateway", "gateway_id": gateway.id}
|
||||
|
||||
# Check if it's a client profile
|
||||
client_profile = db.query(ClientVPNProfile).filter(
|
||||
ClientVPNProfile.cert_cn == cn
|
||||
).first()
|
||||
|
||||
if client_profile:
|
||||
vpn_ip = client_profile.vpn_ip
|
||||
|
||||
# Find active connections for this user and clean up firewall rules
|
||||
if vpn_ip:
|
||||
active_connections = db.query(ConnectionLog).filter(
|
||||
ConnectionLog.user_id == client_profile.user_id,
|
||||
ConnectionLog.disconnected_at.is_(None),
|
||||
ConnectionLog.vpn_ip == vpn_ip
|
||||
).all()
|
||||
|
||||
firewall = FirewallService()
|
||||
|
||||
for conn in active_connections:
|
||||
endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first()
|
||||
gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first()
|
||||
|
||||
if endpoint and gateway:
|
||||
# Remove firewall rule
|
||||
firewall.revoke_connection(
|
||||
client_vpn_ip=vpn_ip,
|
||||
gateway_vpn_ip=gateway.vpn_ip or "",
|
||||
target_ip=endpoint.internal_ip,
|
||||
target_port=endpoint.port,
|
||||
protocol=endpoint.protocol.value
|
||||
)
|
||||
|
||||
# Mark connection as disconnected
|
||||
conn.disconnected_at = db.func.now()
|
||||
|
||||
# Clear VPN IP from profile
|
||||
client_profile.vpn_ip = None
|
||||
db.commit()
|
||||
|
||||
return {"status": "ok", "type": "client", "user_id": client_profile.user_id}
|
||||
|
||||
return {"status": "unknown_client", "common_name": cn}
|
||||
|
||||
Reference in New Issue
Block a user