fix: backend code review fixes for network diagrams
- Replace legacy Optional imports with modern str | None syntax - Type JSONB columns as Mapped[list[dict[str, Any]]] - Escape SQL LIKE wildcards (%, _) in diagram search - Type DiagramNode.position as Position(x, y) Pydantic model - Wrap AI response parsing in KeyError handler for clean 422 errors - Remove unused Optional/TYPE_CHECKING imports from schemas/models - Extract _get_available_slugs helper to DRY duplicate queries Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -62,6 +62,14 @@ def _diagram_to_list_item(diagram: NetworkDiagram) -> NetworkDiagramListItem:
|
||||
)
|
||||
|
||||
|
||||
async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]:
|
||||
stmt = select(DeviceType.slug).where(
|
||||
or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return {row[0] for row in result.all()}
|
||||
|
||||
|
||||
@router.get("/clients", response_model=list[str])
|
||||
async def list_client_names(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
@@ -102,7 +110,8 @@ async def list_diagrams(
|
||||
stmt = stmt.where(NetworkDiagram.client_name == client_name)
|
||||
|
||||
if search:
|
||||
search_filter = f"%{search}%"
|
||||
escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
search_filter = f"%{escaped}%"
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
NetworkDiagram.name.ilike(search_filter),
|
||||
@@ -232,14 +241,7 @@ async def import_diagram(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> DiagramImportResponse:
|
||||
available_stmt = select(DeviceType.slug).where(
|
||||
or_(
|
||||
DeviceType.is_system.is_(True),
|
||||
DeviceType.team_id == current_user.team_id,
|
||||
)
|
||||
)
|
||||
result = await db.execute(available_stmt)
|
||||
available_slugs = {row[0] for row in result.all()}
|
||||
available_slugs = await _get_available_slugs(current_user.team_id, db)
|
||||
|
||||
warnings: list[str] = []
|
||||
for node in data.nodes:
|
||||
@@ -271,14 +273,8 @@ async def ai_generate_diagram(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> AIGenerateResponse:
|
||||
stmt = select(DeviceType.slug).where(
|
||||
or_(
|
||||
DeviceType.is_system.is_(True),
|
||||
DeviceType.team_id == current_user.team_id,
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
available_slugs = [row[0] for row in result.all()]
|
||||
available_slugs_set = await _get_available_slugs(current_user.team_id, db)
|
||||
available_slugs = list(available_slugs_set)
|
||||
|
||||
existing_node_ids: list[str] | None = None
|
||||
if data.mode == "merge" and data.existingBounds:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Device type model for network diagrams."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
@@ -9,9 +8,6 @@ from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class DeviceType(Base):
|
||||
"""A device type for network diagram nodes (system or team-custom)."""
|
||||
@@ -36,7 +32,7 @@ class DeviceType(Base):
|
||||
Boolean, nullable=False, default=False,
|
||||
comment="True for built-in types that cannot be deleted",
|
||||
)
|
||||
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
team_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("teams.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Network diagram model."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
@@ -27,16 +27,16 @@ class NetworkDiagram(Base):
|
||||
index=True,
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
client_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
asset_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
nodes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||
edges: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||
thumbnail_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
client_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||
edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||
thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
is_archived: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False,
|
||||
)
|
||||
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
@@ -50,4 +50,4 @@ class NetworkDiagram(Base):
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by])
|
||||
creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by])
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Pydantic schemas for device types."""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""Pydantic schemas for network diagrams."""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
x: float
|
||||
y: float
|
||||
|
||||
|
||||
class DeviceProperties(BaseModel):
|
||||
hostname: str | None = None
|
||||
ip: str | None = None
|
||||
@@ -22,7 +26,7 @@ class DiagramNode(BaseModel):
|
||||
id: str
|
||||
type: str
|
||||
label: str
|
||||
position: dict
|
||||
position: Position
|
||||
properties: DeviceProperties = Field(default_factory=DeviceProperties)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.schemas.network_diagram import (
|
||||
DiagramNode,
|
||||
DiagramEdge,
|
||||
DeviceProperties,
|
||||
Position,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -108,35 +109,39 @@ async def generate_diagram(
|
||||
logger.error("Failed to parse AI response as JSON: %s", e)
|
||||
raise ValueError("AI generated an invalid response, please try again")
|
||||
|
||||
nodes = []
|
||||
for raw_node in data.get("nodes", []):
|
||||
node_type = raw_node.get("type", "server")
|
||||
if node_type not in available_slugs:
|
||||
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
|
||||
node_type = "server"
|
||||
try:
|
||||
nodes = []
|
||||
for raw_node in data.get("nodes", []):
|
||||
node_type = raw_node.get("type", "server")
|
||||
if node_type not in available_slugs:
|
||||
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
|
||||
node_type = "server"
|
||||
|
||||
nodes.append(DiagramNode(
|
||||
id=raw_node["id"],
|
||||
type=node_type,
|
||||
label=raw_node.get("label", node_type),
|
||||
position=raw_node.get("position", {"x": 0, "y": 0}),
|
||||
properties=DeviceProperties(**{
|
||||
k: v for k, v in raw_node.get("properties", {}).items()
|
||||
if k in DeviceProperties.model_fields
|
||||
}),
|
||||
))
|
||||
nodes.append(DiagramNode(
|
||||
id=raw_node["id"],
|
||||
type=node_type,
|
||||
label=raw_node.get("label", node_type),
|
||||
position=Position(**raw_node.get("position", {"x": 0, "y": 0})),
|
||||
properties=DeviceProperties(**{
|
||||
k: v for k, v in raw_node.get("properties", {}).items()
|
||||
if k in DeviceProperties.model_fields
|
||||
}),
|
||||
))
|
||||
|
||||
edges = []
|
||||
for raw_edge in data.get("edges", []):
|
||||
edges.append(DiagramEdge(
|
||||
id=raw_edge["id"],
|
||||
source=raw_edge["source"],
|
||||
target=raw_edge["target"],
|
||||
label=raw_edge.get("label"),
|
||||
connectionType=raw_edge.get("connectionType", "ethernet"),
|
||||
speed=raw_edge.get("speed"),
|
||||
notes=raw_edge.get("notes"),
|
||||
))
|
||||
edges = []
|
||||
for raw_edge in data.get("edges", []):
|
||||
edges.append(DiagramEdge(
|
||||
id=raw_edge["id"],
|
||||
source=raw_edge["source"],
|
||||
target=raw_edge["target"],
|
||||
label=raw_edge.get("label"),
|
||||
connectionType=raw_edge.get("connectionType", "ethernet"),
|
||||
speed=raw_edge.get("speed"),
|
||||
notes=raw_edge.get("notes"),
|
||||
))
|
||||
except KeyError as e:
|
||||
logger.warning("AI response missing required field: %s", e)
|
||||
raise ValueError(f"AI generated incomplete data (missing {e}), please try again")
|
||||
|
||||
return AIGenerateResponse(
|
||||
nodes=nodes,
|
||||
|
||||
Reference in New Issue
Block a user