"""Network diagrams API endpoints."""

import logging
from datetime import datetime, timezone
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, or_
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from app.api.deps import get_current_active_user
from app.models.user import User
from app.models.device_type import DeviceType
from app.models.network_diagram import NetworkDiagram
from app.schemas.network_diagram import (
    NetworkDiagramCreate,
    NetworkDiagramUpdate,
    NetworkDiagramResponse,
    NetworkDiagramListItem,
    AIGenerateRequest,
    AIGenerateResponse,
    DiagramImportRequest,
    DiagramImportResponse,
    DiagramExportResponse,
    DiagramNode,
    DiagramEdge,
)
from app.services import network_diagram_ai_service

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"])

# Escape character for user-supplied LIKE/ILIKE terms. A forward slash is used
# rather than a backslash so the rendered ESCAPE clause does not depend on how
# a given database treats backslashes inside string literals.
_LIKE_ESCAPE = "/"


def _escape_like(term: str) -> str:
    """Return *term* with LIKE wildcards escaped so it is matched literally.

    The escape character itself is doubled first so that the escapes added
    for ``%`` and ``_`` are not escaped a second time.
    """
    return (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", f"{_LIKE_ESCAPE}%")
        .replace("_", f"{_LIKE_ESCAPE}_")
    )


async def _available_device_slugs(team_id: UUID, db: AsyncSession) -> list[str]:
    """Return slugs of device types that are system-wide or owned by *team_id*."""
    stmt = select(DeviceType.slug).where(
        or_(
            DeviceType.is_system.is_(True),
            DeviceType.team_id == team_id,
        )
    )
    result = await db.execute(stmt)
    return [row[0] for row in result.all()]


async def _get_diagram_or_404(
    diagram_id: UUID,
    team_id: UUID,
    db: AsyncSession,
) -> NetworkDiagram:
    """Load a diagram by primary key, scoped to a team.

    Raises:
        HTTPException: 404 when the diagram does not exist, belongs to a
            different team, or is archived. The three cases are reported
            identically.
    """
    diagram = await db.get(NetworkDiagram, diagram_id)
    if not diagram or diagram.team_id != team_id or diagram.is_archived:
        raise HTTPException(status_code=404, detail="Diagram not found")
    return diagram


def _diagram_to_response(diagram: NetworkDiagram) -> NetworkDiagramResponse:
    """Convert an ORM diagram into the full response schema."""
    return NetworkDiagramResponse.model_validate(diagram)


def _diagram_to_list_item(diagram: NetworkDiagram) -> NetworkDiagramListItem:
    """Convert an ORM diagram into the list-item schema.

    ``node_count`` is the length of ``diagram.nodes`` when that attribute is a
    list, and 0 for any other value (including ``None``).
    """
    nodes = diagram.nodes if isinstance(diagram.nodes, list) else []
    return NetworkDiagramListItem(
        id=diagram.id,
        name=diagram.name,
        client_name=diagram.client_name,
        description=diagram.description,
        node_count=len(nodes),
        created_by=diagram.created_by,
        created_at=diagram.created_at,
        updated_at=diagram.updated_at,
    )


