586 lines
18 KiB
Python
586 lines
18 KiB
Python
"""Internal API endpoints for container-to-container communication.
|
|
|
|
These endpoints are used by OpenVPN containers to fetch their configuration.
|
|
They should only be accessible from within the Docker network.
|
|
"""
|
|
|
|
import os
from collections import deque
from pathlib import Path

from fastapi import APIRouter, Depends, HTTPException, Response, Query
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session

from ..database import get_db
from ..models.vpn_server import VPNServer, VPNServerStatus
from ..models.vpn_profile import VPNProfile
from ..models.client_vpn_profile import ClientVPNProfile
from ..models.gateway import Gateway
from ..models.access import ConnectionLog
from ..models.endpoint import Endpoint
from ..services.vpn_server_service import VPNServerService
from ..services.vpn_sync_service import VPNSyncService
from ..services.certificate_service import CertificateService
from ..services.firewall_service import FirewallService
|
|
|
|
|
|
class ClientConnectedRequest(BaseModel):
    """Body of the ``client-connected`` callback sent for a connecting VPN client."""

    # Certificate common name; matched against the ``cert_cn`` of
    # VPNProfile (gateways) and ClientVPNProfile (users/technicians).
    common_name: str
    # Client's real (non-tunnel) address.
    # NOTE(review): not read by the handlers visible in this module.
    real_ip: str
    # Address assigned to the client inside the VPN; stored on the matched
    # profile and used as the client side of firewall rules.
    vpn_ip: str
|
|
|
|
|
|
class ClientDisconnectedRequest(BaseModel):
    """Body of the ``client-disconnected`` callback sent when a VPN client leaves."""

    # Certificate common name; matched against the ``cert_cn`` of
    # VPNProfile (gateways) and ClientVPNProfile (users/technicians).
    common_name: str
    # Session counters, each defaulting to 0 when omitted. Units are
    # presumably bytes and seconds — confirm against the disconnect script.
    # NOTE(review): not read by the handlers visible in this module.
    bytes_received: int = 0
    bytes_sent: int = 0
    duration: int = 0
|
|
|
|
# Directory holding the OpenVPN server logs and supervisord.log (a volume
# shared with the OpenVPN container). Defaults to /var/log/openvpn; a
# non-empty OPENVPN_LOG_DIR environment variable overrides it, e.g. for
# local development outside Docker.
LOG_DIR = Path(os.environ.get("OPENVPN_LOG_DIR") or "/var/log/openvpn")
|
|
|
|
# Every route declared below is served under /api/internal.
# NOTE(review): no authentication dependency is attached to this router;
# access control presumably relies on Docker network isolation — confirm
# how/where the router is mounted.
router = APIRouter(
    tags=["internal"],
    prefix="/api/internal",
)
|
|
|
|
|
|
@router.get("/health")
async def health_check():
    """Liveness probe polled during container startup; always reports healthy."""
    payload = {"status": "healthy"}
    return payload
