237 lines
7.7 KiB
Python
237 lines
7.7 KiB
Python
"""Connection management API routes."""
|
|
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
from ..database import get_db
|
|
from ..models.endpoint import Endpoint
|
|
from ..models.gateway import Gateway
|
|
from ..models.user import User, UserRole
|
|
from ..models.access import UserGatewayAccess, ConnectionLog
|
|
from ..services.vpn_service import VPNService
|
|
from ..services.firewall_service import FirewallService
|
|
from .deps import get_current_user
|
|
|
|
# Router carrying the connection-management routes defined in this module
# (/connect, /disconnect, /active, /history). Presumably mounted under a
# prefix by the application — confirm in the app setup.
router: APIRouter = APIRouter()
|
|
|
|
|
|
class ConnectRequest(BaseModel):
    """Request body for ``POST /connect``.

    Identifies the endpoint the caller wants to reach and the gateway it sits
    behind. The handler looks the endpoint up scoped to that gateway, so an
    endpoint belonging to a different gateway is reported as not found.
    """

    # ID of the gateway to connect through.
    gateway_id: int

    # ID of the target endpoint; must belong to ``gateway_id``.
    endpoint_id: int
|
|
|
|
|
|
class ConnectResponse(BaseModel):
    """Response body for ``POST /connect``.

    When the request is refused (e.g. the gateway is offline) only ``success``
    and ``message`` are populated; the optional fields keep their ``None``
    defaults.
    """

    # True when the connection intent was accepted and logged.
    success: bool

    # Human-readable status, or instructions for the client.
    message: str

    # The /connect handler in this module always returns None here: VPN
    # configuration is obtained through the gateway's VPN profiles instead.
    vpn_config: str | None = None

    # Internal IP of the target endpoint (taken from ``Endpoint.internal_ip``).
    target_ip: str | None = None

    # Port of the target endpoint (taken from ``Endpoint.port``).
    target_port: int | None = None

    # ID of the ConnectionLog row created for this request; the client sends
    # it back in ``DisconnectRequest.connection_id``.
    connection_id: int | None = None
|
|
|
|
|
|
class DisconnectRequest(BaseModel):
    """Request body for ``POST /disconnect``."""

    # ID of a ConnectionLog row owned by the caller (the value returned as
    # ``ConnectResponse.connection_id``).
    connection_id: int
|
|
|
|
|
|
def user_has_gateway_access(db: Session, user: User, gateway_id: int) -> bool:
    """Return whether ``user`` may connect through gateway ``gateway_id``.

    Access rules:

    * ``UserRole.SUPER_ADMIN`` -- any gateway that exists.
    * ``UserRole.ADMIN`` -- any existing gateway whose ``tenant_id`` matches
      the admin's ``tenant_id``.
    * every other role -- only gateways with an explicit
      ``UserGatewayAccess`` grant for this user.

    Args:
        db: Active SQLAlchemy session.
        user: The user requesting access.
        gateway_id: Primary key of the gateway being accessed.

    Returns:
        ``True`` if access is allowed, otherwise ``False``. Always a real
        ``bool`` -- a missing gateway yields ``False``, never ``None``.
    """
    if user.role in (UserRole.SUPER_ADMIN, UserRole.ADMIN):
        gateway = db.query(Gateway).filter(Gateway.id == gateway_id).first()
        if gateway is None:
            # Previously `gateway and ...` leaked None here despite `-> bool`.
            return False
        if user.role == UserRole.SUPER_ADMIN:
            return True
        # Tenant admins are confined to gateways of their own tenant.
        return gateway.tenant_id == user.tenant_id

    # Regular users need an explicit per-gateway grant.
    access = db.query(UserGatewayAccess).filter(
        UserGatewayAccess.user_id == user.id,
        UserGatewayAccess.gateway_id == gateway_id
    ).first()
    return access is not None
