feat: extract network map builder from PR 124 #137

Merged
chihlasm merged 47 commits from feat/network-map-builder-prod into main 2026-04-13 06:38:01 +00:00
45 changed files with 4826 additions and 3 deletions

View File

@@ -0,0 +1,132 @@
"""Add account-scoped device_types table with platform seed data.
Revision ID: 073
Revises: b3c7e9f2a1d8
Create Date: 2026-04-12
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
import uuid
revision = "073"
down_revision = "b3c7e9f2a1d8"
branch_labels = None
depends_on = None
_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001"
_CURRENT_ACCOUNT = (
"COALESCE("
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
"'00000000-0000-0000-0000-000000000000'"
")::uuid"
)
SYSTEM_DEVICE_TYPES = [
("router", "Router", "network", 0),
("switch", "Switch", "network", 1),
("firewall", "Firewall", "network", 2),
("access-point", "Access Point", "network", 3),
("load-balancer", "Load Balancer", "network", 4),
("server", "Server", "compute", 0),
("workstation", "Workstation", "compute", 1),
("vm", "Virtual Machine", "compute", 2),
("container", "Container", "compute", 3),
("nas", "NAS", "storage", 0),
("san", "SAN", "storage", 1),
("cloud-storage", "Cloud Storage", "storage", 2),
("cloud", "Cloud", "cloud", 0),
("aws", "AWS", "cloud", 1),
("azure", "Azure", "cloud", 2),
("gcp", "Google Cloud", "cloud", 3),
("printer", "Printer", "endpoint", 0),
("phone", "Phone", "endpoint", 1),
("iot", "IoT Device", "endpoint", 2),
("camera", "Camera", "endpoint", 3),
("tablet", "Tablet", "endpoint", 4),
("laptop", "Laptop", "endpoint", 5),
("ups", "UPS", "infrastructure", 0),
("pdu", "PDU", "infrastructure", 1),
("rack", "Rack", "infrastructure", 2),
("patch-panel", "Patch Panel", "infrastructure", 3),
("nvr", "NVR", "security", 0),
("badge-reader", "Badge Reader", "security", 1),
]
def upgrade() -> None:
op.create_table(
"device_types",
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column("slug", sa.String(50), nullable=False),
sa.Column("label", sa.String(100), nullable=False),
sa.Column("category", sa.String(50), nullable=False),
sa.Column("is_system", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
sa.Column("sort_order", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
)
op.create_unique_constraint("uq_device_types_slug_account", "device_types", ["slug", "account_id"])
op.create_index("ix_device_types_account_id", "device_types", ["account_id"])
device_types_table = sa.table(
"device_types",
sa.column("id", UUID(as_uuid=True)),
sa.column("slug", sa.String),
sa.column("label", sa.String),
sa.column("category", sa.String),
sa.column("is_system", sa.Boolean),
sa.column("account_id", UUID(as_uuid=True)),
sa.column("sort_order", sa.Integer),
)
op.bulk_insert(device_types_table, [
{
"id": uuid.uuid4(),
"slug": slug,
"label": label,
"category": category,
"is_system": True,
"account_id": uuid.UUID(_PLATFORM_UUID),
"sort_order": sort_order,
}
for slug, label, category, sort_order in SYSTEM_DEVICE_TYPES
])
op.execute("ALTER TABLE device_types ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE device_types FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY device_types_select ON device_types
FOR SELECT
USING (
account_id = {_CURRENT_ACCOUNT}
OR account_id = '{_PLATFORM_UUID}'::uuid
)
""")
op.execute(f"""
CREATE POLICY device_types_insert ON device_types
FOR INSERT
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
""")
op.execute(f"""
CREATE POLICY device_types_update ON device_types
FOR UPDATE
USING (account_id = {_CURRENT_ACCOUNT})
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
""")
op.execute(f"""
CREATE POLICY device_types_delete ON device_types
FOR DELETE
USING (account_id = {_CURRENT_ACCOUNT})
""")
def downgrade() -> None:
op.execute("DROP POLICY IF EXISTS device_types_delete ON device_types")
op.execute("DROP POLICY IF EXISTS device_types_update ON device_types")
op.execute("DROP POLICY IF EXISTS device_types_insert ON device_types")
op.execute("DROP POLICY IF EXISTS device_types_select ON device_types")
op.execute("ALTER TABLE device_types DISABLE ROW LEVEL SECURITY")
op.drop_table("device_types")

View File

@@ -0,0 +1,57 @@
"""Add network_diagrams table.
Revision ID: 074
Revises: 073
Create Date: 2026-04-12
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID, JSONB
revision = "074"
down_revision = "073"
branch_labels = None
depends_on = None
_CURRENT_ACCOUNT = (
"COALESCE("
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
"'00000000-0000-0000-0000-000000000000'"
")::uuid"
)
def upgrade() -> None:
op.create_table(
"network_diagrams",
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("client_name", sa.String(255), nullable=True),
sa.Column("asset_name", sa.String(255), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("nodes", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
sa.Column("edges", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
sa.Column("thumbnail_url", sa.Text(), nullable=True),
sa.Column("is_archived", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.Column("created_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
)
op.create_index("ix_network_diagrams_account_id", "network_diagrams", ["account_id"])
op.create_index("idx_network_diagrams_account_client", "network_diagrams", ["account_id", "client_name"])
op.execute("ALTER TABLE network_diagrams ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE network_diagrams FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON network_diagrams
USING (account_id = {_CURRENT_ACCOUNT})
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
""")
def downgrade() -> None:
op.execute("DROP POLICY IF EXISTS tenant_isolation ON network_diagrams")
op.execute("ALTER TABLE network_diagrams DISABLE ROW LEVEL SECURITY")
op.drop_table("network_diagrams")

View File

