fix: backend code review fixes for network diagrams

- Replace legacy Optional imports with modern str | None syntax
- Type JSONB columns as Mapped[list[dict[str, Any]]]
- Escape SQL LIKE wildcards (%, _) in diagram search
- Type DiagramNode.position as Position(x, y) Pydantic model
- Wrap AI response parsing in KeyError handler for clean 422 errors
- Remove unused Optional/TYPE_CHECKING imports from schemas/models
- Extract _get_available_slugs helper to DRY duplicate queries

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-04-04 19:11:44 +00:00
parent 2ea56f2563
commit 663a96c8a5
6 changed files with 61 additions and 61 deletions

View File

@@ -62,6 +62,14 @@ def _diagram_to_list_item(diagram: NetworkDiagram) -> NetworkDiagramListItem:
) )
async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]:
stmt = select(DeviceType.slug).where(
or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id)
)
result = await db.execute(stmt)
return {row[0] for row in result.all()}
@router.get("/clients", response_model=list[str]) @router.get("/clients", response_model=list[str])
async def list_client_names( async def list_client_names(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
@@ -102,7 +110,8 @@ async def list_diagrams(
stmt = stmt.where(NetworkDiagram.client_name == client_name) stmt = stmt.where(NetworkDiagram.client_name == client_name)
if search: if search:
search_filter = f"%{search}%" escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
search_filter = f"%{escaped}%"
stmt = stmt.where( stmt = stmt.where(
or_( or_(
NetworkDiagram.name.ilike(search_filter), NetworkDiagram.name.ilike(search_filter),
@@ -232,14 +241,7 @@ async def import_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse: ) -> DiagramImportResponse:
available_stmt = select(DeviceType.slug).where( available_slugs = await _get_available_slugs(current_user.team_id, db)
or_(
DeviceType.is_system.is_(True),
DeviceType.team_id == current_user.team_id,
)
)
result = await db.execute(available_stmt)
available_slugs = {row[0] for row in result.all()}
warnings: list[str] = [] warnings: list[str] = []
for node in data.nodes: for node in data.nodes:
@@ -271,14 +273,8 @@ async def ai_generate_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse: ) -> AIGenerateResponse:
stmt = select(DeviceType.slug).where( available_slugs_set = await _get_available_slugs(current_user.team_id, db)
or_( available_slugs = list(available_slugs_set)
DeviceType.is_system.is_(True),
DeviceType.team_id == current_user.team_id,
)
)
result = await db.execute(stmt)
available_slugs = [row[0] for row in result.all()]
existing_node_ids: list[str] | None = None existing_node_ids: list[str] | None = None
if data.mode == "merge" and data.existingBounds: if data.mode == "merge" and data.existingBounds:

View File

@@ -1,7 +1,6 @@
"""Device type model for network diagrams.""" """Device type model for network diagrams."""
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@@ -9,9 +8,6 @@ from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base from app.core.database import Base
if TYPE_CHECKING:
pass
class DeviceType(Base): class DeviceType(Base):
"""A device type for network diagram nodes (system or team-custom).""" """A device type for network diagram nodes (system or team-custom)."""
@@ -36,7 +32,7 @@ class DeviceType(Base):
Boolean, nullable=False, default=False, Boolean, nullable=False, default=False,
comment="True for built-in types that cannot be deleted", comment="True for built-in types that cannot be deleted",
) )
team_id: Mapped[Optional[uuid.UUID]] = mapped_column( team_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="CASCADE"), ForeignKey("teams.id", ondelete="CASCADE"),
nullable=True, nullable=True,

View File

@@ -1,7 +1,7 @@
"""Network diagram model.""" """Network diagram model."""
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING from typing import Any, TYPE_CHECKING
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -27,16 +27,16 @@ class NetworkDiagram(Base):
index=True, index=True,
) )
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
client_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) client_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
asset_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) description: Mapped[str | None] = mapped_column(Text, nullable=True)
nodes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'") nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
edges: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'") edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
thumbnail_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True) thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True)
is_archived: Mapped[bool] = mapped_column( is_archived: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False, Boolean, nullable=False, default=False,
) )
created_by: Mapped[Optional[uuid.UUID]] = mapped_column( created_by: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("users.id"), ForeignKey("users.id"),
nullable=True, nullable=True,
@@ -50,4 +50,4 @@ class NetworkDiagram(Base):
onupdate=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc),
) )
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by]) creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by])