|
|
|
|
|
|
@router.post("/connect", response_model=ConnectResponse)
def connect_to_endpoint(
    request: ConnectRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Record the caller's intent to reach one endpoint behind a gateway.

    Dynamic VPN configuration is no longer generated here: gateways carry
    pre-provisioned VPN profiles, so this route only validates the request
    and writes a ``ConnectionLog`` row.

    Returns:
        ``ConnectResponse`` with ``success=False`` when the gateway is
        offline; otherwise ``success=True`` plus the endpoint's internal IP,
        port and the new connection-log ID.

    Raises:
        HTTPException: 403 if the caller has no access to the gateway; 404 if
            the gateway does not exist, or the endpoint does not exist on that
            gateway.
    """
    # Authorisation comes first so unauthorised callers learn nothing else.
    if not user_has_gateway_access(db, current_user, request.gateway_id):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="No access to this gateway",
        )

    gateway = db.query(Gateway).filter_by(id=request.gateway_id).first()
    if not gateway:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Gateway not found",
        )

    # An offline gateway is reported in-band rather than as an HTTP error.
    if not gateway.is_online:
        return ConnectResponse(success=False, message="Gateway is offline")

    # The endpoint must live behind the requested gateway.
    endpoint = (
        db.query(Endpoint)
        .filter_by(id=request.endpoint_id, gateway_id=request.gateway_id)
        .first()
    )
    if not endpoint:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Endpoint not found",
        )

    log_entry = ConnectionLog(
        user_id=current_user.id,
        gateway_id=gateway.id,
        endpoint_id=endpoint.id,
        client_ip=None,  # Will be updated when VPN connects
    )
    db.add(log_entry)
    db.commit()
    db.refresh(log_entry)  # populate the generated primary key

    return ConnectResponse(
        success=True,
        message="Connection logged. Use the gateway's VPN profile configuration to connect.",
        vpn_config=None,  # VPN config is now obtained through gateway VPN profiles
        target_ip=endpoint.internal_ip,
        target_port=endpoint.port,
        connection_id=log_entry.id,
    )
|
|
|
|
|
|
@router.post("/disconnect")
def disconnect_from_endpoint(
    request: DisconnectRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Close one of the caller's connections.

    Steps, in order:

    1. Look up the ``ConnectionLog`` row, which must belong to the caller.
    2. If the log has a ``vpn_ip`` and both its endpoint and gateway still
       exist, revoke the matching firewall rule.
    3. Ask the VPN service to disconnect the caller's VPN client.
    4. Stamp ``disconnected_at`` and commit.

    Returns:
        ``{"message": ...}``. A connection that is already closed is reported
        as such and nothing else is done.

    Raises:
        HTTPException: 404 if no such connection exists for the caller.
    """
    connection = (
        db.query(ConnectionLog)
        .filter_by(id=request.connection_id, user_id=current_user.id)
        .first()
    )
    if not connection:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Connection not found",
        )

    if connection.disconnected_at:
        return {"message": "Already disconnected"}

    # Firewall rules can only be revoked once the client's VPN IP is known.
    client_vpn_ip = connection.vpn_ip
    if client_vpn_ip:
        endpoint = db.query(Endpoint).filter_by(id=connection.endpoint_id).first()
        gateway = db.query(Gateway).filter_by(id=connection.gateway_id).first()
        if endpoint and gateway:
            FirewallService().revoke_connection(
                client_vpn_ip=client_vpn_ip,
                gateway_vpn_ip=gateway.vpn_ip,
                target_ip=endpoint.internal_ip,
                target_port=endpoint.port,
                protocol=endpoint.protocol.value,
            )

    # NOTE(review): the common name is built from the user only (not the
    # connection), so this presumably drops the user's whole VPN session even
    # if other ConnectionLog rows remain open — confirm intended.
    VPNService().disconnect_client(
        f"client-{current_user.username}-{current_user.id}"
    )

    connection.disconnected_at = datetime.utcnow()
    db.commit()

    return {"message": "Disconnected successfully"}
