openvpn-endpoint-server/server/app/api/connections.py

237 lines
7.7 KiB
Python

"""Connection management API routes."""
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status, Response
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ..database import get_db
from ..models.endpoint import Endpoint
from ..models.gateway import Gateway
from ..models.user import User, UserRole
from ..models.access import UserGatewayAccess, ConnectionLog
from ..services.vpn_service import VPNService
from ..services.firewall_service import FirewallService
from .deps import get_current_user
# Router holding the connection routes defined in this module
# (/connect, /disconnect, /active, /history). Presumably included by the
# application under an API prefix elsewhere — confirm where it is mounted.
router = APIRouter()
class ConnectRequest(BaseModel):
    """Request to establish connection to endpoint.

    Body of ``POST /connect``.
    """
    # Gateway to connect through; the caller's access to it is checked first.
    gateway_id: int
    # Endpoint to reach; it must belong to ``gateway_id`` or /connect returns 404.
    endpoint_id: int
class ConnectResponse(BaseModel):
    """Response with connection details.

    Returned by ``POST /connect``. When ``success`` is False only
    ``success`` and ``message`` are populated; the optional fields stay None.
    """
    # False when the gateway is offline (the only non-exception failure in /connect).
    success: bool
    # Human-readable status text for the client.
    message: str
    # Always None from /connect now; VPN config comes from gateway VPN profiles.
    vpn_config: str | None = None
    # The endpoint's ``internal_ip`` on success.
    target_ip: str | None = None
    # The endpoint's ``port`` on success.
    target_port: int | None = None
    # Id of the created ConnectionLog row; pass it to /disconnect.
    connection_id: int | None = None
class DisconnectRequest(BaseModel):
    """Request to disconnect from endpoint.

    Body of ``POST /disconnect``.
    """
    # ConnectionLog id (as returned in ``ConnectResponse.connection_id``);
    # it must belong to the calling user or /disconnect returns 404.
    connection_id: int
def user_has_gateway_access(db: Session, user: User, gateway_id: int) -> bool:
    """Return whether ``user`` may use the gateway with id ``gateway_id``.

    Rules:
      * SUPER_ADMIN: any gateway that exists.
      * ADMIN: any existing gateway whose ``tenant_id`` equals the admin's.
      * Every other role: only gateways granted by a ``UserGatewayAccess`` row.

    For the admin roles a non-existent gateway always yields ``False``. For
    other roles only the grant row is consulted, so a grant that still points
    at a deleted gateway (if such rows can exist) would yield ``True``.

    Args:
        db: Active SQLAlchemy session.
        user: The authenticated user.
        gateway_id: Primary key of the gateway being accessed.

    Returns:
        ``True`` if access is allowed, otherwise ``False`` — always a real
        bool, never ``None``.
    """
    if user.role in (UserRole.SUPER_ADMIN, UserRole.ADMIN):
        gateway = db.query(Gateway).filter(Gateway.id == gateway_id).first()
        if gateway is None:
            # Explicit False: ``gateway and ...`` would have returned None here.
            return False
        if user.role == UserRole.SUPER_ADMIN:
            return True
        # Tenant admins are scoped to their own tenant's gateways.
        return gateway.tenant_id == user.tenant_id

    grant = (
        db.query(UserGatewayAccess)
        .filter(
            UserGatewayAccess.user_id == user.id,
            UserGatewayAccess.gateway_id == gateway_id,
        )
        .first()
    )
    return grant is not None
