fix: backend code review fixes for network diagrams

- Replace legacy Optional imports with modern str | None syntax
- Type JSONB columns as Mapped[list[dict[str, Any]]]
- Escape SQL LIKE wildcards (%, _) in diagram search
- Type DiagramNode.position as Position(x, y) Pydantic model
- Wrap AI response parsing in KeyError handler for clean 422 errors
- Remove unused Optional/TYPE_CHECKING imports from schemas/models
- Extract _get_available_slugs helper to DRY duplicate queries

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-04-04 19:11:44 +00:00
parent c53ced8725
commit e042cf6186
7 changed files with 78 additions and 62 deletions

View File

@@ -2,7 +2,7 @@
> **Purpose:** Quick-reference file showing exactly where the project stands. > **Purpose:** Quick-reference file showing exactly where the project stands.
> **For Claude Code:** Read this first to understand what's done and what's next. > **For Claude Code:** Read this first to understand what's done and what's next.
> **Last Updated:** April 4, 2026 > **Last Updated:** April 4, 2026 (evening)
--- ---
@@ -215,6 +215,22 @@
- **ConcludeSessionModal copy refresh** — Forward-facing action verbs, "Close & Generate" CTA, consistent outcome labels - **ConcludeSessionModal copy refresh** — Forward-facing action verbs, "Close & Generate" CTA, consistent outcome labels
- Deleted unused FlowPilotActionBar component (227 lines dead code) - Deleted unused FlowPilotActionBar component (227 lines dead code)
### Network Diagrams (In Progress)
- Network diagram editor with React Flow (@xyflow/react v12) canvas
- Device node system: 27 device types across 7 categories (network, compute, storage, cloud, endpoint, infrastructure, security)
- Custom device type creation via DeviceToolbar
- Connection edges with 6 types (ethernet, fiber, wifi, vpn, vlan, wan) — color-coded, dashed for wireless/VPN
- Properties panel for editing device and connection details
- AI-assisted diagram generation (describe network → auto-layout)
- Auto-save every 30 seconds, manual save, JSON export
- **React Flow UI Components** — Cherry-picked and Charcoal-restyled: BaseNode (structured header/content/footer slots), BaseHandle (styled connection handles), LabeledHandle (named port labels), NodeStatusIndicator (status border effect: emerald/red/yellow), NodeTooltip (hover details via NodeToolbar), LabeledGroupNode (subnet/VLAN/site/DMZ containers), AnimatedSvgEdge (traffic flow visualization)
- Grouping category in toolbar: Subnet, VLAN, Site, DMZ drag-drop to canvas
- Traffic flow toggle on edges (switches between static and animated)
- Context menu with copy/paste/duplicate/select all shortcuts
- Drop position uses `screenToFlowPosition()` for correct placement at any zoom/pan level
- **Bug fix:** PropertiesPanel inputs now work — selection uses IDs instead of stale object snapshots
### Maintenance Flows (Hidden from UI) ### Maintenance Flows (Hidden from UI)
- Batch session launch, saved target lists - Batch session launch, saved target lists

View File

@@ -62,6 +62,14 @@ def _diagram_to_list_item(diagram: NetworkDiagram) -> NetworkDiagramListItem:
) )
async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]:
stmt = select(DeviceType.slug).where(
or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id)
)
result = await db.execute(stmt)
return {row[0] for row in result.all()}
@router.get("/clients", response_model=list[str]) @router.get("/clients", response_model=list[str])
async def list_client_names( async def list_client_names(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
@@ -102,7 +110,8 @@ async def list_diagrams(
stmt = stmt.where(NetworkDiagram.client_name == client_name) stmt = stmt.where(NetworkDiagram.client_name == client_name)
if search: if search:
search_filter = f"%{search}%" escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
search_filter = f"%{escaped}%"
stmt = stmt.where( stmt = stmt.where(
or_( or_(
NetworkDiagram.name.ilike(search_filter), NetworkDiagram.name.ilike(search_filter),
@@ -232,14 +241,7 @@ async def import_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse: ) -> DiagramImportResponse:
available_stmt = select(DeviceType.slug).where( available_slugs = await _get_available_slugs(current_user.team_id, db)
or_(
DeviceType.is_system.is_(True),
DeviceType.team_id == current_user.team_id,
)
)
result = await db.execute(available_stmt)
available_slugs = {row[0] for row in result.all()}
warnings: list[str] = [] warnings: list[str] = []
for node in data.nodes: for node in data.nodes:
@@ -271,14 +273,8 @@ async def ai_generate_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse: ) -> AIGenerateResponse:
stmt = select(DeviceType.slug).where( available_slugs_set = await _get_available_slugs(current_user.team_id, db)
or_( available_slugs = list(available_slugs_set)
DeviceType.is_system.is_(True),
DeviceType.team_id == current_user.team_id,
)
)
result = await db.execute(stmt)
available_slugs = [row[0] for row in result.all()]
existing_node_ids: list[str] | None = None existing_node_ids: list[str] | None = None
if data.mode == "merge" and data.existingBounds: if data.mode == "merge" and data.existingBounds:

View File

@@ -1,7 +1,6 @@
"""Device type model for network diagrams.""" """Device type model for network diagrams."""
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@@ -9,9 +8,6 @@ from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base from app.core.database import Base
if TYPE_CHECKING:
pass
class DeviceType(Base): class DeviceType(Base):
"""A device type for network diagram nodes (system or team-custom).""" """A device type for network diagram nodes (system or team-custom)."""
@@ -36,7 +32,7 @@ class DeviceType(Base):
Boolean, nullable=False, default=False, Boolean, nullable=False, default=False,
comment="True for built-in types that cannot be deleted", comment="True for built-in types that cannot be deleted",
) )
team_id: Mapped[Optional[uuid.UUID]] = mapped_column( team_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="CASCADE"), ForeignKey("teams.id", ondelete="CASCADE"),
nullable=True, nullable=True,

View File

@@ -1,7 +1,7 @@
"""Network diagram model.""" """Network diagram model."""
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING from typing import Any, TYPE_CHECKING
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -27,16 +27,16 @@ class NetworkDiagram(Base):
index=True, index=True,
) )
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
client_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) client_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
asset_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) description: Mapped[str | None] = mapped_column(Text, nullable=True)
nodes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'") nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
edges: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'") edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
thumbnail_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True) thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True)
is_archived: Mapped[bool] = mapped_column( is_archived: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False, Boolean, nullable=False, default=False,
) )
created_by: Mapped[Optional[uuid.UUID]] = mapped_column( created_by: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("users.id"), ForeignKey("users.id"),
nullable=True, nullable=True,
@@ -50,4 +50,4 @@ class NetworkDiagram(Base):
onupdate=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc),
) )
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by]) creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by])