|
|
|
|
|
|
@router.get("/active")
def list_active_connections(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the caller's connections that have not been disconnected yet.

    The gateways and endpoints referenced by the open connections are loaded
    with one ``IN`` query per table, instead of two extra queries per
    connection (the previous N+1 pattern).

    Returns:
        A list of dicts with keys ``connection_id``, ``gateway_name``,
        ``endpoint_name``, ``endpoint_address``, ``connected_at`` (ISO 8601
        string) and ``vpn_ip``. The name/address fields are ``None`` when the
        referenced gateway or endpoint row no longer exists.
    """
    connections = db.query(ConnectionLog).filter(
        ConnectionLog.user_id == current_user.id,
        ConnectionLog.disconnected_at.is_(None)
    ).all()

    def rows_by_id(model, ids):
        """Fetch ``model`` rows whose ``id`` is in ``ids`` with one query, keyed by id."""
        wanted = list({i for i in ids if i is not None})
        if not wanted:
            # Skip the round-trip (and an empty IN clause) when nothing is referenced.
            return {}
        return {row.id: row for row in db.query(model).filter(model.id.in_(wanted)).all()}

    gateways = rows_by_id(Gateway, (conn.gateway_id for conn in connections))
    endpoints = rows_by_id(Endpoint, (conn.endpoint_id for conn in connections))

    result = []
    for conn in connections:
        gateway = gateways.get(conn.gateway_id)
        endpoint = endpoints.get(conn.endpoint_id)
        result.append({
            "connection_id": conn.id,
            "gateway_name": gateway.name if gateway else None,
            "endpoint_name": endpoint.name if endpoint else None,
            # NOTE(review): /connect reports `endpoint.internal_ip`; confirm
            # the Endpoint model also exposes `address`.
            "endpoint_address": endpoint.address if endpoint else None,
            "connected_at": conn.connected_at.isoformat(),
            "vpn_ip": conn.vpn_ip
        })

    return result
|
|
|
|
|
|
@router.get("/history")
def get_connection_history(
    skip: int = 0,
    limit: int = 50,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return one page of connection history, newest first.

    What the caller sees depends on their role:

    * ``UserRole.SUPER_ADMIN`` -- every connection;
    * other admins (``current_user.is_admin``) -- connections made through
      gateways of their own tenant;
    * everyone else -- only their own connections.

    Args:
        skip: Number of rows to skip. Negative values are treated as 0.
        limit: Maximum number of rows to return. Negative values are treated
            as 0, i.e. an empty page.
        db: Active SQLAlchemy session.
        current_user: The authenticated caller.

    Returns:
        A list of dicts with keys ``connection_id``, ``username``,
        ``gateway_name``, ``endpoint_name``, ``connected_at`` and
        ``disconnected_at`` (ISO 8601 strings, the latter ``None`` while
        open), ``duration_seconds`` and ``client_ip``. Name fields are
        ``None`` when the referenced row no longer exists.
    """
    # Query parameters arrive unvalidated: PostgreSQL rejects a negative
    # LIMIT/OFFSET, and SQLite treats a negative LIMIT as "no limit".
    skip = max(skip, 0)
    limit = max(limit, 0)

    if current_user.is_admin:
        if current_user.role == UserRole.SUPER_ADMIN:
            query = db.query(ConnectionLog)
        else:
            # Tenant admins: restrict to connections through their tenant's gateways.
            query = db.query(ConnectionLog).join(
                Gateway,
                Gateway.id == ConnectionLog.gateway_id
            ).filter(Gateway.tenant_id == current_user.tenant_id)
    else:
        query = db.query(ConnectionLog).filter(ConnectionLog.user_id == current_user.id)

    # `id` breaks ties between equal timestamps so pages do not overlap or skip rows.
    connections = (
        query.order_by(ConnectionLog.connected_at.desc(), ConnectionLog.id.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    def rows_by_id(model, ids):
        """Fetch ``model`` rows whose ``id`` is in ``ids`` with one query, keyed by id."""
        wanted = list({i for i in ids if i is not None})
        if not wanted:
            return {}
        return {row.id: row for row in db.query(model).filter(model.id.in_(wanted)).all()}

    # One query per table instead of three queries per row (N+1).
    users = rows_by_id(User, (conn.user_id for conn in connections))
    gateways = rows_by_id(Gateway, (conn.gateway_id for conn in connections))
    endpoints = rows_by_id(Endpoint, (conn.endpoint_id for conn in connections))

    result = []
    for conn in connections:
        user = users.get(conn.user_id)
        gateway = gateways.get(conn.gateway_id)
        endpoint = endpoints.get(conn.endpoint_id)
        result.append({
            "connection_id": conn.id,
            "username": user.username if user else None,
            "gateway_name": gateway.name if gateway else None,
            "endpoint_name": endpoint.name if endpoint else None,
            "connected_at": conn.connected_at.isoformat(),
            "disconnected_at": conn.disconnected_at.isoformat() if conn.disconnected_at else None,
            "duration_seconds": conn.duration_seconds,
            "client_ip": conn.client_ip
        })

    return result
|