fix: backend code review fixes for network diagrams
- Replace legacy Optional imports with modern str | None syntax - Type JSONB columns as Mapped[list[dict[str, Any]]] - Escape SQL LIKE wildcards (%, _) in diagram search - Type DiagramNode.position as Position(x, y) Pydantic model - Wrap AI response parsing in KeyError handler for clean 422 errors - Remove unused Optional/TYPE_CHECKING imports from schemas/models - Extract _get_available_slugs helper to DRY duplicate queries Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -62,6 +62,14 @@ def _diagram_to_list_item(diagram: NetworkDiagram) -> NetworkDiagramListItem:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]:
|
||||||
|
stmt = select(DeviceType.slug).where(
|
||||||
|
or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return {row[0] for row in result.all()}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/clients", response_model=list[str])
|
@router.get("/clients", response_model=list[str])
|
||||||
async def list_client_names(
|
async def list_client_names(
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
@@ -102,7 +110,8 @@ async def list_diagrams(
|
|||||||
stmt = stmt.where(NetworkDiagram.client_name == client_name)
|
stmt = stmt.where(NetworkDiagram.client_name == client_name)
|
||||||
|
|
||||||
if search:
|
if search:
|
||||||
search_filter = f"%{search}%"
|
escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
|
search_filter = f"%{escaped}%"
|
||||||
stmt = stmt.where(
|
stmt = stmt.where(
|
||||||
or_(
|
or_(
|
||||||
NetworkDiagram.name.ilike(search_filter),
|
NetworkDiagram.name.ilike(search_filter),
|
||||||
@@ -232,14 +241,7 @@ async def import_diagram(
|
|||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
) -> DiagramImportResponse:
|
) -> DiagramImportResponse:
|
||||||
available_stmt = select(DeviceType.slug).where(
|
available_slugs = await _get_available_slugs(current_user.team_id, db)
|
||||||
or_(
|
|
||||||
DeviceType.is_system.is_(True),
|
|
||||||
DeviceType.team_id == current_user.team_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = await db.execute(available_stmt)
|
|
||||||
available_slugs = {row[0] for row in result.all()}
|
|
||||||
|
|
||||||
warnings: list[str] = []
|
warnings: list[str] = []
|
||||||
for node in data.nodes:
|
for node in data.nodes:
|
||||||
@@ -271,14 +273,8 @@ async def ai_generate_diagram(
|
|||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
) -> AIGenerateResponse:
|
) -> AIGenerateResponse:
|
||||||
stmt = select(DeviceType.slug).where(
|
available_slugs_set = await _get_available_slugs(current_user.team_id, db)
|
||||||
or_(
|
available_slugs = list(available_slugs_set)
|
||||||
DeviceType.is_system.is_(True),
|
|
||||||
DeviceType.team_id == current_user.team_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = await db.execute(stmt)
|
|
||||||
available_slugs = [row[0] for row in result.all()]
|
|
||||||
|
|
||||||
existing_node_ids: list[str] | None = None
|
existing_node_ids: list[str] | None = None
|
||||||
if data.mode == "merge" and data.existingBounds:
|
if data.mode == "merge" and data.existingBounds:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Device type model for network diagrams."""
|
"""Device type model for network diagrams."""
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, TYPE_CHECKING
|
|
||||||
|
|
||||||
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
|
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
@@ -9,9 +8,6 @@ from sqlalchemy.dialects.postgresql import UUID
|
|||||||
|
|
||||||
from app.core.database import Base
|
from app.core.database import Base
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceType(Base):
|
class DeviceType(Base):
|
||||||
"""A device type for network diagram nodes (system or team-custom)."""
|
"""A device type for network diagram nodes (system or team-custom)."""
|
||||||
@@ -36,7 +32,7 @@ class DeviceType(Base):
|
|||||||
Boolean, nullable=False, default=False,
|
Boolean, nullable=False, default=False,
|
||||||
comment="True for built-in types that cannot be deleted",
|
comment="True for built-in types that cannot be deleted",
|
||||||
)
|
)
|
||||||
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
team_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("teams.id", ondelete="CASCADE"),
|
ForeignKey("teams.id", ondelete="CASCADE"),
|
||||||
nullable=True,
|
nullable=True,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Network diagram model."""
|
"""Network diagram model."""
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
|
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
@@ -27,16 +27,16 @@ class NetworkDiagram(Base):
|
|||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
client_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
client_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
asset_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
nodes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||||
edges: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||||
thumbnail_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
is_archived: Mapped[bool] = mapped_column(
|
is_archived: Mapped[bool] = mapped_column(
|
||||||
Boolean, nullable=False, default=False,
|
Boolean, nullable=False, default=False,
|
||||||
)
|
)
|
||||||
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
|
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("users.id"),
|
ForeignKey("users.id"),
|
||||||
nullable=True,
|
nullable=True,
|
||||||
@@ -50,4 +50,4 @@ class NetworkDiagram(Base):
|
|||||||
onupdate=lambda: datetime.now(timezone.utc),
|
onupdate=lambda: datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by])
|
creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by])
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Pydantic schemas for device types."""
|
"""Pydantic schemas for device types."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
"""Pydantic schemas for network diagrams."""
|
"""Pydantic schemas for network diagrams."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Position(BaseModel):
|
||||||
|
x: float
|
||||||
|
y: float
|
||||||
|
|
||||||
|
|
||||||
class DeviceProperties(BaseModel):
|
class DeviceProperties(BaseModel):
|
||||||
hostname: str | None = None
|
hostname: str | None = None
|
||||||
ip: str | None = None
|
ip: str | None = None
|
||||||
@@ -22,7 +26,7 @@ class DiagramNode(BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
type: str
|
type: str
|
||||||
label: str
|
label: str
|
||||||
position: dict
|
position: Position
|
||||||
properties: DeviceProperties = Field(default_factory=DeviceProperties)
|
properties: DeviceProperties = Field(default_factory=DeviceProperties)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from app.schemas.network_diagram import (
|
|||||||
DiagramNode,
|
DiagramNode,
|
||||||
DiagramEdge,
|
DiagramEdge,
|
||||||
DeviceProperties,
|
DeviceProperties,
|
||||||
|
Position,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -108,35 +109,39 @@ async def generate_diagram(
|
|||||||
logger.error("Failed to parse AI response as JSON: %s", e)
|
logger.error("Failed to parse AI response as JSON: %s", e)
|
||||||
raise ValueError("AI generated an invalid response, please try again")
|
raise ValueError("AI generated an invalid response, please try again")
|
||||||
|
|
||||||
nodes = []
|
try:
|
||||||
for raw_node in data.get("nodes", []):
|
nodes = []
|
||||||
node_type = raw_node.get("type", "server")
|
for raw_node in data.get("nodes", []):
|
||||||
if node_type not in available_slugs:
|
node_type = raw_node.get("type", "server")
|
||||||
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
|
if node_type not in available_slugs:
|
||||||
node_type = "server"
|
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
|
||||||
|
node_type = "server"
|
||||||
|
|
||||||
nodes.append(DiagramNode(
|
nodes.append(DiagramNode(
|
||||||
id=raw_node["id"],
|
id=raw_node["id"],
|
||||||
type=node_type,
|
type=node_type,
|
||||||
label=raw_node.get("label", node_type),
|
label=raw_node.get("label", node_type),
|
||||||
position=raw_node.get("position", {"x": 0, "y": 0}),
|
position=Position(**raw_node.get("position", {"x": 0, "y": 0})),
|
||||||
properties=DeviceProperties(**{
|
properties=DeviceProperties(**{
|
||||||
k: v for k, v in raw_node.get("properties", {}).items()
|
k: v for k, v in raw_node.get("properties", {}).items()
|
||||||
if k in DeviceProperties.model_fields
|
if k in DeviceProperties.model_fields
|
||||||
}),
|
}),
|
||||||
))
|
))
|
||||||
|
|
||||||
edges = []
|
edges = []
|
||||||
for raw_edge in data.get("edges", []):
|
for raw_edge in data.get("edges", []):
|
||||||
edges.append(DiagramEdge(
|
edges.append(DiagramEdge(
|
||||||
id=raw_edge["id"],
|
id=raw_edge["id"],
|
||||||
source=raw_edge["source"],
|
source=raw_edge["source"],
|
||||||
target=raw_edge["target"],
|
target=raw_edge["target"],
|
||||||
label=raw_edge.get("label"),
|
label=raw_edge.get("label"),
|
||||||
connectionType=raw_edge.get("connectionType", "ethernet"),
|
connectionType=raw_edge.get("connectionType", "ethernet"),
|
||||||
speed=raw_edge.get("speed"),
|
speed=raw_edge.get("speed"),
|
||||||
notes=raw_edge.get("notes"),
|
notes=raw_edge.get("notes"),
|
||||||
))
|
))
|
||||||
|
except KeyError as e:
|
||||||
|
logger.warning("AI response missing required field: %s", e)
|
||||||
|
raise ValueError(f"AI generated incomplete data (missing {e}), please try again")
|
||||||
|
|
||||||
return AIGenerateResponse(
|
return AIGenerateResponse(
|
||||||
nodes=nodes,
|
nodes=nodes,
|
||||||
|
|||||||
Reference in New Issue
Block a user