# Source: openvpn-endpoint-server/server/app/api/internal.py
# (listing metadata: 586 lines, 18 KiB, Python)

"""Internal API endpoints for container-to-container communication.
These endpoints are used by OpenVPN containers to fetch their configuration.
They should only be accessible from within the Docker network.
"""
import os
from collections import deque
from pathlib import Path

from fastapi import APIRouter, Depends, HTTPException, Response, Query
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session

from ..database import get_db
from ..models.vpn_server import VPNServer, VPNServerStatus
from ..models.vpn_profile import VPNProfile
from ..models.client_vpn_profile import ClientVPNProfile
from ..models.gateway import Gateway
from ..models.access import ConnectionLog
from ..models.endpoint import Endpoint
from ..services.vpn_server_service import VPNServerService
from ..services.vpn_sync_service import VPNSyncService
from ..services.certificate_service import CertificateService
from ..services.firewall_service import FirewallService
class ClientConnectedRequest(BaseModel):
    """Request when VPN client connects.

    Body of ``POST /api/internal/vpn-servers/{server_id}/client-connected``.
    """
    # Certificate common name of the client; matched against
    # VPNProfile.cert_cn and ClientVPNProfile.cert_cn by the handler.
    common_name: str
    # Presumably the client's real (pre-tunnel) address as reported by the
    # OpenVPN script — NOTE(review): not read by client_connected; confirm.
    real_ip: str
    # Address assigned inside the VPN; stored on the profile/gateway and used
    # as the source IP for firewall rules.
    vpn_ip: str
class ClientDisconnectedRequest(BaseModel):
    """Request when VPN client disconnects.

    Body of ``POST /api/internal/vpn-servers/{server_id}/client-disconnected``.
    """
    # Certificate common name of the disconnecting client.
    common_name: str
    # Session traffic counters; default to 0 when the script omits them.
    # NOTE(review): not read by client_disconnected in this module.
    bytes_received: int = 0
    bytes_sent: int = 0
    # Session length, presumably in seconds — confirm with the OpenVPN script.
    duration: int = 0
# Directory holding the OpenVPN and supervisord log files; it is a volume
# shared with the OpenVPN container.
LOG_DIR = Path("/var/log", "openvpn")

# All routes in this module are mounted under /api/internal and grouped under
# the "internal" tag.
router = APIRouter(tags=["internal"], prefix="/api/internal")
@router.get("/health")
async def health_check():
    """Report liveness for container startup checks.

    Always answers healthy; no database or service dependency is probed.
    """
    payload = {"status": "healthy"}
    return payload
@router.get("/vpn-servers/active")
async def get_active_servers(db: Session = Depends(get_db)):
    """List every VPN server flagged ``is_active``.

    Used by the OpenVPN container to discover which servers to start.
    Servers are returned whether or not they are ready; each entry carries
    ``is_ready``, ``has_ca`` and ``has_cert`` so the caller can decide which
    ones it can actually launch.

    Returns:
        One dict per active server with its id, name, port, protocol,
        management port, VPN network/netmask and boolean readiness flags.
    """
    servers = db.query(VPNServer).filter(
        VPNServer.is_active == True  # noqa: E712 - SQLAlchemy column expression
    ).all()

    result = []
    for s in servers:
        ca = s.certificate_authority
        result.append({
            "id": s.id,
            "name": s.name,
            "port": s.port,
            "protocol": s.protocol.value,
            "management_port": s.management_port,
            "vpn_network": s.vpn_network,
            "vpn_netmask": s.vpn_netmask,
            "is_ready": s.is_ready,
            # bool() so a NULL is_ready is reported as False, not null.
            "has_ca": bool(ca is not None and ca.is_ready),
            # Same truthiness test as get_server_config uses before rendering,
            # so an empty cert/key is not advertised as present.
            "has_cert": bool(s.server_cert and s.server_key),
        })
    return result