@router.post("/connect", response_model=ConnectResponse)
def connect_to_endpoint(
    request: ConnectRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Log the caller's intent to reach an endpoint through a gateway.

    No VPN configuration is generated here: clients are expected to connect
    with the gateway's pre-provisioned VPN profile. This route checks access,
    records a ``ConnectionLog`` row and hands back the endpoint's address.

    Checks run in this order, and the order is visible to clients:
      1. caller has access to the gateway, else 403 "No access to this gateway";
      2. gateway exists, else 404 "Gateway not found";
      3. gateway is online, else a ``success=False`` response (not an error);
      4. endpoint exists on that gateway, else 404 "Endpoint not found".

    Returns:
        ``ConnectResponse`` with ``vpn_config=None``, the endpoint's
        ``internal_ip`` and ``port``, and the new ``connection_id``.
    """
    gateway_id = request.gateway_id

    # Authorization runs before the existence lookup, so a caller without
    # access gets 403 even for an id that does not exist.
    if not user_has_gateway_access(db, current_user, gateway_id):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="No access to this gateway",
        )

    gw = db.query(Gateway).filter(Gateway.id == gateway_id).first()
    if not gw:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Gateway not found",
        )
    if not gw.is_online:
        return ConnectResponse(success=False, message="Gateway is offline")

    # The endpoint must hang off the requested gateway, not just exist.
    target = (
        db.query(Endpoint)
        .filter(
            Endpoint.id == request.endpoint_id,
            Endpoint.gateway_id == gateway_id,
        )
        .first()
    )
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Endpoint not found",
        )

    # NOTE: Dynamic VPN config generation has been replaced by VPN profiles;
    # this route now only records the connection intent.
    log_entry = ConnectionLog(
        user_id=current_user.id,
        gateway_id=gw.id,
        endpoint_id=target.id,
        client_ip=None,  # filled in later, once the VPN session is known
    )
    db.add(log_entry)
    db.commit()
    db.refresh(log_entry)  # load the database-assigned id

    return ConnectResponse(
        success=True,
        message="Connection logged. Use the gateway's VPN profile configuration to connect.",
        vpn_config=None,
        target_ip=target.internal_ip,
        target_port=target.port,
        connection_id=log_entry.id,
    )
@router.post("/disconnect")
def disconnect_from_endpoint(
    request: DisconnectRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Close one of the caller's logged connections.

    Steps, in order:
      1. Find the ``ConnectionLog`` row by id, scoped to the caller (else 404).
      2. If it already has ``disconnected_at``, return early with no side effects.
      3. If a VPN IP was recorded and both the endpoint and gateway rows still
         exist, call ``FirewallService.revoke_connection`` for that
         client -> gateway -> endpoint tuple.
      4. Call ``VPNService.disconnect_client`` with a common name derived from
         the caller's username and id.
      5. Stamp ``disconnected_at`` with ``datetime.utcnow()`` (naive UTC) and commit.

    Raises:
        HTTPException: 404 if the connection does not exist for this user.
    """
    log_entry = (
        db.query(ConnectionLog)
        .filter(
            ConnectionLog.id == request.connection_id,
            ConnectionLog.user_id == current_user.id,
        )
        .first()
    )
    if not log_entry:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Connection not found",
        )
    if log_entry.disconnected_at:
        return {"message": "Already disconnected"}

    # Firewall rules can only be revoked when the client's VPN IP is known.
    if log_entry.vpn_ip:
        target = db.query(Endpoint).filter(Endpoint.id == log_entry.endpoint_id).first()
        gw = db.query(Gateway).filter(Gateway.id == log_entry.gateway_id).first()
        if target and gw:
            FirewallService().revoke_connection(
                client_vpn_ip=log_entry.vpn_ip,
                gateway_vpn_ip=gw.vpn_ip,
                target_ip=target.internal_ip,
                target_port=target.port,
                protocol=target.protocol.value,
            )

    # The common name is built from the user alone, not from this connection,
    # so this call targets the caller's VPN client as a whole.
    # NOTE(review): if the user has other active connections they share that
    # client — confirm that disconnecting it here is intended.
    VPNService().disconnect_client(f"client-{current_user.username}-{current_user.id}")

    # NOTE(review): if either service call above raises, disconnected_at is
    # never set and the row keeps showing as active — confirm acceptable.
    log_entry.disconnected_at = datetime.utcnow()
    db.commit()
    return {"message": "Disconnected successfully"}
