"""Connection management API routes.""" from datetime import datetime from fastapi import APIRouter, Depends, HTTPException, status, Response from pydantic import BaseModel from sqlalchemy.orm import Session from ..database import get_db from ..models.endpoint import Endpoint from ..models.gateway import Gateway from ..models.user import User, UserRole from ..models.access import UserGatewayAccess, ConnectionLog from ..services.vpn_service import VPNService from ..services.firewall_service import FirewallService from ..services.client_vpn_profile_service import ClientVPNProfileService from .deps import get_current_user router = APIRouter() class ConnectRequest(BaseModel): """Request to establish connection to endpoint.""" gateway_id: int endpoint_id: int class ConnectResponse(BaseModel): """Response with connection details.""" success: bool message: str vpn_config: str | None = None target_ip: str | None = None target_port: int | None = None connection_id: int | None = None class DisconnectRequest(BaseModel): """Request to disconnect from endpoint.""" connection_id: int def user_has_gateway_access(db: Session, user: User, gateway_id: int) -> bool: """Check if user has access to gateway.""" if user.role in (UserRole.SUPER_ADMIN, UserRole.ADMIN): gateway = db.query(Gateway).filter(Gateway.id == gateway_id).first() if user.role == UserRole.SUPER_ADMIN: return gateway is not None return gateway and gateway.tenant_id == user.tenant_id access = db.query(UserGatewayAccess).filter( UserGatewayAccess.user_id == user.id, UserGatewayAccess.gateway_id == gateway_id ).first() return access is not None @router.post("/connect", response_model=ConnectResponse) def connect_to_endpoint( request: ConnectRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """Request connection to a specific endpoint through a gateway.""" # Check gateway access if not user_has_gateway_access(db, current_user, request.gateway_id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="No access to this gateway" ) # Get gateway gateway = db.query(Gateway).filter(Gateway.id == request.gateway_id).first() if not gateway: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found" ) if not gateway.is_online: return ConnectResponse( success=False, message="Gateway is offline" ) # Get endpoint endpoint = db.query(Endpoint).filter( Endpoint.id == request.endpoint_id, Endpoint.gateway_id == request.gateway_id ).first() if not endpoint: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Endpoint not found" ) # Get or create VPN profile for user client_profile_service = ClientVPNProfileService(db) vpn_config = client_profile_service.get_vpn_config_for_user(current_user) if not vpn_config: return ConnectResponse( success=False, message="No VPN server available. Please contact administrator." ) # Log connection connection = ConnectionLog( user_id=current_user.id, gateway_id=gateway.id, endpoint_id=endpoint.id, client_ip=None # Will be updated when VPN connects ) db.add(connection) db.commit() db.refresh(connection) # Set up firewall rules for this connection # The client's VPN IP will be determined after VPN connects # For now, we'll configure firewall rules when the client-connect script runs return ConnectResponse( success=True, message="VPN configuration ready. Connect to access endpoint.", vpn_config=vpn_config, target_ip=endpoint.internal_ip, target_port=endpoint.port, connection_id=connection.id ) @router.post("/disconnect") def disconnect_from_endpoint( request: DisconnectRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """Disconnect from endpoint.""" connection = db.query(ConnectionLog).filter( ConnectionLog.id == request.connection_id, ConnectionLog.user_id == current_user.id ).first() if not connection: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Connection not found" ) if connection.disconnected_at: return {"message": "Already disconnected"} # Revoke firewall rules if VPN IP is known if connection.vpn_ip: endpoint = db.query(Endpoint).filter(Endpoint.id == connection.endpoint_id).first() gateway = db.query(Gateway).filter(Gateway.id == connection.gateway_id).first() if endpoint and gateway: firewall = FirewallService() firewall.revoke_connection( client_vpn_ip=connection.vpn_ip, gateway_vpn_ip=gateway.vpn_ip, target_ip=endpoint.internal_ip, target_port=endpoint.port, protocol=endpoint.protocol.value ) # Disconnect VPN client vpn_service = VPNService() client_cn = f"client-{current_user.username}-{current_user.id}" vpn_service.disconnect_client(client_cn) # Update log connection.disconnected_at = datetime.utcnow() db.commit() return {"message": "Disconnected successfully"} @router.get("/active") def list_active_connections( db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """List active connections for current user.""" connections = db.query(ConnectionLog).filter( ConnectionLog.user_id == current_user.id, ConnectionLog.disconnected_at.is_(None) ).all() result = [] for conn in connections: gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first() endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first() result.append({ "connection_id": conn.id, "gateway_name": gateway.name if gateway else None, "endpoint_name": endpoint.name if endpoint else None, "endpoint_address": endpoint.address if endpoint else None, "connected_at": conn.connected_at.isoformat(), "vpn_ip": conn.vpn_ip }) return result @router.get("/history") def get_connection_history( skip: int = 0, limit: int = 50, db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """Get connection history for current user.""" query = db.query(ConnectionLog).filter(ConnectionLog.user_id == current_user.id) # Admins can see all connections in their tenant if current_user.is_admin: if current_user.role == UserRole.SUPER_ADMIN: query = db.query(ConnectionLog) else: # Filter by tenant's gateways query = db.query(ConnectionLog).join( Gateway, Gateway.id == ConnectionLog.gateway_id ).filter(Gateway.tenant_id == current_user.tenant_id) connections = query.order_by(ConnectionLog.connected_at.desc()).offset(skip).limit(limit).all() result = [] for conn in connections: user = db.query(User).filter(User.id == conn.user_id).first() gateway = db.query(Gateway).filter(Gateway.id == conn.gateway_id).first() endpoint = db.query(Endpoint).filter(Endpoint.id == conn.endpoint_id).first() result.append({ "connection_id": conn.id, "username": user.username if user else None, "gateway_name": gateway.name if gateway else None, "endpoint_name": endpoint.name if endpoint else None, "connected_at": conn.connected_at.isoformat(), "disconnected_at": conn.disconnected_at.isoformat() if conn.disconnected_at else None, "duration_seconds": conn.duration_seconds, "client_ip": conn.client_ip }) return result