"""Network diagrams API endpoints."""

import base64
import copy
import logging
from datetime import datetime, timezone
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import select, or_
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from app.api.deps import get_current_active_user
from app.models.user import User
from app.models.device_type import DeviceType
from app.models.network_diagram import NetworkDiagram
from app.core.service_account import PLATFORM_ACCOUNT_ID
from app.schemas.network_diagram import (
    NetworkDiagramCreate,
    NetworkDiagramUpdate,
    NetworkDiagramResponse,
    NetworkDiagramListItem,
    AIGenerateRequest,
    AIGenerateResponse,
    DiagramImportRequest,
    DiagramImportResponse,
    DiagramExportResponse,
    DiagramNode,
    DiagramEdge,
)
from app.services import network_diagram_ai_service, storage_service

# Maps system device-type slugs to their category — mirrors frontend deviceRegistry.ts
_SLUG_CATEGORY: dict[str, str] = {
    "router": "network",
    "switch": "network",
    "access-point": "network",
    "load-balancer": "network",
    "firewall": "security",
    "badge-reader": "security",
    "server": "compute",
    "vm": "compute",
    "container": "compute",
    "nas": "storage",
    "san": "storage",
    "cloud-storage": "storage",
    "cloud": "cloud",
    "aws": "cloud",
    "azure": "cloud",
    "gcp": "cloud",
    "isp": "cloud",
    "workstation": "endpoint",
    "laptop": "endpoint",
    "tablet": "endpoint",
    "phone": "endpoint",
    "printer": "endpoint",
    "ups": "infrastructure",
    "pdu": "infrastructure",
    "rack": "infrastructure",
    "patch-panel": "infrastructure",
    "camera": "infrastructure",
    "nvr": "infrastructure",
    "iot": "infrastructure",
}

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"])


async def _get_diagram_or_404(
    diagram_id: UUID,
    account_id: UUID,
    db: AsyncSession,
) -> NetworkDiagram:
    """Return the diagram with ``diagram_id`` owned by ``account_id``.

    Raises:
        HTTPException: 404 when the diagram does not exist, belongs to a
            different account, or is archived. All three cases share the same
            response so callers cannot probe for other accounts' diagrams.
    """
    diagram = await db.get(NetworkDiagram, diagram_id)
    if not diagram or diagram.account_id != account_id or diagram.is_archived:
        raise HTTPException(status_code=404, detail="Diagram not found")
    return diagram


def _diagram_to_response(diagram: NetworkDiagram) -> NetworkDiagramResponse:
    """Serialize an ORM diagram into the full response schema."""
    return NetworkDiagramResponse.model_validate(diagram)


def _diagram_to_list_item(
    diagram: NetworkDiagram,
    custom_slug_category: dict[str, str] | None = None,
) -> NetworkDiagramListItem:
    """Build the lightweight list representation of a diagram.

    Args:
        diagram: The diagram to summarize.
        custom_slug_category: Extra ``slug -> category`` mappings (e.g. the
            account's custom device types). These override the built-in
            ``_SLUG_CATEGORY`` entries for the same slug.

    Returns:
        A list item whose ``category_counts`` tallies nodes per category.
        Nodes whose ``type`` is unknown, or that are not dicts, count as
        ``"other"``.
    """
    nodes = diagram.nodes if isinstance(diagram.nodes, list) else []
    slug_to_cat = {**_SLUG_CATEGORY, **(custom_slug_category or {})}

    category_counts: dict[str, int] = {}
    for node in nodes:
        slug = node.get("type", "") if isinstance(node, dict) else ""
        cat = slug_to_cat.get(slug, "other")
        category_counts[cat] = category_counts.get(cat, 0) + 1

    return NetworkDiagramListItem(
        id=diagram.id,
        name=diagram.name,
        client_name=diagram.client_name,
        description=diagram.description,
        node_count=len(nodes),
        category_counts=category_counts,
        thumbnail_url=diagram.thumbnail_url,
        created_by=diagram.created_by,
        created_at=diagram.created_at,
        updated_at=diagram.updated_at,
    )


async def _get_available_slugs(account_id: UUID, db: AsyncSession) -> set[str]:
    """Return device-type slugs usable by ``account_id``.

    The result includes platform-wide types (owned by ``PLATFORM_ACCOUNT_ID``)
    plus the account's own types.
    """
    stmt = select(DeviceType.slug).where(
        or_(
            DeviceType.account_id == PLATFORM_ACCOUNT_ID,
            DeviceType.account_id == account_id,
        )
    )
    result = await db.execute(stmt)
    return {row[0] for row in result.all()}


