fix: backend code review fixes for network diagrams

- Replace legacy Optional imports with modern str | None syntax
- Type JSONB columns as Mapped[list[dict[str, Any]]]
- Escape SQL LIKE wildcards (%, _) in diagram search
- Type DiagramNode.position as Position(x, y) Pydantic model
- Wrap AI response parsing in KeyError handler for clean 422 errors
- Remove unused Optional/TYPE_CHECKING imports from schemas/models
- Extract _get_available_slugs helper to DRY duplicate queries

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-04-04 19:11:44 +00:00
parent 2ea56f2563
commit 663a96c8a5
6 changed files with 61 additions and 61 deletions

View File

@@ -62,6 +62,14 @@ def _diagram_to_list_item(diagram: NetworkDiagram) -> NetworkDiagramListItem:
)
async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]:
stmt = select(DeviceType.slug).where(
or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id)
)
result = await db.execute(stmt)
return {row[0] for row in result.all()}
@router.get("/clients", response_model=list[str])
async def list_client_names(
db: Annotated[AsyncSession, Depends(get_db)],
@@ -102,7 +110,8 @@ async def list_diagrams(
stmt = stmt.where(NetworkDiagram.client_name == client_name)
if search:
search_filter = f"%{search}%"
escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
search_filter = f"%{escaped}%"
stmt = stmt.where(
or_(
NetworkDiagram.name.ilike(search_filter),
@@ -232,14 +241,7 @@ async def import_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse:
available_stmt = select(DeviceType.slug).where(
or_(
DeviceType.is_system.is_(True),
DeviceType.team_id == current_user.team_id,
)
)
result = await db.execute(available_stmt)
available_slugs = {row[0] for row in result.all()}
available_slugs = await _get_available_slugs(current_user.team_id, db)
warnings: list[str] = []
for node in data.nodes:
@@ -271,14 +273,8 @@ async def ai_generate_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse:
stmt = select(DeviceType.slug).where(
or_(
DeviceType.is_system.is_(True),
DeviceType.team_id == current_user.team_id,
)
)
result = await db.execute(stmt)
available_slugs = [row[0] for row in result.all()]
available_slugs_set = await _get_available_slugs(current_user.team_id, db)
available_slugs = list(available_slugs_set)
existing_node_ids: list[str] | None = None
if data.mode == "merge" and data.existingBounds:

View File

@@ -1,7 +1,6 @@
"""Device type model for network diagrams."""
import uuid
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
@@ -9,9 +8,6 @@ from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
if TYPE_CHECKING:
pass
class DeviceType(Base):
"""A device type for network diagram nodes (system or team-custom)."""
@@ -36,7 +32,7 @@ class DeviceType(Base):
Boolean, nullable=False, default=False,
comment="True for built-in types that cannot be deleted",
)
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
team_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="CASCADE"),
nullable=True,

View File

@@ -1,7 +1,7 @@
"""Network diagram model."""
import uuid
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -27,16 +27,16 @@ class NetworkDiagram(Base):
index=True,
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
client_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
asset_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
nodes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'")
edges: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'")
thumbnail_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
client_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True)
is_archived: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False,
)
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
created_by: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id"),
nullable=True,
@@ -50,4 +50,4 @@ class NetworkDiagram(Base):
onupdate=lambda: datetime.now(timezone.utc),
)
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by])
creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by])

View File

@@ -1,6 +1,5 @@
"""Pydantic schemas for device types."""
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field

View File

@@ -1,11 +1,15 @@
"""Pydantic schemas for network diagrams."""
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field
class Position(BaseModel):
x: float
y: float
class DeviceProperties(BaseModel):
hostname: str | None = None
ip: str | None = None
@@ -22,7 +26,7 @@ class DiagramNode(BaseModel):
id: str
type: str
label: str
position: dict
position: Position
properties: DeviceProperties = Field(default_factory=DeviceProperties)

View File

@@ -10,6 +10,7 @@ from app.schemas.network_diagram import (
DiagramNode,
DiagramEdge,
DeviceProperties,
Position,
)
logger = logging.getLogger(__name__)
@@ -108,35 +109,39 @@ async def generate_diagram(
logger.error("Failed to parse AI response as JSON: %s", e)
raise ValueError("AI generated an invalid response, please try again")
nodes = []
for raw_node in data.get("nodes", []):
node_type = raw_node.get("type", "server")
if node_type not in available_slugs:
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
node_type = "server"
try:
nodes = []
for raw_node in data.get("nodes", []):
node_type = raw_node.get("type", "server")
if node_type not in available_slugs:
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
node_type = "server"
nodes.append(DiagramNode(
id=raw_node["id"],
type=node_type,
label=raw_node.get("label", node_type),
position=raw_node.get("position", {"x": 0, "y": 0}),
properties=DeviceProperties(**{
k: v for k, v in raw_node.get("properties", {}).items()
if k in DeviceProperties.model_fields
}),
))
nodes.append(DiagramNode(
id=raw_node["id"],
type=node_type,
label=raw_node.get("label", node_type),
position=Position(**raw_node.get("position", {"x": 0, "y": 0})),
properties=DeviceProperties(**{
k: v for k, v in raw_node.get("properties", {}).items()
if k in DeviceProperties.model_fields
}),
))
edges = []
for raw_edge in data.get("edges", []):
edges.append(DiagramEdge(
id=raw_edge["id"],
source=raw_edge["source"],
target=raw_edge["target"],
label=raw_edge.get("label"),
connectionType=raw_edge.get("connectionType", "ethernet"),
speed=raw_edge.get("speed"),
notes=raw_edge.get("notes"),
))
edges = []
for raw_edge in data.get("edges", []):
edges.append(DiagramEdge(
id=raw_edge["id"],
source=raw_edge["source"],
target=raw_edge["target"],
label=raw_edge.get("label"),
connectionType=raw_edge.get("connectionType", "ethernet"),
speed=raw_edge.get("speed"),
notes=raw_edge.get("notes"),
))
except KeyError as e:
logger.warning("AI response missing required field: %s", e)
raise ValueError(f"AI generated incomplete data (missing {e}), please try again")
return AIGenerateResponse(
nodes=nodes,