@router.get("/clients", response_model=list[str])
async def list_client_names(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[str]:
    """Return the distinct, sorted, non-empty client names of the team's diagrams.

    Archived diagrams are excluded.
    """
    stmt = (
        select(NetworkDiagram.client_name)
        .where(
            NetworkDiagram.team_id == current_user.team_id,
            NetworkDiagram.is_archived.is_(False),
            NetworkDiagram.client_name.isnot(None),
            NetworkDiagram.client_name != "",
        )
        .distinct()
        .order_by(NetworkDiagram.client_name)
    )
    result = await db.execute(stmt)
    return [row[0] for row in result.all()]


@router.get("/", response_model=list[NetworkDiagramListItem])
async def list_diagrams(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
    client_name: str | None = Query(default=None),
    search: str | None = Query(default=None),
) -> list[NetworkDiagramListItem]:
    """List the team's non-archived diagrams, most recently updated first.

    Args:
        client_name: When non-empty, keep only diagrams with exactly this
            client name.
        search: When non-empty, keep only diagrams whose name or client name
            contains this text (case-insensitive). The text is matched
            literally; ``%`` and ``_`` are not treated as wildcards.
    """
    stmt = (
        select(NetworkDiagram)
        .where(
            NetworkDiagram.team_id == current_user.team_id,
            NetworkDiagram.is_archived.is_(False),
        )
        .order_by(NetworkDiagram.updated_at.desc())
    )
    if client_name:
        stmt = stmt.where(NetworkDiagram.client_name == client_name)
    if search:
        # Escape LIKE metacharacters so user input such as "100%" or "a_b"
        # is searched for literally rather than acting as a pattern.
        pattern = f"%{_escape_like(search)}%"
        stmt = stmt.where(
            or_(
                NetworkDiagram.name.ilike(pattern, escape=_LIKE_ESCAPE),
                NetworkDiagram.client_name.ilike(pattern, escape=_LIKE_ESCAPE),
            )
        )
    result = await db.execute(stmt)
    rows = result.scalars().all()
    return [_diagram_to_list_item(r) for r in rows]


@router.post("/", response_model=NetworkDiagramResponse, status_code=201)
async def create_diagram(
    data: NetworkDiagramCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Create a diagram owned by the current user's team and return it."""
    diagram = NetworkDiagram(
        team_id=current_user.team_id,
        name=data.name,
        client_name=data.client_name,
        asset_name=data.asset_name,
        description=data.description,
        nodes=[n.model_dump() for n in data.nodes],
        edges=[e.model_dump() for e in data.edges],
        created_by=current_user.id,
    )
    db.add(diagram)
    await db.commit()
    await db.refresh(diagram)
    return _diagram_to_response(diagram)


@router.get("/{diagram_id}", response_model=NetworkDiagramResponse)
async def get_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Return one of the team's diagrams, or 404 if it is not accessible."""
    diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
    return _diagram_to_response(diagram)


@router.put("/{diagram_id}", response_model=NetworkDiagramResponse)
async def update_diagram(
    diagram_id: UUID,
    data: NetworkDiagramUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Apply a partial update to a diagram and return the refreshed record.

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``). ``updated_at`` is always set to the current UTC
    time, after the requested fields have been applied.

    Raises:
        HTTPException: 404 when the diagram is not accessible.
    """
    diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
    update_data = data.model_dump(exclude_unset=True)

    # Make sure graph elements are stored as plain dicts: anything that still
    # exposes ``model_dump`` is dumped, everything else is kept as-is.
    for key in ("nodes", "edges"):
        items = update_data.get(key)
        if items is not None:
            update_data[key] = [
                item.model_dump() if hasattr(item, "model_dump") else item
                for item in items
            ]

    for field, value in update_data.items():
        setattr(diagram, field, value)
    diagram.updated_at = datetime.now(timezone.utc)
    await db.commit()
    await db.refresh(diagram)
    return _diagram_to_response(diagram)


@router.delete("/{diagram_id}", status_code=204)
async def archive_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
    """Soft-delete a diagram by setting ``is_archived`` instead of removing the row.

    Raises:
        HTTPException: 404 when the diagram is not accessible (which includes
            a diagram that is already archived).
    """
    diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
    diagram.is_archived = True
    diagram.updated_at = datetime.now(timezone.utc)
    await db.commit()


@router.post("/{diagram_id}/duplicate", response_model=NetworkDiagramResponse, status_code=201)
async def duplicate_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Create a copy of a diagram named ``Copy of <original name>``.

    The copy is created by, and attributed to, the current user.

    Raises:
        HTTPException: 404 when the source diagram is not accessible.
    """
    source = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
    duplicate = NetworkDiagram(
        team_id=current_user.team_id,
        name=f"Copy of {source.name}",
        client_name=source.client_name,
        asset_name=source.asset_name,
        description=source.description,
        nodes=source.nodes,
        edges=source.edges,
        created_by=current_user.id,
    )
    db.add(duplicate)
    await db.commit()
    await db.refresh(duplicate)
    return _diagram_to_response(duplicate)


@router.get("/{diagram_id}/export", response_model=DiagramExportResponse)
async def export_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramExportResponse:
    """Return a diagram in the export format (schema version 1).

    Missing ``nodes``/``edges`` are exported as empty lists, and
    ``exportedAt`` is the current UTC time in ISO 8601 form.

    Raises:
        HTTPException: 404 when the diagram is not accessible.
    """
    diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
    nodes = [DiagramNode(**n) for n in (diagram.nodes or [])]
    edges = [DiagramEdge(**e) for e in (diagram.edges or [])]
    return DiagramExportResponse(
        schemaVersion=1,
        name=diagram.name,
        client_name=diagram.client_name,
        description=diagram.description,
        nodes=nodes,
        edges=edges,
        exportedAt=datetime.now(timezone.utc).isoformat(),
    )


@router.post("/import", response_model=DiagramImportResponse, status_code=201)
async def import_diagram(
    data: DiagramImportRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse:
    """Create a diagram from an import payload.

    Nodes whose ``type`` is not among the device types available to the team
    do not block the import; each distinct unknown type is reported once in
    ``warnings``, in order of first appearance.
    """
    available_slugs = set(await _available_device_slugs(current_user.team_id, db))

    # dict.fromkeys de-duplicates while keeping first-seen order, so a diagram
    # with many nodes of the same unknown type yields a single warning.
    unknown_types = dict.fromkeys(
        node.type for node in data.nodes if node.type not in available_slugs
    )
    warnings: list[str] = [
        f"Unknown device type '{node_type}' — will render with default icon"
        for node_type in unknown_types
    ]

    diagram = NetworkDiagram(
        team_id=current_user.team_id,
        name=data.name,
        client_name=data.client_name,
        description=data.description,
        nodes=[n.model_dump() for n in data.nodes],
        edges=[e.model_dump() for e in data.edges],
        created_by=current_user.id,
    )
    db.add(diagram)
    await db.commit()
    await db.refresh(diagram)
    return DiagramImportResponse(
        diagram=_diagram_to_response(diagram),
        warnings=warnings,
    )


@router.post("/ai-generate", response_model=AIGenerateResponse)
async def ai_generate_diagram(
    data: AIGenerateRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse:
    """Generate a diagram through the AI service.

    The service receives the request together with the device-type slugs
    available to the current team.

    Raises:
        HTTPException: 422 with the error text when the service raises
            ``ValueError``; 500 with a generic message for any other
            unexpected failure (which is logged with its traceback). An
            ``HTTPException`` raised by the service is passed through as-is.
    """
    available_slugs = await _available_device_slugs(current_user.team_id, db)

    existing_node_ids: list[str] | None = None
    if data.mode == "merge" and data.existingBounds:
        # NOTE(review): this is always an empty list, never the real ids of
        # the existing nodes — looks like a placeholder; confirm what the
        # service expects here.
        existing_node_ids = []

    try:
        return await network_diagram_ai_service.generate_diagram(
            request=data,
            available_slugs=available_slugs,
            existing_node_ids=existing_node_ids,
        )
    except HTTPException:
        # Already a deliberate HTTP error; do not mask it as a generic 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        logger.exception("AI diagram generation failed")
        raise HTTPException(status_code=500, detail="Diagram generation failed") from e