- Colorize: semantic category colors for all device types (network=blue, security=orange, compute=emerald, endpoint=amber, storage=violet, cloud=cyan, infra=steel); better icons (Router, ShieldAlert, Boxes, Package, Gauge, PlugZap, Video, Radio); MiniMap uses category colors - Onboard: centered AI generate prompt on empty canvas with 5 MSP-specific example chips, ⌘↵ shortcut, spinner; AIAssistPanel only shown with nodes - Arrange: properties panel — status badge grid at top, fields grouped into Network (IP/Subnet/VLAN) and Hardware (Hostname/Vendor/Model/Role) sections - Delight: segmented topology color bar on listing cards; backend returns category_counts via single extra query on list endpoint - Harden: real PNG export via html-to-image + getNodesBounds/getViewportForBounds - Polish: ChevronDown replaces unicode ▾, click-outside for client filter, consistent spinner in empty prompt Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
333 lines
12 KiB
Python
333 lines
12 KiB
Python
"""Network diagrams API endpoints."""
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import Annotated
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlalchemy import select, or_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.database import get_db
|
|
from app.api.deps import get_current_active_user
|
|
from app.models.user import User
|
|
from app.models.device_type import DeviceType
|
|
from app.models.network_diagram import NetworkDiagram
|
|
from app.schemas.network_diagram import (
|
|
NetworkDiagramCreate,
|
|
NetworkDiagramUpdate,
|
|
NetworkDiagramResponse,
|
|
NetworkDiagramListItem,
|
|
AIGenerateRequest,
|
|
AIGenerateResponse,
|
|
DiagramImportRequest,
|
|
DiagramImportResponse,
|
|
DiagramExportResponse,
|
|
DiagramNode,
|
|
DiagramEdge,
|
|
)
|
|
from app.services import network_diagram_ai_service
|
|
|
|
# Built-in (system) device-type slugs grouped by category — mirrors the frontend
# deviceRegistry.ts, so update both together. Category order here also fixes the
# key order of the flattened mapping below.
_CATEGORY_SLUGS: dict[str, tuple[str, ...]] = {
    "network": ("router", "switch", "access-point", "load-balancer"),
    "security": ("firewall", "badge-reader"),
    "compute": ("server", "vm", "container"),
    "storage": ("nas", "san", "cloud-storage"),
    "cloud": ("cloud", "aws", "azure", "gcp", "isp"),
    "endpoint": ("workstation", "laptop", "tablet", "phone", "printer"),
    "infrastructure": ("ups", "pdu", "rack", "patch-panel", "camera", "nvr", "iot"),
}

# Flattened slug -> category lookup used for per-diagram category counts.
_SLUG_CATEGORY: dict[str, str] = {
    slug: category
    for category, slugs in _CATEGORY_SLUGS.items()
    for slug in slugs
}
|
|
|
|
# Module-level logger; ai_generate_diagram uses it to record unexpected failures.
logger = logging.getLogger(__name__)

# Every endpoint in this module is registered on this router under /network-diagrams.
router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"])
|
|
|
|
|
|
async def _get_diagram_or_404(
    diagram_id: UUID,
    team_id: UUID,
    db: AsyncSession,
) -> NetworkDiagram:
    """Load a diagram visible to ``team_id``, or raise HTTP 404.

    A missing row, a row owned by another team, and an archived row all yield
    the same 404. Callers therefore cannot tell these cases apart.
    """
    found = await db.get(NetworkDiagram, diagram_id)
    visible = bool(found) and found.team_id == team_id and not found.is_archived
    if not visible:
        raise HTTPException(status_code=404, detail="Diagram not found")
    return found
|
|
|
|
|
|
def _diagram_to_response(diagram: NetworkDiagram) -> NetworkDiagramResponse:
    """Build the full API response schema from an ORM ``NetworkDiagram`` row."""
    response = NetworkDiagramResponse.model_validate(diagram)
    return response