|
|
|
|
|
|
@router.get("/vpn-servers/active")
async def get_active_servers(db: Session = Depends(get_db)):
    """List all VPN servers flagged as active, together with readiness details.

    Used by the OpenVPN container to discover which servers to start.
    Every active server is returned whether or not it is ready; callers
    should inspect ``is_ready``, ``has_ca`` and ``has_cert`` before
    starting one.

    Returns:
        A list of dicts with ``id``, ``name``, ``port``, ``protocol``,
        ``management_port``, ``vpn_network``, ``vpn_netmask`` and the
        readiness flags ``is_ready``, ``has_ca`` and ``has_cert``.
    """
    servers = db.query(VPNServer).filter(VPNServer.is_active.is_(True)).all()

    result = []
    for s in servers:
        ca = s.certificate_authority
        result.append({
            "id": s.id,
            "name": s.name,
            "port": s.port,
            "protocol": s.protocol.value,
            "management_port": s.management_port,
            "vpn_network": s.vpn_network,
            "vpn_netmask": s.vpn_netmask,
            "is_ready": s.is_ready,
            # Coerce to bool so a missing/NULL readiness flag serializes as false.
            "has_ca": bool(ca is not None and ca.is_ready),
            # Truthiness check mirrors get_server_config, so has_cert agrees
            # with that endpoint's certificate precondition.
            "has_cert": bool(s.server_cert and s.server_key),
        })
    return result
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/config")
async def get_server_config(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Serve the generated OpenVPN config for one server as plain text.

    Side effect: fetching the config moves the server to ``STARTING``.

    Raises:
        HTTPException: 404 for an unknown server, 400 when its certificate
            or key has not been generated yet.
    """
    svc = VPNServerService(db)
    server = svc.get_server_by_id(server_id)
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")
    if not (server.server_cert and server.server_key):
        raise HTTPException(status_code=400, detail="Server certificate not generated")

    rendered = svc.generate_server_config(server)

    server.status = VPNServerStatus.STARTING
    db.commit()

    return Response(content=rendered, media_type="text/plain")
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/ca")
async def get_server_ca(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Serve the PEM certificate of the server's certificate authority.

    Raises:
        HTTPException: 404 for an unknown server, 400 when the server has
            no CA or the CA has no certificate.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    authority = server.certificate_authority
    if not (authority and authority.ca_cert):
        raise HTTPException(status_code=400, detail="CA not available")

    return Response(content=authority.ca_cert, media_type="application/x-pem-file")
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/cert")
async def get_server_cert(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Serve the server's own PEM certificate.

    Raises:
        HTTPException: 404 for an unknown server, 400 when no certificate
            has been generated.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    pem = server.server_cert
    if not pem:
        raise HTTPException(status_code=400, detail="Server certificate not generated")

    return Response(content=pem, media_type="application/x-pem-file")
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/key")
async def get_server_key(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Serve the server's PEM private key.

    NOTE(review): this returns private key material and the route declares
    no authentication dependency — confirm it is only reachable from inside
    the Docker network.

    Raises:
        HTTPException: 404 for an unknown server, 400 when no key has been
            generated.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    private_key = server.server_key
    if not private_key:
        raise HTTPException(status_code=400, detail="Server key not generated")

    return Response(content=private_key, media_type="application/x-pem-file")
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/dh")
async def get_server_dh(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Serve the Diffie-Hellman parameters stored on the server's CA.

    Raises:
        HTTPException: 404 for an unknown server, 400 when the server has
            no CA or the CA has no DH parameters.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    authority = server.certificate_authority
    if not (authority and authority.dh_params):
        raise HTTPException(status_code=400, detail="DH parameters not available")

    return Response(content=authority.dh_params, media_type="application/x-pem-file")
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/ta")
async def get_server_ta(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Serve the server's TLS-Auth key as plain text.

    Raises:
        HTTPException: 404 for an unknown server, 400 when no TA key exists.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    tls_auth = server.ta_key
    if not tls_auth:
        raise HTTPException(status_code=400, detail="TA key not available")

    return Response(content=tls_auth, media_type="text/plain")
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/crl")
async def get_server_crl(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Serve the certificate revocation list of the server's CA.

    The CRL content comes from ``CertificateService.get_crl``.

    Raises:
        HTTPException: 404 for an unknown server, 400 when it has no CA.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    authority = server.certificate_authority
    if not authority:
        raise HTTPException(status_code=400, detail="CA not available")

    crl_pem = CertificateService(db).get_crl(authority)
    return Response(content=crl_pem, media_type="application/x-pem-file")
|
|
|
|
|
|
@router.post("/vpn-servers/{server_id}/started")
async def notify_server_started(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Record that the OpenVPN process for ``server_id`` is up (status RUNNING).

    Raises:
        HTTPException: 404 for an unknown server.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    server.status = VPNServerStatus.RUNNING
    db.commit()

    acknowledgement = {"status": "ok"}
    return acknowledgement
|
|
|
|
|
|
@router.post("/vpn-servers/{server_id}/stopped")
async def notify_server_stopped(
    server_id: int,
    db: Session = Depends(get_db)
):
    """Record that the OpenVPN process for ``server_id`` has exited (status STOPPED).

    Raises:
        HTTPException: 404 for an unknown server.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    server.status = VPNServerStatus.STOPPED
    db.commit()

    acknowledgement = {"status": "ok"}
    return acknowledgement
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/logs")
async def get_server_logs(
    server_id: int,
    lines: int = Query(default=100, ge=0, le=1000),
    db: Session = Depends(get_db),
):
    """Return the tail of an OpenVPN server's log file as JSON.

    Args:
        server_id: VPN server ID.
        lines: Number of trailing lines to return (0-1000).

    Returns:
        ``server_id``, ``server_name``, ``total_lines`` (lines in the file),
        ``returned_lines`` and ``lines`` (each with trailing whitespace
        stripped). If the log file does not exist, ``{"lines": [],
        "message": "Log file not found"}`` is returned instead.

    Raises:
        HTTPException: 404 for an unknown server, 500 if the log file
            exists but cannot be read.
    """
    server = db.query(VPNServer).filter(VPNServer.id == server_id).first()
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    log_file = LOG_DIR / f"server-{server_id}.log"
    try:
        # Stream the file, keeping only the last `lines` lines in memory.
        # Undecodable bytes are replaced rather than failing the request.
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            tail = deque(maxlen=lines)
            total_lines = 0
            for total_lines, line in enumerate(f, start=1):
                tail.append(line)
    except FileNotFoundError:
        return {"lines": [], "message": "Log file not found"}
    except OSError as exc:
        raise HTTPException(status_code=500, detail=f"Error reading log: {exc}") from exc

    return {
        "server_id": server_id,
        "server_name": server.name,
        "total_lines": total_lines,
        "returned_lines": len(tail),
        "lines": [line.rstrip() for line in tail],
    }
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/logs/raw")
async def get_server_logs_raw(
    server_id: int,
    lines: int = Query(default=100, ge=0, le=5000),
    db: Session = Depends(get_db),
):
    """Return the tail of an OpenVPN server's log file as plain text.

    Args:
        server_id: VPN server ID.
        lines: Number of trailing lines to return (0-5000).

    Returns:
        The last ``lines`` lines, newlines preserved, or the text
        ``Log file not found`` when the file does not exist.

    Raises:
        HTTPException: 404 for an unknown server, 500 if the log file
            exists but cannot be read.
    """
    server = db.query(VPNServer).filter(VPNServer.id == server_id).first()
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    log_file = LOG_DIR / f"server-{server_id}.log"
    try:
        # A bounded deque keeps only the last `lines` lines while streaming.
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            tail = deque(f, maxlen=lines)
    except FileNotFoundError:
        return Response(content="Log file not found", media_type="text/plain")
    except OSError as exc:
        raise HTTPException(status_code=500, detail=f"Error reading log: {exc}") from exc

    return Response(content="".join(tail), media_type="text/plain")
|
|
|
|
|
|
@router.get("/logs/supervisord")
async def get_supervisord_logs(
    lines: int = Query(default=100, ge=0, le=1000)
):
    """Return the tail of supervisord.log (in LOG_DIR) as JSON.

    Args:
        lines: Number of trailing lines to return (0-1000).

    Returns:
        ``total_lines``, ``returned_lines`` and ``lines`` (trailing
        whitespace stripped), or ``{"lines": [], "message": "Supervisord log
        not found"}`` when the file does not exist.

    Raises:
        HTTPException: 500 if the log file exists but cannot be read.
    """
    log_file = LOG_DIR / "supervisord.log"
    try:
        # Stream the file, keeping only the last `lines` lines in memory.
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            tail = deque(maxlen=lines)
            total_lines = 0
            for total_lines, line in enumerate(f, start=1):
                tail.append(line)
    except FileNotFoundError:
        return {"lines": [], "message": "Supervisord log not found"}
    except OSError as exc:
        raise HTTPException(status_code=500, detail=f"Error reading log: {exc}") from exc

    return {
        "total_lines": total_lines,
        "returned_lines": len(tail),
        "lines": [line.rstrip() for line in tail],
    }
|
|
|
|
|
|
@router.get("/debug/sync")
async def debug_sync(db: Session = Depends(get_db)):
    """Dump VPN sync diagnostics: profiles, live clients per server, sync result.

    NOTE(review): although this is a GET, it calls
    ``VPNSyncService.sync_all_connections()`` and so may have side effects —
    confirm that is intended.
    """
    sync_service = VPNSyncService(db)
    vpn_service = VPNServerService(db)

    # Each VPN profile with the certificate CN it authenticates with.
    profile_cns = [
        {"id": p.id, "name": p.name, "cert_cn": p.cert_cn, "gateway_id": p.gateway_id}
        for p in db.query(VPNProfile).all()
    ]

    # Live clients across all active servers; a failing server contributes
    # an error entry instead of aborting the whole report.
    active_servers = db.query(VPNServer).filter(VPNServer.is_active == True).all()  # noqa: E712
    connected_clients = []
    for srv in active_servers:
        try:
            for entry in vpn_service.get_connected_clients(srv):
                entry["server_id"] = srv.id
                entry["server_name"] = srv.name
                connected_clients.append(entry)
        except Exception as err:
            connected_clients.append({"server_id": srv.id, "error": str(err)})

    return {
        "profiles": profile_cns,
        "connected_clients": connected_clients,
        "sync_result": sync_service.sync_all_connections(),
    }
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/ccd")
async def get_ccd_files(
    server_id: int,
    db: Session = Depends(get_db)
):
    """List every Client Config Directory (CCD) file for a VPN server.

    CCD files carry the iroute directives that tell OpenVPN which client
    handles traffic for which subnet.

    Raises:
        HTTPException: 404 for an unknown server.
    """
    server = (
        db.query(VPNServer)
        .filter(VPNServer.id == server_id)
        .first()
    )
    if not server:
        raise HTTPException(status_code=404, detail="Server not found")

    ccd_files = VPNServerService(db).get_all_ccd_files(server)
    return {"server_id": server_id, "count": len(ccd_files), "files": ccd_files}
|
|
|
|
|
|
@router.get("/vpn-servers/{server_id}/ccd/{common_name}")
async def get_ccd_file(
    server_id: int,
    common_name: str,
    db: Session = Depends(get_db),
):
    """Return the CCD file for one client of a VPN server as plain text.

    The VPNProfile is located by ``server_id`` and certificate common name
    (uses the module-level ``VPNProfile`` import).

    Raises:
        HTTPException: 404 if no profile matches, or if the generated CCD
            content is empty.
    """
    profile = db.query(VPNProfile).filter(
        VPNProfile.vpn_server_id == server_id,
        VPNProfile.cert_cn == common_name,
    ).first()
    if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")

    content = VPNServerService(db).generate_ccd_file(profile)
    if not content:
        raise HTTPException(status_code=404, detail="No CCD content for this profile")

    return Response(content=content, media_type="text/plain")
|
|
|
|
|
|
@router.post("/vpn-servers/{server_id}/client-connected")
async def client_connected(
    server_id: int,
    request: ClientConnectedRequest,
    db: Session = Depends(get_db),
):
    """Handle a VPN client-connect event.

    Called by the OpenVPN client-connect script when a client connects. The
    certificate common name decides how the event is handled:

    * gateway profile (``VPNProfile``): the profile, and its gateway when
      one is linked, are updated with the new VPN IP; the gateway is marked
      online;
    * user/technician profile (``ClientVPNProfile``): the profile's VPN IP
      is recorded, then a firewall rule is opened for each of the user's
      pending connections whose endpoint and gateway exist and whose gateway
      already has a VPN IP. Connections that cannot be routed yet keep
      ``vpn_ip`` unset and are retried on the next connect.

    Args:
        server_id: ID of the reporting VPN server (not used to scope the
            profile lookups).
        request: Common name and VPN IP of the connecting client.

    Returns:
        ``{"status": "ok", "type": "gateway", "gateway_id": ...}``,
        ``{"status": "ok", "type": "client", "user_id": ..., "rules_created": ...}``
        or ``{"status": "unknown_client", "common_name": ...}``.
    """
    cn = request.common_name
    vpn_ip = request.vpn_ip

    gateway_profile = db.query(VPNProfile).filter(VPNProfile.cert_cn == cn).first()
    if gateway_profile:
        gateway = gateway_profile.gateway
        if gateway is not None:
            gateway.is_online = True
            gateway.vpn_ip = vpn_ip
        gateway_profile.vpn_ip = vpn_ip
        # SQL NOW() via sqlalchemy.func, evaluated by the database at flush
        # time (a Session has no ``func`` attribute).
        gateway_profile.last_connection = func.now()
        db.commit()
        return {
            "status": "ok",
            "type": "gateway",
            "gateway_id": gateway.id if gateway is not None else None,
        }

    client_profile = db.query(ClientVPNProfile).filter(
        ClientVPNProfile.cert_cn == cn
    ).first()
    if not client_profile:
        return {"status": "unknown_client", "common_name": cn}

    client_profile.vpn_ip = vpn_ip
    client_profile.last_connection = func.now()
    db.commit()

    # Open connections of this user not yet bound to a VPN session.
    pending_connections = db.query(ConnectionLog).filter(
        ConnectionLog.user_id == client_profile.user_id,
        ConnectionLog.disconnected_at.is_(None),
        ConnectionLog.vpn_ip.is_(None),
    ).all()

    firewall = FirewallService()
    rules_created = 0
    for conn in pending_connections:
        endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first()
        gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first()
        if not (endpoint and gateway and gateway.vpn_ip):
            continue

        # Allow the client to reach the endpoint through the gateway.
        success = firewall.allow_connection(
            client_vpn_ip=vpn_ip,
            gateway_vpn_ip=gateway.vpn_ip,
            target_ip=endpoint.internal_ip,
            target_port=endpoint.port,
            protocol=endpoint.protocol.value,
        )
        if success:
            conn.vpn_ip = vpn_ip
            rules_created += 1

    db.commit()
    return {
        "status": "ok",
        "type": "client",
        "user_id": client_profile.user_id,
        "rules_created": rules_created,
    }
|
|
|
|
|
|
@router.post("/vpn-servers/{server_id}/client-disconnected")
async def client_disconnected(
    server_id: int,
    request: ClientDisconnectedRequest,
    db: Session = Depends(get_db),
):
    """Handle a VPN client-disconnect event.

    Called by the OpenVPN client-disconnect script when a client disconnects.

    * gateway profile (``VPNProfile``): its gateway, when one is linked, is
      marked offline;
    * user/technician profile (``ClientVPNProfile``): every open connection
      bound to the profile's current VPN IP is closed and its firewall rule
      revoked (when its endpoint and gateway still exist); the profile's VPN
      IP is then cleared.

    Args:
        server_id: ID of the reporting VPN server (not used to scope the
            profile lookups).
        request: Common name of the disconnecting client (traffic counters
            are not read here).

    Returns:
        ``{"status": "ok", "type": "gateway", "gateway_id": ...}``,
        ``{"status": "ok", "type": "client", "user_id": ...}`` or
        ``{"status": "unknown_client", "common_name": ...}``.
    """
    cn = request.common_name

    gateway_profile = db.query(VPNProfile).filter(VPNProfile.cert_cn == cn).first()
    if gateway_profile:
        gateway = gateway_profile.gateway
        if gateway is not None:
            gateway.is_online = False
        db.commit()
        return {
            "status": "ok",
            "type": "gateway",
            "gateway_id": gateway.id if gateway is not None else None,
        }

    client_profile = db.query(ClientVPNProfile).filter(
        ClientVPNProfile.cert_cn == cn
    ).first()
    if not client_profile:
        return {"status": "unknown_client", "common_name": cn}

    vpn_ip = client_profile.vpn_ip
    if vpn_ip:
        active_connections = db.query(ConnectionLog).filter(
            ConnectionLog.user_id == client_profile.user_id,
            ConnectionLog.disconnected_at.is_(None),
            ConnectionLog.vpn_ip == vpn_ip,
        ).all()

        firewall = FirewallService()
        for conn in active_connections:
            endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first()
            gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first()

            if endpoint and gateway:
                firewall.revoke_connection(
                    client_vpn_ip=vpn_ip,
                    gateway_vpn_ip=gateway.vpn_ip or "",
                    target_ip=endpoint.internal_ip,
                    target_port=endpoint.port,
                    protocol=endpoint.protocol.value,
                )

            # Close the log entry even if its endpoint/gateway row is gone;
            # otherwise it would remain open with disconnected_at unset.
            # SQL NOW() via sqlalchemy.func (a Session has no ``func``).
            conn.disconnected_at = func.now()

    client_profile.vpn_ip = None
    db.commit()

    return {"status": "ok", "type": "client", "user_id": client_profile.user_id}
|