View File

@@ -1,6 +1,5 @@
"""Pydantic schemas for device types.""" """Pydantic schemas for device types."""
from datetime import datetime from datetime import datetime
from typing import Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View File

@@ -1,11 +1,15 @@
"""Pydantic schemas for network diagrams.""" """Pydantic schemas for network diagrams."""
from datetime import datetime from datetime import datetime
from typing import Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class Position(BaseModel):
x: float
y: float
class DeviceProperties(BaseModel): class DeviceProperties(BaseModel):
hostname: str | None = None hostname: str | None = None
ip: str | None = None ip: str | None = None
@@ -22,7 +26,7 @@ class DiagramNode(BaseModel):
id: str id: str
type: str type: str
label: str label: str
position: dict position: Position
properties: DeviceProperties = Field(default_factory=DeviceProperties) properties: DeviceProperties = Field(default_factory=DeviceProperties)

View File

@@ -10,6 +10,7 @@ from app.schemas.network_diagram import (
DiagramNode, DiagramNode,
DiagramEdge, DiagramEdge,
DeviceProperties, DeviceProperties,
Position,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -108,35 +109,39 @@ async def generate_diagram(
logger.error("Failed to parse AI response as JSON: %s", e) logger.error("Failed to parse AI response as JSON: %s", e)
raise ValueError("AI generated an invalid response, please try again") raise ValueError("AI generated an invalid response, please try again")
nodes = [] try:
for raw_node in data.get("nodes", []): nodes = []
node_type = raw_node.get("type", "server") for raw_node in data.get("nodes", []):
if node_type not in available_slugs: node_type = raw_node.get("type", "server")
logger.warning("Unknown device type '%s', falling back to 'server'", node_type) if node_type not in available_slugs:
node_type = "server" logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
node_type = "server"
nodes.append(DiagramNode( nodes.append(DiagramNode(
id=raw_node["id"], id=raw_node["id"],
type=node_type, type=node_type,
label=raw_node.get("label", node_type), label=raw_node.get("label", node_type),
position=raw_node.get("position", {"x": 0, "y": 0}), position=Position(**raw_node.get("position", {"x": 0, "y": 0})),
properties=DeviceProperties(**{ properties=DeviceProperties(**{
k: v for k, v in raw_node.get("properties", {}).items() k: v for k, v in raw_node.get("properties", {}).items()
if k in DeviceProperties.model_fields if k in DeviceProperties.model_fields
}), }),
)) ))
edges = [] edges = []
for raw_edge in data.get("edges", []): for raw_edge in data.get("edges", []):
edges.append(DiagramEdge( edges.append(DiagramEdge(
id=raw_edge["id"], id=raw_edge["id"],
source=raw_edge["source"], source=raw_edge["source"],
target=raw_edge["target"], target=raw_edge["target"],
label=raw_edge.get("label"), label=raw_edge.get("label"),
connectionType=raw_edge.get("connectionType", "ethernet"), connectionType=raw_edge.get("connectionType", "ethernet"),
speed=raw_edge.get("speed"), speed=raw_edge.get("speed"),
notes=raw_edge.get("notes"), notes=raw_edge.get("notes"),
)) ))
except KeyError as e:
logger.warning("AI response missing required field: %s", e)
raise ValueError(f"AI generated incomplete data (missing {e}), please try again")
return AIGenerateResponse( return AIGenerateResponse(
nodes=nodes, nodes=nodes,