|
|
|
|
|
|
def _diagram_to_list_item(
    diagram: NetworkDiagram,
    custom_slug_category: dict[str, str] | None = None,
) -> NetworkDiagramListItem:
    """Summarize a diagram for the listing endpoint.

    Args:
        diagram: The ORM row to summarize.
        custom_slug_category: Optional slug -> category entries for the team's
            custom device types. They take precedence over ``_SLUG_CATEGORY``.

    Returns:
        A ``NetworkDiagramListItem`` with the node count and a
        ``category_counts`` histogram (category -> number of nodes). A node is
        counted under ``"other"`` in any of these cases:
        - its ``type`` is missing or not a string;
        - the slug is unknown;
        - the slug maps to an empty or null category.
    """
    # Tolerate a ``nodes`` value that is not a list (e.g. NULL/legacy data).
    nodes = diagram.nodes if isinstance(diagram.nodes, list) else []
    slug_to_cat = {**_SLUG_CATEGORY, **(custom_slug_category or {})}

    category_counts: dict[str, int] = {}
    for node in nodes:
        slug = node.get("type") if isinstance(node, dict) else None
        # A non-string slug from malformed JSON (e.g. a list) would be
        # unhashable in the lookup. A falsy category would leak a non-str key
        # into category_counts. NOTE(review): confirm whether
        # DeviceType.category is nullable.
        cat = (slug_to_cat.get(slug) if isinstance(slug, str) else None) or "other"
        category_counts[cat] = category_counts.get(cat, 0) + 1

    return NetworkDiagramListItem(
        id=diagram.id,
        name=diagram.name,
        client_name=diagram.client_name,
        description=diagram.description,
        node_count=len(nodes),
        category_counts=category_counts,
        created_by=diagram.created_by,
        created_at=diagram.created_at,
        updated_at=diagram.updated_at,
    )
|
|
|
|
|
|
async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]:
    """Return every device-type slug usable by ``team_id``.

    The result is the union of the system device types and the team's own
    custom device types.
    """
    visible_to_team = or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id)
    rows = await db.execute(select(DeviceType.slug).where(visible_to_team))
    return set(rows.scalars().all())
