diff --git a/backend/app/api/endpoints/network_diagrams.py b/backend/app/api/endpoints/network_diagrams.py new file mode 100644 index 00000000..a54478d2 --- /dev/null +++ b/backend/app/api/endpoints/network_diagrams.py @@ -0,0 +1,297 @@ +"""Network diagrams API endpoints.""" +import logging +from datetime import datetime, timezone +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import select, or_ +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.api.deps import get_current_active_user +from app.models.user import User +from app.models.device_type import DeviceType +from app.models.network_diagram import NetworkDiagram +from app.schemas.network_diagram import ( + NetworkDiagramCreate, + NetworkDiagramUpdate, + NetworkDiagramResponse, + NetworkDiagramListItem, + AIGenerateRequest, + AIGenerateResponse, + DiagramImportRequest, + DiagramImportResponse, + DiagramExportResponse, + DiagramNode, + DiagramEdge, +) +from app.services import network_diagram_ai_service + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"]) + + +async def _get_diagram_or_404( + diagram_id: UUID, + team_id: UUID, + db: AsyncSession, +) -> NetworkDiagram: + diagram = await db.get(NetworkDiagram, diagram_id) + if not diagram or diagram.team_id != team_id or diagram.is_archived: + raise HTTPException(status_code=404, detail="Diagram not found") + return diagram + + +def _diagram_to_response(diagram: NetworkDiagram) -> NetworkDiagramResponse: + return NetworkDiagramResponse.model_validate(diagram) + + +def _diagram_to_list_item(diagram: NetworkDiagram) -> NetworkDiagramListItem: + nodes = diagram.nodes if isinstance(diagram.nodes, list) else [] + return NetworkDiagramListItem( + id=diagram.id, + name=diagram.name, + client_name=diagram.client_name, + description=diagram.description, + node_count=len(nodes), + created_by=diagram.created_by, + created_at=diagram.created_at, + updated_at=diagram.updated_at, + ) + + +@router.get("/clients", response_model=list[str]) +async def list_client_names( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> list[str]: + stmt = ( + select(NetworkDiagram.client_name) + .where( + NetworkDiagram.team_id == current_user.team_id, + NetworkDiagram.is_archived.is_(False), + NetworkDiagram.client_name.isnot(None), + NetworkDiagram.client_name != "", + ) + .distinct() + .order_by(NetworkDiagram.client_name) + ) + result = await db.execute(stmt) + return [row[0] for row in result.all()] + + +@router.get("/", response_model=list[NetworkDiagramListItem]) +async def list_diagrams( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], + client_name: str | None = Query(default=None), + search: str | None = Query(default=None), +) -> list[NetworkDiagramListItem]: + stmt = ( + select(NetworkDiagram) + .where( + NetworkDiagram.team_id == current_user.team_id, + NetworkDiagram.is_archived.is_(False), + ) + .order_by(NetworkDiagram.updated_at.desc()) + ) + + if client_name: + stmt = stmt.where(NetworkDiagram.client_name == client_name) + + if search: + search_filter = f"%{search}%" + stmt = stmt.where( + or_( + NetworkDiagram.name.ilike(search_filter), + NetworkDiagram.client_name.ilike(search_filter), + ) + ) + + result = await db.execute(stmt) + rows = result.scalars().all() + return [_diagram_to_list_item(r) for r in rows] + + +@router.post("/", response_model=NetworkDiagramResponse, status_code=201) +async def create_diagram( + data: NetworkDiagramCreate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> NetworkDiagramResponse: + diagram = NetworkDiagram( + team_id=current_user.team_id, + name=data.name, + client_name=data.client_name, + asset_name=data.asset_name, + description=data.description, + nodes=[n.model_dump() for n in data.nodes], + edges=[e.model_dump() for e in data.edges], + created_by=current_user.id, + ) + db.add(diagram) + await db.commit() + await db.refresh(diagram) + return _diagram_to_response(diagram) + + +@router.get("/{diagram_id}", response_model=NetworkDiagramResponse) +async def get_diagram( + diagram_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> NetworkDiagramResponse: + diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + return _diagram_to_response(diagram) + + +@router.put("/{diagram_id}", response_model=NetworkDiagramResponse) +async def update_diagram( + diagram_id: UUID, + data: NetworkDiagramUpdate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> NetworkDiagramResponse: + diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + + update_data = data.model_dump(exclude_unset=True) + if "nodes" in update_data and update_data["nodes"] is not None: + update_data["nodes"] = [n.model_dump() if hasattr(n, "model_dump") else n for n in update_data["nodes"]] + if "edges" in update_data and update_data["edges"] is not None: + update_data["edges"] = [e.model_dump() if hasattr(e, "model_dump") else e for e in update_data["edges"]] + + for field, value in update_data.items(): + setattr(diagram, field, value) + + diagram.updated_at = datetime.now(timezone.utc) + await db.commit() + await db.refresh(diagram) + return _diagram_to_response(diagram) + + +@router.delete("/{diagram_id}", status_code=204) +async def archive_diagram( + diagram_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> None: + diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + diagram.is_archived = True + diagram.updated_at = datetime.now(timezone.utc) + await db.commit() + + +@router.post("/{diagram_id}/duplicate", response_model=NetworkDiagramResponse, status_code=201) +async def duplicate_diagram( + diagram_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> NetworkDiagramResponse: + source = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + copy = NetworkDiagram( + team_id=current_user.team_id, + name=f"Copy of {source.name}", + client_name=source.client_name, + asset_name=source.asset_name, + description=source.description, + nodes=source.nodes, + edges=source.edges, + created_by=current_user.id, + ) + db.add(copy) + await db.commit() + await db.refresh(copy) + return _diagram_to_response(copy) + + +@router.get("/{diagram_id}/export", response_model=DiagramExportResponse) +async def export_diagram( + diagram_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> DiagramExportResponse: + diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + nodes = [DiagramNode(**n) for n in (diagram.nodes or [])] + edges = [DiagramEdge(**e) for e in (diagram.edges or [])] + return DiagramExportResponse( + schemaVersion=1, + name=diagram.name, + client_name=diagram.client_name, + description=diagram.description, + nodes=nodes, + edges=edges, + exportedAt=datetime.now(timezone.utc).isoformat(), + ) + + +@router.post("/import", response_model=DiagramImportResponse, status_code=201) +async def import_diagram( + data: DiagramImportRequest, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> DiagramImportResponse: + available_stmt = select(DeviceType.slug).where( + or_( + DeviceType.is_system.is_(True), + DeviceType.team_id == current_user.team_id, + ) + ) + result = await db.execute(available_stmt) + available_slugs = {row[0] for row in result.all()} + + warnings: list[str] = [] + for node in data.nodes: + if node.type not in available_slugs: + warnings.append(f"Unknown device type '{node.type}' — will render with default icon") + + diagram = NetworkDiagram( + team_id=current_user.team_id, + name=data.name, + client_name=data.client_name, + description=data.description, + nodes=[n.model_dump() for n in data.nodes], + edges=[e.model_dump() for e in data.edges], + created_by=current_user.id, + ) + db.add(diagram) + await db.commit() + await db.refresh(diagram) + + return DiagramImportResponse( + diagram=_diagram_to_response(diagram), + warnings=warnings, + ) + + +@router.post("/ai-generate", response_model=AIGenerateResponse) +async def ai_generate_diagram( + data: AIGenerateRequest, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> AIGenerateResponse: + stmt = select(DeviceType.slug).where( + or_( + DeviceType.is_system.is_(True), + DeviceType.team_id == current_user.team_id, + ) + ) + result = await db.execute(stmt) + available_slugs = [row[0] for row in result.all()] + + existing_node_ids: list[str] | None = None + if data.mode == "merge" and data.existingBounds: + existing_node_ids = [] + + try: + return await network_diagram_ai_service.generate_diagram( + request=data, + available_slugs=available_slugs, + existing_node_ids=existing_node_ids, + ) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + except Exception: + logger.exception("AI diagram generation failed") + raise HTTPException(status_code=500, detail="Diagram generation failed") diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 7c6367e0..6dd46eff 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -34,6 +34,7 @@ from app.api.endpoints import session_branches from app.api.endpoints import session_handoffs from app.api.endpoints import session_resolutions from app.api.endpoints import device_types +from app.api.endpoints import network_diagrams api_router = APIRouter() @@ -80,6 +81,7 @@ api_router.include_router(integrations.router) api_router.include_router(onboarding.router) api_router.include_router(branding.router) api_router.include_router(supporting_data.router) +api_router.include_router(network_diagrams.router) # Must be before ai_sessions to avoid /{diagram_id} conflict api_router.include_router(session_handoffs.queue_router) # Must be before ai_sessions to avoid /{session_id} conflict api_router.include_router(session_resolutions.router) # Must be before ai_sessions to avoid /{session_id} conflict api_router.include_router(ai_sessions.router)