@@ -0,0 +1,120 @@
"""Device types API endpoints."""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select, or_
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.api.deps import get_current_active_user
from app.models.user import User
from app.models.device_type import DeviceType
from app.schemas.device_type import (
DeviceTypeCreate,
DeviceTypeUpdate,
DeviceTypeResponse,
)
from app.core.service_account import PLATFORM_ACCOUNT_ID
router = APIRouter(prefix="/device-types", tags=["device-types"])
@router.get("/", response_model=list[DeviceTypeResponse])
async def list_device_types(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[DeviceTypeResponse]:
stmt = (
select(DeviceType)
.where(
or_(
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
DeviceType.account_id == current_user.account_id,
)
)
.order_by(DeviceType.category, DeviceType.sort_order, DeviceType.label)
)
result = await db.execute(stmt)
rows = result.scalars().all()
return [DeviceTypeResponse.model_validate(r) for r in rows]
@router.post("/", response_model=DeviceTypeResponse, status_code=201)
async def create_device_type(
data: DeviceTypeCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DeviceTypeResponse:
existing = await db.execute(
select(DeviceType).where(
DeviceType.slug == data.slug,
DeviceType.account_id == current_user.account_id,
)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your account")
system_existing = await db.execute(
select(DeviceType).where(
DeviceType.slug == data.slug,
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
)
)
if system_existing.scalar_one_or_none():
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' conflicts with a system type")
device_type = DeviceType(
slug=data.slug,
label=data.label,
category=data.category,
is_system=False,
account_id=current_user.account_id,
sort_order=data.sort_order,
)
db.add(device_type)
await db.commit()
await db.refresh(device_type)
return DeviceTypeResponse.model_validate(device_type)
@router.put("/{device_type_id}", response_model=DeviceTypeResponse)
async def update_device_type(
device_type_id: UUID,
data: DeviceTypeUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DeviceTypeResponse:
device_type = await db.get(DeviceType, device_type_id)
if not device_type:
raise HTTPException(status_code=404, detail="Device type not found")
if device_type.is_system:
raise HTTPException(status_code=403, detail="Cannot modify system device types")
if device_type.account_id != current_user.account_id:
raise HTTPException(status_code=404, detail="Device type not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(device_type, field, value)
await db.commit()
await db.refresh(device_type)
return DeviceTypeResponse.model_validate(device_type)
@router.delete("/{device_type_id}", status_code=204)
async def delete_device_type(
device_type_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
device_type = await db.get(DeviceType, device_type_id)
if not device_type:
raise HTTPException(status_code=404, detail="Device type not found")
if device_type.is_system:
raise HTTPException(status_code=403, detail="Cannot delete system device types")
if device_type.account_id != current_user.account_id:
raise HTTPException(status_code=404, detail="Device type not found")
await db.delete(device_type)
await db.commit()

View File

@@ -0,0 +1,331 @@
"""Network diagrams API endpoints."""
import logging
from datetime import datetime, timezone
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, or_
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.api.deps import get_current_active_user
from app.models.user import User
from app.models.device_type import DeviceType
from app.models.network_diagram import NetworkDiagram
from app.core.service_account import PLATFORM_ACCOUNT_ID
from app.schemas.network_diagram import (
NetworkDiagramCreate,
NetworkDiagramUpdate,
NetworkDiagramResponse,
NetworkDiagramListItem,
AIGenerateRequest,
AIGenerateResponse,
DiagramImportRequest,
DiagramImportResponse,
DiagramExportResponse,
DiagramNode,
DiagramEdge,
)
from app.services import network_diagram_ai_service
# Maps system device-type slugs to their category — mirrors frontend deviceRegistry.ts
_SLUG_CATEGORY: dict[str, str] = {
"router": "network", "switch": "network", "access-point": "network", "load-balancer": "network",
"firewall": "security", "badge-reader": "security",
"server": "compute", "vm": "compute", "container": "compute",
"nas": "storage", "san": "storage", "cloud-storage": "storage",
"cloud": "cloud", "aws": "cloud", "azure": "cloud", "gcp": "cloud", "isp": "cloud",
"workstation": "endpoint", "laptop": "endpoint", "tablet": "endpoint",
"phone": "endpoint", "printer": "endpoint",
"ups": "infrastructure", "pdu": "infrastructure", "rack": "infrastructure",
"patch-panel": "infrastructure", "camera": "infrastructure",
"nvr": "infrastructure", "iot": "infrastructure",
}
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"])
async def _get_diagram_or_404(
diagram_id: UUID,
account_id: UUID,
db: AsyncSession,
) -> NetworkDiagram:
diagram = await db.get(NetworkDiagram, diagram_id)
if not diagram or diagram.account_id != account_id or diagram.is_archived:
raise HTTPException(status_code=404, detail="Diagram not found")
return diagram
def _diagram_to_response(diagram: NetworkDiagram) -> NetworkDiagramResponse:
return NetworkDiagramResponse.model_validate(diagram)
def _diagram_to_list_item(
diagram: NetworkDiagram,
custom_slug_category: dict[str, str] | None = None,
) -> NetworkDiagramListItem:
nodes = diagram.nodes if isinstance(diagram.nodes, list) else []
slug_to_cat = {**_SLUG_CATEGORY, **(custom_slug_category or {})}
category_counts: dict[str, int] = {}
for node in nodes:
slug = node.get("type", "") if isinstance(node, dict) else ""
cat = slug_to_cat.get(slug, "other")
category_counts[cat] = category_counts.get(cat, 0) + 1
return NetworkDiagramListItem(
id=diagram.id,
name=diagram.name,
client_name=diagram.client_name,
description=diagram.description,
node_count=len(nodes),
category_counts=category_counts,
created_by=diagram.created_by,
created_at=diagram.created_at,
updated_at=diagram.updated_at,
)
async def _get_available_slugs(account_id: UUID, db: AsyncSession) -> set[str]:
stmt = select(DeviceType.slug).where(
or_(
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
DeviceType.account_id == account_id,
)
)
result = await db.execute(stmt)
return {row[0] for row in result.all()}
@router.get("/clients", response_model=list[str])
async def list_client_names(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[str]:
stmt = (
select(NetworkDiagram.client_name)
.where(
NetworkDiagram.account_id == current_user.account_id,
NetworkDiagram.is_archived.is_(False),
NetworkDiagram.client_name.isnot(None),
NetworkDiagram.client_name != "",
)
.distinct()
.order_by(NetworkDiagram.client_name)
)
result = await db.execute(stmt)
return [row[0] for row in result.all()]
@router.get("/", response_model=list[NetworkDiagramListItem])
async def list_diagrams(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
client_name: str | None = Query(default=None),
search: str | None = Query(default=None),
) -> list[NetworkDiagramListItem]:
stmt = (
select(NetworkDiagram)
.where(
NetworkDiagram.account_id == current_user.account_id,
NetworkDiagram.is_archived.is_(False),
)
.order_by(NetworkDiagram.updated_at.desc())
)
if client_name:
stmt = stmt.where(NetworkDiagram.client_name == client_name)
if search:
escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
search_filter = f"%{escaped}%"
stmt = stmt.where(
or_(
NetworkDiagram.name.ilike(search_filter),
NetworkDiagram.client_name.ilike(search_filter),
)
)
# Single query for custom device types so category_counts is accurate
dt_stmt = select(DeviceType.slug, DeviceType.category).where(
DeviceType.is_system.is_(False),
DeviceType.account_id == current_user.account_id,
)
dt_result = await db.execute(dt_stmt)
custom_slug_category = {row[0]: row[1] for row in dt_result.all()}
result = await db.execute(stmt)
rows = result.scalars().all()
return [_diagram_to_list_item(r, custom_slug_category) for r in rows]
@router.post("/", response_model=NetworkDiagramResponse, status_code=201)
async def create_diagram(
data: NetworkDiagramCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
diagram = NetworkDiagram(
account_id=current_user.account_id,
name=data.name,
client_name=data.client_name,
asset_name=data.asset_name,
description=data.description,
nodes=[n.model_dump() for n in data.nodes],
edges=[e.model_dump() for e in data.edges],
created_by=current_user.id,
)
db.add(diagram)
await db.commit()
await db.refresh(diagram)
return _diagram_to_response(diagram)
@router.get("/{diagram_id}", response_model=NetworkDiagramResponse)
async def get_diagram(
diagram_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
return _diagram_to_response(diagram)
@router.put("/{diagram_id}", response_model=NetworkDiagramResponse)
async def update_diagram(
diagram_id: UUID,
data: NetworkDiagramUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
update_data = data.model_dump(exclude_unset=True)
if "nodes" in update_data and update_data["nodes"] is not None:
update_data["nodes"] = [n.model_dump() if hasattr(n, "model_dump") else n for n in update_data["nodes"]]
if "edges" in update_data and update_data["edges"] is not None:
update_data["edges"] = [e.model_dump() if hasattr(e, "model_dump") else e for e in update_data["edges"]]
for field, value in update_data.items():
setattr(diagram, field, value)
diagram.updated_at = datetime.now(timezone.utc)
await db.commit()
await db.refresh(diagram)
return _diagram_to_response(diagram)
@router.delete("/{diagram_id}", status_code=204)
async def archive_diagram(
diagram_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
diagram.is_archived = True
diagram.updated_at = datetime.now(timezone.utc)
await db.commit()
@router.post("/{diagram_id}/duplicate", response_model=NetworkDiagramResponse, status_code=201)
async def duplicate_diagram(
diagram_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
source = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
copy = NetworkDiagram(
account_id=current_user.account_id,
name=f"Copy of {source.name}",
client_name=source.client_name,
asset_name=source.asset_name,
description=source.description,
nodes=source.nodes,
edges=source.edges,
created_by=current_user.id,
)
db.add(copy)
await db.commit()
await db.refresh(copy)
return _diagram_to_response(copy)
@router.get("/{diagram_id}/export", response_model=DiagramExportResponse)
async def export_diagram(
diagram_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramExportResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
nodes = [DiagramNode(**n) for n in (diagram.nodes or [])]
edges = [DiagramEdge(**e) for e in (diagram.edges or [])]
return DiagramExportResponse(
schemaVersion=1,
name=diagram.name,
client_name=diagram.client_name,
description=diagram.description,
nodes=nodes,
edges=edges,
exportedAt=datetime.now(timezone.utc).isoformat(),
)
@router.post("/import", response_model=DiagramImportResponse, status_code=201)
async def import_diagram(
data: DiagramImportRequest,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse:
available_slugs = await _get_available_slugs(current_user.account_id, db)
warnings: list[str] = []
for node in data.nodes:
if node.type not in available_slugs:
warnings.append(f"Unknown device type '{node.type}' — will render with default icon")
diagram = NetworkDiagram(
account_id=current_user.account_id,
name=data.name,
client_name=data.client_name,
description=data.description,
nodes=[n.model_dump() for n in data.nodes],
edges=[e.model_dump() for e in data.edges],
created_by=current_user.id,
)
db.add(diagram)
await db.commit()
await db.refresh(diagram)
return DiagramImportResponse(
diagram=_diagram_to_response(diagram),
warnings=warnings,
)
@router.post("/ai-generate", response_model=AIGenerateResponse)
async def ai_generate_diagram(
data: AIGenerateRequest,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse:
available_slugs_set = await _get_available_slugs(current_user.account_id, db)
available_slugs = list(available_slugs_set)
existing_node_ids: list[str] | None = None
if data.mode == "merge" and data.existingBounds:
existing_node_ids = []
try:
return await network_diagram_ai_service.generate_diagram(
request=data,
available_slugs=available_slugs,
existing_node_ids=existing_node_ids,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except Exception:
logger.exception("AI diagram generation failed")
raise HTTPException(status_code=500, detail="Diagram generation failed")

View File

@@ -24,6 +24,7 @@ from app.api.endpoints import (
branding,
categories,
copilot,
device_types,
feedback,
flow_proposals,
flowpilot_analytics,
@@ -32,6 +33,7 @@ from app.api.endpoints import (
invite,
kb_accelerator,
maintenance_schedules,
network_diagrams,
notifications,
onboarding,
public_templates,
@@ -93,7 +95,6 @@ api_router.include_router(admin_settings.router)
api_router.include_router(admin_categories.router)
api_router.include_router(admin_survey.router)
api_router.include_router(admin_gallery.router)
# ---------------------------------------------------------------------------
# User-facing endpoints — tenant context required
# ---------------------------------------------------------------------------
@@ -130,6 +131,7 @@ api_router.include_router(integrations.router, dependencies=_tenant_deps)
api_router.include_router(onboarding.router, dependencies=_tenant_deps)
api_router.include_router(branding.router, dependencies=_tenant_deps)
api_router.include_router(supporting_data.router, dependencies=_tenant_deps)
api_router.include_router(network_diagrams.router, dependencies=_tenant_deps)
# session_handoffs queue router must come before ai_sessions to avoid conflict
api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps)
api_router.include_router(session_resolutions.router, dependencies=_tenant_deps)
@@ -142,3 +144,4 @@ api_router.include_router(script_builder.router, dependencies=_tenant_deps)
api_router.include_router(beta_feedback.router, dependencies=_tenant_deps)
api_router.include_router(session_branches.router, dependencies=_tenant_deps)
api_router.include_router(session_handoffs.router, dependencies=_tenant_deps)
api_router.include_router(device_types.router, dependencies=_tenant_deps)

View File

@@ -128,6 +128,7 @@ class Settings(BaseSettings):
"variable_inference": "fast",
"kb_convert": "standard",
"script_build": "standard",
"network_diagram_generate": "standard",
}
def get_model_for_action(self, action_type: str) -> str:

View File

@@ -56,6 +56,8 @@ from .session_handoff import SessionHandoff
from .session_resolution_output import SessionResolutionOutput
from .template_tree import TemplateTree
from .platform_step import PlatformStep
from .device_type import DeviceType
from .network_diagram import NetworkDiagram
__all__ = [
"User",
@@ -126,4 +128,6 @@ __all__ = [
"SessionResolutionOutput",
"TemplateTree",
"PlatformStep",
"DeviceType",
"NetworkDiagram",
]

View File

@@ -0,0 +1,47 @@
"""Device type model for network diagrams."""
import uuid
from datetime import datetime, timezone
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
class DeviceType(Base):
"""A device type for network diagram nodes (platform or account-custom)."""
__tablename__ = "device_types"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
slug: Mapped[str] = mapped_column(
String(50), nullable=False,
comment="Unique identifier used in diagram node data",
)
label: Mapped[str] = mapped_column(
String(100), nullable=False,
comment="Display name",
)
category: Mapped[str] = mapped_column(
String(50), nullable=False,
comment="network, compute, storage, cloud, endpoint, infrastructure, security",
)
is_system: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False,
comment="True for built-in types that cannot be deleted",
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
comment="Platform account for system types, tenant account for custom types",
)
sort_order: Mapped[int] = mapped_column(
Integer, nullable=False, default=0,
comment="Display order within category",
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)

View File

@@ -0,0 +1,53 @@
"""Network diagram model."""
import uuid
from datetime import datetime, timezone
from typing import Any, TYPE_CHECKING
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.core.database import Base
if TYPE_CHECKING:
from app.models.user import User
class NetworkDiagram(Base):
"""A network topology diagram scoped to one account."""
__tablename__ = "network_diagrams"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
client_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True)
is_archived: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False,
)
created_by: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id"),
nullable=True,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by])

View File

@@ -0,0 +1,37 @@
"""Pydantic schemas for device types."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, Field
class DeviceTypeCreate(BaseModel):
slug: str = Field(min_length=1, max_length=50, pattern=r"^[a-z0-9\-]+$")
label: str = Field(min_length=1, max_length=100)
category: str = Field(
min_length=1, max_length=50,
pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$",
)
sort_order: int = Field(default=0, ge=0)
class DeviceTypeUpdate(BaseModel):
label: str | None = Field(default=None, min_length=1, max_length=100)
category: str | None = Field(
default=None, min_length=1, max_length=50,
pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$",
)
sort_order: int | None = Field(default=None, ge=0)
class DeviceTypeResponse(BaseModel):
id: UUID
slug: str
label: str
category: str
is_system: bool
account_id: UUID
sort_order: int
created_at: datetime
model_config = {"from_attributes": True}

View File

@@ -0,0 +1,136 @@
"""Pydantic schemas for network diagrams."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, Field
class Position(BaseModel):
x: float
y: float
class DeviceProperties(BaseModel):
hostname: str | None = None
ip: str | None = None
subnet: str | None = None
vendor: str | None = None
model: str | None = None
role: str | None = None
vlan: str | None = None
notes: str | None = None
status: str = Field(default="unknown", pattern=r"^(unknown|online|offline|degraded)$")
class DiagramNode(BaseModel):
id: str
type: str
label: str
position: Position
properties: DeviceProperties = Field(default_factory=DeviceProperties)
class DiagramEdge(BaseModel):
id: str
source: str
target: str
label: str | None = None
connectionType: str = "ethernet"
speed: str | None = None
notes: str | None = None
routing: str | None = None
class NetworkDiagramCreate(BaseModel):
name: str = Field(min_length=1, max_length=255)
client_name: str | None = None
asset_name: str | None = None
description: str | None = None
nodes: list[DiagramNode] = Field(default_factory=list)
edges: list[DiagramEdge] = Field(default_factory=list)
class NetworkDiagramUpdate(BaseModel):
name: str | None = Field(default=None, min_length=1, max_length=255)
client_name: str | None = None
asset_name: str | None = None
description: str | None = None
nodes: list[DiagramNode] | None = None
edges: list[DiagramEdge] | None = None
class NetworkDiagramResponse(BaseModel):
id: UUID
account_id: UUID
name: str
client_name: str | None = None
asset_name: str | None = None
description: str | None = None
nodes: list[DiagramNode] = Field(default_factory=list)
edges: list[DiagramEdge] = Field(default_factory=list)
thumbnail_url: str | None = None
is_archived: bool = False
created_by: UUID | None = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class NetworkDiagramListItem(BaseModel):
id: UUID
name: str
client_name: str | None = None
description: str | None = None
node_count: int = 0
category_counts: dict[str, int] = Field(default_factory=dict)
created_by: UUID | None = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class ExistingBounds(BaseModel):
minX: float
maxX: float
minY: float
maxY: float
class AIGenerateRequest(BaseModel):
description: str = Field(min_length=1, max_length=5000)
client_name: str | None = None
mode: str = Field(default="replace", pattern=r"^(replace|merge)$")
existingBounds: ExistingBounds | None = None
class AIGenerateResponse(BaseModel):
nodes: list[DiagramNode]
edges: list[DiagramEdge]
suggestedName: str | None = None
notes: str | None = None
class DiagramImportRequest(BaseModel):
schemaVersion: int = Field(ge=1, le=1)
name: str = Field(min_length=1, max_length=255)
client_name: str | None = None
description: str | None = None
nodes: list[DiagramNode] = Field(default_factory=list)
edges: list[DiagramEdge] = Field(default_factory=list)
class DiagramImportResponse(BaseModel):
diagram: NetworkDiagramResponse
warnings: list[str] = Field(default_factory=list)
class DiagramExportResponse(BaseModel):
schemaVersion: int = 1
name: str
client_name: str | None = None
description: str | None = None
nodes: list[DiagramNode]
edges: list[DiagramEdge]
exportedAt: str

View File

@@ -0,0 +1,151 @@
"""AI service for generating network diagrams from natural language."""
import json
import logging
from app.core.ai_provider import get_ai_provider
from app.core.config import settings
from app.schemas.network_diagram import (
AIGenerateRequest,
AIGenerateResponse,
DiagramNode,
DiagramEdge,
DeviceProperties,
Position,
)
logger = logging.getLogger(__name__)
SYSTEM_PROMPT_TEMPLATE = """You are a network diagram generator for MSP engineers.
Given a plain English description of a network, you must return ONLY valid JSON with no markdown, no explanation, no preamble.
Return this exact structure:
{{
"nodes": [
{{
"id": "unique-string",
"type": "device-type-slug",
"label": "device label",
"position": {{ "x": number, "y": number }},
"properties": {{
"hostname": "string or null",
"ip": "string or null",
"subnet": "string or null",
"vendor": "string or null",
"model": "string or null",
"role": "string or null",
"vlan": "string or null",
"notes": "string or null",
"status": "unknown"
}}
}}
],
"edges": [
{{
"id": "unique-string",
"source": "node-id",
"target": "node-id",
"label": "connection label or null",
"connectionType": "ethernet|fiber|wifi|vpn|vlan|wan",
"speed": "string or null",
"notes": "string or null"
}}
],
"suggestedName": "short descriptive diagram name",
"notes": "any important assumptions or missing info, or null"
}}
Available device type slugs: {available_slugs}
Position nodes thoughtfully in a logical network topology layout.
Use x/y coordinates between 0 and 1200 for x, 0 and 800 for y.
Place WAN/internet at top, core network in middle, endpoints at bottom.
{merge_instructions}"""
MERGE_INSTRUCTIONS = """
IMPORTANT: You are ADDING devices to an existing diagram. Do NOT replace existing devices.
The existing diagram occupies this bounding box: minX={minX}, maxX={maxX}, minY={minY}, maxY={maxY}.
Place all new nodes OUTSIDE this bounding box — below (y > {maxY} + 100) or to the right (x > {maxX} + 100).
You may create edges that connect new nodes to existing nodes if the description implies a connection.
Use these existing node IDs for connections: {existing_node_ids}"""
async def generate_diagram(
request: AIGenerateRequest,
available_slugs: list[str],
existing_node_ids: list[str] | None = None,
) -> AIGenerateResponse:
merge_instructions = ""
if request.mode == "merge" and request.existingBounds:
b = request.existingBounds
merge_instructions = MERGE_INSTRUCTIONS.format(
minX=b.minX, maxX=b.maxX, minY=b.minY, maxY=b.maxY,
existing_node_ids=", ".join(existing_node_ids or []),
)
system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
available_slugs=", ".join(available_slugs),
merge_instructions=merge_instructions,
)
model = settings.get_model_for_action("network_diagram_generate")
provider = get_ai_provider(model)
messages = [{"role": "user", "content": request.description}]
response_text, input_tokens, output_tokens = await provider.generate_json(
system_prompt=system_prompt,
messages=messages,
max_tokens=4096,
)
logger.info(
"Network diagram AI generation: input_tokens=%d, output_tokens=%d",
input_tokens, output_tokens,
)
try:
data = json.loads(response_text)
except json.JSONDecodeError as e:
logger.error("Failed to parse AI response as JSON: %s", e)
raise ValueError("AI generated an invalid response, please try again")
try:
nodes = []
for raw_node in data.get("nodes", []):
node_type = raw_node.get("type", "server")
if node_type not in available_slugs:
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
node_type = "server"
nodes.append(DiagramNode(
id=raw_node["id"],
type=node_type,
label=raw_node.get("label", node_type),
position=Position(**raw_node.get("position", {"x": 0, "y": 0})),
properties=DeviceProperties(**{
k: v for k, v in raw_node.get("properties", {}).items()
if k in DeviceProperties.model_fields
}),
))
edges = []
for raw_edge in data.get("edges", []):
edges.append(DiagramEdge(
id=raw_edge["id"],
source=raw_edge["source"],
target=raw_edge["target"],
label=raw_edge.get("label"),
connectionType=raw_edge.get("connectionType", "ethernet"),
speed=raw_edge.get("speed"),
notes=raw_edge.get("notes"),
))
except KeyError as e:
logger.warning("AI response missing required field: %s", e)
raise ValueError(f"AI generated incomplete data (missing {e}), please try again")
return AIGenerateResponse(
nodes=nodes,
edges=edges,
suggestedName=data.get("suggestedName"),
notes=data.get("notes"),
)

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
import uuid
import pytest
from sqlalchemy import select
from app.models.device_type import DeviceType
from app.models.user import User
from app.core.service_account import PLATFORM_ACCOUNT_ID
async def _login_headers(client, email: str, password: str) -> dict[str, str]:
response = await client.post(
"/api/v1/auth/login/json",
json={"email": email, "password": password},
)
assert response.status_code == 200
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.asyncio
async def test_device_types_include_platform_and_account_custom(client, test_db, auth_headers, test_user):
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
user = result.scalar_one()
test_db.add(
DeviceType(
id=uuid.uuid4(),
slug="platform-router",
label="Platform Router",
category="network",
is_system=True,
account_id=PLATFORM_ACCOUNT_ID,
sort_order=0,
)
)
await test_db.commit()
create_response = await client.post(
"/api/v1/device-types/",
json={
"slug": "tenant-appliance",
"label": "Tenant Appliance",
"category": "network",
"sort_order": 3,
},
headers=auth_headers,
)
assert create_response.status_code == 201
assert create_response.json()["account_id"] == str(user.account_id)
list_response = await client.get("/api/v1/device-types/", headers=auth_headers)
assert list_response.status_code == 200
payload = list_response.json()
slugs = {item["slug"] for item in payload}
assert "platform-router" in slugs
assert "tenant-appliance" in slugs
@pytest.mark.asyncio
async def test_network_diagrams_are_account_scoped(client, test_db, auth_headers, test_user):
other_user = {
"email": "other-network@example.com",
"password": "TestPassword123!",
"name": "Other Network User",
}
register_response = await client.post("/api/v1/auth/register", json=other_user)
assert register_response.status_code in (200, 201)
other_headers = await _login_headers(client, other_user["email"], other_user["password"])
owner_result = await test_db.execute(select(User).where(User.email == test_user["email"]))
owner = owner_result.scalar_one()
create_response = await client.post(
"/api/v1/network-diagrams/",
json={
"name": "HQ Core",
"client_name": "Acme",
"description": "Primary topology",
"nodes": [],
"edges": [],
},
headers=auth_headers,
)
assert create_response.status_code == 201
diagram = create_response.json()
assert diagram["account_id"] == str(owner.account_id)
own_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=auth_headers)
assert own_get.status_code == 200
other_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=other_headers)
assert other_get.status_code == 404

View File

@@ -23,6 +23,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"date-fns": "^4.1.0",
"html-to-image": "^1.11.13",
"immer": "^11.1.3",
"lucide-react": "^0.563.0",
"monaco-editor": "^0.55.1",
@@ -5331,6 +5332,12 @@
"dev": true,
"license": "MIT"
},
"node_modules/html-to-image": {
"version": "1.11.13",
"resolved": "https://registry.npmjs.org/html-to-image/-/html-to-image-1.11.13.tgz",
"integrity": "sha512-cuOPoI7WApyhBElTTb9oqsawRvZ0rHhaHwghRLlTuffoD1B2aDemlCruLeZrUIIdvG7gs9xeELEPm6PhuASqrg==",
"license": "MIT"
},
"node_modules/html-url-attributes": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/html-url-attributes/-/html-url-attributes-3.0.1.tgz",

View File

@@ -36,6 +36,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"date-fns": "^4.1.0",
"html-to-image": "^1.11.13",
"immer": "^11.1.3",
"lucide-react": "^0.563.0",
"monaco-editor": "^0.55.1",

Binary file not shown.

After

Width:  |  Height:  |  Size: 686 KiB

View File

@@ -0,0 +1,23 @@
import apiClient from './client'
import type { DeviceTypeResponse, DeviceTypeCreate } from '@/types'
export const deviceTypesApi = {
async list(): Promise<DeviceTypeResponse[]> {
const response = await apiClient.get<DeviceTypeResponse[]>('/device-types/')
return response.data
},
async create(data: DeviceTypeCreate): Promise<DeviceTypeResponse> {
const response = await apiClient.post<DeviceTypeResponse>('/device-types/', data)
return response.data
},
async update(id: string, data: Partial<DeviceTypeCreate>): Promise<DeviceTypeResponse> {
const response = await apiClient.put<DeviceTypeResponse>(`/device-types/${id}`, data)
return response.data
},
async remove(id: string): Promise<void> {
await apiClient.delete(`/device-types/${id}`)
},
}

View File

@@ -35,3 +35,5 @@ export { betaFeedbackApi } from './betaFeedback'
export { branchesApi } from './branches'
export { handoffsApi } from './handoffs'
export { resolutionsApi } from './resolutions'
export { deviceTypesApi } from './deviceTypes'
export { networkDiagramsApi } from './networkDiagrams'

View File

@@ -0,0 +1,63 @@
import apiClient from './client'
import type {
NetworkDiagramResponse,
NetworkDiagramListItem,
NetworkDiagramCreate,
NetworkDiagramUpdate,
AIGenerateRequest,
AIGenerateResponse,
DiagramImportData,
DiagramImportResponse,
DiagramExportResponse,
} from '@/types'
export const networkDiagramsApi = {
async list(params?: { client_name?: string; search?: string }): Promise<NetworkDiagramListItem[]> {
const response = await apiClient.get<NetworkDiagramListItem[]>('/network-diagrams/', { params })
return response.data
},
async get(id: string): Promise<NetworkDiagramResponse> {
const response = await apiClient.get<NetworkDiagramResponse>(`/network-diagrams/${id}`)
return response.data
},
async create(data: NetworkDiagramCreate): Promise<NetworkDiagramResponse> {
const response = await apiClient.post<NetworkDiagramResponse>('/network-diagrams/', data)
return response.data
},
async update(id: string, data: NetworkDiagramUpdate): Promise<NetworkDiagramResponse> {
const response = await apiClient.put<NetworkDiagramResponse>(`/network-diagrams/${id}`, data)
return response.data
},
async archive(id: string): Promise<void> {
await apiClient.delete(`/network-diagrams/${id}`)
},
async duplicate(id: string): Promise<NetworkDiagramResponse> {
const response = await apiClient.post<NetworkDiagramResponse>(`/network-diagrams/${id}/duplicate`)
return response.data
},
async exportJson(id: string): Promise<DiagramExportResponse> {
const response = await apiClient.get<DiagramExportResponse>(`/network-diagrams/${id}/export`)
return response.data
},
async importJson(data: DiagramImportData): Promise<DiagramImportResponse> {
const response = await apiClient.post<DiagramImportResponse>('/network-diagrams/import', data)
return response.data
},
async aiGenerate(data: AIGenerateRequest): Promise<AIGenerateResponse> {
const response = await apiClient.post<AIGenerateResponse>('/network-diagrams/ai-generate', data)
return response.data
},
async listClients(): Promise<string[]> {
const response = await apiClient.get<string[]>('/network-diagrams/clients')
return response.data
},
}

View File

@@ -5,7 +5,7 @@ import {
LayoutGrid, Clock, AlertTriangle, GitBranch, Code2, Wand2,
ListChecks, Download, BarChart3,
Settings, Pin, PinOff,
History, FileText,
History, FileText, Network,
} from 'lucide-react'
import { cn } from '@/lib/utils'
import { useUserPreferencesStore } from '@/store/userPreferencesStore'
@@ -86,10 +86,11 @@ export function Sidebar() {
{
href: '/trees', icon: GitBranch, label: 'Flows', shortLabel: 'Flows',
badge: stats?.tree_counts.total || undefined,
matchPaths: ['/trees', '/flows', '/my-trees', '/step-library', '/review-queue'],
matchPaths: ['/trees', '/flows', '/my-trees', '/step-library', '/review-queue', '/network-diagrams'],
children: [
{ href: '/trees', label: 'Flow Library', count: stats?.tree_counts.total || undefined },
{ href: '/trees?type=procedural', label: 'Projects', count: stats?.tree_counts.procedural || undefined },
{ href: '/network-diagrams', label: 'Network Maps' },
{ href: '/step-library', label: 'Solutions Library' },
{ href: '/review-queue', label: 'Review Queue' },
],
@@ -134,6 +135,7 @@ export function Sidebar() {
{ href: '/trees?type=procedural', label: 'Projects', count: stats?.tree_counts.procedural || undefined },
],
},
{ href: '/network-diagrams', icon: Network, label: 'Network Maps', shortLabel: 'NetMap', matchPaths: ['/network-diagrams'] },
{ href: '/scripts', icon: Code2, label: 'Scripts', shortLabel: 'Scripts' },
{ href: '/script-builder', icon: Wand2, label: 'Script Builder', shortLabel: 'Builder' },
{ href: '/review-queue', icon: ListChecks, label: 'Review Queue', shortLabel: 'Review' },

View File

@@ -0,0 +1,232 @@
import { useState, useCallback, useEffect } from 'react'
import { Sparkles, ArrowRight, PencilRuler, Wand2, X } from 'lucide-react'
import { networkDiagramsApi } from '@/api'
import type { AIGenerateResponse } from '@/types'
const EXAMPLE_PROMPTS = [
'Small office with firewall and core switch',
'Azure hybrid cloud with VPN gateway',
'Branch office connected to HQ via MPLS',
'Data center with redundant core switches',
'Remote workforce with Meraki and cloud apps',
]
interface CanvasEmptyPromptProps {
onGenerate: (result: AIGenerateResponse, mode: 'replace' | 'merge') => void
}
export function CanvasEmptyPrompt({ onGenerate }: CanvasEmptyPromptProps) {
const [mode, setMode] = useState<'choice' | 'ai' | 'manual'>('choice')
const [description, setDescription] = useState('')
const [loading, setLoading] = useState(false)
const [error, setError] = useState<string | null>(null)
const switchToManual = useCallback(() => {
if (loading) return
setMode('manual')
setError(null)
}, [loading])
const handleGenerate = useCallback(async (text?: string) => {
const desc = (text ?? description).trim()
if (!desc) return
setLoading(true)
setError(null)
try {
const result = await networkDiagramsApi.aiGenerate({
description: desc,
mode: 'replace',
existingBounds: null,
})
onGenerate(result, 'replace')
} catch (err: unknown) {
setError(err instanceof Error ? err.message : 'Generation failed. Please try again.')
} finally {
setLoading(false)
}
}, [description, onGenerate])
useEffect(() => {
if (mode === 'manual') return
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
event.preventDefault()
switchToManual()
}
}
window.addEventListener('keydown', handleKeyDown)
return () => window.removeEventListener('keydown', handleKeyDown)
}, [mode, switchToManual])
if (mode === 'manual') {
return (
<div className="pointer-events-none absolute inset-x-0 bottom-6 z-10 flex justify-center px-6">
<div className="pointer-events-auto flex max-w-xl items-center gap-3 rounded-lg border border-default bg-card px-4 py-3 shadow-xl">
<div className="flex h-8 w-8 shrink-0 items-center justify-center rounded-full bg-accent/10 text-accent">
<PencilRuler size={14} />
</div>
<div className="min-w-0 flex-1">
<p className="text-sm font-medium text-heading">Manual mode is on</p>
<p className="text-xs text-muted-foreground">
Drag devices from the left panel onto the canvas, or reopen AI whenever you want.
</p>
</div>
<button
onClick={() => setMode('ai')}
className="inline-flex shrink-0 items-center gap-1 rounded-full border border-default px-3 py-1 text-xs font-medium text-primary hover:border-accent hover:text-accent"
>
<Sparkles size={12} />
Open AI Generator
</button>
</div>
</div>
)
}
return (
<div
className="pointer-events-none absolute inset-0 z-10 flex items-center justify-center bg-[rgba(10,14,20,0.42)] px-6"
onClick={event => {
if (event.target === event.currentTarget) {
switchToManual()
}
}}
>
<div className="pointer-events-auto relative w-full max-w-lg rounded-lg border border-default bg-card p-8 shadow-2xl">
<button
onClick={switchToManual}
disabled={loading}
aria-label="Close AI prompt and build manually"
className="absolute right-4 top-4 inline-flex h-8 w-8 items-center justify-center rounded-full border border-default text-muted-foreground hover:border-hover hover:text-primary disabled:opacity-40"
>
<X size={14} />
</button>
{mode === 'choice' ? (
<>
<div className="mb-6 text-center">
<div className="mb-2 flex items-center justify-center gap-2">
<Wand2 size={16} className="text-accent" />
<h2 className="font-heading text-base font-semibold text-heading">
Start a network map
</h2>
</div>
<p className="text-xs text-muted-foreground">
Generate a topology with AI or start with a blank canvas and build it manually.
</p>
<p className="mt-2 text-[11px] text-muted-foreground/80">
Press <span className="font-medium text-primary">Esc</span> or click outside to skip AI and start dragging devices.
</p>
</div>
<div className="grid gap-3 sm:grid-cols-2">
<button
onClick={() => setMode('ai')}
className="rounded-lg border border-accent/40 bg-accent/10 p-4 text-left transition-colors hover:border-accent hover:bg-accent/15"
>
<div className="mb-3 inline-flex rounded-lg bg-accent/15 p-2 text-accent">
<Sparkles size={16} />
</div>
<div className="mb-1 text-sm font-semibold text-heading">Generate with AI</div>
<p className="text-xs text-muted-foreground">
Describe the environment and let AI lay out the first version for you.
</p>
</button>
<button
onClick={switchToManual}
className="rounded-lg border border-default bg-elevated/40 p-4 text-left transition-colors hover:border-accent hover:bg-elevated/60"
>
<div className="mb-3 inline-flex rounded-lg bg-primary/10 p-2 text-primary">
<PencilRuler size={16} />
</div>
<div className="mb-1 text-sm font-semibold text-heading">Build manually</div>
<p className="text-xs text-muted-foreground">
Close this prompt and use click-and-drag from the left toolbar to place devices on the canvas.
</p>
</button>
</div>
</>
) : (
<>
<div className="mb-5 text-center">
<div className="mb-2 flex items-center justify-center gap-2">
<Sparkles size={16} className="text-accent" />
<h2 className="font-heading text-base font-semibold text-heading">
Describe your network
</h2>
</div>
<p className="text-xs text-muted-foreground">
AI will generate the topology in seconds, or you can go back and switch to manual creation.
</p>
<p className="mt-2 text-[11px] text-muted-foreground/80">
Press <span className="font-medium text-primary">Esc</span>, click outside, or use the close button to build manually instead.
</p>
</div>
<div className="relative mb-3">
<textarea
value={description}
onChange={e => setDescription(e.target.value)}
onKeyDown={e => {
if (e.key === 'Enter' && (e.metaKey || e.ctrlKey)) handleGenerate()
}}
placeholder="e.g. Small office with a firewall, core switch, 3 access points, a file server, and 20 workstations"
rows={3}
disabled={loading}
autoFocus
className="w-full resize-none rounded-lg border border-default bg-input px-4 py-3 pb-7 text-sm text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none disabled:opacity-50"
/>
<span className="pointer-events-none absolute bottom-2 right-3 text-[10px] text-muted-foreground">
to generate
</span>
</div>
<div className="mb-4 flex flex-wrap gap-1.5">
{EXAMPLE_PROMPTS.map(p => (
<button
key={p}
onClick={() => handleGenerate(p)}
disabled={loading}
className="rounded-full border border-default px-3 py-1 text-xs text-muted-foreground transition-colors hover:border-accent hover:text-accent disabled:opacity-40"
>
{p}
</button>
))}
</div>
{error && <p className="mb-3 text-xs text-red-400">{error}</p>}
<div className="flex gap-2">
<button
onClick={switchToManual}
disabled={loading}
className="flex-1 rounded-lg border border-default px-4 py-2.5 text-sm font-medium text-primary hover:border-accent hover:text-accent disabled:opacity-40"
>
Build Manually
</button>
{loading ? (
<div className="flex flex-1 items-center justify-center gap-2 py-2.5">
<div className="h-4 w-4 animate-spin rounded-full border-2 border-accent border-t-transparent" />
<span className="text-sm text-muted-foreground">Mapping your network</span>
</div>
) : (
<button
onClick={() => handleGenerate()}
disabled={!description.trim()}
className="flex flex-1 items-center justify-center gap-2 rounded-lg bg-accent px-4 py-2.5 text-sm font-medium text-white transition-opacity hover:bg-accent/90 disabled:opacity-40"
>
<Sparkles size={14} />
Generate Diagram
<ArrowRight size={14} />
</button>
)}
</div>
</>
)}
</div>
</div>
)
}

View File

@@ -0,0 +1,119 @@
import { useEffect, useRef } from 'react'
import { Copy, CopyPlus, Trash2, ClipboardPaste, BoxSelect, Maximize2, BringToFront, SendToBack } from 'lucide-react'
import { cn } from '@/lib/utils'
interface MenuAction {
label: string
icon: React.ElementType
shortcut: string
onClick: () => void
disabled?: boolean
dividerBefore?: boolean
}
interface ContextMenuProps {
position: { x: number; y: number }
actions: MenuAction[]
onClose: () => void
}
export function ContextMenu({ position, actions, onClose }: ContextMenuProps) {
const menuRef = useRef<HTMLDivElement>(null)
const clampedPosition = { ...position }
if (typeof window !== 'undefined') {
const itemCount = actions.length
const dividerCount = actions.filter(a => a.dividerBefore).length
const menuWidth = 192
const menuHeight = itemCount * 36 + dividerCount * 9 + 8
if (clampedPosition.x + menuWidth > window.innerWidth) {
clampedPosition.x = window.innerWidth - menuWidth - 8
}
if (clampedPosition.y + menuHeight > window.innerHeight) {
clampedPosition.y = window.innerHeight - menuHeight - 8
}
}
useEffect(() => {
const handleClickOutside = (e: MouseEvent) => {
if (menuRef.current && !menuRef.current.contains(e.target as HTMLElement)) {
onClose()
}
}
const handleEscape = (e: KeyboardEvent) => {
if (e.key === 'Escape') onClose()
}
const handleScroll = () => onClose()
document.addEventListener('mousedown', handleClickOutside)
document.addEventListener('keydown', handleEscape)
document.addEventListener('scroll', handleScroll, true)
return () => {
document.removeEventListener('mousedown', handleClickOutside)
document.removeEventListener('keydown', handleEscape)
document.removeEventListener('scroll', handleScroll, true)
}
}, [onClose])
return (
<div
ref={menuRef}
className="fixed z-50 w-48 rounded-lg border border-default bg-card py-1 shadow-lg"
style={{ left: clampedPosition.x, top: clampedPosition.y }}
>
{actions.map((action) => (
<div key={action.label}>
{action.dividerBefore && (
<div className="my-1 border-t border-default" />
)}
<button
onClick={() => {
action.onClick()
onClose()
}}
disabled={action.disabled}
className={cn(
'flex w-full items-center gap-2 px-3 py-2 text-xs text-primary hover:bg-elevated',
action.disabled && 'opacity-40 pointer-events-none',
)}
>
<action.icon size={14} />
<span>{action.label}</span>
<span className="ml-auto text-[10px] text-muted-foreground">{action.shortcut}</span>
</button>
</div>
))}
</div>
)
}
// eslint-disable-next-line react-refresh/only-export-components
export function getNodeMenuActions(handlers: {
onCopy: () => void
onDuplicate: () => void
onBringToFront: () => void
onSendToBack: () => void
onDelete: () => void
}): MenuAction[] {
return [
{ label: 'Copy', icon: Copy, shortcut: 'Ctrl+C', onClick: handlers.onCopy },
{ label: 'Duplicate', icon: CopyPlus, shortcut: 'Ctrl+D', onClick: handlers.onDuplicate },
{ label: 'Bring to Front', icon: BringToFront, shortcut: ']', onClick: handlers.onBringToFront, dividerBefore: true },
{ label: 'Send to Back', icon: SendToBack, shortcut: '[', onClick: handlers.onSendToBack },
{ label: 'Delete', icon: Trash2, shortcut: 'Del', onClick: handlers.onDelete, dividerBefore: true },
]
}
// eslint-disable-next-line react-refresh/only-export-components
export function getCanvasMenuActions(handlers: {
onPaste: () => void
onSelectAll: () => void
onFitView: () => void
hasClipboard: boolean
}): MenuAction[] {
return [
{ label: 'Paste', icon: ClipboardPaste, shortcut: 'Ctrl+V', onClick: handlers.onPaste, disabled: !handlers.hasClipboard },
{ label: 'Select All', icon: BoxSelect, shortcut: 'Ctrl+A', onClick: handlers.onSelectAll },
{ label: 'Fit View', icon: Maximize2, shortcut: '⌘⇧F', onClick: handlers.onFitView },
]
}

View File

@@ -0,0 +1,167 @@
import { useState, useCallback, useRef, useEffect } from 'react'
import { useNavigate } from 'react-router-dom'
import { ChevronLeft, Save, Download, FileJson, Image, FileText } from 'lucide-react'
interface DiagramHeaderProps {
name: string
clientName: string | null
isDirty: boolean
isSaving: boolean
lastSavedAt: Date | null
diagramId: string | null
onNameChange: (name: string) => void
onSave: () => void
onExportPng: () => void
onExportPdf: () => void
onExportJson: () => void
}
export function DiagramHeader({
name,
clientName,
isDirty,
isSaving,
lastSavedAt,
diagramId,
onNameChange,
onSave,
onExportPng,
onExportPdf,
onExportJson,
}: DiagramHeaderProps) {
const navigate = useNavigate()
const [editing, setEditing] = useState(false)
const [editValue, setEditValue] = useState(name)
const [showExportMenu, setShowExportMenu] = useState(false)
const inputRef = useRef<HTMLInputElement>(null)
const exportMenuRef = useRef<HTMLDivElement>(null)
useEffect(() => {
if (editing && inputRef.current) {
inputRef.current.focus()
inputRef.current.select()
}
}, [editing])
useEffect(() => {
setEditValue(name)
}, [name])
useEffect(() => {
if (!showExportMenu) return
const handleClick = (e: MouseEvent) => {
if (exportMenuRef.current && !exportMenuRef.current.contains(e.target as HTMLElement)) {
setShowExportMenu(false)
}
}
document.addEventListener('mousedown', handleClick)
return () => document.removeEventListener('mousedown', handleClick)
}, [showExportMenu])
const handleConfirmName = useCallback(() => {
setEditing(false)
if (editValue.trim() && editValue !== name) {
onNameChange(editValue.trim())
} else {
setEditValue(name)
}
}, [editValue, name, onNameChange])
const formatLastSaved = () => {
if (!lastSavedAt) return null
// eslint-disable-next-line react-hooks/purity
const diff = Date.now() - lastSavedAt.getTime()
if (diff < 60_000) return 'Saved just now'
const mins = Math.floor(diff / 60_000)
return `Saved ${mins}m ago`
}
return (
<div className="flex h-14 items-center gap-3 border-b border-default bg-card px-4">
<button
onClick={() => navigate('/network-diagrams')}
className="flex items-center gap-1 text-xs text-muted-foreground hover:text-primary"
>
<ChevronLeft size={16} />
Network Maps
</button>
<div className="mx-2 h-5 w-px bg-border-default" />
{editing ? (
<input
ref={inputRef}
value={editValue}
onChange={e => setEditValue(e.target.value)}
onBlur={handleConfirmName}
onKeyDown={e => { if (e.key === 'Enter') handleConfirmName(); if (e.key === 'Escape') { setEditing(false); setEditValue(name) } }}
className="rounded border border-accent bg-input px-2 py-1 text-sm font-heading font-semibold text-heading focus:outline-none"
/>
) : (
<button
onClick={() => setEditing(true)}
className="text-sm font-heading font-semibold text-heading hover:text-accent"
>
{name || 'Untitled Diagram'}
</button>
)}
{clientName && (
<span className="rounded-full bg-elevated px-2 py-0.5 text-[10px] text-muted-foreground">
{clientName}
</span>
)}
<div className="flex-1" />
{isDirty && !isSaving ? (
<span className="text-[10px] text-amber-400">Unsaved changes</span>
) : lastSavedAt ? (
<span className="text-[10px] text-muted-foreground">{formatLastSaved()}</span>
) : null}
<button
onClick={onSave}
disabled={isSaving}
className="flex items-center gap-1.5 rounded bg-accent px-3 py-1.5 text-xs font-medium text-white hover:bg-accent/90 disabled:opacity-50"
>
<Save size={14} />
{isSaving ? 'Saving...' : 'Save'}
</button>
<div className="relative" ref={exportMenuRef}>
<button
onClick={() => setShowExportMenu(prev => !prev)}
className="flex items-center gap-1.5 rounded border border-default px-3 py-1.5 text-xs text-primary hover:border-hover"
>
<Download size={14} />
Export
</button>
{showExportMenu && (
<div className="absolute right-0 top-full z-50 mt-1 w-40 rounded border border-default bg-card py-1 shadow-lg">
<button
onClick={() => { onExportPng(); setShowExportMenu(false) }}
className="flex w-full items-center gap-2 px-3 py-1.5 text-xs text-primary hover:bg-elevated"
>
<Image size={12} /> Export PNG
</button>
<button
onClick={() => { onExportPdf(); setShowExportMenu(false) }}
className="flex w-full items-center gap-2 px-3 py-1.5 text-xs text-primary hover:bg-elevated"
>
<FileText size={12} /> Export PDF
</button>
{diagramId && (
<button
onClick={() => { onExportJson(); setShowExportMenu(false) }}
className="flex w-full items-center gap-2 px-3 py-1.5 text-xs text-primary hover:bg-elevated"
>
<FileJson size={12} /> Export JSON
</button>
)}
</div>
)}
</div>
</div>
)
}

View File

@@ -0,0 +1,119 @@
import { useCallback } from 'react'
import {
ReactFlow,
Background,
Controls,
MiniMap,
BackgroundVariant,
type OnConnect,
type OnNodesChange,
type OnEdgesChange,
type Node,
type Edge,
} from '@xyflow/react'
import { nodeTypes } from './nodes/nodeTypes'
import { edgeTypes } from './edges/edgeTypes'
import { getDeviceRenderConfig } from './nodes/deviceRegistry'
import type { DeviceNodeData } from './nodes/DeviceNode'
interface NetworkCanvasProps {
nodes: Node[]
edges: Edge[]
onNodesChange: OnNodesChange
onEdgesChange: OnEdgesChange
onConnect: OnConnect
onNodeSelect: (nodeId: string | null) => void
onEdgeSelect: (edgeId: string | null) => void
onDrop: (event: React.DragEvent) => void
onDragOver: (event: React.DragEvent) => void
onDragLeave?: (event: React.DragEvent) => void
isDragOver?: boolean
onNodeContextMenu?: (event: React.MouseEvent, node: Node) => void
onPaneContextMenu?: (event: MouseEvent | React.MouseEvent) => void
onPaneClick?: () => void
}
export function NetworkCanvas({
nodes,
edges,
onNodesChange,
onEdgesChange,
onConnect,
onNodeSelect,
onEdgeSelect,
onDrop,
onDragOver,
onDragLeave,
isDragOver,
onNodeContextMenu,
onPaneContextMenu,
onPaneClick: onPaneClickProp,
}: NetworkCanvasProps) {
const handleSelectionChange = useCallback(({ nodes: selectedNodes, edges: selectedEdges }: { nodes: Node[]; edges: Edge[] }) => {
if (selectedNodes.length === 1) {
onNodeSelect(selectedNodes[0].id)
onEdgeSelect(null)
} else if (selectedEdges.length === 1) {
onEdgeSelect(selectedEdges[0].id)
onNodeSelect(null)
} else {
onNodeSelect(null)
onEdgeSelect(null)
}
}, [onNodeSelect, onEdgeSelect])
const handlePaneClick = useCallback(() => {
onNodeSelect(null)
onEdgeSelect(null)
onPaneClickProp?.()
}, [onNodeSelect, onEdgeSelect, onPaneClickProp])
const getNodeColor = useCallback((node: Node) => {
if (node.type === 'group') return 'var(--color-bg-elevated)'
const data = node.data as unknown as DeviceNodeData
return getDeviceRenderConfig(data?.deviceType || '', data?.category).color
}, [])
return (
<div className="relative h-full w-full" onDragLeave={onDragLeave}>
<ReactFlow
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onConnect={onConnect}
onSelectionChange={handleSelectionChange}
onPaneClick={handlePaneClick}
onDrop={onDrop}
onDragOver={onDragOver}
onNodeContextMenu={onNodeContextMenu}
onPaneContextMenu={onPaneContextMenu}
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
defaultEdgeOptions={{ type: 'connection' }}
deleteKeyCode={['Backspace', 'Delete']}
multiSelectionKeyCode="Shift"
snapToGrid={true}
snapGrid={[20, 20]}
fitView
className="bg-page"
>
<Background variant={BackgroundVariant.Dots} color="var(--color-border-default)" gap={20} size={1} />
<Controls className="!border-default !bg-card [&>button]:!border-default [&>button]:!bg-card [&>button]:!fill-text-primary" />
<MiniMap
nodeColor={getNodeColor}
maskColor="rgba(0,0,0,0.5)"
className="!border-default !bg-card"
position="bottom-right"
/>
</ReactFlow>
{isDragOver && (
<div className="pointer-events-none absolute inset-2 z-10 flex items-center justify-center rounded-lg border-2 border-dashed border-accent/30">
<span className="rounded-md bg-card/80 px-3 py-1.5 text-sm text-muted-foreground">
Drop to add
</span>
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,71 @@
import { memo } from 'react'
import { BaseEdge, EdgeLabelRenderer, getStraightPath, getBezierPath, getSmoothStepPath, type EdgeProps } from '@xyflow/react'
interface ConnectionEdgeData {
connectionType?: string
routing?: string | null
speed?: string | null
notes?: string | null
[key: string]: unknown
}
const CONNECTION_STYLES: Record<string, { stroke: string; strokeDasharray?: string; strokeWidth: number }> = {
ethernet: { stroke: '#60a5fa', strokeWidth: 2 },
fiber: { stroke: '#34d399', strokeWidth: 3 },
wifi: { stroke: '#a78bfa', strokeDasharray: '3,3', strokeWidth: 2 },
vpn: { stroke: '#eab308', strokeDasharray: '8,4', strokeWidth: 2 },
vlan: { stroke: '#848b9b', strokeWidth: 2 },
wan: { stroke: '#f87171', strokeDasharray: '12,4', strokeWidth: 2 },
}
const DEFAULT_STYLE = { stroke: '#848b9b', strokeWidth: 2 }
function getEdgePath(routing: string | null | undefined, props: EdgeProps) {
const base = {
sourceX: props.sourceX,
sourceY: props.sourceY,
sourcePosition: props.sourcePosition,
targetX: props.targetX,
targetY: props.targetY,
targetPosition: props.targetPosition,
}
if (routing === 'curved') return getBezierPath(base)
if (routing === 'step') return getSmoothStepPath(base)
return getStraightPath(base)
}
function ConnectionEdgeComponent(props: EdgeProps) {
const edgeData = props.data as ConnectionEdgeData | undefined
const connectionType = edgeData?.connectionType || 'ethernet'
const style = CONNECTION_STYLES[connectionType] || DEFAULT_STYLE
const [edgePath, labelX, labelY] = getEdgePath(edgeData?.routing, props)
return (
<>
<BaseEdge
path={edgePath}
style={{
...style,
...(props.selected ? { stroke: '#60a5fa', strokeWidth: style.strokeWidth + 1 } : {}),
}}
/>
{props.label && (
<EdgeLabelRenderer>
<div
className="nodrag nopan rounded bg-page px-1.5 py-0.5 text-[10px] text-muted-foreground"
style={{
position: 'absolute',
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
pointerEvents: 'all',
}}
>
{props.label}
</div>
</EdgeLabelRenderer>
)}
</>
)
}
export const ConnectionEdge = memo(ConnectionEdgeComponent)

View File

@@ -0,0 +1,7 @@
import { ConnectionEdge } from './ConnectionEdge'
import { AnimatedSvgEdge } from '../ui/animated-svg-edge'
export const edgeTypes = {
connection: ConnectionEdge,
animated: AnimatedSvgEdge,
}

View File

@@ -0,0 +1,252 @@
import { useCallback, useEffect, useRef } from 'react'
import { useReactFlow, type Node, type Edge } from '@xyflow/react'
interface ClipboardData {
nodes: Array<{
type: string
data: Record<string, unknown>
style?: React.CSSProperties
relativePosition: { x: number; y: number }
}>
edges: Array<{
sourceIndex: number
targetIndex: number
type?: string
data?: Record<string, unknown>
label?: string
}>
}
function generateId(prefix: string): string {
return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 7)}`
}
function isInputFocused(): boolean {
const tag = document.activeElement?.tagName
return tag === 'INPUT' || tag === 'TEXTAREA' || tag === 'SELECT'
}
export function useCanvasShortcuts({
nodes: _nodes, // eslint-disable-line @typescript-eslint/no-unused-vars
edges,
setNodes,
setEdges,
setIsDirty,
canvasRef,
}: {
nodes: Node[]
edges: Edge[]
setNodes: React.Dispatch<React.SetStateAction<Node[]>>
setEdges: React.Dispatch<React.SetStateAction<Edge[]>>
setIsDirty: (dirty: boolean) => void
canvasRef: React.RefObject<HTMLDivElement | null>
}) {
const { getNodes, fitView, screenToFlowPosition, setNodes: rfSetNodes } = useReactFlow()
const clipboardRef = useRef<ClipboardData | null>(null)
const getSelectedNodes = useCallback((): Node[] => {
return getNodes().filter(n => n.selected)
}, [getNodes])
const copyNodes = useCallback(() => {
const selected = getSelectedNodes()
if (selected.length === 0) return
const centroid = {
x: selected.reduce((sum, n) => sum + n.position.x, 0) / selected.length,
y: selected.reduce((sum, n) => sum + n.position.y, 0) / selected.length,
}
const selectedIds = new Set(selected.map(n => n.id))
const clipNodes = selected.map(n => ({
type: n.type || 'device',
data: structuredClone(n.data),
style: n.style ? { ...n.style } : undefined,
relativePosition: {
x: n.position.x - centroid.x,
y: n.position.y - centroid.y,
},
}))
const selectedList = selected.map(n => n.id)
const clipEdges = edges
.filter(e => selectedIds.has(e.source) && selectedIds.has(e.target))
.map(e => ({
sourceIndex: selectedList.indexOf(e.source),
targetIndex: selectedList.indexOf(e.target),
type: e.type,
data: e.data ? structuredClone(e.data) as Record<string, unknown> : undefined,
label: typeof e.label === 'string' ? e.label : undefined,
}))
clipboardRef.current = { nodes: clipNodes, edges: clipEdges }
}, [getSelectedNodes, edges])
const pasteNodes = useCallback(() => {
const clipboard = clipboardRef.current
if (!clipboard || clipboard.nodes.length === 0) return
const canvasEl = canvasRef.current
if (!canvasEl) return
const rect = canvasEl.getBoundingClientRect()
const center = screenToFlowPosition({
x: rect.left + rect.width / 2,
y: rect.top + rect.height / 2,
})
const newNodeIds: string[] = []
const newNodes: Node[] = clipboard.nodes.map(cn => {
const prefix = cn.type === 'group' ? 'group' : 'device'
const id = generateId(prefix)
newNodeIds.push(id)
return {
id,
type: cn.type,
position: {
x: center.x + cn.relativePosition.x,
y: center.y + cn.relativePosition.y,
},
data: structuredClone(cn.data) as Record<string, unknown>,
style: cn.style ? { ...cn.style } : undefined,
selected: true,
}
})
const newEdges: Edge[] = clipboard.edges.map(ce => ({
id: generateId('edge'),
source: newNodeIds[ce.sourceIndex],
target: newNodeIds[ce.targetIndex],
type: ce.type,
data: ce.data ? structuredClone(ce.data) as Record<string, unknown> : undefined,
label: ce.label,
}))
setNodes(nds => [
...nds.map(n => ({ ...n, selected: false })),
...newNodes,
])
setEdges(eds => [...eds, ...newEdges])
setIsDirty(true)
}, [canvasRef, screenToFlowPosition, setNodes, setEdges, setIsDirty])
const duplicateNodes = useCallback(() => {
const selected = getSelectedNodes()
if (selected.length === 0) return
const selectedIds = new Set(selected.map(n => n.id))
const idMap = new Map<string, string>()
const newNodes: Node[] = selected.map(n => {
const prefix = n.type === 'group' ? 'group' : 'device'
const newId = generateId(prefix)
idMap.set(n.id, newId)
return {
id: newId,
type: n.type,
position: { x: n.position.x + 30, y: n.position.y + 30 },
data: structuredClone(n.data) as Record<string, unknown>,
style: n.style ? { ...n.style } : undefined,
selected: true,
}
})
const newEdges: Edge[] = edges
.filter(e => selectedIds.has(e.source) && selectedIds.has(e.target))
.map(e => ({
id: generateId('edge'),
source: idMap.get(e.source)!,
target: idMap.get(e.target)!,
type: e.type,
data: e.data ? structuredClone(e.data) as Record<string, unknown> : undefined,
label: e.label,
}))
setNodes(nds => [
...nds.map(n => ({ ...n, selected: false })),
...newNodes,
])
setEdges(eds => [...eds, ...newEdges])
setIsDirty(true)
}, [getSelectedNodes, edges, setNodes, setEdges, setIsDirty])
const selectAll = useCallback(() => {
rfSetNodes(nds => nds.map(n => ({ ...n, selected: true })))
}, [rfSetNodes])
const deleteSelected = useCallback(() => {
const selected = getSelectedNodes()
if (selected.length === 0) return
const selectedIds = new Set(selected.map(n => n.id))
setNodes(nds => nds.filter(n => !selectedIds.has(n.id)))
setEdges(eds => eds.filter(e => !selectedIds.has(e.source) && !selectedIds.has(e.target)))
setIsDirty(true)
}, [getSelectedNodes, setNodes, setEdges, setIsDirty])
const bringSelectedToFront = useCallback(() => {
const selected = getSelectedNodes()
if (!selected.length) return
const selectedIds = new Set(selected.map(n => n.id))
setNodes(nds => {
const maxZ = Math.max(0, ...nds.map(n => n.zIndex ?? 0))
return nds.map(n => selectedIds.has(n.id) ? { ...n, zIndex: maxZ + 1 } : n)
})
setIsDirty(true)
}, [getSelectedNodes, setNodes, setIsDirty])
const sendSelectedToBack = useCallback(() => {
const selected = getSelectedNodes()
if (!selected.length) return
const selectedIds = new Set(selected.map(n => n.id))
setNodes(nds => {
const minZ = Math.min(0, ...nds.map(n => n.zIndex ?? 0))
return nds.map(n => selectedIds.has(n.id) ? { ...n, zIndex: minZ - 1 } : n)
})
setIsDirty(true)
}, [getSelectedNodes, setNodes, setIsDirty])
useEffect(() => {
const handleKeyDown = (e: KeyboardEvent) => {
if (isInputFocused()) return
const ctrl = e.ctrlKey || e.metaKey
if (ctrl && e.key === 'c') {
e.preventDefault()
copyNodes()
} else if (ctrl && e.key === 'v') {
e.preventDefault()
pasteNodes()
} else if (ctrl && e.key === 'd') {
e.preventDefault()
duplicateNodes()
} else if (ctrl && e.key === 'a') {
e.preventDefault()
selectAll()
} else if (ctrl && e.shiftKey && (e.key === 'f' || e.key === 'F')) {
e.preventDefault()
fitView({ padding: 0.2 })
} else if (e.key === ']' && !ctrl) {
e.preventDefault()
bringSelectedToFront()
} else if (e.key === '[' && !ctrl) {
e.preventDefault()
sendSelectedToBack()
}
}
document.addEventListener('keydown', handleKeyDown)
return () => document.removeEventListener('keydown', handleKeyDown)
}, [copyNodes, pasteNodes, duplicateNodes, selectAll, fitView, bringSelectedToFront, sendSelectedToBack])
return {
copyNodes,
pasteNodes,
duplicateNodes,
selectAll,
deleteSelected,
bringSelectedToFront,
sendSelectedToBack,
hasClipboard: () => clipboardRef.current !== null && clipboardRef.current.nodes.length > 0,
}
}

View File

@@ -0,0 +1,106 @@
import { memo } from 'react'
import { Position, NodeResizer, type NodeProps } from '@xyflow/react'
import { BaseNode, BaseNodeHeader, BaseNodeHeaderTitle, BaseNodeContent } from '../ui/base-node'
import { BaseHandle } from '../ui/base-handle'
import { NodeStatusIndicator, type NodeStatus } from '../ui/node-status-indicator'
import { NodeTooltip, NodeTooltipTrigger, NodeTooltipContent } from '../ui/node-tooltip'
import { getDeviceRenderConfig } from './deviceRegistry'
import type { DeviceProperties } from '@/types'
export interface DeviceNodeData {
label: string
deviceType: string
category?: string
properties: DeviceProperties
[key: string]: unknown
}
function TooltipRow({ label, value }: { label: string; value: string | null | undefined }) {
if (!value) return null
return (
<div className="flex gap-2">
<span className="text-[10px] uppercase tracking-wider text-muted-foreground">{label}</span>
<span className="text-xs font-mono text-primary">{value}</span>
</div>
)
}
const NODE_DEFAULT = 120 // default square side in px
const NODE_MIN = 80 // minimum square side in px
const NODE_MAX = 280 // maximum square side in px
function DeviceNodeComponent({ data, selected, width, height }: NodeProps) {
const nodeData = data as unknown as DeviceNodeData
const { icon: Icon, color } = getDeviceRenderConfig(nodeData.deviceType, nodeData.category)
const status = (nodeData.properties?.status || 'unknown') as NodeStatus
const ip = nodeData.properties?.ip
const props = nodeData.properties || {}
// Use the shorter dimension so content never overflows a non-square node
const size = Math.min(width ?? NODE_DEFAULT, height ?? NODE_DEFAULT)
const scale = size / NODE_DEFAULT
// Icon: 28px at default, clamped to [14, 72]
const iconPx = Math.round(Math.max(14, Math.min(72, scale * 28)))
// Label font: 11px at default, clamped to [9, 20]
const labelPx = Math.max(9, Math.min(20, Math.round(scale * 11)))
// IP font: 9px at default, clamped to [8, 16]
const ipPx = Math.max(8, Math.min(16, Math.round(scale * 9)))
const hasTooltipContent = props.hostname || props.ip || props.vendor || props.model || props.role || props.notes
return (
<>
<NodeResizer
isVisible={selected}
minWidth={NODE_MIN}
minHeight={NODE_MIN}
maxWidth={NODE_MAX}
maxHeight={NODE_MAX}
keepAspectRatio
lineStyle={{ borderColor: 'var(--color-accent)', borderWidth: 1 }}
handleStyle={{ width: 8, height: 8, borderColor: 'var(--color-accent)', background: 'var(--color-card)' }}
/>
<NodeStatusIndicator status={status}>
<NodeTooltip>
<NodeTooltipTrigger>
<BaseNode className="w-full h-full group flex flex-col items-center justify-center">
<BaseNodeHeader className="flex-col gap-1 items-center py-2 px-2">
<Icon size={iconPx} style={{ color }} />
<BaseNodeHeaderTitle className="text-center leading-tight" style={{ fontSize: labelPx }}>
{nodeData.label}
</BaseNodeHeaderTitle>
</BaseNodeHeader>
{ip && (
<BaseNodeContent className="items-center pt-0 pb-1">
<span className="font-mono text-muted-foreground" style={{ fontSize: ipPx }}>{ip}</span>
</BaseNodeContent>
)}
<BaseHandle type="target" position={Position.Top} />
<BaseHandle type="source" position={Position.Bottom} />
<BaseHandle type="target" position={Position.Left} id="left" />
<BaseHandle type="source" position={Position.Right} id="right" />
</BaseNode>
</NodeTooltipTrigger>
{hasTooltipContent && (
<NodeTooltipContent position={Position.Top}>
<div className="flex flex-col gap-1 min-w-[140px]">
<TooltipRow label="Host" value={props.hostname} />
<TooltipRow label="IP" value={props.ip} />
{(props.vendor || props.model) && (
<TooltipRow label="HW" value={[props.vendor, props.model].filter(Boolean).join(' ')} />
)}
<TooltipRow label="Role" value={props.role} />
{props.notes && (
<TooltipRow label="Notes" value={props.notes.length > 100 ? props.notes.slice(0, 100) + '...' : props.notes} />
)}
</div>
</NodeTooltipContent>
)}
</NodeTooltip>
</NodeStatusIndicator>
</>
)
}
export const DeviceNode = memo(DeviceNodeComponent)

View File

@@ -0,0 +1,113 @@
import type { LucideIcon } from 'lucide-react'
import {
Router, Network, BrickWallFire, Wifi, Server, Monitor, Boxes, Package, Cloud,
Printer, Smartphone, HardDrive, Gauge, Database, CloudCog,
Cpu, Tablet, Laptop, BatteryCharging, RectangleVertical,
Cable, Camera, KeyRound, Globe, Video, PlugZap, Radio,
} from 'lucide-react'
export interface DeviceRenderConfig {
icon: LucideIcon
color: string
}
// Category-semantic color palette — each color carries meaning:
// Network (blue) — backbone connectivity layer
// Security (orange) — critical/protective elements
// Compute (emerald)— running workloads and VMs
// Endpoint (amber) — user-facing devices
// Storage (violet) — data at rest
// Cloud (cyan) — external/internet-connected
// Infra (steel) — physical/passive hardware
export const NETWORK_COLOR = '#60a5fa'
export const SECURITY_COLOR = '#f87171'
export const COMPUTE_COLOR = '#34d399'
export const ENDPOINT_COLOR = '#fbbf24'
export const STORAGE_COLOR = '#a78bfa'
export const CLOUD_COLOR = '#67e8f9'
export const INFRA_COLOR = '#94a3b8'
const SYSTEM_DEVICE_ICONS: Record<string, DeviceRenderConfig> = {
// Network layer
'router': { icon: Router, color: NETWORK_COLOR },
'switch': { icon: Network, color: NETWORK_COLOR },
'access-point': { icon: Wifi, color: NETWORK_COLOR },
'load-balancer': { icon: Gauge, color: NETWORK_COLOR },
// Security
'firewall': { icon: BrickWallFire, color: SECURITY_COLOR },
'badge-reader': { icon: KeyRound, color: SECURITY_COLOR },
// Compute
'server': { icon: Server, color: COMPUTE_COLOR },
'vm': { icon: Boxes, color: COMPUTE_COLOR },
'container': { icon: Package, color: COMPUTE_COLOR },
// Storage
'nas': { icon: Database, color: STORAGE_COLOR },
'san': { icon: HardDrive, color: STORAGE_COLOR },
'cloud-storage': { icon: CloudCog, color: STORAGE_COLOR },
// Cloud / Internet
'cloud': { icon: Cloud, color: CLOUD_COLOR },
'aws': { icon: Cloud, color: CLOUD_COLOR },
'azure': { icon: Cloud, color: CLOUD_COLOR },
'gcp': { icon: Cloud, color: CLOUD_COLOR },
'isp': { icon: Globe, color: CLOUD_COLOR },
// Endpoints
'workstation': { icon: Monitor, color: ENDPOINT_COLOR },
'laptop': { icon: Laptop, color: ENDPOINT_COLOR },
'tablet': { icon: Tablet, color: ENDPOINT_COLOR },
'phone': { icon: Smartphone, color: ENDPOINT_COLOR },
'printer': { icon: Printer, color: ENDPOINT_COLOR },
// Infrastructure / physical
'ups': { icon: BatteryCharging, color: INFRA_COLOR },
'pdu': { icon: PlugZap, color: INFRA_COLOR },
'rack': { icon: RectangleVertical, color: INFRA_COLOR },
'patch-panel': { icon: Cable, color: INFRA_COLOR },
'camera': { icon: Camera, color: INFRA_COLOR },
'nvr': { icon: Video, color: INFRA_COLOR },
'iot': { icon: Radio, color: INFRA_COLOR },
}
const CATEGORY_DEFAULTS: Record<string, DeviceRenderConfig> = {
'network': { icon: Router, color: NETWORK_COLOR },
'compute': { icon: Server, color: COMPUTE_COLOR },
'storage': { icon: Database, color: STORAGE_COLOR },
'cloud': { icon: Cloud, color: CLOUD_COLOR },
'endpoint': { icon: Monitor, color: ENDPOINT_COLOR },
'infrastructure': { icon: PlugZap, color: INFRA_COLOR },
'security': { icon: BrickWallFire, color: SECURITY_COLOR },
}
const FALLBACK: DeviceRenderConfig = { icon: Cpu, color: INFRA_COLOR }
export function getDeviceRenderConfig(slug: string, category?: string): DeviceRenderConfig {
if (SYSTEM_DEVICE_ICONS[slug]) return SYSTEM_DEVICE_ICONS[slug]
if (category && CATEGORY_DEFAULTS[category]) return CATEGORY_DEFAULTS[category]
return FALLBACK
}
export const CATEGORY_LABELS: Record<string, string> = {
'network': 'Network',
'compute': 'Compute',
'storage': 'Storage',
'cloud': 'Cloud',
'endpoint': 'Endpoints',
'infrastructure': 'Infrastructure',
'security': 'Security',
}
export const CATEGORY_COLORS: Record<string, string> = {
'network': NETWORK_COLOR,
'compute': COMPUTE_COLOR,
'storage': STORAGE_COLOR,
'cloud': CLOUD_COLOR,
'endpoint': ENDPOINT_COLOR,
'infrastructure': INFRA_COLOR,
'security': SECURITY_COLOR,
}
export const CATEGORY_ORDER = ['network', 'compute', 'storage', 'cloud', 'endpoint', 'infrastructure', 'security']

View File

@@ -0,0 +1,7 @@
import { DeviceNode } from './DeviceNode'
import { GroupNode } from '../ui/labeled-group-node'
export const nodeTypes = {
device: DeviceNode,
group: GroupNode,
}

View File

@@ -0,0 +1,168 @@
import { useState, useCallback } from 'react'
import { Sparkles, ChevronUp, ChevronDown, AlertTriangle } from 'lucide-react'
import { cn } from '@/lib/utils'
import { networkDiagramsApi } from '@/api'
import type { AIGenerateResponse } from '@/types'
interface AIAssistPanelProps {
onGenerate: (result: AIGenerateResponse, mode: 'replace' | 'merge') => void
getExistingBounds: () => { minX: number; maxX: number; minY: number; maxY: number } | null
hasNodes: boolean
}
export function AIAssistPanel({ onGenerate, getExistingBounds, hasNodes }: AIAssistPanelProps) {
const [expanded, setExpanded] = useState(false)
const [description, setDescription] = useState('')
const [mode, setMode] = useState<'replace' | 'merge'>('replace')
const [loading, setLoading] = useState(false)
const [error, setError] = useState<string | null>(null)
const [replaceConfirm, setReplaceConfirm] = useState(false)
const handleGenerate = useCallback(async () => {
if (!description.trim()) return
setLoading(true)
setError(null)
setReplaceConfirm(false)
try {
const result = await networkDiagramsApi.aiGenerate({
description: description.trim(),
mode,
existingBounds: mode === 'merge' ? getExistingBounds() : null,
})
onGenerate(result, mode)
setDescription('')
setExpanded(false)
} catch (err: unknown) {
const msg = err instanceof Error ? err.message : 'Generation failed. Please try again.'
setError(msg)
} finally {
setLoading(false)
}
}, [description, mode, onGenerate, getExistingBounds])
// Reset confirm state when mode changes or panel collapses
const handleModeChange = (newMode: 'replace' | 'merge') => {
setMode(newMode)
setReplaceConfirm(false)
}
const needsReplaceConfirm = mode === 'replace' && hasNodes
if (!expanded) {
return (
<div className="border-t border-default bg-card">
<button
onClick={() => setExpanded(true)}
className="flex w-full items-center justify-center gap-2 px-4 py-2 text-xs text-muted-foreground hover:text-primary"
>
<Sparkles size={14} />
AI Generate
<ChevronUp size={14} />
</button>
</div>
)
}
return (
<div className="border-t border-default bg-card">
<div className="flex items-center justify-between border-b border-default px-4 py-2">
<div className="flex items-center gap-2 text-xs font-medium text-heading">
<Sparkles size={14} />
AI Generate
</div>
<button
onClick={() => { setExpanded(false); setReplaceConfirm(false) }}
className="text-muted-foreground hover:text-primary"
>
<ChevronDown size={14} />
</button>
</div>
<div className="flex flex-col gap-3 p-4">
<div className="flex gap-2">
<button
onClick={() => handleModeChange('replace')}
className={cn(
'rounded px-3 py-1 text-xs font-medium transition-colors',
mode === 'replace'
? 'bg-accent text-white'
: 'border border-default text-muted-foreground hover:text-primary',
)}
>
Generate New
</button>
<button
onClick={() => handleModeChange('merge')}
className={cn(
'rounded px-3 py-1 text-xs font-medium transition-colors',
mode === 'merge'
? 'bg-accent text-white'
: 'border border-default text-muted-foreground hover:text-primary',
)}
>
Add to Diagram
</button>
</div>
{needsReplaceConfirm && (
<div className="flex items-start gap-2 rounded border border-yellow-500/30 bg-yellow-500/5 px-3 py-2">
<AlertTriangle size={14} className="mt-0.5 shrink-0 text-yellow-400" />
<p className="text-[11px] text-yellow-400">
This will replace your current diagram. Save first if needed.
</p>
</div>
)}
<textarea
value={description}
onChange={e => setDescription(e.target.value)}
placeholder="Describe the network you want to create... e.g. 'Small office with a firewall, core switch, 3 access points, and a file server'"
rows={3}
disabled={loading}
className="w-full resize-none rounded border border-default bg-input px-3 py-2 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none disabled:opacity-50"
/>
{error && <p className="text-[11px] text-red-400">{error}</p>}
{loading ? (
<div className="flex items-center justify-center gap-2 py-2">
<div className="h-4 w-4 animate-spin rounded-full border-2 border-accent border-t-transparent" />
<span className="text-xs text-muted-foreground">Generating your network diagram</span>
</div>
) : needsReplaceConfirm && !replaceConfirm ? (
<button
onClick={() => setReplaceConfirm(true)}
disabled={!description.trim()}
className="rounded border border-yellow-500/40 bg-yellow-500/10 px-4 py-2 text-xs font-medium text-yellow-400 hover:bg-yellow-500/20 disabled:opacity-50"
>
Replace Diagram
</button>
) : needsReplaceConfirm && replaceConfirm ? (
<div className="flex gap-2">
<button
onClick={() => setReplaceConfirm(false)}
className="flex-1 rounded border border-default px-3 py-2 text-xs text-primary hover:bg-elevated"
>
Cancel
</button>
<button
onClick={handleGenerate}
disabled={!description.trim()}
className="flex-1 rounded bg-red-500/20 px-3 py-2 text-xs font-medium text-red-400 hover:bg-red-500/30 disabled:opacity-50"
>
Yes, Replace
</button>
</div>
) : (
<button
onClick={handleGenerate}
disabled={!description.trim()}
className="rounded bg-accent px-4 py-2 text-xs font-medium text-white hover:bg-accent/90 disabled:opacity-50"
>
Generate
</button>
)}
</div>
</div>
)
}

View File

@@ -0,0 +1,227 @@
import { useState, useMemo, useCallback } from 'react'
import { Search, Plus, ChevronDown, ChevronRight, X, LayoutGrid, GripVertical, Globe } from 'lucide-react'
import { getDeviceRenderConfig, CATEGORY_LABELS, CATEGORY_ORDER } from '../nodes/deviceRegistry'
import type { DeviceTypeResponse, DeviceTypeCreate } from '@/types'
import { deviceTypesApi } from '@/api'
interface DeviceToolbarProps {
deviceTypes: DeviceTypeResponse[]
onDeviceTypesChange: () => void
}
export function DeviceToolbar({ deviceTypes, onDeviceTypesChange }: DeviceToolbarProps) {
const [search, setSearch] = useState('')
const [collapsedCategories, setCollapsedCategories] = useState<Set<string>>(new Set())
const [showAddForm, setShowAddForm] = useState(false)
const [newType, setNewType] = useState<DeviceTypeCreate>({ slug: '', label: '', category: 'network' })
const [addError, setAddError] = useState<string | null>(null)
const [addLoading, setAddLoading] = useState(false)
const filteredByCategory = useMemo(() => {
const lower = search.toLowerCase()
const filtered = search
? deviceTypes.filter(dt => dt.label.toLowerCase().includes(lower) || dt.slug.toLowerCase().includes(lower))
: deviceTypes
const grouped: Record<string, DeviceTypeResponse[]> = {}
for (const dt of filtered) {
if (!grouped[dt.category]) grouped[dt.category] = []
grouped[dt.category].push(dt)
}
return grouped
}, [deviceTypes, search])
const toggleCategory = useCallback((cat: string) => {
setCollapsedCategories(prev => {
const next = new Set(prev)
if (next.has(cat)) next.delete(cat)
else next.add(cat)
return next
})
}, [])
const handleDragStart = useCallback((e: React.DragEvent, deviceType: DeviceTypeResponse) => {
e.dataTransfer.setData('application/reactflow-device', JSON.stringify({
slug: deviceType.slug,
label: deviceType.label,
category: deviceType.category,
}))
e.dataTransfer.effectAllowed = 'move'
}, [])
const handleAddType = useCallback(async () => {
if (!newType.slug || !newType.label) {
setAddError('Slug and label are required')
return
}
setAddLoading(true)
setAddError(null)
try {
await deviceTypesApi.create(newType)
setNewType({ slug: '', label: '', category: 'network' })
setShowAddForm(false)
onDeviceTypesChange()
} catch (err: unknown) {
const msg = err instanceof Error ? err.message : 'Failed to create device type'
setAddError(msg)
} finally {
setAddLoading(false)
}
}, [newType, onDeviceTypesChange])
return (
<div className="flex h-full w-[200px] flex-col border-r border-default bg-sidebar">
<div className="relative p-2">
<Search size={14} className="absolute left-4 top-1/2 -translate-y-1/2 text-muted-foreground" />
<input
type="text"
placeholder="Search devices..."
value={search}
onChange={e => setSearch(e.target.value)}
className="w-full rounded-md border border-default bg-input pl-8 pr-2 py-1.5 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
/>
</div>
<div className="flex-1 overflow-y-auto px-2 pb-2">
{CATEGORY_ORDER.map(cat => {
const items = filteredByCategory[cat] || []
const isCloud = cat === 'cloud'
const ispMatchesSearch = !search || 'isp'.includes(search.toLowerCase()) || 'internet service provider'.includes(search.toLowerCase())
const showIsp = isCloud && ispMatchesSearch
if (!items.length && !showIsp) return null
const collapsed = collapsedCategories.has(cat)
const totalCount = items.length + (showIsp ? 1 : 0)
return (
<div key={cat} className="mb-1">
<button
onClick={() => toggleCategory(cat)}
className="flex w-full items-center gap-1 rounded px-1 py-1 text-[10px] font-semibold uppercase tracking-wider text-muted-foreground hover:text-primary"
>
{collapsed ? <ChevronRight size={12} /> : <ChevronDown size={12} />}
{CATEGORY_LABELS[cat] || cat}
<span className="ml-auto text-[10px] font-normal">{totalCount}</span>
</button>
{!collapsed && (
<div className="flex flex-col gap-0.5">
{items.map(dt => {
const { icon: Icon, color } = getDeviceRenderConfig(dt.slug, dt.category)
return (
<div
key={dt.id}
draggable
onDragStart={e => handleDragStart(e, dt)}
className="flex cursor-grab items-center gap-2 rounded px-2 py-1.5 text-xs text-primary hover:bg-elevated active:cursor-grabbing active:scale-[0.98] transition-transform"
>
<GripVertical size={12} className="shrink-0 text-muted-foreground/50" />
<Icon size={14} style={{ color }} />
<span>{dt.label}</span>
</div>
)
})}
{showIsp && (
<div
draggable
onDragStart={e => {
e.dataTransfer.setData('application/reactflow-device', JSON.stringify({
slug: 'isp',
label: 'ISP',
category: 'cloud',
}))
e.dataTransfer.effectAllowed = 'move'
}}
className="flex cursor-grab items-center gap-2 rounded px-2 py-1.5 text-xs text-primary hover:bg-elevated active:cursor-grabbing active:scale-[0.98] transition-transform"
>
<GripVertical size={12} className="shrink-0 text-muted-foreground/50" />
<Globe size={14} style={{ color: 'var(--color-accent)' }} />
<span>ISP</span>
</div>
)}
</div>
)}
</div>
)
})}
</div>
{/* Grouping section */}
<div className="mb-1 mt-2 border-t border-default pt-2">
<div className="flex items-center gap-1 px-1 py-1 text-[10px] font-semibold uppercase tracking-wider text-muted-foreground">
Grouping
</div>
<div className="flex flex-col gap-0.5">
{[
{ slug: 'subnet', label: 'Subnet' },
{ slug: 'vlan', label: 'VLAN' },
{ slug: 'site', label: 'Site' },
{ slug: 'dmz', label: 'DMZ' },
].map(item => (
<div
key={item.slug}
draggable
onDragStart={e => {
e.dataTransfer.setData('application/reactflow-group', JSON.stringify(item))
e.dataTransfer.effectAllowed = 'move'
}}
className="flex cursor-grab items-center gap-2 rounded px-2 py-1.5 text-xs text-primary hover:bg-elevated active:cursor-grabbing active:scale-[0.98] transition-transform"
>
<GripVertical size={12} className="shrink-0 text-muted-foreground/50" />
<LayoutGrid size={14} className="text-muted-foreground" />
<span>{item.label}</span>
</div>
))}
</div>
</div>
<div className="border-t border-default p-2">
{!showAddForm ? (
<button
onClick={() => setShowAddForm(true)}
className="flex w-full items-center justify-center gap-1 rounded border border-default px-2 py-1.5 text-xs text-muted-foreground hover:border-hover hover:text-primary"
>
<Plus size={12} />
Custom Type
</button>
) : (
<div className="flex flex-col gap-1.5">
<div className="flex items-center justify-between">
<span className="text-[10px] font-semibold uppercase tracking-wider text-muted-foreground">New Type</span>
<button onClick={() => { setShowAddForm(false); setAddError(null) }} className="text-muted-foreground hover:text-primary">
<X size={12} />
</button>
</div>
<input
placeholder="slug (e.g. pacs-server)"
value={newType.slug}
onChange={e => setNewType(prev => ({ ...prev, slug: e.target.value.toLowerCase().replace(/[^a-z0-9-]/g, '') }))}
className="rounded border border-default bg-input px-2 py-1 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
/>
<input
placeholder="Label (e.g. PACS Server)"
value={newType.label}
onChange={e => setNewType(prev => ({ ...prev, label: e.target.value }))}
className="rounded border border-default bg-input px-2 py-1 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
/>
<select
value={newType.category}
onChange={e => setNewType(prev => ({ ...prev, category: e.target.value }))}
className="rounded border border-default bg-input px-2 py-1 text-xs text-primary focus:border-accent focus:outline-none"
>
{CATEGORY_ORDER.map(c => (
<option key={c} value={c}>{CATEGORY_LABELS[c]}</option>
))}
</select>
{addError && <p className="text-[10px] text-red-400">{addError}</p>}
<button
onClick={handleAddType}
disabled={addLoading}
className="rounded bg-accent px-2 py-1 text-xs font-medium text-white hover:bg-accent/90 disabled:opacity-50"
>
{addLoading ? 'Adding...' : 'Add Type'}
</button>
</div>
)}
</div>
</div>
)
}

View File

@@ -0,0 +1,412 @@
import { useCallback, useState, useEffect } from 'react'
import { Trash2, Minus, Spline, GitBranch, BringToFront, SendToBack } from 'lucide-react'
import { cn } from '@/lib/utils'
import type { DeviceProperties, DiagramEdge } from '@/types'
import type { Node, Edge } from '@xyflow/react'
import type { DeviceNodeData } from '../nodes/DeviceNode'
interface PropertiesPanelProps {
selectedNode: Node | null
selectedEdge: Edge | null
onNodeUpdate: (nodeId: string, data: Partial<DeviceNodeData>) => void
onEdgeUpdate: (edgeId: string, data: Partial<DiagramEdge>) => void
onEdgeTypeChange: (edgeId: string, edgeType: string) => void
onBringToFront: (nodeId: string) => void
onSendToBack: (nodeId: string) => void
onDeleteNode: (nodeId: string) => void
onDeleteEdge: (edgeId: string) => void
}
type NodeStatus = 'online' | 'offline' | 'degraded' | 'unknown'
const STATUS_CONFIG: Record<NodeStatus, { color: string; label: string }> = {
online: { color: '#34d399', label: 'Online' },
offline: { color: '#f87171', label: 'Offline' },
degraded: { color: '#fbbf24', label: 'Degraded' },
unknown: { color: '#94a3b8', label: 'Unknown' },
}
const STATUS_OPTIONS = Object.keys(STATUS_CONFIG) as NodeStatus[]
const CONNECTION_TYPE_OPTIONS = ['ethernet', 'fiber', 'wifi', 'vpn', 'vlan', 'wan'] as const
function FieldLabel({ children }: { children: React.ReactNode }) {
return (
<label className="text-[10px] font-semibold uppercase tracking-wider text-muted-foreground">
{children}
</label>
)
}
function FieldInput({ value, onChange, placeholder, mono }: {
value: string
onChange: (val: string) => void
placeholder?: string
mono?: boolean
}) {
return (
<input
type="text"
value={value}
onChange={e => onChange(e.target.value)}
placeholder={placeholder}
className={cn(
'w-full rounded border border-default bg-input px-2 py-1.5 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none',
mono && 'font-mono',
)}
/>
)
}
function SectionDivider({ label }: { label: string }) {
return (
<div className="flex items-center gap-2 pt-1">
<span className="whitespace-nowrap text-[10px] font-semibold uppercase tracking-wider text-muted-foreground">
{label}
</span>
<div className="flex-1 border-t border-default" />
</div>
)
}
export function PropertiesPanel({
selectedNode,
selectedEdge,
onNodeUpdate,
onEdgeUpdate,
onEdgeTypeChange,
onBringToFront,
onSendToBack,
onDeleteNode,
onDeleteEdge,
}: PropertiesPanelProps) {
const [deleteConfirm, setDeleteConfirm] = useState(false)
// Reset confirm state whenever the selection changes
// eslint-disable-next-line react-hooks/set-state-in-effect
useEffect(() => { setDeleteConfirm(false) }, [selectedNode?.id, selectedEdge?.id])
const handlePropertyChange = useCallback((field: keyof DeviceProperties, value: string) => {
if (!selectedNode) return
const nodeData = selectedNode.data as unknown as DeviceNodeData
onNodeUpdate(selectedNode.id, {
properties: { ...nodeData.properties, [field]: value },
} as Partial<DeviceNodeData>)
}, [selectedNode, onNodeUpdate])
const handleLabelChange = useCallback((value: string) => {
if (!selectedNode) return
onNodeUpdate(selectedNode.id, { label: value } as Partial<DeviceNodeData>)
}, [selectedNode, onNodeUpdate])
if (!selectedNode && !selectedEdge) {
return (
<div className="flex h-full w-[260px] flex-col items-center justify-center border-l border-default bg-sidebar px-4">
<p className="text-center text-xs text-muted-foreground">
Select a device or connection to edit its properties
</p>
<p className="mt-1 text-center text-[10px] text-muted-foreground/60">
Hover a device to preview its info
</p>
</div>
)
}
if (selectedEdge) {
const edgeData = (selectedEdge.data || {}) as Record<string, unknown>
const connectionType = (edgeData.connectionType as string) || 'ethernet'
const isCustomType = !CONNECTION_TYPE_OPTIONS.includes(connectionType as typeof CONNECTION_TYPE_OPTIONS[number])
return (
<div className="flex h-full w-[260px] flex-col border-l border-default bg-sidebar">
<div className="border-b border-default px-3 py-2">
<h3 className="text-xs font-semibold text-heading">Connection</h3>
</div>
<div className="flex flex-1 flex-col gap-3 overflow-y-auto p-3">
<div className="flex flex-col gap-1">
<FieldLabel>Label</FieldLabel>
<FieldInput
value={(selectedEdge.label as string) || ''}
onChange={val => onEdgeUpdate(selectedEdge.id, { label: val || null })}
placeholder="Connection label"
/>
</div>
<div className="flex flex-col gap-1">
<FieldLabel>Type</FieldLabel>
<select
value={isCustomType ? '__custom__' : connectionType}
onChange={e => {
const val = e.target.value
if (val !== '__custom__') {
onEdgeUpdate(selectedEdge.id, { connectionType: val })
}
}}
className="w-full rounded border border-default bg-input px-2 py-1.5 text-xs text-primary focus:border-accent focus:outline-none"
>
{CONNECTION_TYPE_OPTIONS.map(opt => (
<option key={opt} value={opt}>{opt.charAt(0).toUpperCase() + opt.slice(1)}</option>
))}
<option value="__custom__">Custom</option>
</select>
{isCustomType && (
<FieldInput
value={connectionType}
onChange={val => onEdgeUpdate(selectedEdge.id, { connectionType: val })}
placeholder="Custom type name"
/>
)}
</div>
<div className="flex flex-col gap-1">
<FieldLabel>Speed</FieldLabel>
<FieldInput
value={(edgeData.speed as string) || ''}
onChange={val => onEdgeUpdate(selectedEdge.id, { speed: val || null })}
placeholder="e.g. 1 Gbps"
/>
</div>
<div className="flex flex-col gap-1">
<FieldLabel>Notes</FieldLabel>
<FieldInput
value={(edgeData.notes as string) || ''}
onChange={val => onEdgeUpdate(selectedEdge.id, { notes: val || null })}
placeholder="Port info, cable type…"
mono
/>
</div>
<div className="flex flex-col gap-1">
<FieldLabel>Line Style</FieldLabel>
<div className="flex gap-1">
{([
{ value: null, icon: Minus, label: 'Straight' },
{ value: 'curved', icon: Spline, label: 'Curved' },
{ value: 'step', icon: GitBranch, label: 'Step' },
] as const).map(({ value, icon: Icon, label }) => {
const routing = (edgeData.routing as string | null | undefined) ?? null
const active = routing === value
return (
<button
key={label}
title={label}
onClick={() => onEdgeUpdate(selectedEdge.id, { routing: value })}
className={cn(
'flex flex-1 items-center justify-center gap-1 rounded border py-1.5 text-[10px] transition-colors',
active
? 'border-accent bg-accent/10 text-accent'
: 'border-default text-muted-foreground hover:border-hover hover:text-primary',
)}
>
<Icon size={12} />
{label}
</button>
)
})}
</div>
</div>
<div className="flex items-center justify-between">
<FieldLabel>Show Traffic</FieldLabel>
<button
onClick={() => {
const newType = selectedEdge.type === 'animated' ? 'connection' : 'animated'
onEdgeTypeChange(selectedEdge.id, newType)
}}
className={cn(
'relative h-5 w-9 rounded-full transition-colors',
selectedEdge.type === 'animated' ? 'bg-accent' : 'bg-elevated',
)}
>
<span
className={cn(
'absolute top-0.5 left-0.5 h-4 w-4 rounded-full bg-white transition-transform',
selectedEdge.type === 'animated' && 'translate-x-4',
)}
/>
</button>
</div>
</div>
<div className="border-t border-default p-3">
{deleteConfirm ? (
<div className="flex flex-col gap-1.5">
<p className="text-center text-[10px] text-muted-foreground">Delete this connection?</p>
<div className="flex gap-1.5">
<button
onClick={() => setDeleteConfirm(false)}
className="flex-1 rounded border border-default px-2 py-1.5 text-xs text-primary hover:bg-elevated"
>
Cancel
</button>
<button
onClick={() => onDeleteEdge(selectedEdge.id)}
className="flex-1 rounded bg-red-500/20 px-2 py-1.5 text-xs font-medium text-red-400 hover:bg-red-500/30"
>
Delete
</button>
</div>
</div>
) : (
<button
onClick={() => setDeleteConfirm(true)}
className="flex w-full items-center justify-center gap-1.5 rounded border border-red-500/30 px-2 py-1.5 text-xs text-red-400 hover:bg-red-500/10"
>
<Trash2 size={12} />
Delete Connection
</button>
)}
</div>
</div>
)
}
const nodeData = selectedNode!.data as unknown as DeviceNodeData
const props = nodeData.properties || {} as DeviceProperties
const currentStatus = (props.status || 'unknown') as NodeStatus
return (
<div className="flex h-full w-[260px] flex-col border-l border-default bg-sidebar">
<div className="border-b border-default px-3 py-2">
<h3 className="text-xs font-semibold text-heading">Device Properties</h3>
</div>
<div className="flex flex-1 flex-col gap-3 overflow-y-auto p-3">
{/* Identity */}
<div className="flex flex-col gap-1">
<FieldLabel>Name</FieldLabel>
<FieldInput value={nodeData.label} onChange={handleLabelChange} placeholder="Device name" />
</div>
{/* Layering */}
<div className="flex flex-col gap-1">
<FieldLabel>Layer</FieldLabel>
<div className="flex gap-1.5">
<button
onClick={() => onBringToFront(selectedNode!.id)}
title="Bring to Front ]"
className="flex flex-1 items-center justify-center gap-1.5 rounded border border-default px-2 py-1.5 text-[10px] text-muted-foreground hover:border-hover hover:text-primary"
>
<BringToFront size={12} />
Bring Front
</button>
<button
onClick={() => onSendToBack(selectedNode!.id)}
title="Send to Back ["
className="flex flex-1 items-center justify-center gap-1.5 rounded border border-default px-2 py-1.5 text-[10px] text-muted-foreground hover:border-hover hover:text-primary"
>
<SendToBack size={12} />
Send Back
</button>
</div>
</div>
{/* Status badge grid */}
<div className="flex flex-col gap-1.5">
<FieldLabel>Status</FieldLabel>
<div className="grid grid-cols-2 gap-1">
{STATUS_OPTIONS.map(opt => {
const { color, label } = STATUS_CONFIG[opt]
const active = currentStatus === opt
return (
<button
key={opt}
onClick={() => handlePropertyChange('status', opt)}
className={cn(
'flex items-center justify-center gap-1.5 rounded border py-1.5 text-[10px] font-medium transition-colors',
active
? 'border-transparent text-white'
: 'border-default text-muted-foreground hover:border-hover hover:text-primary',
)}
style={active ? { backgroundColor: color } : undefined}
>
<span
className="h-1.5 w-1.5 shrink-0 rounded-full"
style={{ backgroundColor: active ? 'rgba(255,255,255,0.8)' : color }}
/>
{label}
</button>
)
})}
</div>
</div>
{/* Network section */}
<SectionDivider label="Network" />
<div className="flex flex-col gap-2">
<div className="flex flex-col gap-1">
<FieldLabel>IP Address</FieldLabel>
<FieldInput value={props.ip || ''} onChange={v => handlePropertyChange('ip', v)} placeholder="e.g. 10.0.0.1" mono />
</div>
<div className="flex flex-col gap-1">
<FieldLabel>Subnet</FieldLabel>
<FieldInput value={props.subnet || ''} onChange={v => handlePropertyChange('subnet', v)} placeholder="e.g. 10.0.0.0/24" mono />
</div>
<div className="flex flex-col gap-1">
<FieldLabel>VLAN</FieldLabel>
<FieldInput value={props.vlan || ''} onChange={v => handlePropertyChange('vlan', v)} placeholder="e.g. 10" mono />
</div>
</div>
{/* Hardware section */}
<SectionDivider label="Hardware" />
<div className="flex flex-col gap-2">
<div className="flex flex-col gap-1">
<FieldLabel>Hostname</FieldLabel>
<FieldInput value={props.hostname || ''} onChange={v => handlePropertyChange('hostname', v)} placeholder="e.g. core-rtr-01" mono />
</div>
<div className="flex flex-col gap-1">
<FieldLabel>Vendor</FieldLabel>
<FieldInput value={props.vendor || ''} onChange={v => handlePropertyChange('vendor', v)} placeholder="e.g. Cisco" />
</div>
<div className="flex flex-col gap-1">
<FieldLabel>Model</FieldLabel>
<FieldInput value={props.model || ''} onChange={v => handlePropertyChange('model', v)} placeholder="e.g. ISR 4331" />
</div>
<div className="flex flex-col gap-1">
<FieldLabel>Role</FieldLabel>
<FieldInput value={props.role || ''} onChange={v => handlePropertyChange('role', v)} placeholder="e.g. Core gateway" />
</div>
</div>
{/* Notes */}
<SectionDivider label="Notes" />
<div className="flex flex-col gap-1">
<textarea
value={props.notes || ''}
onChange={e => handlePropertyChange('notes', e.target.value)}
placeholder="Additional notes…"
rows={3}
className="w-full resize-none rounded border border-default bg-input px-2 py-1.5 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
/>
</div>
</div>
<div className="border-t border-default p-3">
{deleteConfirm ? (
<div className="flex flex-col gap-1.5">
<p className="text-center text-[10px] text-muted-foreground">Delete this device?</p>
<div className="flex gap-1.5">
<button
onClick={() => setDeleteConfirm(false)}
className="flex-1 rounded border border-default px-2 py-1.5 text-xs text-primary hover:bg-elevated"
>
Cancel
</button>
<button
onClick={() => onDeleteNode(selectedNode!.id)}
className="flex-1 rounded bg-red-500/20 px-2 py-1.5 text-xs font-medium text-red-400 hover:bg-red-500/30"
>
Delete
</button>
</div>
</div>
) : (
<button
onClick={() => setDeleteConfirm(true)}
className="flex w-full items-center justify-center gap-1.5 rounded border border-red-500/30 px-2 py-1.5 text-xs text-red-400 hover:bg-red-500/10"
>
<Trash2 size={12} />
Delete Device
</button>
)}
</div>
</div>
)
}

View File

@@ -0,0 +1,131 @@
import { memo } from 'react'
import {
BaseEdge,
getSmoothStepPath,
getStraightPath,
getBezierPath,
type EdgeProps,
} from '@xyflow/react'
interface AnimatedEdgeData {
connectionType?: string
duration?: number
direction?: 'forward' | 'reverse' | 'alternate' | 'alternate-reverse'
path?: 'bezier' | 'smoothstep' | 'step' | 'straight'
repeat?: number | 'indefinite'
shape?: 'circle' | 'package'
speed?: string | null
notes?: string | null
[key: string]: unknown
}
const CONNECTION_COLORS: Record<string, string> = {
ethernet: '#60a5fa',
fiber: '#34d399',
wifi: '#a78bfa',
vpn: '#eab308',
vlan: '#848b9b',
wan: '#f87171',
}
const DEFAULT_COLOR = '#848b9b'
function getPath(
props: EdgeProps,
pathType: string,
): [string, number, number] {
const params = {
sourceX: props.sourceX,
sourceY: props.sourceY,
sourcePosition: props.sourcePosition,
targetX: props.targetX,
targetY: props.targetY,
targetPosition: props.targetPosition,
}
switch (pathType) {
case 'bezier': {
const [path, labelX, labelY] = getBezierPath(params)
return [path, labelX, labelY]
}
case 'straight': {
const [path, labelX, labelY] = getStraightPath(params)
return [path, labelX, labelY]
}
default: {
const [path, labelX, labelY] = getSmoothStepPath(params)
return [path, labelX, labelY]
}
}
}
function getAnimateMotionProps(data: AnimatedEdgeData) {
const duration = data.duration ?? 2
const direction = data.direction ?? 'forward'
const repeat = data.repeat ?? 'indefinite'
const keyPoints: Record<string, string> = {
forward: '0;1',
reverse: '1;0',
alternate: '0;1',
'alternate-reverse': '1;0',
}
return {
dur: `${duration}s`,
repeatCount: String(repeat),
keyPoints: keyPoints[direction] || '0;1',
keyTimes: '0;1',
}
}
function AnimatedSvgEdgeComponent(props: EdgeProps) {
const data = (props.data || {}) as AnimatedEdgeData
const connectionType = data.connectionType || 'ethernet'
const color = CONNECTION_COLORS[connectionType] || DEFAULT_COLOR
const pathType = data.path ?? 'smoothstep'
const shape = data.shape ?? 'circle'
const [edgePath] = getPath(props, pathType)
const motionProps = getAnimateMotionProps(data)
return (
<>
<BaseEdge
path={edgePath}
style={{
stroke: color,
strokeWidth: props.selected ? 3 : 2,
...(connectionType === 'wifi' || connectionType === 'wan' || connectionType === 'vpn'
? { strokeDasharray: connectionType === 'wifi' ? '3,3' : '8,4' }
: {}),
}}
/>
<circle r={0} fill={color}>
<animateMotion
path={edgePath}
calcMode="linear"
{...motionProps}
/>
<animate
attributeName="r"
values="0;3;3;3;0"
keyTimes="0;0.05;0.5;0.95;1"
dur={motionProps.dur}
repeatCount={motionProps.repeatCount}
/>
</circle>
{shape === 'package' && (
<rect x={-4} y={-4} width={8} height={8} rx={2} fill={color} opacity={0.8}>
<animateMotion
path={edgePath}
calcMode="linear"
{...motionProps}
/>
</rect>
)}
</>
)
}
export const AnimatedSvgEdge = memo(AnimatedSvgEdgeComponent)

View File

@@ -0,0 +1,20 @@
import type { ComponentProps } from 'react'
import { Handle, type HandleProps } from '@xyflow/react'
import { cn } from '@/lib/utils'
export type BaseHandleProps = HandleProps
export function BaseHandle({ className, children, ...props }: ComponentProps<typeof Handle>) {
return (
<Handle
{...props}
className={cn(
'h-[10px] w-[10px] rounded-full border border-default bg-elevated transition-opacity',
'opacity-0 group-hover:opacity-100',
className,
)}
>
{children}
</Handle>
)
}

View File

@@ -0,0 +1,56 @@
import type { ComponentProps } from 'react'
import { cn } from '@/lib/utils'
export function BaseNode({ className, ...props }: ComponentProps<'div'>) {
return (
<div
className={cn(
'bg-card text-heading relative rounded-lg border border-default',
'transition-colors hover:border-hover',
'in-[.selected]:border-accent',
className,
)}
tabIndex={0}
{...props}
/>
)
}
export function BaseNodeHeader({ className, ...props }: ComponentProps<'header'>) {
return (
<header
{...props}
className={cn('flex flex-row items-center gap-2 px-3 py-2', className)}
/>
)
}
export function BaseNodeHeaderTitle({ className, ...props }: ComponentProps<'h3'>) {
return (
<h3
data-slot="base-node-title"
className={cn('select-none flex-1 text-xs font-semibold text-heading', className)}
{...props}
/>
)
}
export function BaseNodeContent({ className, ...props }: ComponentProps<'div'>) {
return (
<div
data-slot="base-node-content"
className={cn('flex flex-col gap-y-1 px-3 pb-2', className)}
{...props}
/>
)
}
export function BaseNodeFooter({ className, ...props }: ComponentProps<'div'>) {
return (
<div
data-slot="base-node-footer"
className={cn('flex flex-col items-center gap-y-1 border-t border-default px-3 pt-1.5 pb-2', className)}
{...props}
/>
)
}

View File

@@ -0,0 +1,68 @@
import type { ReactNode, ComponentProps } from 'react'
import { Panel, NodeResizer, type NodeProps, type PanelPosition } from '@xyflow/react'
import { BaseNode } from './base-node'
import { cn } from '@/lib/utils'
export type GroupNodeLabelProps = ComponentProps<'div'>
export function GroupNodeLabel({ children, className, ...props }: GroupNodeLabelProps) {
return (
<div className="h-full w-full" {...props}>
<div className={cn('bg-card text-muted-foreground w-fit p-2 text-[10px] font-semibold uppercase tracking-wider', className)}>
{children}
</div>
</div>
)
}
export interface GroupNodeData {
label?: string
groupType?: string
[key: string]: unknown
}
export type GroupNodeProps = Partial<NodeProps> & {
label?: ReactNode
position?: PanelPosition
}
function getLabelClassName(position?: PanelPosition): string {
switch (position) {
case 'top-left': return 'rounded-br-sm'
case 'top-center': return 'rounded-b-sm'
case 'top-right': return 'rounded-bl-sm'
case 'bottom-left': return 'rounded-tr-sm'
case 'bottom-right': return 'rounded-tl-sm'
case 'bottom-center': return 'rounded-t-sm'
default: return 'rounded-br-sm'
}
}
export function GroupNode({ data, selected }: NodeProps) {
const nodeData = data as unknown as GroupNodeData
const label = nodeData.label || 'Group'
return (
<>
<NodeResizer
isVisible={selected}
minWidth={150}
minHeight={100}
lineStyle={{ borderColor: 'var(--color-accent)', borderWidth: 1 }}
handleStyle={{ width: 8, height: 8, borderColor: 'var(--color-accent)', background: 'var(--color-card)' }}
/>
<BaseNode
className={cn(
'h-full w-full min-h-[100px] min-w-[150px] overflow-hidden rounded-lg bg-elevated/30 border-default/50',
selected && 'border-accent',
)}
>
<Panel className="m-0 p-0" position="top-left">
<GroupNodeLabel className={getLabelClassName('top-left')}>
{label}
</GroupNodeLabel>
</Panel>
</BaseNode>
</>
)
}

View File

@@ -0,0 +1,39 @@
import type { ComponentProps } from 'react'
import { type HandleProps, Position } from '@xyflow/react'
import { cn } from '@/lib/utils'
import { BaseHandle } from './base-handle'
const flexDirections: Record<string, string> = {
[Position.Top]: 'flex-col',
[Position.Right]: 'flex-row-reverse justify-end',
[Position.Bottom]: 'flex-col-reverse justify-end',
[Position.Left]: 'flex-row',
}
export function LabeledHandle({
className,
labelClassName,
handleClassName,
title,
position,
...props
}: HandleProps &
ComponentProps<'div'> & {
title: string
handleClassName?: string
labelClassName?: string
}) {
const { ref, ...handleProps } = props
return (
<div
title={title}
className={cn('relative flex items-center', flexDirections[position], className)}
ref={ref}
>
<BaseHandle position={position} className={handleClassName} {...handleProps} />
<label className={cn('text-muted-foreground text-[10px] font-mono px-1.5', labelClassName)}>
{title}
</label>
</div>
)
}

View File

@@ -0,0 +1,43 @@
import type { ReactNode } from 'react'
import { cn } from '@/lib/utils'
export type NodeStatus = 'online' | 'offline' | 'degraded' | 'unknown'
const STATUS_BORDER_COLORS: Record<NodeStatus, string> = {
online: 'border-emerald-400',
offline: 'border-red-400',
degraded: 'border-yellow-400',
unknown: '',
}
const STATUS_GLOW: Record<NodeStatus, string> = {
online: 'shadow-[0_0_8px_rgba(52,211,153,0.3)]',
offline: 'shadow-[0_0_8px_rgba(248,113,113,0.3)]',
degraded: 'shadow-[0_0_8px_rgba(250,204,21,0.3)]',
unknown: '',
}
interface NodeStatusIndicatorProps {
status?: NodeStatus
children: ReactNode
className?: string
}
export function NodeStatusIndicator({ status = 'unknown', children, className }: NodeStatusIndicatorProps) {
if (status === 'unknown') {
return <>{children}</>
}
return (
<div
className={cn(
'rounded-lg border-2 transition-colors',
STATUS_BORDER_COLORS[status],
STATUS_GLOW[status],
className,
)}
>
{children}
</div>
)
}

View File

@@ -0,0 +1,77 @@
import { createContext, useContext, useState, useCallback, type ReactNode, type ComponentProps } from 'react'
import { NodeToolbar, type NodeToolbarProps } from '@xyflow/react'
import { cn } from '@/lib/utils'
interface NodeTooltipContextValue {
visible: boolean
show: () => void
hide: () => void
}
const NodeTooltipContext = createContext<NodeTooltipContextValue>({
visible: false,
show: () => {},
hide: () => {},
})
export function NodeTooltip({ children, ...props }: ComponentProps<'div'>) {
const [visible, setVisible] = useState(false)
const show = useCallback(() => setVisible(true), [])
const hide = useCallback(() => setVisible(false), [])
return (
<NodeTooltipContext.Provider value={{ visible, show, hide }}>
<div {...props}>{children}</div>
</NodeTooltipContext.Provider>
)
}
export function NodeTooltipTrigger({
children,
onMouseEnter,
onMouseLeave,
...props
}: ComponentProps<'div'>) {
const { show, hide } = useContext(NodeTooltipContext)
return (
<div
onMouseEnter={(e) => {
show()
onMouseEnter?.(e)
}}
onMouseLeave={(e) => {
hide()
onMouseLeave?.(e)
}}
{...props}
>
{children}
</div>
)
}
export function NodeTooltipContent({
className,
position,
children,
...props
}: Omit<NodeToolbarProps, 'children'> & { children: ReactNode }) {
const { visible } = useContext(NodeTooltipContext)
if (!visible) return null
return (
<NodeToolbar
position={position}
className={cn(
'rounded-lg border border-default bg-elevated px-3 py-2',
'pointer-events-none',
className,
)}
{...props}
>
{children}
</NodeToolbar>
)
}

View File

@@ -0,0 +1,660 @@
import { useState, useCallback, useEffect, useRef } from 'react'
import { useParams, useNavigate } from 'react-router-dom'
import {
ReactFlowProvider,
useNodesState,
useEdgesState,
addEdge,
useReactFlow,
getNodesBounds,
getViewportForBounds,
type Node,
type Edge,
type Connection,
} from '@xyflow/react'
import '@xyflow/react/dist/style.css'
import { NetworkCanvas } from '@/components/network/NetworkCanvas'
import { ContextMenu, getNodeMenuActions, getCanvasMenuActions } from '@/components/network/ContextMenu'
import { useCanvasShortcuts } from '@/components/network/hooks/useCanvasShortcuts'
import { DiagramHeader } from '@/components/network/DiagramHeader'
import { DeviceToolbar } from '@/components/network/panels/DeviceToolbar'
import { PropertiesPanel } from '@/components/network/panels/PropertiesPanel'
import { AIAssistPanel } from '@/components/network/panels/AIAssistPanel'
import { CanvasEmptyPrompt } from '@/components/network/CanvasEmptyPrompt'
import { networkDiagramsApi, deviceTypesApi } from '@/api'
import { toast } from '@/lib/toast'
import type { DeviceTypeResponse, DeviceProperties, AIGenerateResponse, DiagramEdge, DiagramNode } from '@/types'
import type { DeviceNodeData } from '@/components/network/nodes/DeviceNode'
type ContextMenuState = {
type: 'node' | 'canvas'
position: { x: number; y: number }
nodeId?: string
} | null
function DiagramEditorInner() {
const { id } = useParams<{ id: string }>()
const navigate = useNavigate()
const { getNodes, fitView, screenToFlowPosition } = useReactFlow()
const [diagramId, setDiagramId] = useState<string | null>(id || null)
const [name, setName] = useState('Untitled Diagram')
const [clientName, setClientName] = useState<string | null>(null)
const [assetName, setAssetName] = useState<string | null>(null)
const [description, setDescription] = useState<string | null>(null)
const [nodes, setNodes, onNodesChange] = useNodesState<Node>([])
const [edges, setEdges, onEdgesChange] = useEdgesState<Edge>([])
const [selectedNodeId, setSelectedNodeId] = useState<string | null>(null)
const [selectedEdgeId, setSelectedEdgeId] = useState<string | null>(null)
const selectedNode = selectedNodeId ? nodes.find(n => n.id === selectedNodeId) ?? null : null
const selectedEdge = selectedEdgeId ? edges.find(e => e.id === selectedEdgeId) ?? null : null
const [isDirty, setIsDirty] = useState(false)
const [isSaving, setIsSaving] = useState(false)
const [lastSavedAt, setLastSavedAt] = useState<Date | null>(null)
const isDirtyRef = useRef(false)
const diagramIdRef = useRef<string | null>(id || null)
const [deviceTypes, setDeviceTypes] = useState<DeviceTypeResponse[]>([])
const [loading, setLoading] = useState(!!id)
const [isDragOver, setIsDragOver] = useState(false)
const canvasRef = useRef<HTMLDivElement | null>(null)
const [contextMenu, setContextMenu] = useState<ContextMenuState>(null)
const [pendingDeleteNodeId, setPendingDeleteNodeId] = useState<string | null>(null)
useEffect(() => { isDirtyRef.current = isDirty }, [isDirty])
useEffect(() => { diagramIdRef.current = diagramId }, [diagramId])
const {
copyNodes,
pasteNodes,
duplicateNodes,
selectAll,
deleteSelected,
hasClipboard,
} = useCanvasShortcuts({
nodes,
edges,
setNodes,
setEdges,
setIsDirty: (v: boolean) => setIsDirty(v),
canvasRef,
})
const handleNodesChange: typeof onNodesChange = useCallback((changes) => {
onNodesChange(changes)
const hasRealChange = changes.some(c => c.type !== 'select')
if (hasRealChange) setIsDirty(true)
}, [onNodesChange])
const handleEdgesChange: typeof onEdgesChange = useCallback((changes) => {
onEdgesChange(changes)
const hasRealChange = changes.some(c => c.type !== 'select')
if (hasRealChange) setIsDirty(true)
}, [onEdgesChange])
const loadDeviceTypes = useCallback(async () => {
try {
const types = await deviceTypesApi.list()
setDeviceTypes(types)
} catch { /* ignore */ }
}, [])
useEffect(() => { loadDeviceTypes() }, [loadDeviceTypes])
useEffect(() => {
if (!id) return
let cancelled = false
;(async () => {
try {
const diagram = await networkDiagramsApi.get(id)
if (cancelled) return
setName(diagram.name)
setClientName(diagram.client_name)
setAssetName(diagram.asset_name)
setDescription(diagram.description)
setNodes(
diagram.nodes.map(n => {
if (n.nodeType === 'group') {
return {
id: n.id,
type: 'group',
position: n.position,
style: n.style || { width: 300, height: 200 },
data: {
label: n.label,
groupType: n.type,
},
}
}
return {
id: n.id,
type: 'device',
position: n.position,
style: n.style || { width: 120, height: 120 },
data: {
label: n.label,
deviceType: n.type,
properties: n.properties,
} satisfies DeviceNodeData,
}
})
)
setEdges(
diagram.edges.map(e => ({
id: e.id,
source: e.source,
target: e.target,
type: 'connection',
label: e.label || undefined,
data: {
connectionType: e.connectionType,
speed: e.speed,
notes: e.notes,
routing: e.routing ?? null,
},
}))
)
setLastSavedAt(new Date(diagram.updated_at))
} catch {
toast.error('Failed to load diagram')
navigate('/network-diagrams')
} finally {
if (!cancelled) setLoading(false)
}
})()
return () => { cancelled = true }
}, [id, navigate, setNodes, setEdges])
const serializeNodes = useCallback((): DiagramNode[] => {
return getNodes().map(n => {
if (n.type === 'group') {
const data = n.data as Record<string, unknown>
const width = typeof n.style?.width === 'number' ? n.style.width : (n.measured?.width ?? 300)
const height = typeof n.style?.height === 'number' ? n.style.height : (n.measured?.height ?? 200)
return {
id: n.id,
type: (data.groupType as string) || 'subnet',
label: (data.label as string) || 'Group',
position: n.position,
properties: {} as DeviceProperties,
nodeType: 'group',
style: { width, height },
}
}
const data = n.data as unknown as DeviceNodeData
const dw = typeof n.style?.width === 'number' ? n.style.width : (n.measured?.width ?? 120)
const dh = typeof n.style?.height === 'number' ? n.style.height : (n.measured?.height ?? 120)
return {
id: n.id,
type: data.deviceType,
label: data.label,
position: n.position,
properties: data.properties,
style: { width: dw, height: dh },
}
})
}, [getNodes])
const serializeEdges = useCallback((): DiagramEdge[] => {
return edges.map(e => {
const d = (e.data as Record<string, unknown>) || {}
return {
id: e.id,
source: e.source,
target: e.target,
label: (e.label as string) || null,
connectionType: d.connectionType as string || 'ethernet',
speed: d.speed as string || null,
notes: d.notes as string || null,
routing: d.routing as string || null,
}
})
}, [edges])
const handleSave = useCallback(async () => {
setIsSaving(true)
try {
const payload = {
name,
client_name: clientName,
asset_name: assetName,
description,
nodes: serializeNodes(),
edges: serializeEdges(),
}
if (diagramIdRef.current) {
await networkDiagramsApi.update(diagramIdRef.current, payload)
} else {
const created = await networkDiagramsApi.create(payload)
setDiagramId(created.id)
navigate(`/network-diagrams/${created.id}`, { replace: true })
}
setIsDirty(false)
setLastSavedAt(new Date())
} catch {
toast.error('Failed to save diagram')
} finally {
setIsSaving(false)
}
}, [name, clientName, assetName, description, serializeNodes, serializeEdges, navigate])
useEffect(() => {
const interval = setInterval(() => {
if (isDirtyRef.current && diagramIdRef.current) {
handleSave()
}
}, 30_000)
return () => clearInterval(interval)
}, [handleSave])
const onConnect = useCallback((connection: Connection) => {
setEdges(eds => addEdge({
...connection,
type: 'connection',
data: { connectionType: 'ethernet' },
}, eds))
setIsDirty(true)
}, [setEdges])
const onDragOver = useCallback((event: React.DragEvent) => {
event.preventDefault()
event.dataTransfer.dropEffect = 'move'
setIsDragOver(true)
}, [])
const onDragLeave = useCallback((event: React.DragEvent) => {
const relatedTarget = event.relatedTarget as HTMLElement | null
if (relatedTarget && (event.currentTarget as HTMLElement).contains(relatedTarget)) return
setIsDragOver(false)
}, [])
const handleNodeContextMenu = useCallback((event: React.MouseEvent, node: Node) => {
event.preventDefault()
const currentNodes = getNodes()
const isInSelection = currentNodes.find(n => n.id === node.id)?.selected
if (!isInSelection) {
setNodes(nds => nds.map(n => ({ ...n, selected: n.id === node.id })))
setSelectedNodeId(node.id)
setSelectedEdgeId(null)
}
setContextMenu({
type: 'node',
position: { x: event.clientX, y: event.clientY },
nodeId: node.id,
})
}, [getNodes, setNodes, setSelectedNodeId, setSelectedEdgeId])
const handlePaneContextMenu = useCallback((event: MouseEvent | React.MouseEvent) => {
event.preventDefault()
setContextMenu({
type: 'canvas',
position: { x: event.clientX, y: event.clientY },
})
}, [])
const closeContextMenu = useCallback(() => {
setContextMenu(null)
}, [])
const onDrop = useCallback((event: React.DragEvent) => {
event.preventDefault()
setIsDragOver(false)
const position = screenToFlowPosition({ x: event.clientX, y: event.clientY })
// Handle device drops
const deviceRaw = event.dataTransfer.getData('application/reactflow-device')
if (deviceRaw) {
const { slug, label, category } = JSON.parse(deviceRaw) as { slug: string; label: string; category: string }
const newNode: Node = {
id: `device-${Date.now()}-${Math.random().toString(36).slice(2, 7)}`,
type: 'device',
position,
style: { width: 120, height: 120 },
data: {
label,
deviceType: slug,
category,
properties: {
hostname: null,
ip: null,
subnet: null,
vendor: null,
model: null,
role: null,
vlan: null,
notes: null,
status: 'unknown',
} satisfies DeviceProperties,
} satisfies DeviceNodeData,
}
setNodes(nds => [...nds, newNode])
setIsDirty(true)
return
}
// Handle group drops
const groupRaw = event.dataTransfer.getData('application/reactflow-group')
if (groupRaw) {
const { slug, label } = JSON.parse(groupRaw) as { slug: string; label: string }
const newNode: Node = {
id: `group-${Date.now()}-${Math.random().toString(36).slice(2, 7)}`,
type: 'group',
position,
style: { width: 300, height: 200 },
data: {
label,
groupType: slug,
},
}
setNodes(nds => [...nds, newNode])
setIsDirty(true)
}
}, [setNodes, screenToFlowPosition])
const handleNodeUpdate = useCallback((nodeId: string, updates: Partial<DeviceNodeData>) => {
setNodes(nds => nds.map(n => {
if (n.id !== nodeId) return n
return { ...n, data: { ...n.data, ...updates } }
}))
setIsDirty(true)
}, [setNodes])
const handleEdgeUpdate = useCallback((edgeId: string, updates: Partial<DiagramEdge>) => {
setEdges(eds => eds.map(e => {
if (e.id !== edgeId) return e
return {
...e,
label: updates.label !== undefined ? (updates.label || undefined) : e.label,
data: {
...(e.data || {}),
...(updates.connectionType !== undefined ? { connectionType: updates.connectionType } : {}),
...(updates.speed !== undefined ? { speed: updates.speed } : {}),
...(updates.notes !== undefined ? { notes: updates.notes } : {}),
...(updates.routing !== undefined ? { routing: updates.routing } : {}),
},
}
}))
setIsDirty(true)
}, [setEdges])
const handleEdgeTypeChange = useCallback((edgeId: string, edgeType: string) => {
setEdges(eds => eds.map(e => {
if (e.id !== edgeId) return e
return { ...e, type: edgeType }
}))
setIsDirty(true)
}, [setEdges])
const handleDeleteNode = useCallback((nodeId: string) => {
setNodes(nds => nds.filter(n => n.id !== nodeId))
setEdges(eds => eds.filter(e => e.source !== nodeId && e.target !== nodeId))
setSelectedNodeId(null)
setIsDirty(true)
}, [setNodes, setEdges])
const handleDeleteEdge = useCallback((edgeId: string) => {
setEdges(eds => eds.filter(e => e.id !== edgeId))
setSelectedEdgeId(null)
setIsDirty(true)
}, [setEdges])
const handleBringToFront = useCallback((nodeId: string) => {
setNodes(nds => {
const maxZ = Math.max(0, ...nds.map(n => n.zIndex ?? 0))
return nds.map(n => n.id === nodeId ? { ...n, zIndex: maxZ + 1 } : n)
})
setIsDirty(true)
}, [setNodes])
const handleSendToBack = useCallback((nodeId: string) => {
setNodes(nds => {
const minZ = Math.min(0, ...nds.map(n => n.zIndex ?? 0))
return nds.map(n => n.id === nodeId ? { ...n, zIndex: minZ - 1 } : n)
})
setIsDirty(true)
}, [setNodes])
const handleAIGenerate = useCallback((result: AIGenerateResponse, mode: 'replace' | 'merge') => {
const newNodes: Node[] = result.nodes.map(n => ({
id: n.id,
type: 'device',
position: n.position,
data: {
label: n.label,
deviceType: n.type,
properties: n.properties,
} satisfies DeviceNodeData,
}))
const newEdges: Edge[] = result.edges.map(e => ({
id: e.id,
source: e.source,
target: e.target,
type: 'connection',
label: e.label || undefined,
data: { connectionType: e.connectionType, speed: e.speed, notes: e.notes },
}))
if (mode === 'replace') {
setNodes(newNodes)
setEdges(newEdges)
} else {
setNodes(nds => [...nds, ...newNodes])
setEdges(eds => [...eds, ...newEdges])
}
if (result.suggestedName && !diagramId) {
setName(result.suggestedName)
toast.success(`Generated: ${result.suggestedName}`)
} else {
toast.success('Diagram generated')
}
if (result.notes) {
toast.info(result.notes)
}
setIsDirty(true)
setTimeout(() => fitView({ padding: 0.2 }), 100)
}, [setNodes, setEdges, diagramId, fitView])
const getExistingBounds = useCallback(() => {
const currentNodes = getNodes()
if (currentNodes.length === 0) return null
let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity
for (const n of currentNodes) {
minX = Math.min(minX, n.position.x)
maxX = Math.max(maxX, n.position.x + 120)
minY = Math.min(minY, n.position.y)
maxY = Math.max(maxY, n.position.y + 80)
}
return { minX, maxX, minY, maxY }
}, [getNodes])
const handleExportPng = useCallback(async () => {
if (nodes.length === 0) {
toast.warning('Add some devices to the diagram before exporting')
return
}
try {
const { toPng } = await import('html-to-image')
const IMAGE_WIDTH = 1920
const IMAGE_HEIGHT = 1080
const bounds = getNodesBounds(nodes)
const viewport = getViewportForBounds(bounds, IMAGE_WIDTH, IMAGE_HEIGHT, 0.5, 2, 0.15)
const flowEl = document.querySelector('.react-flow__viewport') as HTMLElement | null
if (!flowEl) {
toast.error('Could not find canvas to export')
return
}
const dataUrl = await toPng(flowEl, {
backgroundColor: '#16181f',
width: IMAGE_WIDTH,
height: IMAGE_HEIGHT,
style: {
width: `${IMAGE_WIDTH}px`,
height: `${IMAGE_HEIGHT}px`,
transform: `translate(${viewport.x}px, ${viewport.y}px) scale(${viewport.zoom})`,
transformOrigin: 'top left',
},
})
const a = document.createElement('a')
a.download = `${name.replace(/[^a-zA-Z0-9-_ ]/g, '') || 'diagram'}.png`
a.href = dataUrl
a.click()
} catch {
toast.error('PNG export failed — try Print > Save as PDF instead')
}
}, [nodes, name])
const handleExportPdf = useCallback(() => {
window.print()
}, [])
const handleExportJson = useCallback(async () => {
if (!diagramId) return
try {
const data = await networkDiagramsApi.exportJson(diagramId)
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `${name.replace(/[^a-zA-Z0-9-_ ]/g, '')}.json`
a.click()
URL.revokeObjectURL(url)
} catch {
toast.error('Failed to export diagram')
}
}, [diagramId, name])
if (loading) {
return (
<div className="flex h-full items-center justify-center">
<div className="h-6 w-6 animate-spin rounded-full border-2 border-accent border-t-transparent" />
</div>
)
}
return (
<div className="flex h-full flex-col">
<DiagramHeader
name={name}
clientName={clientName}
isDirty={isDirty}
isSaving={isSaving}
lastSavedAt={lastSavedAt}
diagramId={diagramId}
onNameChange={n => { setName(n); setIsDirty(true) }}
onSave={handleSave}
onExportPng={handleExportPng}
onExportPdf={handleExportPdf}
onExportJson={handleExportJson}
/>
<div className="flex flex-1 min-h-0">
<DeviceToolbar deviceTypes={deviceTypes} onDeviceTypesChange={loadDeviceTypes} />
<div className="flex flex-1 flex-col min-h-0">
<div className="relative flex-1 min-h-0" ref={canvasRef}>
<NetworkCanvas
nodes={nodes}
edges={edges}
onNodesChange={handleNodesChange}
onEdgesChange={handleEdgesChange}
onConnect={onConnect}
onNodeSelect={setSelectedNodeId}
onEdgeSelect={setSelectedEdgeId}
onDrop={onDrop}
onDragOver={onDragOver}
onDragLeave={onDragLeave}
isDragOver={isDragOver}
onNodeContextMenu={handleNodeContextMenu}
onPaneContextMenu={handlePaneContextMenu}
onPaneClick={closeContextMenu}
/>
{nodes.length === 0 && !loading && (
<CanvasEmptyPrompt onGenerate={handleAIGenerate} />
)}
</div>
{nodes.length > 0 && (
<AIAssistPanel
onGenerate={handleAIGenerate}
getExistingBounds={getExistingBounds}
hasNodes={nodes.length > 0}
/>
)}
</div>
<PropertiesPanel
selectedNode={selectedNode}
selectedEdge={selectedEdge}
onNodeUpdate={handleNodeUpdate}
onEdgeUpdate={handleEdgeUpdate}
onEdgeTypeChange={handleEdgeTypeChange}
onBringToFront={handleBringToFront}
onSendToBack={handleSendToBack}
onDeleteNode={handleDeleteNode}
onDeleteEdge={handleDeleteEdge}
/>
</div>
{contextMenu && (
<ContextMenu
position={contextMenu.position}
actions={
contextMenu.type === 'node'
? getNodeMenuActions({
onCopy: copyNodes,
onDuplicate: duplicateNodes,
onBringToFront: () => { if (contextMenu.nodeId) handleBringToFront(contextMenu.nodeId) },
onSendToBack: () => { if (contextMenu.nodeId) handleSendToBack(contextMenu.nodeId) },
onDelete: () => {
const nodeId = contextMenu.nodeId
setContextMenu(null)
if (nodeId) setPendingDeleteNodeId(nodeId)
else deleteSelected()
},
})
: getCanvasMenuActions({
onPaste: pasteNodes,
onSelectAll: selectAll,
onFitView: () => fitView({ padding: 0.2 }),
hasClipboard: hasClipboard(),
})
}
onClose={closeContextMenu}
/>
)}
{pendingDeleteNodeId && (
<div className="pointer-events-none absolute inset-x-0 bottom-4 z-50 flex justify-center">
<div className="pointer-events-auto flex items-center gap-3 rounded-lg border border-default bg-card px-4 py-2.5 shadow-lg">
<span className="text-xs text-muted-foreground">Delete this device?</span>
<button
onClick={() => setPendingDeleteNodeId(null)}
className="rounded border border-default px-3 py-1 text-xs text-primary hover:bg-elevated"
>
Cancel
</button>
<button
onClick={() => { handleDeleteNode(pendingDeleteNodeId); setPendingDeleteNodeId(null) }}
className="rounded bg-red-500/20 px-3 py-1 text-xs font-medium text-red-400 hover:bg-red-500/30"
>
Delete
</button>
</div>
</div>
)}
</div>
)
}
export default function DiagramEditor() {
return (
<ReactFlowProvider>
<DiagramEditorInner />
</ReactFlowProvider>
)
}

View File

@@ -0,0 +1,323 @@
import { useState, useEffect, useCallback, useMemo, useRef } from 'react'
import { useNavigate } from 'react-router-dom'
import { Plus, Search, Network, MoreHorizontal, Upload, ChevronDown } from 'lucide-react'
import { cn } from '@/lib/utils'
import { networkDiagramsApi } from '@/api'
import { toast } from '@/lib/toast'
import { CATEGORY_COLORS } from '@/components/network/nodes/deviceRegistry'
import type { NetworkDiagramListItem, DiagramImportData } from '@/types'
const OTHER_COLOR = '#4f5666'
function TopologyBar({ categoryCounts, nodeCount }: { categoryCounts: Record<string, number>; nodeCount: number }) {
if (nodeCount === 0) return null
const sorted = Object.entries(categoryCounts).sort(([, a], [, b]) => b - a)
const tooltipLabel = sorted.map(([cat, count]) => `${count} ${cat}`).join(' · ')
return (
<div className="group/bar relative flex h-2 w-full overflow-hidden rounded-full" title={tooltipLabel}>
{sorted.map(([cat, count]) => (
<div
key={cat}
style={{
width: `${(count / nodeCount) * 100}%`,
backgroundColor: CATEGORY_COLORS[cat] ?? OTHER_COLOR,
}}
/>
))}
</div>
)
}
export default function NetworkDiagramsPage() {
const navigate = useNavigate()
const [diagrams, setDiagrams] = useState<NetworkDiagramListItem[]>([])
const [loading, setLoading] = useState(true)
const [search, setSearch] = useState('')
const [clientFilter, setClientFilter] = useState<string | null>(null)
const [clientOptions, setClientOptions] = useState<string[]>([])
const [clientDropdownOpen, setClientDropdownOpen] = useState(false)
const [clientSearch, setClientSearch] = useState('')
const [menuOpenId, setMenuOpenId] = useState<string | null>(null)
const [confirmArchiveId, setConfirmArchiveId] = useState<string | null>(null)
const clientDropdownRef = useRef<HTMLDivElement>(null)
useEffect(() => {
if (!clientDropdownOpen) return
const handleClick = (e: MouseEvent) => {
if (clientDropdownRef.current && !clientDropdownRef.current.contains(e.target as Node)) {
setClientDropdownOpen(false)
}
}
document.addEventListener('mousedown', handleClick)
return () => document.removeEventListener('mousedown', handleClick)
}, [clientDropdownOpen])
const loadDiagrams = useCallback(async () => {
try {
const params: Record<string, string> = {}
if (clientFilter) params.client_name = clientFilter
if (search) params.search = search
const data = await networkDiagramsApi.list(params)
setDiagrams(data)
} catch {
toast.error('Failed to load diagrams')
} finally {
setLoading(false)
}
}, [clientFilter, search])
const loadClients = useCallback(async () => {
try {
const clients = await networkDiagramsApi.listClients()
setClientOptions(clients)
} catch { /* ignore */ }
}, [])
useEffect(() => { loadDiagrams() }, [loadDiagrams])
useEffect(() => { loadClients() }, [loadClients])
const filteredClients = useMemo(() => {
if (!clientSearch) return clientOptions
const lower = clientSearch.toLowerCase()
return clientOptions.filter(c => c.toLowerCase().includes(lower))
}, [clientOptions, clientSearch])
const handleDuplicate = useCallback(async (id: string) => {
try {
const dup = await networkDiagramsApi.duplicate(id)
toast.success(`Created: ${dup.name}`)
loadDiagrams()
} catch {
toast.error('Failed to duplicate')
}
setMenuOpenId(null)
}, [loadDiagrams])
const handleArchive = useCallback(async (id: string) => {
try {
await networkDiagramsApi.archive(id)
toast.success('Diagram archived')
loadDiagrams()
} catch {
toast.error('Failed to archive')
}
setMenuOpenId(null)
setConfirmArchiveId(null)
}, [loadDiagrams])
const handleImport = useCallback(async () => {
const input = document.createElement('input')
input.type = 'file'
input.accept = '.json'
input.onchange = async (e) => {
const file = (e.target as HTMLInputElement).files?.[0]
if (!file) return
try {
const text = await file.text()
const data = JSON.parse(text) as DiagramImportData
const result = await networkDiagramsApi.importJson(data)
if (result.warnings.length > 0) {
toast.warning(`Imported with warnings: ${result.warnings.join(', ')}`)
} else {
toast.success('Diagram imported')
}
navigate(`/network-diagrams/${result.diagram.id}`)
} catch {
toast.error('Failed to import — check that the file is a valid diagram JSON')
}
}
input.click()
}, [navigate])
const formatDate = (dateStr: string) => {
return new Date(dateStr).toLocaleDateString('en-US', { month: 'short', day: 'numeric', year: 'numeric' })
}
return (
<div className="mx-auto max-w-6xl px-6 py-8">
<div className="mb-6 flex items-start justify-between">
<div>
<h1 className="font-heading text-2xl font-bold text-heading">Network Maps</h1>
<p className="mt-1 text-sm text-muted-foreground">Visual network topology documentation for your clients</p>
</div>
<div className="flex gap-2">
<button
onClick={handleImport}
className="flex items-center gap-1.5 rounded border border-default px-3 py-2 text-sm text-primary hover:border-hover"
>
<Upload size={14} />
Import
</button>
<button
onClick={() => navigate('/network-diagrams/new')}
className="flex items-center gap-1.5 rounded bg-accent px-4 py-2 text-sm font-medium text-white hover:bg-accent/90"
>
<Plus size={14} />
New Diagram
</button>
</div>
</div>
<div className="mb-6 flex gap-3">
<div className="relative flex-1">
<Search size={14} className="absolute left-3 top-1/2 -translate-y-1/2 text-muted-foreground" />
<input
type="text"
placeholder="Search diagrams..."
value={search}
onChange={e => setSearch(e.target.value)}
className="w-full rounded border border-default bg-input pl-9 pr-3 py-2 text-sm text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
/>
</div>
<div className="relative w-48" ref={clientDropdownRef}>
<button
onClick={() => setClientDropdownOpen(prev => !prev)}
className="flex w-full items-center justify-between rounded border border-default bg-input px-3 py-2 text-sm text-primary"
>
<span className={clientFilter ? 'text-primary' : 'text-muted-foreground'}>
{clientFilter || 'All clients'}
</span>
<ChevronDown size={14} className="shrink-0 text-muted-foreground" />
</button>
{clientDropdownOpen && (
<div className="absolute left-0 top-full z-50 mt-1 w-full rounded border border-default bg-card shadow-lg">
<div className="p-2">
<input
type="text"
placeholder="Search clients..."
value={clientSearch}
onChange={e => setClientSearch(e.target.value)}
className="w-full rounded border border-default bg-input px-2 py-1 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
autoFocus
/>
</div>
<div className="max-h-48 overflow-y-auto">
<button
onClick={() => { setClientFilter(null); setClientDropdownOpen(false); setClientSearch('') }}
className={cn('w-full px-3 py-1.5 text-left text-xs hover:bg-elevated', !clientFilter && 'text-accent')}
>
All clients
</button>
{filteredClients.map(c => (
<button
key={c}
onClick={() => { setClientFilter(c); setClientDropdownOpen(false); setClientSearch('') }}
className={cn('w-full px-3 py-1.5 text-left text-xs text-primary hover:bg-elevated', clientFilter === c && 'text-accent')}
>
{c}
</button>
))}
</div>
</div>
)}
</div>
</div>
{loading && (
<div className="grid grid-cols-1 gap-4 md:grid-cols-2 lg:grid-cols-3">
{[1, 2, 3].map(i => (
<div key={i} className="h-40 animate-pulse rounded-lg border border-default bg-card" />
))}
</div>
)}
{!loading && diagrams.length === 0 && (
<div className="flex flex-col items-center justify-center py-20">
<Network size={48} className="mb-4 text-muted-foreground" />
<h2 className="font-heading text-lg font-semibold text-heading">No network maps yet</h2>
<p className="mt-1 text-sm text-muted-foreground">Create your first network diagram to get started</p>
<button
onClick={() => navigate('/network-diagrams/new')}
className="mt-4 rounded bg-accent px-4 py-2 text-sm font-medium text-white hover:bg-accent/90"
>
Create First Diagram
</button>
</div>
)}
{!loading && diagrams.length > 0 && (
<div className="grid grid-cols-1 gap-4 md:grid-cols-2 lg:grid-cols-3">
{diagrams.map(d => (
<div
key={d.id}
onClick={() => navigate(`/network-diagrams/${d.id}`)}
className="group relative cursor-pointer rounded-lg border border-default bg-card p-4 hover:border-hover"
>
<div className="mb-2 flex items-start justify-between">
<h3 className="font-heading text-sm font-semibold text-heading">{d.name}</h3>
<button
onClick={e => { e.stopPropagation(); setMenuOpenId(menuOpenId === d.id ? null : d.id) }}
className="rounded p-1 text-muted-foreground opacity-0 hover:bg-elevated group-hover:opacity-100"
>
<MoreHorizontal size={14} />
</button>
</div>
{d.client_name && (
<span className="mb-2 inline-block rounded-full bg-elevated px-2 py-0.5 text-[10px] text-muted-foreground">
{d.client_name}
</span>
)}
{d.description && (
<p className="mb-3 line-clamp-2 text-xs text-muted-foreground">{d.description}</p>
)}
{d.node_count > 0 && (
<div className="mb-2">
<TopologyBar categoryCounts={d.category_counts} nodeCount={d.node_count} />
</div>
)}
<div className="flex items-center justify-between text-[10px] text-muted-foreground">
<span>{d.node_count} device{d.node_count !== 1 ? 's' : ''}</span>
<span>{formatDate(d.created_at)}</span>
</div>
{menuOpenId === d.id && (
<div className="absolute right-2 top-10 z-50 w-44 rounded border border-default bg-card py-1 shadow-lg">
{confirmArchiveId === d.id ? (
<>
<p className="px-3 py-1.5 text-[10px] text-muted-foreground">Archive this diagram?</p>
<div className="flex gap-1 px-2 pb-1.5">
<button
onClick={e => { e.stopPropagation(); setConfirmArchiveId(null) }}
className="flex-1 rounded border border-default px-2 py-1 text-[10px] text-primary hover:bg-elevated"
>
Cancel
</button>
<button
onClick={e => { e.stopPropagation(); handleArchive(d.id) }}
className="flex-1 rounded bg-red-500/20 px-2 py-1 text-[10px] font-medium text-red-400 hover:bg-red-500/30"
>
Archive
</button>
</div>
</>
) : (
<>
<button
onClick={e => { e.stopPropagation(); navigate(`/network-diagrams/${d.id}`) }}
className="w-full px-3 py-1.5 text-left text-xs text-primary hover:bg-elevated"
>
Open
</button>
<button
onClick={e => { e.stopPropagation(); handleDuplicate(d.id) }}
className="w-full px-3 py-1.5 text-left text-xs text-primary hover:bg-elevated"
>
Duplicate
</button>
<button
onClick={e => { e.stopPropagation(); setConfirmArchiveId(d.id) }}
className="w-full px-3 py-1.5 text-left text-xs text-red-400 hover:bg-elevated"
>
Archive
</button>
</>
)}
</div>
)}
</div>
))}
</div>
)}
</div>
)
}

View File

@@ -60,6 +60,8 @@ const DevBranchingPage = lazyWithRetry(() => import('@/pages/DevBranchingPage'))
const GuidesHubPage = lazyWithRetry(() => import('@/pages/GuidesHubPage'))
const GuideDetailPage = lazyWithRetry(() => import('@/pages/GuideDetailPage'))
const AccountSettingsPage = lazyWithRetry(() => import('@/pages/AccountSettingsPage'))
const NetworkDiagramsPage = lazyWithRetry(() => import('@/pages/NetworkDiagrams'))
const DiagramEditorPage = lazyWithRetry(() => import('@/pages/NetworkDiagrams/DiagramEditor'))
// Admin pages
const AdminLayout = lazyWithRetry(() => import('@/components/admin/AdminLayout'))
const AdminDashboardPage = lazyWithRetry(() => import('@/pages/admin/DashboardPage'))
@@ -195,6 +197,9 @@ export const router = sentryCreateBrowserRouter([
{ path: 'scripts', element: page(ScriptLibraryPage) },
{ path: 'scripts/manage', element: page(ScriptManagePage) },
{ path: 'script-builder', element: page(ScriptBuilderPage) },
{ path: 'network-diagrams', element: page(NetworkDiagramsPage) },
{ path: 'network-diagrams/new', element: page(DiagramEditorPage) },
{ path: 'network-diagrams/:id', element: page(DiagramEditorPage) },
{ path: 'kb-accelerator', element: page(KBAcceleratorPage) },
{ path: 'assistant', element: page(AssistantChatPage) },
{ path: 'assistant/:sessionId', element: page(AssistantChatPage) },

View File

@@ -98,3 +98,4 @@ export * from './script-builder'
export * from './integrations'
export * from './notification'
export type * from './public-templates'
export * from './network-diagram'

View File

@@ -0,0 +1,134 @@
export interface DeviceProperties {
hostname: string | null
ip: string | null
subnet: string | null
vendor: string | null
model: string | null
role: string | null
vlan: string | null
notes: string | null
status: 'unknown' | 'online' | 'offline' | 'degraded'
}
export interface DiagramNode {
id: string
type: string
label: string
position: { x: number; y: number }
properties: DeviceProperties
nodeType?: string
style?: { width?: number; height?: number } | null
}
export interface DiagramEdge {
id: string
source: string
target: string
label: string | null
connectionType: string
speed: string | null
notes: string | null
routing?: string | null
}
export interface DeviceTypeResponse {
id: string
slug: string
label: string
category: string
is_system: boolean
account_id: string
sort_order: number
created_at: string
}
export interface DeviceTypeCreate {
slug: string
label: string
category: string
sort_order?: number
}
export interface NetworkDiagramResponse {
id: string
account_id: string
name: string
client_name: string | null
asset_name: string | null
description: string | null
nodes: DiagramNode[]
edges: DiagramEdge[]
thumbnail_url: string | null
is_archived: boolean
created_by: string | null
created_at: string
updated_at: string
}
export interface NetworkDiagramListItem {
id: string
name: string
client_name: string | null
description: string | null
node_count: number
category_counts: Record<string, number>
created_by: string | null
created_at: string
updated_at: string
}
export interface NetworkDiagramCreate {
name: string
client_name?: string | null
asset_name?: string | null
description?: string | null
nodes?: DiagramNode[]
edges?: DiagramEdge[]
}
export interface NetworkDiagramUpdate {
name?: string
client_name?: string | null
asset_name?: string | null
description?: string | null
nodes?: DiagramNode[]
edges?: DiagramEdge[]
}
export interface AIGenerateRequest {
description: string
client_name?: string | null
mode: 'replace' | 'merge'
existingBounds?: { minX: number; maxX: number; minY: number; maxY: number } | null
}
export interface AIGenerateResponse {
nodes: DiagramNode[]
edges: DiagramEdge[]
suggestedName: string | null
notes: string | null
}
export interface DiagramImportData {
schemaVersion: number
name: string
client_name?: string | null
description?: string | null
nodes: DiagramNode[]
edges: DiagramEdge[]
}
export interface DiagramImportResponse {
diagram: NetworkDiagramResponse
warnings: string[]
}
export interface DiagramExportResponse {
schemaVersion: number
name: string
client_name: string | null
description: string | null
nodes: DiagramNode[]
edges: DiagramEdge[]
exportedAt: string
}