@router.get("/active")
def list_active_connections(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the caller's connections whose ``disconnected_at`` is still NULL.

    Each entry has the connection id, gateway name, endpoint name and address,
    the ISO-8601 ``connected_at`` timestamp, and ``vpn_ip`` (which may be None).
    Gateway/endpoint fields are None when the referenced row is missing.
    """

    def _summarize(log_row):
        # Resolve the related gateway and endpoint for a single log row.
        gw = db.query(Gateway).filter(Gateway.id == log_row.gateway_id).first()
        ep = db.query(Endpoint).filter(Endpoint.id == log_row.endpoint_id).first()
        # NOTE(review): this reads ``ep.address`` whereas /connect and
        # /disconnect use ``ep.internal_ip`` — confirm Endpoint exposes both.
        return {
            "connection_id": log_row.id,
            "gateway_name": gw.name if gw else None,
            "endpoint_name": ep.name if ep else None,
            "endpoint_address": ep.address if ep else None,
            "connected_at": log_row.connected_at.isoformat(),
            "vpn_ip": log_row.vpn_ip,
        }

    still_open = (
        db.query(ConnectionLog)
        .filter(
            ConnectionLog.user_id == current_user.id,
            ConnectionLog.disconnected_at.is_(None),
        )
        .all()
    )
    return [_summarize(log_row) for log_row in still_open]
def _rows_by_id(db: Session, model, ids) -> dict:
    """Load every ``model`` row whose ``id`` is in ``ids`` using one query.

    Returns a dict mapping id -> row. ``None`` ids are ignored, and an empty
    id set returns ``{}`` without querying the database.
    """
    wanted = {row_id for row_id in ids if row_id is not None}
    if not wanted:
        return {}
    rows = db.query(model).filter(model.id.in_(wanted)).all()
    return {row.id: row for row in rows}


@router.get("/history")
def get_connection_history(
    skip: int = 0,
    limit: int = 50,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return a page of connection-log entries visible to the caller.

    Visibility:
      * non-admins see only their own connections;
      * admins with role SUPER_ADMIN see every connection;
      * other admins see connections made through their tenant's gateways.

    Entries are ordered newest first (``connected_at`` descending).

    Args:
        skip: Number of entries to skip (offset); must be >= 0.
        limit: Maximum number of entries to return; must be >= 0.
        db: Active SQLAlchemy session (injected).
        current_user: Authenticated caller (injected).

    Returns:
        A list of dicts with connection id, username, gateway and endpoint
        names, ISO-8601 connect/disconnect timestamps, ``duration_seconds``
        and ``client_ip``. Name fields are None when the referenced row is
        missing.

    Raises:
        HTTPException: 400 if ``skip`` or ``limit`` is negative.
    """
    # Reject negative paging values up front. Depending on the database
    # backend they either fail at query time (surfacing as a 500) or
    # silently disable paging (SQLite treats a negative LIMIT as unbounded).
    if skip < 0 or limit < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="skip and limit must be non-negative",
        )

    if not current_user.is_admin:
        query = db.query(ConnectionLog).filter(ConnectionLog.user_id == current_user.id)
    elif current_user.role == UserRole.SUPER_ADMIN:
        query = db.query(ConnectionLog)
    else:
        # Tenant admins: only connections routed through their tenant's gateways.
        query = (
            db.query(ConnectionLog)
            .join(Gateway, Gateway.id == ConnectionLog.gateway_id)
            .filter(Gateway.tenant_id == current_user.tenant_id)
        )

    connections = (
        query.order_by(ConnectionLog.connected_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    # Resolve related rows with one query per table instead of three per row.
    users = _rows_by_id(db, User, (c.user_id for c in connections))
    gateways = _rows_by_id(db, Gateway, (c.gateway_id for c in connections))
    endpoints = _rows_by_id(db, Endpoint, (c.endpoint_id for c in connections))

    result = []
    for conn in connections:
        user = users.get(conn.user_id)
        gateway = gateways.get(conn.gateway_id)
        endpoint = endpoints.get(conn.endpoint_id)
        result.append({
            "connection_id": conn.id,
            "username": user.username if user else None,
            "gateway_name": gateway.name if gateway else None,
            "endpoint_name": endpoint.name if endpoint else None,
            "connected_at": conn.connected_at.isoformat(),
            "disconnected_at": conn.disconnected_at.isoformat() if conn.disconnected_at else None,
            # NOTE(review): disconnect only sets disconnected_at; presumably
            # duration_seconds is derived on the model — confirm.
            "duration_seconds": conn.duration_seconds,
            "client_ip": conn.client_ip,
        })
    return result