diff --git a/backend/app/api/endpoints/network_diagrams.py b/backend/app/api/endpoints/network_diagrams.py index a54478d2..93c9c29e 100644 --- a/backend/app/api/endpoints/network_diagrams.py +++ b/backend/app/api/endpoints/network_diagrams.py @@ -62,6 +62,14 @@ def _diagram_to_list_item(diagram: NetworkDiagram) -> NetworkDiagramListItem: ) +async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]: + stmt = select(DeviceType.slug).where( + or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id) + ) + result = await db.execute(stmt) + return {row[0] for row in result.all()} + + @router.get("/clients", response_model=list[str]) async def list_client_names( db: Annotated[AsyncSession, Depends(get_db)], @@ -102,7 +110,8 @@ async def list_diagrams( stmt = stmt.where(NetworkDiagram.client_name == client_name) if search: - search_filter = f"%{search}%" + escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + search_filter = f"%{escaped}%" stmt = stmt.where( or_( NetworkDiagram.name.ilike(search_filter), @@ -232,14 +241,7 @@ async def import_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> DiagramImportResponse: - available_stmt = select(DeviceType.slug).where( - or_( - DeviceType.is_system.is_(True), - DeviceType.team_id == current_user.team_id, - ) - ) - result = await db.execute(available_stmt) - available_slugs = {row[0] for row in result.all()} + available_slugs = await _get_available_slugs(current_user.team_id, db) warnings: list[str] = [] for node in data.nodes: @@ -271,14 +273,8 @@ async def ai_generate_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> AIGenerateResponse: - stmt = select(DeviceType.slug).where( - or_( - DeviceType.is_system.is_(True), - DeviceType.team_id == current_user.team_id, - ) - ) - result = await db.execute(stmt) - available_slugs = [row[0] for row in result.all()] + available_slugs_set = await _get_available_slugs(current_user.team_id, db) + available_slugs = list(available_slugs_set) existing_node_ids: list[str] | None = None if data.mode == "merge" and data.existingBounds: diff --git a/backend/app/models/device_type.py b/backend/app/models/device_type.py index cb419cfb..85da01e6 100644 --- a/backend/app/models/device_type.py +++ b/backend/app/models/device_type.py @@ -1,7 +1,6 @@ """Device type model for network diagrams.""" import uuid from datetime import datetime, timezone -from typing import Optional, TYPE_CHECKING from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey from sqlalchemy.orm import Mapped, mapped_column @@ -9,9 +8,6 @@ from sqlalchemy.dialects.postgresql import UUID from app.core.database import Base -if TYPE_CHECKING: - pass - class DeviceType(Base): """A device type for network diagram nodes (system or team-custom).""" @@ -36,7 +32,7 @@ class DeviceType(Base): Boolean, nullable=False, default=False, comment="True for built-in types that cannot be deleted", ) - team_id: Mapped[Optional[uuid.UUID]] = mapped_column( + team_id: Mapped[uuid.UUID | None] = mapped_column( UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), nullable=True, diff --git a/backend/app/models/network_diagram.py b/backend/app/models/network_diagram.py index a8ba7092..63932b5a 100644 --- a/backend/app/models/network_diagram.py +++ b/backend/app/models/network_diagram.py @@ -1,7 +1,7 @@ """Network diagram model.""" import uuid from datetime import datetime, timezone -from typing import Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -27,16 +27,16 @@ class NetworkDiagram(Base): index=True, ) name: Mapped[str] = mapped_column(String(255), nullable=False) - client_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - asset_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) - nodes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'") - edges: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'") - thumbnail_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + client_name: Mapped[str | None] = mapped_column(String(255), nullable=True) + asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'") + edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'") + thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True) is_archived: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False, ) - created_by: Mapped[Optional[uuid.UUID]] = mapped_column( + created_by: Mapped[uuid.UUID | None] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id"), nullable=True, @@ -50,4 +50,4 @@ class NetworkDiagram(Base): onupdate=lambda: datetime.now(timezone.utc), ) - creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by]) + creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by]) diff --git a/backend/app/schemas/device_type.py b/backend/app/schemas/device_type.py index 8d2b70a2..aeab8ff5 100644 --- a/backend/app/schemas/device_type.py +++ b/backend/app/schemas/device_type.py @@ -1,6 +1,5 @@ """Pydantic schemas for device types.""" from datetime import datetime -from typing import Optional from uuid import UUID from pydantic import BaseModel, Field diff --git a/backend/app/schemas/network_diagram.py b/backend/app/schemas/network_diagram.py index 2b25356a..3adc6360 100644 --- a/backend/app/schemas/network_diagram.py +++ b/backend/app/schemas/network_diagram.py @@ -1,11 +1,15 @@ """Pydantic schemas for network diagrams.""" from datetime import datetime -from typing import Optional from uuid import UUID from pydantic import BaseModel, Field +class Position(BaseModel): + x: float + y: float + + class DeviceProperties(BaseModel): hostname: str | None = None ip: str | None = None @@ -22,7 +26,7 @@ class DiagramNode(BaseModel): id: str type: str label: str - position: dict + position: Position properties: DeviceProperties = Field(default_factory=DeviceProperties) diff --git a/backend/app/services/network_diagram_ai_service.py b/backend/app/services/network_diagram_ai_service.py index 61defa50..5ac0df9c 100644 --- a/backend/app/services/network_diagram_ai_service.py +++ b/backend/app/services/network_diagram_ai_service.py @@ -10,6 +10,7 @@ from app.schemas.network_diagram import ( DiagramNode, DiagramEdge, DeviceProperties, + Position, ) logger = logging.getLogger(__name__) @@ -108,35 +109,39 @@ async def generate_diagram( logger.error("Failed to parse AI response as JSON: %s", e) raise ValueError("AI generated an invalid response, please try again") - nodes = [] - for raw_node in data.get("nodes", []): - node_type = raw_node.get("type", "server") - if node_type not in available_slugs: - logger.warning("Unknown device type '%s', falling back to 'server'", node_type) - node_type = "server" + try: + nodes = [] + for raw_node in data.get("nodes", []): + node_type = raw_node.get("type", "server") + if node_type not in available_slugs: + logger.warning("Unknown device type '%s', falling back to 'server'", node_type) + node_type = "server" - nodes.append(DiagramNode( - id=raw_node["id"], - type=node_type, - label=raw_node.get("label", node_type), - position=raw_node.get("position", {"x": 0, "y": 0}), - properties=DeviceProperties(**{ - k: v for k, v in raw_node.get("properties", {}).items() - if k in DeviceProperties.model_fields - }), - )) + nodes.append(DiagramNode( + id=raw_node["id"], + type=node_type, + label=raw_node.get("label", node_type), + position=Position(**raw_node.get("position", {"x": 0, "y": 0})), + properties=DeviceProperties(**{ + k: v for k, v in raw_node.get("properties", {}).items() + if k in DeviceProperties.model_fields + }), + )) - edges = [] - for raw_edge in data.get("edges", []): - edges.append(DiagramEdge( - id=raw_edge["id"], - source=raw_edge["source"], - target=raw_edge["target"], - label=raw_edge.get("label"), - connectionType=raw_edge.get("connectionType", "ethernet"), - speed=raw_edge.get("speed"), - notes=raw_edge.get("notes"), - )) + edges = [] + for raw_edge in data.get("edges", []): + edges.append(DiagramEdge( + id=raw_edge["id"], + source=raw_edge["source"], + target=raw_edge["target"], + label=raw_edge.get("label"), + connectionType=raw_edge.get("connectionType", "ethernet"), + speed=raw_edge.get("speed"), + notes=raw_edge.get("notes"), + )) + except KeyError as e: + logger.warning("AI response missing required field: %s", e) + raise ValueError(f"AI generated incomplete data (missing {e}), please try again") return AIGenerateResponse( nodes=nodes,