diff --git a/backend/alembic/versions/073_add_device_types_table.py b/backend/alembic/versions/073_add_device_types_table.py new file mode 100644 index 00000000..70b61e5d --- /dev/null +++ b/backend/alembic/versions/073_add_device_types_table.py @@ -0,0 +1,132 @@ +"""Add account-scoped device_types table with platform seed data. + +Revision ID: 073 +Revises: b3c7e9f2a1d8 +Create Date: 2026-04-12 +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID +import uuid + + +revision = "073" +down_revision = "b3c7e9f2a1d8" +branch_labels = None +depends_on = None + +_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001" +_CURRENT_ACCOUNT = ( + "COALESCE(" + "NULLIF(current_setting('app.current_account_id', TRUE), ''), " + "'00000000-0000-0000-0000-000000000000'" + ")::uuid" +) + +SYSTEM_DEVICE_TYPES = [ + ("router", "Router", "network", 0), + ("switch", "Switch", "network", 1), + ("firewall", "Firewall", "network", 2), + ("access-point", "Access Point", "network", 3), + ("load-balancer", "Load Balancer", "network", 4), + ("server", "Server", "compute", 0), + ("workstation", "Workstation", "compute", 1), + ("vm", "Virtual Machine", "compute", 2), + ("container", "Container", "compute", 3), + ("nas", "NAS", "storage", 0), + ("san", "SAN", "storage", 1), + ("cloud-storage", "Cloud Storage", "storage", 2), + ("cloud", "Cloud", "cloud", 0), + ("aws", "AWS", "cloud", 1), + ("azure", "Azure", "cloud", 2), + ("gcp", "Google Cloud", "cloud", 3), + ("printer", "Printer", "endpoint", 0), + ("phone", "Phone", "endpoint", 1), + ("iot", "IoT Device", "endpoint", 2), + ("camera", "Camera", "endpoint", 3), + ("tablet", "Tablet", "endpoint", 4), + ("laptop", "Laptop", "endpoint", 5), + ("ups", "UPS", "infrastructure", 0), + ("pdu", "PDU", "infrastructure", 1), + ("rack", "Rack", "infrastructure", 2), + ("patch-panel", "Patch Panel", "infrastructure", 3), + ("nvr", "NVR", "security", 0), + ("badge-reader", "Badge Reader", "security", 1), +] + + +def upgrade() -> None: + op.create_table( + "device_types", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column("slug", sa.String(50), nullable=False), + sa.Column("label", sa.String(100), nullable=False), + sa.Column("category", sa.String(50), nullable=False), + sa.Column("is_system", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False), + sa.Column("sort_order", sa.Integer(), nullable=False, server_default=sa.text("0")), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), + ) + + op.create_unique_constraint("uq_device_types_slug_account", "device_types", ["slug", "account_id"]) + op.create_index("ix_device_types_account_id", "device_types", ["account_id"]) + + device_types_table = sa.table( + "device_types", + sa.column("id", UUID(as_uuid=True)), + sa.column("slug", sa.String), + sa.column("label", sa.String), + sa.column("category", sa.String), + sa.column("is_system", sa.Boolean), + sa.column("account_id", UUID(as_uuid=True)), + sa.column("sort_order", sa.Integer), + ) + + op.bulk_insert(device_types_table, [ + { + "id": uuid.uuid4(), + "slug": slug, + "label": label, + "category": category, + "is_system": True, + "account_id": uuid.UUID(_PLATFORM_UUID), + "sort_order": sort_order, + } + for slug, label, category, sort_order in SYSTEM_DEVICE_TYPES + ]) + + op.execute("ALTER TABLE device_types ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE device_types FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY device_types_select ON device_types + FOR SELECT + USING ( + account_id = {_CURRENT_ACCOUNT} + OR account_id = '{_PLATFORM_UUID}'::uuid + ) + """) + op.execute(f""" + CREATE POLICY device_types_insert ON device_types + FOR INSERT + WITH CHECK (account_id = {_CURRENT_ACCOUNT}) + """) + op.execute(f""" + CREATE POLICY device_types_update ON device_types + FOR UPDATE + USING (account_id = {_CURRENT_ACCOUNT}) + WITH CHECK (account_id = {_CURRENT_ACCOUNT}) + """) + op.execute(f""" + CREATE POLICY device_types_delete ON device_types + FOR DELETE + USING (account_id = {_CURRENT_ACCOUNT}) + """) + + +def downgrade() -> None: + op.execute("DROP POLICY IF EXISTS device_types_delete ON device_types") + op.execute("DROP POLICY IF EXISTS device_types_update ON device_types") + op.execute("DROP POLICY IF EXISTS device_types_insert ON device_types") + op.execute("DROP POLICY IF EXISTS device_types_select ON device_types") + op.execute("ALTER TABLE device_types DISABLE ROW LEVEL SECURITY") + op.drop_table("device_types") diff --git a/backend/alembic/versions/074_add_network_diagrams_table.py b/backend/alembic/versions/074_add_network_diagrams_table.py new file mode 100644 index 00000000..1cd7397e --- /dev/null +++ b/backend/alembic/versions/074_add_network_diagrams_table.py @@ -0,0 +1,57 @@ +"""Add network_diagrams table. + +Revision ID: 074 +Revises: 073 +Create Date: 2026-04-12 +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID, JSONB + + +revision = "074" +down_revision = "073" +branch_labels = None +depends_on = None + +_CURRENT_ACCOUNT = ( + "COALESCE(" + "NULLIF(current_setting('app.current_account_id', TRUE), ''), " + "'00000000-0000-0000-0000-000000000000'" + ")::uuid" +) + + +def upgrade() -> None: + op.create_table( + "network_diagrams", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("client_name", sa.String(255), nullable=True), + sa.Column("asset_name", sa.String(255), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("nodes", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")), + sa.Column("edges", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")), + sa.Column("thumbnail_url", sa.Text(), nullable=True), + sa.Column("is_archived", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.Column("created_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), + ) + + op.create_index("ix_network_diagrams_account_id", "network_diagrams", ["account_id"]) + op.create_index("idx_network_diagrams_account_client", "network_diagrams", ["account_id", "client_name"]) + op.execute("ALTER TABLE network_diagrams ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE network_diagrams FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON network_diagrams + USING (account_id = {_CURRENT_ACCOUNT}) + WITH CHECK (account_id = {_CURRENT_ACCOUNT}) + """) + + +def downgrade() -> None: + op.execute("DROP POLICY IF EXISTS tenant_isolation ON network_diagrams") + op.execute("ALTER TABLE network_diagrams DISABLE ROW LEVEL SECURITY") + op.drop_table("network_diagrams") diff --git a/backend/app/api/endpoints/device_types.py b/backend/app/api/endpoints/device_types.py new file mode 100644 index 00000000..330c0e9f --- /dev/null +++ b/backend/app/api/endpoints/device_types.py @@ -0,0 +1,120 @@ +"""Device types API endpoints.""" +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select, or_ +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.api.deps import get_current_active_user +from app.models.user import User +from app.models.device_type import DeviceType +from app.schemas.device_type import ( + DeviceTypeCreate, + DeviceTypeUpdate, + DeviceTypeResponse, +) +from app.core.service_account import PLATFORM_ACCOUNT_ID + +router = APIRouter(prefix="/device-types", tags=["device-types"]) + + +@router.get("/", response_model=list[DeviceTypeResponse]) +async def list_device_types( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> list[DeviceTypeResponse]: + stmt = ( + select(DeviceType) + .where( + or_( + DeviceType.account_id == PLATFORM_ACCOUNT_ID, + DeviceType.account_id == current_user.account_id, + ) + ) + .order_by(DeviceType.category, DeviceType.sort_order, DeviceType.label) + ) + result = await db.execute(stmt) + rows = result.scalars().all() + return [DeviceTypeResponse.model_validate(r) for r in rows] + + +@router.post("/", response_model=DeviceTypeResponse, status_code=201) +async def create_device_type( + data: DeviceTypeCreate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> DeviceTypeResponse: + existing = await db.execute( + select(DeviceType).where( + DeviceType.slug == data.slug, + DeviceType.account_id == current_user.account_id, + ) + ) + if existing.scalar_one_or_none(): + raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your account") + + system_existing = await db.execute( + select(DeviceType).where( + DeviceType.slug == data.slug, + DeviceType.account_id == PLATFORM_ACCOUNT_ID, + ) + ) + if system_existing.scalar_one_or_none(): + raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' conflicts with a system type") + + device_type = DeviceType( + slug=data.slug, + label=data.label, + category=data.category, + is_system=False, + account_id=current_user.account_id, + sort_order=data.sort_order, + ) + db.add(device_type) + await db.commit() + await db.refresh(device_type) + return DeviceTypeResponse.model_validate(device_type) + + +@router.put("/{device_type_id}", response_model=DeviceTypeResponse) +async def update_device_type( + device_type_id: UUID, + data: DeviceTypeUpdate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> DeviceTypeResponse: + device_type = await db.get(DeviceType, device_type_id) + if not device_type: + raise HTTPException(status_code=404, detail="Device type not found") + if device_type.is_system: + raise HTTPException(status_code=403, detail="Cannot modify system device types") + if device_type.account_id != current_user.account_id: + raise HTTPException(status_code=404, detail="Device type not found") + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(device_type, field, value) + + await db.commit() + await db.refresh(device_type) + return DeviceTypeResponse.model_validate(device_type) + + +@router.delete("/{device_type_id}", status_code=204) +async def delete_device_type( + device_type_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> None: + device_type = await db.get(DeviceType, device_type_id) + if not device_type: + raise HTTPException(status_code=404, detail="Device type not found") + if device_type.is_system: + raise HTTPException(status_code=403, detail="Cannot delete system device types") + if device_type.account_id != current_user.account_id: + raise HTTPException(status_code=404, detail="Device type not found") + + await db.delete(device_type) + await db.commit() diff --git a/backend/app/api/endpoints/network_diagrams.py b/backend/app/api/endpoints/network_diagrams.py new file mode 100644 index 00000000..e00ecf7a --- /dev/null +++ b/backend/app/api/endpoints/network_diagrams.py @@ -0,0 +1,331 @@ +"""Network diagrams API endpoints.""" +import logging +from datetime import datetime, timezone +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import select, or_ +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.api.deps import get_current_active_user +from app.models.user import User +from app.models.device_type import DeviceType +from app.models.network_diagram import NetworkDiagram +from app.core.service_account import PLATFORM_ACCOUNT_ID +from app.schemas.network_diagram import ( + NetworkDiagramCreate, + NetworkDiagramUpdate, + NetworkDiagramResponse, + NetworkDiagramListItem, + AIGenerateRequest, + AIGenerateResponse, + DiagramImportRequest, + DiagramImportResponse, + DiagramExportResponse, + DiagramNode, + DiagramEdge, +) +from app.services import network_diagram_ai_service + +# Maps system device-type slugs to their category — mirrors frontend deviceRegistry.ts +_SLUG_CATEGORY: dict[str, str] = { + "router": "network", "switch": "network", "access-point": "network", "load-balancer": "network", + "firewall": "security", "badge-reader": "security", + "server": "compute", "vm": "compute", "container": "compute", + "nas": "storage", "san": "storage", "cloud-storage": "storage", + "cloud": "cloud", "aws": "cloud", "azure": "cloud", "gcp": "cloud", "isp": "cloud", + "workstation": "endpoint", "laptop": "endpoint", "tablet": "endpoint", + "phone": "endpoint", "printer": "endpoint", + "ups": "infrastructure", "pdu": "infrastructure", "rack": "infrastructure", + "patch-panel": "infrastructure", "camera": "infrastructure", + "nvr": "infrastructure", "iot": "infrastructure", +} + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"]) + + +async def _get_diagram_or_404( + diagram_id: UUID, + account_id: UUID, + db: AsyncSession, +) -> NetworkDiagram: + diagram = await db.get(NetworkDiagram, diagram_id) + if not diagram or diagram.account_id != account_id or diagram.is_archived: + raise HTTPException(status_code=404, detail="Diagram not found") + return diagram + + +def _diagram_to_response(diagram: NetworkDiagram) -> NetworkDiagramResponse: + return NetworkDiagramResponse.model_validate(diagram) + + +def _diagram_to_list_item( + diagram: NetworkDiagram, + custom_slug_category: dict[str, str] | None = None, +) -> NetworkDiagramListItem: + nodes = diagram.nodes if isinstance(diagram.nodes, list) else [] + slug_to_cat = {**_SLUG_CATEGORY, **(custom_slug_category or {})} + + category_counts: dict[str, int] = {} + for node in nodes: + slug = node.get("type", "") if isinstance(node, dict) else "" + cat = slug_to_cat.get(slug, "other") + category_counts[cat] = category_counts.get(cat, 0) + 1 + + return NetworkDiagramListItem( + id=diagram.id, + name=diagram.name, + client_name=diagram.client_name, + description=diagram.description, + node_count=len(nodes), + category_counts=category_counts, + created_by=diagram.created_by, + created_at=diagram.created_at, + updated_at=diagram.updated_at, + ) + + +async def _get_available_slugs(account_id: UUID, db: AsyncSession) -> set[str]: + stmt = select(DeviceType.slug).where( + or_( + DeviceType.account_id == PLATFORM_ACCOUNT_ID, + DeviceType.account_id == account_id, + ) + ) + result = await db.execute(stmt) + return {row[0] for row in result.all()} + + +@router.get("/clients", response_model=list[str]) +async def list_client_names( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> list[str]: + stmt = ( + select(NetworkDiagram.client_name) + .where( + NetworkDiagram.account_id == current_user.account_id, + NetworkDiagram.is_archived.is_(False), + NetworkDiagram.client_name.isnot(None), + NetworkDiagram.client_name != "", + ) + .distinct() + .order_by(NetworkDiagram.client_name) + ) + result = await db.execute(stmt) + return [row[0] for row in result.all()] + + +@router.get("/", response_model=list[NetworkDiagramListItem]) +async def list_diagrams( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], + client_name: str | None = Query(default=None), + search: str | None = Query(default=None), +) -> list[NetworkDiagramListItem]: + stmt = ( + select(NetworkDiagram) + .where( + NetworkDiagram.account_id == current_user.account_id, + NetworkDiagram.is_archived.is_(False), + ) + .order_by(NetworkDiagram.updated_at.desc()) + ) + + if client_name: + stmt = stmt.where(NetworkDiagram.client_name == client_name) + + if search: + escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + search_filter = f"%{escaped}%" + stmt = stmt.where( + or_( + NetworkDiagram.name.ilike(search_filter), + NetworkDiagram.client_name.ilike(search_filter), + ) + ) + + # Single query for custom device types so category_counts is accurate + dt_stmt = select(DeviceType.slug, DeviceType.category).where( + DeviceType.is_system.is_(False), + DeviceType.account_id == current_user.account_id, + ) + dt_result = await db.execute(dt_stmt) + custom_slug_category = {row[0]: row[1] for row in dt_result.all()} + + result = await db.execute(stmt) + rows = result.scalars().all() + return [_diagram_to_list_item(r, custom_slug_category) for r in rows] + + +@router.post("/", response_model=NetworkDiagramResponse, status_code=201) +async def create_diagram( + data: NetworkDiagramCreate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> NetworkDiagramResponse: + diagram = NetworkDiagram( + account_id=current_user.account_id, + name=data.name, + client_name=data.client_name, + asset_name=data.asset_name, + description=data.description, + nodes=[n.model_dump() for n in data.nodes], + edges=[e.model_dump() for e in data.edges], + created_by=current_user.id, + ) + db.add(diagram) + await db.commit() + await db.refresh(diagram) + return _diagram_to_response(diagram) + + +@router.get("/{diagram_id}", response_model=NetworkDiagramResponse) +async def get_diagram( + diagram_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> NetworkDiagramResponse: + diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db) + return _diagram_to_response(diagram) + + +@router.put("/{diagram_id}", response_model=NetworkDiagramResponse) +async def update_diagram( + diagram_id: UUID, + data: NetworkDiagramUpdate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> NetworkDiagramResponse: + diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db) + + update_data = data.model_dump(exclude_unset=True) + if "nodes" in update_data and update_data["nodes"] is not None: + update_data["nodes"] = [n.model_dump() if hasattr(n, "model_dump") else n for n in update_data["nodes"]] + if "edges" in update_data and update_data["edges"] is not None: + update_data["edges"] = [e.model_dump() if hasattr(e, "model_dump") else e for e in update_data["edges"]] + + for field, value in update_data.items(): + setattr(diagram, field, value) + + diagram.updated_at = datetime.now(timezone.utc) + await db.commit() + await db.refresh(diagram) + return _diagram_to_response(diagram) + + +@router.delete("/{diagram_id}", status_code=204) +async def archive_diagram( + diagram_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> None: + diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db) + diagram.is_archived = True + diagram.updated_at = datetime.now(timezone.utc) + await db.commit() + + +@router.post("/{diagram_id}/duplicate", response_model=NetworkDiagramResponse, status_code=201) +async def duplicate_diagram( + diagram_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> NetworkDiagramResponse: + source = await _get_diagram_or_404(diagram_id, current_user.account_id, db) + copy = NetworkDiagram( + account_id=current_user.account_id, + name=f"Copy of {source.name}", + client_name=source.client_name, + asset_name=source.asset_name, + description=source.description, + nodes=source.nodes, + edges=source.edges, + created_by=current_user.id, + ) + db.add(copy) + await db.commit() + await db.refresh(copy) + return _diagram_to_response(copy) + + +@router.get("/{diagram_id}/export", response_model=DiagramExportResponse) +async def export_diagram( + diagram_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> DiagramExportResponse: + diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db) + nodes = [DiagramNode(**n) for n in (diagram.nodes or [])] + edges = [DiagramEdge(**e) for e in (diagram.edges or [])] + return DiagramExportResponse( + schemaVersion=1, + name=diagram.name, + client_name=diagram.client_name, + description=diagram.description, + nodes=nodes, + edges=edges, + exportedAt=datetime.now(timezone.utc).isoformat(), + ) + + +@router.post("/import", response_model=DiagramImportResponse, status_code=201) +async def import_diagram( + data: DiagramImportRequest, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> DiagramImportResponse: + available_slugs = await _get_available_slugs(current_user.account_id, db) + + warnings: list[str] = [] + for node in data.nodes: + if node.type not in available_slugs: + warnings.append(f"Unknown device type '{node.type}' — will render with default icon") + + diagram = NetworkDiagram( + account_id=current_user.account_id, + name=data.name, + client_name=data.client_name, + description=data.description, + nodes=[n.model_dump() for n in data.nodes], + edges=[e.model_dump() for e in data.edges], + created_by=current_user.id, + ) + db.add(diagram) + await db.commit() + await db.refresh(diagram) + + return DiagramImportResponse( + diagram=_diagram_to_response(diagram), + warnings=warnings, + ) + + +@router.post("/ai-generate", response_model=AIGenerateResponse) +async def ai_generate_diagram( + data: AIGenerateRequest, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> AIGenerateResponse: + available_slugs_set = await _get_available_slugs(current_user.account_id, db) + available_slugs = list(available_slugs_set) + + existing_node_ids: list[str] | None = None + if data.mode == "merge" and data.existingBounds: + existing_node_ids = [] + + try: + return await network_diagram_ai_service.generate_diagram( + request=data, + available_slugs=available_slugs, + existing_node_ids=existing_node_ids, + ) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + except Exception: + logger.exception("AI diagram generation failed") + raise HTTPException(status_code=500, detail="Diagram generation failed") diff --git a/backend/app/api/router.py b/backend/app/api/router.py index ed32ba58..349f5969 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -24,6 +24,7 @@ from app.api.endpoints import ( branding, categories, copilot, + device_types, feedback, flow_proposals, flowpilot_analytics, @@ -32,6 +33,7 @@ from app.api.endpoints import ( invite, kb_accelerator, maintenance_schedules, + network_diagrams, notifications, onboarding, public_templates, @@ -93,7 +95,6 @@ api_router.include_router(admin_settings.router) api_router.include_router(admin_categories.router) api_router.include_router(admin_survey.router) api_router.include_router(admin_gallery.router) - # --------------------------------------------------------------------------- # User-facing endpoints — tenant context required # --------------------------------------------------------------------------- @@ -130,6 +131,7 @@ api_router.include_router(integrations.router, dependencies=_tenant_deps) api_router.include_router(onboarding.router, dependencies=_tenant_deps) api_router.include_router(branding.router, dependencies=_tenant_deps) api_router.include_router(supporting_data.router, dependencies=_tenant_deps) +api_router.include_router(network_diagrams.router, dependencies=_tenant_deps) # session_handoffs queue router must come before ai_sessions to avoid conflict api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps) api_router.include_router(session_resolutions.router, dependencies=_tenant_deps) @@ -142,3 +144,4 @@ api_router.include_router(script_builder.router, dependencies=_tenant_deps) api_router.include_router(beta_feedback.router, dependencies=_tenant_deps) api_router.include_router(session_branches.router, dependencies=_tenant_deps) api_router.include_router(session_handoffs.router, dependencies=_tenant_deps) +api_router.include_router(device_types.router, dependencies=_tenant_deps) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 5d31b789..90db6f83 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -128,6 +128,7 @@ class Settings(BaseSettings): "variable_inference": "fast", "kb_convert": "standard", "script_build": "standard", + "network_diagram_generate": "standard", } def get_model_for_action(self, action_type: str) -> str: diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 0441624f..5346c6ec 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -56,6 +56,8 @@ from .session_handoff import SessionHandoff from .session_resolution_output import SessionResolutionOutput from .template_tree import TemplateTree from .platform_step import PlatformStep +from .device_type import DeviceType +from .network_diagram import NetworkDiagram __all__ = [ "User", @@ -126,4 +128,6 @@ __all__ = [ "SessionResolutionOutput", "TemplateTree", "PlatformStep", + "DeviceType", + "NetworkDiagram", ] diff --git a/backend/app/models/device_type.py b/backend/app/models/device_type.py new file mode 100644 index 00000000..d2f9c756 --- /dev/null +++ b/backend/app/models/device_type.py @@ -0,0 +1,47 @@ +"""Device type model for network diagrams.""" +import uuid +from datetime import datetime, timezone + +from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID + +from app.core.database import Base + + +class DeviceType(Base): + """A device type for network diagram nodes (platform or account-custom).""" + __tablename__ = "device_types" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + slug: Mapped[str] = mapped_column( + String(50), nullable=False, + comment="Unique identifier used in diagram node data", + ) + label: Mapped[str] = mapped_column( + String(100), nullable=False, + comment="Display name", + ) + category: Mapped[str] = mapped_column( + String(50), nullable=False, + comment="network, compute, storage, cloud, endpoint, infrastructure, security", + ) + is_system: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False, + comment="True for built-in types that cannot be deleted", + ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + comment="Platform account for system types, tenant account for custom types", + ) + sort_order: Mapped[int] = mapped_column( + Integer, nullable=False, default=0, + comment="Display order within category", + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) diff --git a/backend/app/models/network_diagram.py b/backend/app/models/network_diagram.py new file mode 100644 index 00000000..347216da --- /dev/null +++ b/backend/app/models/network_diagram.py @@ -0,0 +1,53 @@ +"""Network diagram model.""" +import uuid +from datetime import datetime, timezone +from typing import Any, TYPE_CHECKING + +from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID, JSONB + +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.user import User + + +class NetworkDiagram(Base): + """A network topology diagram scoped to one account.""" + __tablename__ = "network_diagrams" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + name: Mapped[str] = mapped_column(String(255), nullable=False) + client_name: Mapped[str | None] = mapped_column(String(255), nullable=True) + asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'") + edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'") + thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True) + is_archived: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False, + ) + created_by: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id"), + nullable=True, + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) + + creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by]) diff --git a/backend/app/schemas/device_type.py b/backend/app/schemas/device_type.py new file mode 100644 index 00000000..665fd8a2 --- /dev/null +++ b/backend/app/schemas/device_type.py @@ -0,0 +1,37 @@ +"""Pydantic schemas for device types.""" +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, Field + + +class DeviceTypeCreate(BaseModel): + slug: str = Field(min_length=1, max_length=50, pattern=r"^[a-z0-9\-]+$") + label: str = Field(min_length=1, max_length=100) + category: str = Field( + min_length=1, max_length=50, + pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$", + ) + sort_order: int = Field(default=0, ge=0) + + +class DeviceTypeUpdate(BaseModel): + label: str | None = Field(default=None, min_length=1, max_length=100) + category: str | None = Field( + default=None, min_length=1, max_length=50, + pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$", + ) + sort_order: int | None = Field(default=None, ge=0) + + +class DeviceTypeResponse(BaseModel): + id: UUID + slug: str + label: str + category: str + is_system: bool + account_id: UUID + sort_order: int + created_at: datetime + + model_config = {"from_attributes": True} diff --git a/backend/app/schemas/network_diagram.py b/backend/app/schemas/network_diagram.py new file mode 100644 index 00000000..e31d5283 --- /dev/null +++ b/backend/app/schemas/network_diagram.py @@ -0,0 +1,136 @@ +"""Pydantic schemas for network diagrams.""" +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, Field + + +class Position(BaseModel): + x: float + y: float + + +class DeviceProperties(BaseModel): + hostname: str | None = None + ip: str | None = None + subnet: str | None = None + vendor: str | None = None + model: str | None = None + role: str | None = None + vlan: str | None = None + notes: str | None = None + status: str = Field(default="unknown", pattern=r"^(unknown|online|offline|degraded)$") + + +class DiagramNode(BaseModel): + id: str + type: str + label: str + position: Position + properties: DeviceProperties = Field(default_factory=DeviceProperties) + + +class DiagramEdge(BaseModel): + id: str + source: str + target: str + label: str | None = None + connectionType: str = "ethernet" + speed: str | None = None + notes: str | None = None + routing: str | None = None + + +class NetworkDiagramCreate(BaseModel): + name: str = Field(min_length=1, max_length=255) + client_name: str | None = None + asset_name: str | None = None + description: str | None = None + nodes: list[DiagramNode] = Field(default_factory=list) + edges: list[DiagramEdge] = Field(default_factory=list) + + +class NetworkDiagramUpdate(BaseModel): + name: str | None = Field(default=None, min_length=1, max_length=255) + client_name: str | None = None + asset_name: str | None = None + description: str | None = None + nodes: list[DiagramNode] | None = None + edges: list[DiagramEdge] | None = None + + +class NetworkDiagramResponse(BaseModel): + id: UUID + account_id: UUID + name: str + client_name: str | None = None + asset_name: str | None = None + description: str | None = None + nodes: list[DiagramNode] = Field(default_factory=list) + edges: list[DiagramEdge] = Field(default_factory=list) + thumbnail_url: str | None = None + is_archived: bool = False + created_by: UUID | None = None + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class NetworkDiagramListItem(BaseModel): + id: UUID + name: str + client_name: str | None = None + description: str | None = None + node_count: int = 0 + category_counts: dict[str, int] = Field(default_factory=dict) + created_by: UUID | None = None + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class ExistingBounds(BaseModel): + minX: float + maxX: float + minY: float + maxY: float + + +class AIGenerateRequest(BaseModel): + description: str = Field(min_length=1, max_length=5000) + client_name: str | None = None + mode: str = Field(default="replace", pattern=r"^(replace|merge)$") + existingBounds: ExistingBounds | None = None + + +class AIGenerateResponse(BaseModel): + nodes: list[DiagramNode] + edges: list[DiagramEdge] + suggestedName: str | None = None + notes: str | None = None + + +class DiagramImportRequest(BaseModel): + schemaVersion: int = Field(ge=1, le=1) + name: str = Field(min_length=1, max_length=255) + client_name: str | None = None + description: str | None = None + nodes: list[DiagramNode] = Field(default_factory=list) + edges: list[DiagramEdge] = Field(default_factory=list) + + +class DiagramImportResponse(BaseModel): + diagram: NetworkDiagramResponse + warnings: list[str] = Field(default_factory=list) + + +class DiagramExportResponse(BaseModel): + schemaVersion: int = 1 + name: str + client_name: str | None = None + description: str | None = None + nodes: list[DiagramNode] + edges: list[DiagramEdge] + exportedAt: str diff --git a/backend/app/services/network_diagram_ai_service.py b/backend/app/services/network_diagram_ai_service.py new file mode 100644 index 00000000..5ac0df9c --- /dev/null +++ b/backend/app/services/network_diagram_ai_service.py @@ -0,0 +1,151 @@ +"""AI service for generating network diagrams from natural language.""" +import json +import logging + +from app.core.ai_provider import get_ai_provider +from app.core.config import settings +from app.schemas.network_diagram import ( + AIGenerateRequest, + AIGenerateResponse, + DiagramNode, + DiagramEdge, + DeviceProperties, + Position, +) + +logger = logging.getLogger(__name__) + +SYSTEM_PROMPT_TEMPLATE = """You are a network diagram generator for MSP engineers. +Given a plain English description of a network, you must return ONLY valid JSON with no markdown, no explanation, no preamble. + +Return this exact structure: +{{ + "nodes": [ + {{ + "id": "unique-string", + "type": "device-type-slug", + "label": "device label", + "position": {{ "x": number, "y": number }}, + "properties": {{ + "hostname": "string or null", + "ip": "string or null", + "subnet": "string or null", + "vendor": "string or null", + "model": "string or null", + "role": "string or null", + "vlan": "string or null", + "notes": "string or null", + "status": "unknown" + }} + }} + ], + "edges": [ + {{ + "id": "unique-string", + "source": "node-id", + "target": "node-id", + "label": "connection label or null", + "connectionType": "ethernet|fiber|wifi|vpn|vlan|wan", + "speed": "string or null", + "notes": "string or null" + }} + ], + "suggestedName": "short descriptive diagram name", + "notes": "any important assumptions or missing info, or null" +}} + +Available device type slugs: {available_slugs} + +Position nodes thoughtfully in a logical network topology layout. +Use x/y coordinates between 0 and 1200 for x, 0 and 800 for y. +Place WAN/internet at top, core network in middle, endpoints at bottom. +{merge_instructions}""" + +MERGE_INSTRUCTIONS = """ +IMPORTANT: You are ADDING devices to an existing diagram. Do NOT replace existing devices. +The existing diagram occupies this bounding box: minX={minX}, maxX={maxX}, minY={minY}, maxY={maxY}. +Place all new nodes OUTSIDE this bounding box — below (y > {maxY} + 100) or to the right (x > {maxX} + 100). +You may create edges that connect new nodes to existing nodes if the description implies a connection. +Use these existing node IDs for connections: {existing_node_ids}""" + + +async def generate_diagram( + request: AIGenerateRequest, + available_slugs: list[str], + existing_node_ids: list[str] | None = None, +) -> AIGenerateResponse: + merge_instructions = "" + if request.mode == "merge" and request.existingBounds: + b = request.existingBounds + merge_instructions = MERGE_INSTRUCTIONS.format( + minX=b.minX, maxX=b.maxX, minY=b.minY, maxY=b.maxY, + existing_node_ids=", ".join(existing_node_ids or []), + ) + + system_prompt = SYSTEM_PROMPT_TEMPLATE.format( + available_slugs=", ".join(available_slugs), + merge_instructions=merge_instructions, + ) + + model = settings.get_model_for_action("network_diagram_generate") + provider = get_ai_provider(model) + + messages = [{"role": "user", "content": request.description}] + + response_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=system_prompt, + messages=messages, + max_tokens=4096, + ) + + logger.info( + "Network diagram AI generation: input_tokens=%d, output_tokens=%d", + input_tokens, output_tokens, + ) + + try: + data = json.loads(response_text) + except json.JSONDecodeError as e: + logger.error("Failed to parse AI response as JSON: %s", e) + raise ValueError("AI generated an invalid response, please try again") + + try: + nodes = [] + for raw_node in data.get("nodes", []): + node_type = raw_node.get("type", "server") + if node_type not in available_slugs: + logger.warning("Unknown device type '%s', falling back to 'server'", node_type) + node_type = "server" + + nodes.append(DiagramNode( + id=raw_node["id"], + type=node_type, + label=raw_node.get("label", node_type), + position=Position(**raw_node.get("position", {"x": 0, "y": 0})), + properties=DeviceProperties(**{ + k: v for k, v in raw_node.get("properties", {}).items() + if k in DeviceProperties.model_fields + }), + )) + + edges = [] + for raw_edge in data.get("edges", []): + edges.append(DiagramEdge( + id=raw_edge["id"], + source=raw_edge["source"], + target=raw_edge["target"], + label=raw_edge.get("label"), + connectionType=raw_edge.get("connectionType", "ethernet"), + speed=raw_edge.get("speed"), + notes=raw_edge.get("notes"), + )) + except KeyError as e: + logger.warning("AI response missing required field: %s", e) + raise ValueError(f"AI generated incomplete data (missing {e}), please try again") + + return AIGenerateResponse( + nodes=nodes, + edges=edges, + suggestedName=data.get("suggestedName"), + notes=data.get("notes"), + ) diff --git a/backend/tests/test_network_diagrams.py b/backend/tests/test_network_diagrams.py new file mode 100644 index 00000000..6b6cb93c --- /dev/null +++ b/backend/tests/test_network_diagrams.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import uuid + +import pytest +from sqlalchemy import select + +from app.models.device_type import DeviceType +from app.models.user import User +from app.core.service_account import PLATFORM_ACCOUNT_ID + + +async def _login_headers(client, email: str, password: str) -> dict[str, str]: + response = await client.post( + "/api/v1/auth/login/json", + json={"email": email, "password": password}, + ) + assert response.status_code == 200 + token = response.json()["access_token"] + return {"Authorization": f"Bearer {token}"} + + +@pytest.mark.asyncio +async def test_device_types_include_platform_and_account_custom(client, test_db, auth_headers, test_user): + result = await test_db.execute(select(User).where(User.email == test_user["email"])) + user = result.scalar_one() + + test_db.add( + DeviceType( + id=uuid.uuid4(), + slug="platform-router", + label="Platform Router", + category="network", + is_system=True, + account_id=PLATFORM_ACCOUNT_ID, + sort_order=0, + ) + ) + await test_db.commit() + + create_response = await client.post( + "/api/v1/device-types/", + json={ + "slug": "tenant-appliance", + "label": "Tenant Appliance", + "category": "network", + "sort_order": 3, + }, + headers=auth_headers, + ) + assert create_response.status_code == 201 + assert create_response.json()["account_id"] == str(user.account_id) + + list_response = await client.get("/api/v1/device-types/", headers=auth_headers) + assert list_response.status_code == 200 + payload = list_response.json() + slugs = {item["slug"] for item in payload} + + assert "platform-router" in slugs + assert "tenant-appliance" in slugs + + +@pytest.mark.asyncio +async def test_network_diagrams_are_account_scoped(client, test_db, auth_headers, test_user): + other_user = { + "email": "other-network@example.com", + "password": "TestPassword123!", + "name": "Other Network User", + } + register_response = await client.post("/api/v1/auth/register", json=other_user) + assert register_response.status_code in (200, 201) + other_headers = await _login_headers(client, other_user["email"], other_user["password"]) + + owner_result = await test_db.execute(select(User).where(User.email == test_user["email"])) + owner = owner_result.scalar_one() + + create_response = await client.post( + "/api/v1/network-diagrams/", + json={ + "name": "HQ Core", + "client_name": "Acme", + "description": "Primary topology", + "nodes": [], + "edges": [], + }, + headers=auth_headers, + ) + assert create_response.status_code == 201 + diagram = create_response.json() + assert diagram["account_id"] == str(owner.account_id) + + own_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=auth_headers) + assert own_get.status_code == 200 + + other_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=other_headers) + assert other_get.status_code == 404 diff --git a/frontend/package-lock.json b/frontend/package-lock.json index cd02c6f0..cb2eb520 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -23,6 +23,7 @@ "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "date-fns": "^4.1.0", + "html-to-image": "^1.11.13", "immer": "^11.1.3", "lucide-react": "^0.563.0", "monaco-editor": "^0.55.1", @@ -5331,6 +5332,12 @@ "dev": true, "license": "MIT" }, + "node_modules/html-to-image": { + "version": "1.11.13", + "resolved": "https://registry.npmjs.org/html-to-image/-/html-to-image-1.11.13.tgz", + "integrity": "sha512-cuOPoI7WApyhBElTTb9oqsawRvZ0rHhaHwghRLlTuffoD1B2aDemlCruLeZrUIIdvG7gs9xeELEPm6PhuASqrg==", + "license": "MIT" + }, "node_modules/html-url-attributes": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/html-url-attributes/-/html-url-attributes-3.0.1.tgz", diff --git a/frontend/package.json b/frontend/package.json index d93a3831..f03d09a0 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -36,6 +36,7 @@ "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "date-fns": "^4.1.0", + "html-to-image": "^1.11.13", "immer": "^11.1.3", "lucide-react": "^0.563.0", "monaco-editor": "^0.55.1", diff --git a/frontend/public/images/hero_001.jpg b/frontend/public/images/hero_001.jpg new file mode 100644 index 00000000..e2c292a6 Binary files /dev/null and b/frontend/public/images/hero_001.jpg differ diff --git a/frontend/src/api/deviceTypes.ts b/frontend/src/api/deviceTypes.ts new file mode 100644 index 00000000..1425584a --- /dev/null +++ b/frontend/src/api/deviceTypes.ts @@ -0,0 +1,23 @@ +import apiClient from './client' +import type { DeviceTypeResponse, DeviceTypeCreate } from '@/types' + +export const deviceTypesApi = { + async list(): Promise { + const response = await apiClient.get('/device-types/') + return response.data + }, + + async create(data: DeviceTypeCreate): Promise { + const response = await apiClient.post('/device-types/', data) + return response.data + }, + + async update(id: string, data: Partial): Promise { + const response = await apiClient.put(`/device-types/${id}`, data) + return response.data + }, + + async remove(id: string): Promise { + await apiClient.delete(`/device-types/${id}`) + }, +} diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index b362e193..9ec66c9b 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -35,3 +35,5 @@ export { betaFeedbackApi } from './betaFeedback' export { branchesApi } from './branches' export { handoffsApi } from './handoffs' export { resolutionsApi } from './resolutions' +export { deviceTypesApi } from './deviceTypes' +export { networkDiagramsApi } from './networkDiagrams' diff --git a/frontend/src/api/networkDiagrams.ts b/frontend/src/api/networkDiagrams.ts new file mode 100644 index 00000000..c074fb00 --- /dev/null +++ b/frontend/src/api/networkDiagrams.ts @@ -0,0 +1,63 @@ +import apiClient from './client' +import type { + NetworkDiagramResponse, + NetworkDiagramListItem, + NetworkDiagramCreate, + NetworkDiagramUpdate, + AIGenerateRequest, + AIGenerateResponse, + DiagramImportData, + DiagramImportResponse, + DiagramExportResponse, +} from '@/types' + +export const networkDiagramsApi = { + async list(params?: { client_name?: string; search?: string }): Promise { + const response = await apiClient.get('/network-diagrams/', { params }) + return response.data + }, + + async get(id: string): Promise { + const response = await apiClient.get(`/network-diagrams/${id}`) + return response.data + }, + + async create(data: NetworkDiagramCreate): Promise { + const response = await apiClient.post('/network-diagrams/', data) + return response.data + }, + + async update(id: string, data: NetworkDiagramUpdate): Promise { + const response = await apiClient.put(`/network-diagrams/${id}`, data) + return response.data + }, + + async archive(id: string): Promise { + await apiClient.delete(`/network-diagrams/${id}`) + }, + + async duplicate(id: string): Promise { + const response = await apiClient.post(`/network-diagrams/${id}/duplicate`) + return response.data + }, + + async exportJson(id: string): Promise { + const response = await apiClient.get(`/network-diagrams/${id}/export`) + return response.data + }, + + async importJson(data: DiagramImportData): Promise { + const response = await apiClient.post('/network-diagrams/import', data) + return response.data + }, + + async aiGenerate(data: AIGenerateRequest): Promise { + const response = await apiClient.post('/network-diagrams/ai-generate', data) + return response.data + }, + + async listClients(): Promise { + const response = await apiClient.get('/network-diagrams/clients') + return response.data + }, +} diff --git a/frontend/src/components/layout/Sidebar.tsx b/frontend/src/components/layout/Sidebar.tsx index 506438a5..b8e2ae82 100644 --- a/frontend/src/components/layout/Sidebar.tsx +++ b/frontend/src/components/layout/Sidebar.tsx @@ -5,7 +5,7 @@ import { LayoutGrid, Clock, AlertTriangle, GitBranch, Code2, Wand2, ListChecks, Download, BarChart3, Settings, Pin, PinOff, - History, FileText, + History, FileText, Network, } from 'lucide-react' import { cn } from '@/lib/utils' import { useUserPreferencesStore } from '@/store/userPreferencesStore' @@ -86,10 +86,11 @@ export function Sidebar() { { href: '/trees', icon: GitBranch, label: 'Flows', shortLabel: 'Flows', badge: stats?.tree_counts.total || undefined, - matchPaths: ['/trees', '/flows', '/my-trees', '/step-library', '/review-queue'], + matchPaths: ['/trees', '/flows', '/my-trees', '/step-library', '/review-queue', '/network-diagrams'], children: [ { href: '/trees', label: 'Flow Library', count: stats?.tree_counts.total || undefined }, { href: '/trees?type=procedural', label: 'Projects', count: stats?.tree_counts.procedural || undefined }, + { href: '/network-diagrams', label: 'Network Maps' }, { href: '/step-library', label: 'Solutions Library' }, { href: '/review-queue', label: 'Review Queue' }, ], @@ -134,6 +135,7 @@ export function Sidebar() { { href: '/trees?type=procedural', label: 'Projects', count: stats?.tree_counts.procedural || undefined }, ], }, + { href: '/network-diagrams', icon: Network, label: 'Network Maps', shortLabel: 'NetMap', matchPaths: ['/network-diagrams'] }, { href: '/scripts', icon: Code2, label: 'Scripts', shortLabel: 'Scripts' }, { href: '/script-builder', icon: Wand2, label: 'Script Builder', shortLabel: 'Builder' }, { href: '/review-queue', icon: ListChecks, label: 'Review Queue', shortLabel: 'Review' }, diff --git a/frontend/src/components/network/CanvasEmptyPrompt.tsx b/frontend/src/components/network/CanvasEmptyPrompt.tsx new file mode 100644 index 00000000..cbf6a4f5 --- /dev/null +++ b/frontend/src/components/network/CanvasEmptyPrompt.tsx @@ -0,0 +1,232 @@ +import { useState, useCallback, useEffect } from 'react' +import { Sparkles, ArrowRight, PencilRuler, Wand2, X } from 'lucide-react' +import { networkDiagramsApi } from '@/api' +import type { AIGenerateResponse } from '@/types' + +const EXAMPLE_PROMPTS = [ + 'Small office with firewall and core switch', + 'Azure hybrid cloud with VPN gateway', + 'Branch office connected to HQ via MPLS', + 'Data center with redundant core switches', + 'Remote workforce with Meraki and cloud apps', +] + +interface CanvasEmptyPromptProps { + onGenerate: (result: AIGenerateResponse, mode: 'replace' | 'merge') => void +} + +export function CanvasEmptyPrompt({ onGenerate }: CanvasEmptyPromptProps) { + const [mode, setMode] = useState<'choice' | 'ai' | 'manual'>('choice') + const [description, setDescription] = useState('') + const [loading, setLoading] = useState(false) + const [error, setError] = useState(null) + + const switchToManual = useCallback(() => { + if (loading) return + setMode('manual') + setError(null) + }, [loading]) + + const handleGenerate = useCallback(async (text?: string) => { + const desc = (text ?? description).trim() + if (!desc) return + setLoading(true) + setError(null) + try { + const result = await networkDiagramsApi.aiGenerate({ + description: desc, + mode: 'replace', + existingBounds: null, + }) + onGenerate(result, 'replace') + } catch (err: unknown) { + setError(err instanceof Error ? err.message : 'Generation failed. Please try again.') + } finally { + setLoading(false) + } + }, [description, onGenerate]) + + useEffect(() => { + if (mode === 'manual') return + + const handleKeyDown = (event: KeyboardEvent) => { + if (event.key === 'Escape') { + event.preventDefault() + switchToManual() + } + } + + window.addEventListener('keydown', handleKeyDown) + return () => window.removeEventListener('keydown', handleKeyDown) + }, [mode, switchToManual]) + + if (mode === 'manual') { + return ( +
+
+
+ +
+
+

Manual mode is on

+

+ Drag devices from the left panel onto the canvas, or reopen AI whenever you want. +

+
+ +
+
+ ) + } + + return ( +
{ + if (event.target === event.currentTarget) { + switchToManual() + } + }} + > +
+ + {mode === 'choice' ? ( + <> +
+
+ +

+ Start a network map +

+
+

+ Generate a topology with AI or start with a blank canvas and build it manually. +

+

+ Press Esc or click outside to skip AI and start dragging devices. +

+
+ +
+ + + +
+ + + ) : ( + <> +
+
+ +

+ Describe your network +

+
+

+ AI will generate the topology in seconds, or you can go back and switch to manual creation. +

+

+ Press Esc, click outside, or use the close button to build manually instead. +

+
+ +
+