|
|
|
|
|
|
@router.get("/clients", response_model=list[str])
async def list_client_names(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[str]:
    """List the distinct client names used by the team's active diagrams.

    NULL and empty-string names are skipped, archived diagrams are ignored, and
    the result is ordered by client name.
    """
    conditions = (
        NetworkDiagram.team_id == current_user.team_id,
        NetworkDiagram.is_archived.is_(False),
        NetworkDiagram.client_name.isnot(None),
        NetworkDiagram.client_name != "",
    )
    query = (
        select(NetworkDiagram.client_name)
        .where(*conditions)
        .distinct()
        .order_by(NetworkDiagram.client_name)
    )
    result = await db.execute(query)
    return list(result.scalars().all())
|
|
|
|
|
|
@router.get("/", response_model=list[NetworkDiagramListItem])
async def list_diagrams(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
    client_name: str | None = Query(default=None),
    search: str | None = Query(default=None),
) -> list[NetworkDiagramListItem]:
    """List the team's non-archived diagrams, most recently updated first.

    Args:
        client_name: When non-empty, keep only diagrams for exactly this client.
        search: When non-empty, case-insensitive substring match on the diagram
            name or client name. ``%``, ``_`` and backslash in the term are
            matched literally.

    Returns:
        One ``NetworkDiagramListItem`` per diagram. Its category counts take the
        team's custom device types into account.
    """
    stmt = (
        select(NetworkDiagram)
        .where(
            NetworkDiagram.team_id == current_user.team_id,
            NetworkDiagram.is_archived.is_(False),
        )
        .order_by(NetworkDiagram.updated_at.desc())
    )

    if client_name:
        stmt = stmt.where(NetworkDiagram.client_name == client_name)

    if search:
        # Escape the escape character first, then the LIKE wildcards, so the
        # user's term matches literally. The ESCAPE clause is passed explicitly
        # because only some backends (e.g. PostgreSQL) treat backslash as the
        # default LIKE escape character; SQLite, for one, has none.
        escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        search_filter = f"%{escaped}%"
        stmt = stmt.where(
            or_(
                NetworkDiagram.name.ilike(search_filter, escape="\\"),
                NetworkDiagram.client_name.ilike(search_filter, escape="\\"),
            )
        )

    # One extra query for the team's custom device types, so category_counts
    # can classify their slugs as well as the built-in ones.
    dt_stmt = select(DeviceType.slug, DeviceType.category).where(
        DeviceType.is_system.is_(False),
        DeviceType.team_id == current_user.team_id,
    )
    dt_result = await db.execute(dt_stmt)
    custom_slug_category = {row[0]: row[1] for row in dt_result.all()}

    result = await db.execute(stmt)
    rows = result.scalars().all()
    return [_diagram_to_list_item(r, custom_slug_category) for r in rows]
|
|
|
|
|
|
@router.post("/", response_model=NetworkDiagramResponse, status_code=201)
async def create_diagram(
    data: NetworkDiagramCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Create a diagram owned by the caller's team and return it.

    Raises:
        HTTPException: 422 when the caller is not assigned to a team.
    """
    team_id = current_user.team_id
    if team_id is None:
        raise HTTPException(
            status_code=422,
            detail="Network Diagrams require a team account. Assign your account to a team first.",
        )

    # Persist nodes/edges as plain dicts rather than pydantic models.
    node_payload = [node.model_dump() for node in data.nodes]
    edge_payload = [edge.model_dump() for edge in data.edges]

    new_diagram = NetworkDiagram(
        team_id=team_id,
        name=data.name,
        client_name=data.client_name,
        asset_name=data.asset_name,
        description=data.description,
        nodes=node_payload,
        edges=edge_payload,
        created_by=current_user.id,
    )
    db.add(new_diagram)
    await db.commit()
    await db.refresh(new_diagram)
    return _diagram_to_response(new_diagram)
|
|
|
|
|
|
@router.get("/{diagram_id}", response_model=NetworkDiagramResponse)
async def get_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Return one of the caller's team's diagrams.

    Responds 404 if the diagram is missing, belongs to another team, or is archived.
    """
    found = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
    response = _diagram_to_response(found)
    return response
|
|
|
|
|
|
@router.put("/{diagram_id}", response_model=NetworkDiagramResponse)
async def update_diagram(
    diagram_id: UUID,
    data: NetworkDiagramUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Apply a partial update to one of the team's diagrams.

    Only fields the client actually sent are written (``exclude_unset=True``).
    ``updated_at`` is always refreshed to the current UTC time.
    """
    diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)

    changes = data.model_dump(exclude_unset=True)
    # Make sure non-null nodes/edges end up as plain dicts, even if an
    # element still exposes model_dump() after the outer dump.
    for collection in ("nodes", "edges"):
        items = changes.get(collection)
        if items is not None:
            changes[collection] = [
                item.model_dump() if hasattr(item, "model_dump") else item
                for item in items
            ]

    for attr_name, new_value in changes.items():
        setattr(diagram, attr_name, new_value)

    diagram.updated_at = datetime.now(timezone.utc)
    await db.commit()
    await db.refresh(diagram)
    return _diagram_to_response(diagram)
|
|
|
|
|
|
@router.delete("/{diagram_id}", status_code=204)
async def archive_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
    """Soft-delete a diagram by flagging it as archived (responds 204).

    The row is kept. Archived diagrams are excluded from listings, and
    ``_get_diagram_or_404`` returns 404 for them.
    """
    target = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
    target.is_archived = True
    target.updated_at = datetime.now(timezone.utc)
    await db.commit()
|
|
|
|
|
|
@router.post("/{diagram_id}/duplicate", response_model=NetworkDiagramResponse, status_code=201)
async def duplicate_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
    """Clone one of the team's diagrams into a new diagram named "Copy of <name>".

    The clone belongs to the caller's team and records the caller as ``created_by``.
    """
    source = await _get_diagram_or_404(diagram_id, current_user.team_id, db)

    # NOTE(review): nodes/edges are the same objects held by ``source``, not
    # copies. Deep-copy them if anything starts mutating them in place.
    clone_fields = {
        "team_id": current_user.team_id,
        "name": f"Copy of {source.name}",
        "client_name": source.client_name,
        "asset_name": source.asset_name,
        "description": source.description,
        "nodes": source.nodes,
        "edges": source.edges,
        "created_by": current_user.id,
    }
    clone = NetworkDiagram(**clone_fields)

    db.add(clone)
    await db.commit()
    await db.refresh(clone)
    return _diagram_to_response(clone)
|
|
|
|
|
|
@router.get("/{diagram_id}/export", response_model=DiagramExportResponse)
async def export_diagram(
    diagram_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramExportResponse:
    """Export a diagram's content as a versioned document (``schemaVersion`` 1).

    Stored nodes and edges are re-validated through ``DiagramNode`` and
    ``DiagramEdge``. ``exportedAt`` is the current UTC time in ISO-8601 format.
    """
    diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)

    stored_nodes = diagram.nodes or []
    stored_edges = diagram.edges or []
    validated_nodes = [DiagramNode(**raw) for raw in stored_nodes]
    validated_edges = [DiagramEdge(**raw) for raw in stored_edges]

    return DiagramExportResponse(
        schemaVersion=1,
        name=diagram.name,
        client_name=diagram.client_name,
        description=diagram.description,
        nodes=validated_nodes,
        edges=validated_edges,
        exportedAt=datetime.now(timezone.utc).isoformat(),
    )
|
|
|
|
|
|
@router.post("/import", response_model=DiagramImportResponse, status_code=201)
async def import_diagram(
    data: DiagramImportRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse:
    """Create a new team diagram from an import payload.

    Nodes whose type is not available to the team are still imported. Each
    distinct unknown type is reported once in ``warnings``.

    Raises:
        HTTPException: 422 when the caller is not assigned to a team (same
            rule as ``create_diagram``).
    """
    # Same guard as create_diagram: without it a team-less user would persist a
    # diagram with team_id=None (or hit a DB integrity error).
    if current_user.team_id is None:
        raise HTTPException(
            status_code=422,
            detail="Network Diagrams require a team account. Assign your account to a team first.",
        )

    available_slugs = await _get_available_slugs(current_user.team_id, db)

    # dict.fromkeys de-duplicates while preserving first-seen order, so N nodes
    # of the same unknown type produce a single warning.
    unknown_types = dict.fromkeys(
        node.type for node in data.nodes if node.type not in available_slugs
    )
    warnings: list[str] = [
        f"Unknown device type '{slug}' — will render with default icon"
        for slug in unknown_types
    ]

    diagram = NetworkDiagram(
        team_id=current_user.team_id,
        name=data.name,
        client_name=data.client_name,
        description=data.description,
        nodes=[n.model_dump() for n in data.nodes],
        edges=[e.model_dump() for e in data.edges],
        created_by=current_user.id,
    )
    db.add(diagram)
    await db.commit()
    await db.refresh(diagram)

    return DiagramImportResponse(
        diagram=_diagram_to_response(diagram),
        warnings=warnings,
    )
|
|
|
|
|
|
@router.post("/ai-generate", response_model=AIGenerateResponse)
async def ai_generate_diagram(
    data: AIGenerateRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse:
    """Generate diagram content with the AI service; nothing is persisted here.

    Raises:
        HTTPException: re-raised unchanged if the service raises one; 422 when
            the service raises ``ValueError``; 500 (logged) for any other failure.
    """
    # Sort so the slug list handed to the service is deterministic; set
    # iteration order for strings varies between processes.
    available_slugs = sorted(await _get_available_slugs(current_user.team_id, db))

    # NOTE(review): merge mode with existingBounds always passes an empty list,
    # never the real existing node IDs — confirm this placeholder is intended.
    existing_node_ids: list[str] | None = None
    if data.mode == "merge" and data.existingBounds:
        existing_node_ids = []

    try:
        return await network_diagram_ai_service.generate_diagram(
            request=data,
            available_slugs=available_slugs,
            existing_node_ids=existing_node_ids,
        )
    except HTTPException:
        # Deliberate HTTP errors from the service keep their status code
        # instead of being masked as a generic 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        logger.exception("AI diagram generation failed")
        raise HTTPException(status_code=500, detail="Diagram generation failed") from e
|