"""Internal API endpoints for container-to-container communication. These endpoints are used by OpenVPN containers to fetch their configuration. They should only be accessible from within the Docker network. """ import os from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Response, Query from sqlalchemy.orm import Session from pydantic import BaseModel from ..database import get_db from ..models.vpn_server import VPNServer, VPNServerStatus from ..models.vpn_profile import VPNProfile from ..models.client_vpn_profile import ClientVPNProfile from ..models.gateway import Gateway from ..models.access import ConnectionLog from ..models.endpoint import Endpoint from ..services.vpn_server_service import VPNServerService from ..services.vpn_sync_service import VPNSyncService from ..services.certificate_service import CertificateService from ..services.firewall_service import FirewallService class ClientConnectedRequest(BaseModel): """Request when VPN client connects.""" common_name: str real_ip: str vpn_ip: str class ClientDisconnectedRequest(BaseModel): """Request when VPN client disconnects.""" common_name: str bytes_received: int = 0 bytes_sent: int = 0 duration: int = 0 # Log directory (shared volume with OpenVPN container) LOG_DIR = Path("/var/log/openvpn") router = APIRouter(prefix="/api/internal", tags=["internal"]) @router.get("/health") async def health_check(): """Health check endpoint for container startup.""" return {"status": "healthy"} @router.get("/vpn-servers/active") async def get_active_servers(db: Session = Depends(get_db)): """Get list of all active and ready VPN servers. Used by OpenVPN container to discover which servers to start. """ servers = db.query(VPNServer).filter( VPNServer.is_active == True ).all() return [ { "id": s.id, "name": s.name, "port": s.port, "protocol": s.protocol.value, "management_port": s.management_port, "vpn_network": s.vpn_network, "vpn_netmask": s.vpn_netmask, "is_ready": s.is_ready, "has_ca": s.certificate_authority is not None and s.certificate_authority.is_ready, "has_cert": s.server_cert is not None and s.server_key is not None } for s in servers ] @router.get("/vpn-servers/{server_id}/config") async def get_server_config( server_id: int, db: Session = Depends(get_db) ): """Get OpenVPN server configuration file.""" service = VPNServerService(db) server = service.get_server_by_id(server_id) if not server: raise HTTPException(status_code=404, detail="Server not found") if not server.server_cert or not server.server_key: raise HTTPException(status_code=400, detail="Server certificate not generated") config = service.generate_server_config(server) # Update server status server.status = VPNServerStatus.STARTING db.commit() return Response(content=config, media_type="text/plain") @router.get("/vpn-servers/{server_id}/ca") async def get_server_ca( server_id: int, db: Session = Depends(get_db) ): """Get CA certificate for a VPN server.""" server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") if not server.certificate_authority or not server.certificate_authority.ca_cert: raise HTTPException(status_code=400, detail="CA not available") return Response( content=server.certificate_authority.ca_cert, media_type="application/x-pem-file" ) @router.get("/vpn-servers/{server_id}/cert") async def get_server_cert( server_id: int, db: Session = Depends(get_db) ): """Get server certificate.""" server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") if not server.server_cert: raise HTTPException(status_code=400, detail="Server certificate not generated") return Response( content=server.server_cert, media_type="application/x-pem-file" ) @router.get("/vpn-servers/{server_id}/key") async def get_server_key( server_id: int, db: Session = Depends(get_db) ): """Get server private key.""" server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") if not server.server_key: raise HTTPException(status_code=400, detail="Server key not generated") return Response( content=server.server_key, media_type="application/x-pem-file" ) @router.get("/vpn-servers/{server_id}/dh") async def get_server_dh( server_id: int, db: Session = Depends(get_db) ): """Get DH parameters.""" server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") if not server.certificate_authority or not server.certificate_authority.dh_params: raise HTTPException(status_code=400, detail="DH parameters not available") return Response( content=server.certificate_authority.dh_params, media_type="application/x-pem-file" ) @router.get("/vpn-servers/{server_id}/ta") async def get_server_ta( server_id: int, db: Session = Depends(get_db) ): """Get TLS-Auth key.""" server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") if not server.ta_key: raise HTTPException(status_code=400, detail="TA key not available") return Response( content=server.ta_key, media_type="text/plain" ) @router.get("/vpn-servers/{server_id}/crl") async def get_server_crl( server_id: int, db: Session = Depends(get_db) ): """Get Certificate Revocation List.""" server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") if not server.certificate_authority: raise HTTPException(status_code=400, detail="CA not available") cert_service = CertificateService(db) crl = cert_service.get_crl(server.certificate_authority) return Response( content=crl, media_type="application/x-pem-file" ) @router.post("/vpn-servers/{server_id}/started") async def notify_server_started( server_id: int, db: Session = Depends(get_db) ): """Notify that server has started successfully.""" server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") server.status = VPNServerStatus.RUNNING db.commit() return {"status": "ok"} @router.post("/vpn-servers/{server_id}/stopped") async def notify_server_stopped( server_id: int, db: Session = Depends(get_db) ): """Notify that server has stopped.""" server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") server.status = VPNServerStatus.STOPPED db.commit() return {"status": "ok"} @router.get("/vpn-servers/{server_id}/logs") async def get_server_logs( server_id: int, lines: int = Query(default=100, le=1000), db: Session = Depends(get_db) ): """Get OpenVPN server log file. Args: server_id: VPN server ID lines: Number of lines to return (max 1000) """ server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") log_file = LOG_DIR / f"server-{server_id}.log" if not log_file.exists(): return {"lines": [], "message": "Log file not found"} try: # Read last N lines with open(log_file, 'r') as f: all_lines = f.readlines() log_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines return { "server_id": server_id, "server_name": server.name, "total_lines": len(all_lines), "returned_lines": len(log_lines), "lines": [line.rstrip() for line in log_lines] } except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading log: {str(e)}") @router.get("/vpn-servers/{server_id}/logs/raw") async def get_server_logs_raw( server_id: int, lines: int = Query(default=100, le=5000), db: Session = Depends(get_db) ): """Get OpenVPN server log file as plain text.""" server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") log_file = LOG_DIR / f"server-{server_id}.log" if not log_file.exists(): return Response(content="Log file not found", media_type="text/plain") try: with open(log_file, 'r') as f: all_lines = f.readlines() log_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines return Response( content="".join(log_lines), media_type="text/plain" ) except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading log: {str(e)}") @router.get("/logs/supervisord") async def get_supervisord_logs( lines: int = Query(default=100, le=1000) ): """Get supervisord log file.""" log_file = LOG_DIR / "supervisord.log" if not log_file.exists(): return {"lines": [], "message": "Supervisord log not found"} try: with open(log_file, 'r') as f: all_lines = f.readlines() log_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines return { "total_lines": len(all_lines), "returned_lines": len(log_lines), "lines": [line.rstrip() for line in log_lines] } except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading log: {str(e)}") @router.get("/debug/sync") async def debug_sync(db: Session = Depends(get_db)): """Debug endpoint to check VPN sync status.""" sync_service = VPNSyncService(db) vpn_service = VPNServerService(db) # Get all profiles with their CNs profiles = db.query(VPNProfile).all() profile_cns = [{"id": p.id, "name": p.name, "cert_cn": p.cert_cn, "gateway_id": p.gateway_id} for p in profiles] # Get connected clients from all servers servers = db.query(VPNServer).filter(VPNServer.is_active == True).all() connected_clients = [] for server in servers: try: clients = vpn_service.get_connected_clients(server) for client in clients: client['server_id'] = server.id client['server_name'] = server.name connected_clients.append(client) except Exception as e: connected_clients.append({"server_id": server.id, "error": str(e)}) # Run sync and get result sync_result = sync_service.sync_all_connections() return { "profiles": profile_cns, "connected_clients": connected_clients, "sync_result": sync_result } @router.get("/vpn-servers/{server_id}/ccd") async def get_ccd_files( server_id: int, db: Session = Depends(get_db) ): """Get all CCD (Client Config Directory) files for a VPN server. CCD files contain iroute directives that tell OpenVPN which client handles traffic for which subnet. """ server = db.query(VPNServer).filter(VPNServer.id == server_id).first() if not server: raise HTTPException(status_code=404, detail="Server not found") service = VPNServerService(db) ccd_files = service.get_all_ccd_files(server) return { "server_id": server_id, "count": len(ccd_files), "files": ccd_files } @router.get("/vpn-servers/{server_id}/ccd/{common_name}") async def get_ccd_file( server_id: int, common_name: str, db: Session = Depends(get_db) ): """Get a single CCD file for a specific client.""" from ..models.vpn_profile import VPNProfile profile = db.query(VPNProfile).filter( VPNProfile.vpn_server_id == server_id, VPNProfile.cert_cn == common_name ).first() if not profile: raise HTTPException(status_code=404, detail="Profile not found") service = VPNServerService(db) content = service.generate_ccd_file(profile) if not content: raise HTTPException(status_code=404, detail="No CCD content for this profile") return Response(content=content, media_type="text/plain") @router.post("/vpn-servers/{server_id}/client-connected") async def client_connected( server_id: int, request: ClientConnectedRequest, db: Session = Depends(get_db) ): """Handle VPN client connection event. Called by OpenVPN client-connect script when a client connects. Sets up firewall rules for pending connections. """ # Determine if this is a gateway or a user client cn = request.common_name vpn_ip = request.vpn_ip # Check if it's a gateway profile gateway_profile = db.query(VPNProfile).filter( VPNProfile.cert_cn == cn ).first() if gateway_profile: # Update gateway status gateway = gateway_profile.gateway gateway.is_online = True gateway.vpn_ip = vpn_ip gateway_profile.vpn_ip = vpn_ip gateway_profile.last_connection = db.func.now() db.commit() return {"status": "ok", "type": "gateway", "gateway_id": gateway.id} # Check if it's a client profile (user/technician) client_profile = db.query(ClientVPNProfile).filter( ClientVPNProfile.cert_cn == cn ).first() if client_profile: # Update client profile with VPN IP client_profile.vpn_ip = vpn_ip client_profile.last_connection = db.func.now() db.commit() # Find pending connections for this user and set firewall rules pending_connections = db.query(ConnectionLog).filter( ConnectionLog.user_id == client_profile.user_id, ConnectionLog.disconnected_at.is_(None), ConnectionLog.vpn_ip.is_(None) # Not yet processed ).all() firewall = FirewallService() rules_created = 0 for conn in pending_connections: endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first() gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first() if endpoint and gateway and gateway.vpn_ip: # Set firewall rule to allow client to reach endpoint via gateway success = firewall.allow_connection( client_vpn_ip=vpn_ip, gateway_vpn_ip=gateway.vpn_ip, target_ip=endpoint.internal_ip, target_port=endpoint.port, protocol=endpoint.protocol.value ) if success: conn.vpn_ip = vpn_ip rules_created += 1 db.commit() return { "status": "ok", "type": "client", "user_id": client_profile.user_id, "rules_created": rules_created } return {"status": "unknown_client", "common_name": cn} @router.post("/vpn-servers/{server_id}/client-disconnected") async def client_disconnected( server_id: int, request: ClientDisconnectedRequest, db: Session = Depends(get_db) ): """Handle VPN client disconnection event. Called by OpenVPN client-disconnect script when a client disconnects. Cleans up firewall rules. """ cn = request.common_name # Check if it's a gateway profile gateway_profile = db.query(VPNProfile).filter( VPNProfile.cert_cn == cn ).first() if gateway_profile: # Mark gateway as offline gateway = gateway_profile.gateway gateway.is_online = False db.commit() return {"status": "ok", "type": "gateway", "gateway_id": gateway.id} # Check if it's a client profile client_profile = db.query(ClientVPNProfile).filter( ClientVPNProfile.cert_cn == cn ).first() if client_profile: vpn_ip = client_profile.vpn_ip # Find active connections for this user and clean up firewall rules if vpn_ip: active_connections = db.query(ConnectionLog).filter( ConnectionLog.user_id == client_profile.user_id, ConnectionLog.disconnected_at.is_(None), ConnectionLog.vpn_ip == vpn_ip ).all() firewall = FirewallService() for conn in active_connections: endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first() gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first() if endpoint and gateway: # Remove firewall rule firewall.revoke_connection( client_vpn_ip=vpn_ip, gateway_vpn_ip=gateway.vpn_ip or "", target_ip=endpoint.internal_ip, target_port=endpoint.port, protocol=endpoint.protocol.value ) # Mark connection as disconnected conn.disconnected_at = db.func.now() # Clear VPN IP from profile client_profile.vpn_ip = None db.commit() return {"status": "ok", "type": "client", "user_id": client_profile.user_id} return {"status": "unknown_client", "common_name": cn}