@router.get("/vpn-servers/{server_id}/config")
async def get_server_config(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Render the OpenVPN server configuration file for ``server_id``.

    Side effect: the server's status is set to ``STARTING`` and committed
    (presumably because the container fetches the config right before
    launching the server).

    Raises:
        HTTPException: 404 if the server does not exist, 400 if its
            certificate or private key has not been generated.
    """
    vpn_service = VPNServerService(db)
    server = vpn_service.get_server_by_id(server_id)
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    if not (server.server_cert and server.server_key):
        raise HTTPException(status_code=400, detail="Server certificate not generated")

    config_text = vpn_service.generate_server_config(server)

    server.status = VPNServerStatus.STARTING
    db.commit()

    return Response(content=config_text, media_type="text/plain")
@router.get("/vpn-servers/{server_id}/ca")
async def get_server_ca(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Return the PEM CA certificate of the server's certificate authority.

    Raises:
        HTTPException: 404 for an unknown server, 400 when the server has no
            CA or the CA has no certificate.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    authority = server.certificate_authority
    if not (authority and authority.ca_cert):
        raise HTTPException(status_code=400, detail="CA not available")

    return Response(content=authority.ca_cert, media_type="application/x-pem-file")
@router.get("/vpn-servers/{server_id}/cert")
async def get_server_cert(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Return the server's PEM certificate.

    Raises:
        HTTPException: 404 for an unknown server, 400 if no certificate has
            been generated yet.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    pem = server.server_cert
    if not pem:
        raise HTTPException(status_code=400, detail="Server certificate not generated")

    return Response(content=pem, media_type="application/x-pem-file")
@router.get("/vpn-servers/{server_id}/key")
async def get_server_key(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Return the server's PEM private key.

    No authentication is applied here; exposure is limited only by where
    this router is reachable from.

    Raises:
        HTTPException: 404 for an unknown server, 400 if no key has been
            generated yet.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    private_key = server.server_key
    if not private_key:
        raise HTTPException(status_code=400, detail="Server key not generated")

    return Response(content=private_key, media_type="application/x-pem-file")
@router.get("/vpn-servers/{server_id}/dh")
async def get_server_dh(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Return the Diffie-Hellman parameters stored on the server's CA.

    Raises:
        HTTPException: 404 for an unknown server, 400 when the server has no
            CA or the CA has no DH parameters.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    authority = server.certificate_authority
    if not (authority and authority.dh_params):
        raise HTTPException(status_code=400, detail="DH parameters not available")

    return Response(content=authority.dh_params, media_type="application/x-pem-file")
@router.get("/vpn-servers/{server_id}/ta")
async def get_server_ta(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Return the server's TLS-Auth key as plain text.

    Raises:
        HTTPException: 404 for an unknown server, 400 if the server has no
            TA key.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    tls_auth_key = server.ta_key
    if not tls_auth_key:
        raise HTTPException(status_code=400, detail="TA key not available")

    return Response(content=tls_auth_key, media_type="text/plain")
@router.get("/vpn-servers/{server_id}/crl")
async def get_server_crl(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Return the CRL for the server's CA, as produced by CertificateService.

    Raises:
        HTTPException: 404 for an unknown server, 400 if the server has no CA.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    authority = server.certificate_authority
    if not authority:
        raise HTTPException(status_code=400, detail="CA not available")

    crl_pem = CertificateService(db).get_crl(authority)
    return Response(content=crl_pem, media_type="application/x-pem-file")
@router.post("/vpn-servers/{server_id}/started")
async def notify_server_started(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Mark the server as RUNNING after the container reports a successful start.

    Raises:
        HTTPException: 404 for an unknown server.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    server.status = VPNServerStatus.RUNNING
    db.commit()
    ack = {"status": "ok"}
    return ack
@router.post("/vpn-servers/{server_id}/stopped")
async def notify_server_stopped(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Mark the server as STOPPED after the container reports it has exited.

    Raises:
        HTTPException: 404 for an unknown server.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    server.status = VPNServerStatus.STOPPED
    db.commit()
    ack = {"status": "ok"}
    return ack
@router.get("/vpn-servers/{server_id}/logs")
async def get_server_logs(
    server_id: int,
    lines: int = Query(default=100, le=1000),
    db: Session = Depends(get_db)
):
    """Get the tail of an OpenVPN server log file as JSON.

    Args:
        server_id: VPN server ID; the log is read from
            ``LOG_DIR / "server-<server_id>.log"``.
        lines: Number of trailing lines to return (max 1000). Values <= 0
            return no lines (previously 0 returned the whole file).
        db: Database session (injected).

    Returns:
        ``server_id``, ``server_name``, ``total_lines``, ``returned_lines`` and
        ``lines`` (trailing whitespace stripped), or
        ``{"lines": [], "message": "Log file not found"}`` if the file is
        missing.

    Raises:
        HTTPException: 404 for an unknown server, 500 if the file exists but
            cannot be read.
    """
    server = db.query(VPNServer).filter(VPNServer.id == server_id).first()
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    log_file = LOG_DIR / f"server-{server_id}.log"

    # Stream through a bounded deque so only the last `lines` entries are held
    # in memory; undecodable bytes are replaced rather than failing the request.
    tail = deque(maxlen=max(lines, 0))
    total_lines = 0
    try:
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                tail.append(line)
                total_lines += 1
    except FileNotFoundError:
        # EAFP instead of exists()+open(): also covers the file vanishing
        # (e.g. rotation) between the two calls.
        return {"lines": [], "message": "Log file not found"}
    except OSError as e:
        raise HTTPException(status_code=500, detail=f"Error reading log: {str(e)}") from e

    return {
        "server_id": server_id,
        "server_name": server.name,
        "total_lines": total_lines,
        "returned_lines": len(tail),
        "lines": [line.rstrip() for line in tail],
    }
@router.get("/vpn-servers/{server_id}/logs/raw")
async def get_server_logs_raw(
    server_id: int,
    lines: int = Query(default=100, le=5000),
    db: Session = Depends(get_db)
):
    """Get the tail of an OpenVPN server log file as plain text.

    Args:
        server_id: VPN server ID; the log is read from
            ``LOG_DIR / "server-<server_id>.log"``.
        lines: Number of trailing lines to return (max 5000). Values <= 0
            return an empty body.
        db: Database session (injected).

    Returns:
        The last ``lines`` lines joined verbatim, or the text
        ``"Log file not found"`` if the file is missing.

    Raises:
        HTTPException: 404 for an unknown server, 500 if the file exists but
            cannot be read.
    """
    server = db.query(VPNServer).filter(VPNServer.id == server_id).first()
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    log_file = LOG_DIR / f"server-{server_id}.log"

    # Bounded deque keeps memory proportional to `lines`, not the file size.
    tail = deque(maxlen=max(lines, 0))
    try:
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            tail.extend(f)
    except FileNotFoundError:
        return Response(content="Log file not found", media_type="text/plain")
    except OSError as e:
        raise HTTPException(status_code=500, detail=f"Error reading log: {str(e)}") from e

    return Response(
        content="".join(tail),
        media_type="text/plain"
    )
@router.get("/logs/supervisord")
async def get_supervisord_logs(
    lines: int = Query(default=100, le=1000)
):
    """Get the tail of the supervisord log file as JSON.

    Args:
        lines: Number of trailing lines to return (max 1000). Values <= 0
            return no lines.

    Returns:
        ``total_lines``, ``returned_lines`` and ``lines`` (trailing whitespace
        stripped), or ``{"lines": [], "message": "Supervisord log not found"}``
        if the file is missing.

    Raises:
        HTTPException: 500 if the file exists but cannot be read.
    """
    log_file = LOG_DIR / "supervisord.log"

    tail = deque(maxlen=max(lines, 0))
    total_lines = 0
    try:
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                tail.append(line)
                total_lines += 1
    except FileNotFoundError:
        return {"lines": [], "message": "Supervisord log not found"}
    except OSError as e:
        raise HTTPException(status_code=500, detail=f"Error reading log: {str(e)}") from e

    return {
        "total_lines": total_lines,
        "returned_lines": len(tail),
        "lines": [line.rstrip() for line in tail],
    }
@router.get("/debug/sync")
async def debug_sync(db: Session = Depends(get_db)):
    """Dump VPN sync state for debugging.

    Returns every VPN profile (id, name, cert CN, gateway id), the clients
    reported by each active server, and the result of a full sync run.
    Note that this GET is not read-only: it calls
    ``VPNSyncService.sync_all_connections()``.
    """
    sync_service = VPNSyncService(db)
    vpn_service = VPNServerService(db)

    profile_cns = [
        {
            "id": profile.id,
            "name": profile.name,
            "cert_cn": profile.cert_cn,
            "gateway_id": profile.gateway_id,
        }
        for profile in db.query(VPNProfile).all()
    ]

    active_servers = db.query(VPNServer).filter(VPNServer.is_active == True).all()  # noqa: E712
    connected_clients = []
    for srv in active_servers:
        try:
            for entry in vpn_service.get_connected_clients(srv):
                entry.update(server_id=srv.id, server_name=srv.name)
                connected_clients.append(entry)
        except Exception as exc:
            # A failing server is reported inline rather than aborting the dump.
            connected_clients.append({"server_id": srv.id, "error": str(exc)})

    sync_result = sync_service.sync_all_connections()
    return {
        "profiles": profile_cns,
        "connected_clients": connected_clients,
        "sync_result": sync_result,
    }
@router.get("/vpn-servers/{server_id}/ccd")
async def get_ccd_files(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Get all CCD (Client Config Directory) files for a VPN server.

    CCD files carry the iroute directives that tell OpenVPN which client
    handles traffic for which subnet. Content comes from
    ``VPNServerService.get_all_ccd_files``.

    Raises:
        HTTPException: 404 for an unknown server.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    files = VPNServerService(db).get_all_ccd_files(server)
    response_body = {
        "server_id": server_id,
        "count": len(files),
        "files": files,
    }
    return response_body
@router.get("/vpn-servers/{server_id}/ccd/{common_name}")
async def get_ccd_file(
    server_id: int,
    common_name: str,
    db: Session = Depends(get_db)
):
    """Get the CCD file for one client of a VPN server.

    The profile is looked up by (``vpn_server_id``, ``cert_cn``) and its
    content is rendered by ``VPNServerService.generate_ccd_file``.

    Raises:
        HTTPException: 404 if no profile matches, or if the rendered CCD
            content is empty.
    """
    # VPNProfile comes from the module-level import; the redundant
    # function-scope re-import has been dropped.
    profile = db.query(VPNProfile).filter(
        VPNProfile.vpn_server_id == server_id,
        VPNProfile.cert_cn == common_name
    ).first()
    if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")

    content = VPNServerService(db).generate_ccd_file(profile)
    if not content:
        raise HTTPException(status_code=404, detail="No CCD content for this profile")

    return Response(content=content, media_type="text/plain")
@router.post("/vpn-servers/{server_id}/client-connected")
async def client_connected(
    server_id: int,
    request: ClientConnectedRequest,
    db: Session = Depends(get_db)
):
    """Handle a VPN client connection event.

    Called by the OpenVPN client-connect script when a client connects.

    * Gateway (a ``VPNProfile`` whose ``cert_cn`` matches): stores the VPN IP
      on the profile and its gateway and marks the gateway online.
    * User/technician (a ``ClientVPNProfile`` whose ``cert_cn`` matches):
      stores the VPN IP, then opens a firewall rule for each open
      ``ConnectionLog`` of that user not yet bound to a VPN IP.

    Args:
        server_id: Reporting VPN server. NOTE(review): not used to scope the
            profile lookup.
        request: Common name and assigned VPN IP of the client.
        db: Database session (injected).

    Returns:
        A status dict for the handled client type, or
        ``{"status": "unknown_client", "common_name": ...}``.
    """
    cn = request.common_name
    vpn_ip = request.vpn_ip

    gateway_profile = db.query(VPNProfile).filter(
        VPNProfile.cert_cn == cn
    ).first()
    if gateway_profile:
        gateway = gateway_profile.gateway
        gateway_profile.vpn_ip = vpn_ip
        # sqlalchemy.func, not db.func: Session has no `func` attribute, so the
        # old code raised AttributeError on every connect.
        gateway_profile.last_connection = func.now()
        # Defensive: a profile without a linked gateway must not crash the hook.
        if gateway is not None:
            gateway.is_online = True
            gateway.vpn_ip = vpn_ip
        db.commit()
        return {
            "status": "ok",
            "type": "gateway",
            "gateway_id": gateway.id if gateway is not None else None,
        }

    client_profile = db.query(ClientVPNProfile).filter(
        ClientVPNProfile.cert_cn == cn
    ).first()
    if not client_profile:
        return {"status": "unknown_client", "common_name": cn}

    client_profile.vpn_ip = vpn_ip
    client_profile.last_connection = func.now()
    db.commit()

    # Open connections for this user that have not been bound to a VPN IP.
    pending_connections = db.query(ConnectionLog).filter(
        ConnectionLog.user_id == client_profile.user_id,
        ConnectionLog.disconnected_at.is_(None),
        ConnectionLog.vpn_ip.is_(None)
    ).all()

    firewall = FirewallService()
    rules_created = 0
    for conn in pending_connections:
        endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first()
        gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first()
        # Skipped entries keep vpn_ip NULL, so they are retried on the next connect.
        if not (endpoint and gateway and gateway.vpn_ip):
            continue
        success = firewall.allow_connection(
            client_vpn_ip=vpn_ip,
            gateway_vpn_ip=gateway.vpn_ip,
            target_ip=endpoint.internal_ip,
            target_port=endpoint.port,
            protocol=endpoint.protocol.value
        )
        if success:
            conn.vpn_ip = vpn_ip
            rules_created += 1
    db.commit()

    return {
        "status": "ok",
        "type": "client",
        "user_id": client_profile.user_id,
        "rules_created": rules_created
    }
@router.post("/vpn-servers/{server_id}/client-disconnected")
async def client_disconnected(
    server_id: int,
    request: ClientDisconnectedRequest,
    db: Session = Depends(get_db)
):
    """Handle a VPN client disconnection event.

    Called by the OpenVPN client-disconnect script when a client disconnects.

    * Gateway (a ``VPNProfile`` whose ``cert_cn`` matches): marks the gateway
      offline.
    * User/technician (a ``ClientVPNProfile`` whose ``cert_cn`` matches): for
      each open ``ConnectionLog`` bound to the client's VPN IP, revokes the
      firewall rule (when endpoint and gateway still exist) and closes the
      log entry. Finally clears the profile's VPN IP.

    Args:
        server_id: Reporting VPN server. NOTE(review): unused here.
        request: Disconnect payload; only ``common_name`` is read.
        db: Database session (injected).

    Returns:
        A status dict for the handled client type, or
        ``{"status": "unknown_client", "common_name": ...}``.
    """
    cn = request.common_name

    gateway_profile = db.query(VPNProfile).filter(
        VPNProfile.cert_cn == cn
    ).first()
    if gateway_profile:
        gateway = gateway_profile.gateway
        # Defensive: a profile without a linked gateway must not crash the hook.
        if gateway is not None:
            gateway.is_online = False
        db.commit()
        return {
            "status": "ok",
            "type": "gateway",
            "gateway_id": gateway.id if gateway is not None else None,
        }

    client_profile = db.query(ClientVPNProfile).filter(
        ClientVPNProfile.cert_cn == cn
    ).first()
    if not client_profile:
        return {"status": "unknown_client", "common_name": cn}

    vpn_ip = client_profile.vpn_ip
    if vpn_ip:
        active_connections = db.query(ConnectionLog).filter(
            ConnectionLog.user_id == client_profile.user_id,
            ConnectionLog.disconnected_at.is_(None),
            ConnectionLog.vpn_ip == vpn_ip
        ).all()

        firewall = FirewallService()
        for conn in active_connections:
            endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first()
            gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first()
            if endpoint and gateway:
                firewall.revoke_connection(
                    client_vpn_ip=vpn_ip,
                    gateway_vpn_ip=gateway.vpn_ip or "",
                    target_ip=endpoint.internal_ip,
                    target_port=endpoint.port,
                    protocol=endpoint.protocol.value
                )
            # Close the log entry even if its endpoint/gateway was deleted, so
            # it does not stay open forever. sqlalchemy.func, not db.func:
            # Session has no `func` attribute.
            conn.disconnected_at = func.now()

    client_profile.vpn_ip = None
    db.commit()
    return {"status": "ok", "type": "client", "user_id": client_profile.user_id}