@router.get("/clients", response_model=list[str])
async def list_client_names(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[str]:
    """List distinct, non-empty client names used by the account's active diagrams."""
    stmt = (
        select(NetworkDiagram.client_name)
        .where(
            NetworkDiagram.account_id == current_user.account_id,
            NetworkDiagram.is_archived.is_(False),
            NetworkDiagram.client_name.isnot(None),
            NetworkDiagram.client_name != "",
        )
        .distinct()
        .order_by(NetworkDiagram.client_name)
    )
    result = await db.execute(stmt)
    return [row[0] for row in result.all()]


@router.get("/", response_model=list[NetworkDiagramListItem])
async def list_diagrams(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
    client_name: str | None = Query(default=None),
    search: str | None = Query(default=None),
) -> list[NetworkDiagramListItem]:
    """List the account's active diagrams, most recently updated first.

    Optionally filter by exact client name and/or a case-insensitive substring
    search over diagram name and client name.
    """
    stmt = (
        select(NetworkDiagram)
        .where(
            NetworkDiagram.account_id == current_user.account_id,
            NetworkDiagram.is_archived.is_(False),
        )
        .order_by(NetworkDiagram.updated_at.desc())
    )
    if client_name:
        stmt = stmt.where(NetworkDiagram.client_name == client_name)
    if search:
        # Treat the user's text literally: escape LIKE wildcards (and the
        # escape character itself), then declare "\" as the escape character
        # explicitly rather than relying on the database's default.
        escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        search_filter = f"%{escaped}%"
        stmt = stmt.where(
            or_(
                NetworkDiagram.name.ilike(search_filter, escape="\\"),
                NetworkDiagram.client_name.ilike(search_filter, escape="\\"),
            )
        )

    # Single query for custom device types so category_counts is accurate
    dt_stmt = select(DeviceType.slug, DeviceType.category).where(
        DeviceType.is_system.is_(False),
        DeviceType.account_id == current_user.account_id,
    )
    dt_result = await db.execute(dt_stmt)
    custom_slug_category = {row[0]: row[1] for row in dt_result.all()}

    result = await db.execute(stmt)
    rows = result.scalars().all()
    return [_diagram_to_list_item(r, custom_slug_category) for r in rows]


@router.post("/", response_model=NetworkDiagramResponse, status_code=201)
async def create_diagram(
    data: NetworkDiagramCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Create a diagram owned by the current user's account."""
    diagram = NetworkDiagram(
        account_id=current_user.account_id,
        name=data.name,
        client_name=data.client_name,
        asset_name=data.asset_name,
        description=data.description,
        nodes=[n.model_dump() for n in data.nodes],
        edges=[e.model_dump() for e in data.edges],
        created_by=current_user.id,
    )
    db.add(diagram)
    await db.commit()
    await db.refresh(diagram)
    return _diagram_to_response(diagram)


@router.get("/{diagram_id}", response_model=NetworkDiagramResponse)
async def get_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Fetch a single active diagram belonging to the current account."""
    diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
    return _diagram_to_response(diagram)


@router.put("/{diagram_id}", response_model=NetworkDiagramResponse)
async def update_diagram(
    diagram_id: UUID,
    data: NetworkDiagramUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Apply a partial update to a diagram; only fields sent by the client change."""
    diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)

    # exclude_unset: fields the client omitted are left untouched.
    update_data = data.model_dump(exclude_unset=True)
    # Defensive: make sure nodes/edges are stored as plain dicts, not models.
    if "nodes" in update_data and update_data["nodes"] is not None:
        update_data["nodes"] = [
            n.model_dump() if hasattr(n, "model_dump") else n for n in update_data["nodes"]
        ]
    if "edges" in update_data and update_data["edges"] is not None:
        update_data["edges"] = [
            e.model_dump() if hasattr(e, "model_dump") else e for e in update_data["edges"]
        ]

    for field, value in update_data.items():
        setattr(diagram, field, value)
    diagram.updated_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(diagram)
    return _diagram_to_response(diagram)


@router.delete("/{diagram_id}", status_code=204)
async def archive_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
    """Soft-delete a diagram by marking it archived (the row is kept)."""
    diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
    diagram.is_archived = True
    diagram.updated_at = datetime.now(timezone.utc)
    await db.commit()


@router.post("/{diagram_id}/duplicate", response_model=NetworkDiagramResponse, status_code=201)
async def duplicate_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Create a new diagram copying another diagram's content.

    The copy is named "Copy of <name>" and is attributed to the current user.
    The thumbnail is not copied.
    """
    source = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
    duplicate = NetworkDiagram(
        account_id=current_user.account_id,
        name=f"Copy of {source.name}",
        client_name=source.client_name,
        asset_name=source.asset_name,
        description=source.description,
        # Deep-copy the JSON payloads so the new row does not share mutable
        # list/dict objects with the source instance in this session.
        nodes=copy.deepcopy(source.nodes),
        edges=copy.deepcopy(source.edges),
        created_by=current_user.id,
    )
    db.add(duplicate)
    await db.commit()
    await db.refresh(duplicate)
    return _diagram_to_response(duplicate)


@router.get("/{diagram_id}/export", response_model=DiagramExportResponse)
async def export_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramExportResponse:
    """Export a diagram as a portable, versioned (schemaVersion=1) document."""
    diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
    nodes = [DiagramNode(**n) for n in (diagram.nodes or [])]
    edges = [DiagramEdge(**e) for e in (diagram.edges or [])]
    return DiagramExportResponse(
        schemaVersion=1,
        name=diagram.name,
        client_name=diagram.client_name,
        description=diagram.description,
        nodes=nodes,
        edges=edges,
        exportedAt=datetime.now(timezone.utc).isoformat(),
    )


@router.post("/import", response_model=DiagramImportResponse, status_code=201)
async def import_diagram(
    data: DiagramImportRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse:
    """Create a diagram from an exported document.

    Unknown device types are accepted. The response reports them as warnings,
    one per distinct type.
    """
    available_slugs = await _get_available_slugs(current_user.account_id, db)

    # One warning per distinct unknown type, in first-seen order, rather than
    # one per node (many nodes of the same type would repeat the message).
    unknown_types = [
        slug
        for slug in dict.fromkeys(node.type for node in data.nodes)
        if slug not in available_slugs
    ]
    warnings: list[str] = [
        f"Unknown device type '{slug}' — will render with default icon"
        for slug in unknown_types
    ]

    diagram = NetworkDiagram(
        account_id=current_user.account_id,
        name=data.name,
        client_name=data.client_name,
        description=data.description,
        nodes=[n.model_dump() for n in data.nodes],
        edges=[e.model_dump() for e in data.edges],
        created_by=current_user.id,
    )
    db.add(diagram)
    await db.commit()
    await db.refresh(diagram)
    return DiagramImportResponse(
        diagram=_diagram_to_response(diagram),
        warnings=warnings,
    )


class ThumbnailUploadRequest(BaseModel):
    """Request body for uploading a diagram thumbnail."""

    data_url: str  # base64 PNG data URL: "data:image/png;base64,..."
@router.post("/{diagram_id}/thumbnail", status_code=204)
async def upload_thumbnail(
    diagram_id: UUID,
    body: ThumbnailUploadRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
    """Upload a PNG thumbnail for a diagram from a base64 data URL.

    Returns 422 when the data URL has no comma separator, when the payload is
    not valid base64, or when it decodes to nothing.
    """
    diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)

    # Only the payload after the first comma is used; the "data:...;base64"
    # header itself is not inspected.
    _header, separator, encoded = body.data_url.partition(",")
    if not separator:
        raise HTTPException(status_code=422, detail="Invalid data URL format")

    try:
        image_bytes = base64.b64decode(encoded)
    except ValueError:
        # binascii.Error (bad padding) and non-ASCII input both surface as
        # ValueError; report them as a client error instead of a 500.
        raise HTTPException(status_code=422, detail="Invalid base64 image data") from None
    if not image_bytes:
        raise HTTPException(status_code=422, detail="Empty image data")

    storage_key = await storage_service.upload_file(
        file_data=image_bytes,
        filename=f"thumbnail-{diagram_id}.png",
        content_type="image/png",
        account_id=str(current_user.account_id),
    )
    # NOTE(review): a presigned URL is persisted here. If presigned URLs from
    # storage_service expire, stored thumbnails will go stale — consider
    # storing storage_key and presigning at read time. Confirm.
    presigned_url = storage_service.get_presigned_url(storage_key)
    diagram.thumbnail_url = presigned_url
    await db.commit()


@router.post("/ai-generate", response_model=AIGenerateResponse)
async def ai_generate_diagram(
    data: AIGenerateRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse:
    """Generate a diagram with the AI service using the account's device types.

    Returns 422 when the service rejects the request (ValueError) and 500 on
    any other failure.
    """
    # Sorted so the slug list given to the AI service is stable across calls
    # (set iteration order is arbitrary).
    available_slugs = sorted(await _get_available_slugs(current_user.account_id, db))

    existing_node_ids: list[str] | None = None
    if data.mode == "merge" and data.existingBounds:
        # NOTE(review): this is always an empty list — nothing here populates
        # it from the request. Confirm whether real ids should be passed.
        existing_node_ids = []

    try:
        return await network_diagram_ai_service.generate_diagram(
            request=data,
            available_slugs=available_slugs,
            existing_node_ids=existing_node_ids,
        )
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        logger.exception("AI diagram generation failed")
        raise HTTPException(status_code=500, detail="Diagram generation failed") from e