View File

@@ -1,6 +1,5 @@
"""Pydantic schemas for device types.""" """Pydantic schemas for device types."""
from datetime import datetime from datetime import datetime
from typing import Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View File

@@ -1,11 +1,15 @@
"""Pydantic schemas for network diagrams.""" """Pydantic schemas for network diagrams."""
from datetime import datetime from datetime import datetime
from typing import Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class Position(BaseModel):
x: float
y: float
class DeviceProperties(BaseModel): class DeviceProperties(BaseModel):
hostname: str | None = None hostname: str | None = None
ip: str | None = None ip: str | None = None
@@ -22,7 +26,7 @@ class DiagramNode(BaseModel):
id: str id: str
type: str type: str
label: str label: str
position: dict position: Position
properties: DeviceProperties = Field(default_factory=DeviceProperties) properties: DeviceProperties = Field(default_factory=DeviceProperties)

View File

@@ -10,6 +10,7 @@ from app.schemas.network_diagram import (
DiagramNode, DiagramNode,
DiagramEdge, DiagramEdge,
DeviceProperties, DeviceProperties,
Position,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -108,35 +109,39 @@ async def generate_diagram(
logger.error("Failed to parse AI response as JSON: %s", e) logger.error("Failed to parse AI response as JSON: %s", e)
raise ValueError("AI generated an invalid response, please try again") raise ValueError("AI generated an invalid response, please try again")
nodes = [] try:
for raw_node in data.get("nodes", []): nodes = []
node_type = raw_node.get("type", "server") for raw_node in data.get("nodes", []):
if node_type not in available_slugs: node_type = raw_node.get("type", "server")
logger.warning("Unknown device type '%s', falling back to 'server'", node_type) if node_type not in available_slugs:
node_type = "server" logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
node_type = "server"
nodes.append(DiagramNode( nodes.append(DiagramNode(
id=raw_node["id"], id=raw_node["id"],
type=node_type, type=node_type,
label=raw_node.get("label", node_type), label=raw_node.get("label", node_type),
position=raw_node.get("position", {"x": 0, "y": 0}), position=Position(**raw_node.get("position", {"x": 0, "y": 0})),
properties=DeviceProperties(**{ properties=DeviceProperties(**{
k: v for k, v in raw_node.get("properties", {}).items() k: v for k, v in raw_node.get("properties", {}).items()
if k in DeviceProperties.model_fields if k in DeviceProperties.model_fields
}), }),
)) ))
edges = [] edges = []
for raw_edge in data.get("edges", []): for raw_edge in data.get("edges", []):
edges.append(DiagramEdge( edges.append(DiagramEdge(
id=raw_edge["id"], id=raw_edge["id"],
source=raw_edge["source"], source=raw_edge["source"],
target=raw_edge["target"], target=raw_edge["target"],
label=raw_edge.get("label"), label=raw_edge.get("label"),
connectionType=raw_edge.get("connectionType", "ethernet"), connectionType=raw_edge.get("connectionType", "ethernet"),
speed=raw_edge.get("speed"), speed=raw_edge.get("speed"),
notes=raw_edge.get("notes"), notes=raw_edge.get("notes"),
)) ))
except KeyError as e:
logger.warning("AI response missing required field: %s", e)
raise ValueError(f"AI generated incomplete data (missing {e}), please try again")
return AIGenerateResponse( return AIGenerateResponse(
nodes=nodes, nodes=nodes,