diff --git a/backend/alembic/versions/073_add_device_types_table.py b/backend/alembic/versions/073_add_device_types_table.py index fb1b86e0..70b61e5d 100644 --- a/backend/alembic/versions/073_add_device_types_table.py +++ b/backend/alembic/versions/073_add_device_types_table.py @@ -1,8 +1,8 @@ -"""Add device_types table with system seed data. +"""Add account-scoped device_types table with platform seed data. Revision ID: 073 -Revises: 072 -Create Date: 2026-04-04 +Revises: b3c7e9f2a1d8 +Create Date: 2026-04-12 """ from alembic import op import sqlalchemy as sa @@ -11,10 +11,18 @@ import uuid revision = "073" -down_revision = "072" +down_revision = "b3c7e9f2a1d8" branch_labels = None depends_on = None +_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001" +_CURRENT_ACCOUNT = ( + "COALESCE(" + "NULLIF(current_setting('app.current_account_id', TRUE), ''), " + "'00000000-0000-0000-0000-000000000000'" + ")::uuid" +) + SYSTEM_DEVICE_TYPES = [ ("router", "Router", "network", 0), ("switch", "Switch", "network", 1), @@ -55,16 +63,13 @@ def upgrade() -> None: sa.Column("label", sa.String(100), nullable=False), sa.Column("category", sa.String(50), nullable=False), sa.Column("is_system", sa.Boolean(), nullable=False, server_default=sa.text("false")), - sa.Column("team_id", UUID(as_uuid=True), sa.ForeignKey("teams.id", ondelete="CASCADE"), nullable=True), + sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False), sa.Column("sort_order", sa.Integer(), nullable=False, server_default=sa.text("0")), sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), ) - op.execute( - "ALTER TABLE device_types ADD CONSTRAINT uq_device_types_slug_team " - "UNIQUE NULLS NOT DISTINCT (slug, team_id)" - ) - op.create_index("idx_device_types_team", "device_types", ["team_id"]) + op.create_unique_constraint("uq_device_types_slug_account", "device_types", ["slug", "account_id"]) + op.create_index("ix_device_types_account_id", "device_types", ["account_id"]) device_types_table = sa.table( "device_types", @@ -73,7 +78,7 @@ def upgrade() -> None: sa.column("label", sa.String), sa.column("category", sa.String), sa.column("is_system", sa.Boolean), - sa.column("team_id", UUID(as_uuid=True)), + sa.column("account_id", UUID(as_uuid=True)), sa.column("sort_order", sa.Integer), ) @@ -84,12 +89,44 @@ def upgrade() -> None: "label": label, "category": category, "is_system": True, - "team_id": None, + "account_id": uuid.UUID(_PLATFORM_UUID), "sort_order": sort_order, } for slug, label, category, sort_order in SYSTEM_DEVICE_TYPES ]) + op.execute("ALTER TABLE device_types ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE device_types FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY device_types_select ON device_types + FOR SELECT + USING ( + account_id = {_CURRENT_ACCOUNT} + OR account_id = '{_PLATFORM_UUID}'::uuid + ) + """) + op.execute(f""" + CREATE POLICY device_types_insert ON device_types + FOR INSERT + WITH CHECK (account_id = {_CURRENT_ACCOUNT}) + """) + op.execute(f""" + CREATE POLICY device_types_update ON device_types + FOR UPDATE + USING (account_id = {_CURRENT_ACCOUNT}) + WITH CHECK (account_id = {_CURRENT_ACCOUNT}) + """) + op.execute(f""" + CREATE POLICY device_types_delete ON device_types + FOR DELETE + USING (account_id = {_CURRENT_ACCOUNT}) + """) + def downgrade() -> None: + op.execute("DROP POLICY IF EXISTS device_types_delete ON device_types") + op.execute("DROP POLICY IF EXISTS device_types_update ON device_types") + op.execute("DROP POLICY IF EXISTS device_types_insert ON device_types") + op.execute("DROP POLICY IF EXISTS device_types_select ON device_types") + op.execute("ALTER TABLE device_types DISABLE ROW LEVEL SECURITY") op.drop_table("device_types") diff --git a/backend/alembic/versions/074_add_network_diagrams_table.py b/backend/alembic/versions/074_add_network_diagrams_table.py index a95bd89b..1cd7397e 100644 --- a/backend/alembic/versions/074_add_network_diagrams_table.py +++ b/backend/alembic/versions/074_add_network_diagrams_table.py @@ -2,7 +2,7 @@ Revision ID: 074 Revises: 073 -Create Date: 2026-04-04 +Create Date: 2026-04-12 """ from alembic import op import sqlalchemy as sa @@ -14,12 +14,19 @@ down_revision = "073" branch_labels = None depends_on = None +_CURRENT_ACCOUNT = ( + "COALESCE(" + "NULLIF(current_setting('app.current_account_id', TRUE), ''), " + "'00000000-0000-0000-0000-000000000000'" + ")::uuid" +) + def upgrade() -> None: op.create_table( "network_diagrams", sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), - sa.Column("team_id", UUID(as_uuid=True), sa.ForeignKey("teams.id", ondelete="CASCADE"), nullable=False), + sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False), sa.Column("name", sa.String(255), nullable=False), sa.Column("client_name", sa.String(255), nullable=True), sa.Column("asset_name", sa.String(255), nullable=True), @@ -33,9 +40,18 @@ def upgrade() -> None: sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), ) - op.create_index("idx_network_diagrams_team", "network_diagrams", ["team_id"]) - op.create_index("idx_network_diagrams_client", "network_diagrams", ["team_id", "client_name"]) + op.create_index("ix_network_diagrams_account_id", "network_diagrams", ["account_id"]) + op.create_index("idx_network_diagrams_account_client", "network_diagrams", ["account_id", "client_name"]) + op.execute("ALTER TABLE network_diagrams ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE network_diagrams FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON network_diagrams + USING (account_id = {_CURRENT_ACCOUNT}) + WITH CHECK (account_id = {_CURRENT_ACCOUNT}) + """) def downgrade() -> None: + op.execute("DROP POLICY IF EXISTS tenant_isolation ON network_diagrams") + op.execute("ALTER TABLE network_diagrams DISABLE ROW LEVEL SECURITY") op.drop_table("network_diagrams") diff --git a/backend/app/api/endpoints/device_types.py b/backend/app/api/endpoints/device_types.py index 7daa409c..330c0e9f 100644 --- a/backend/app/api/endpoints/device_types.py +++ b/backend/app/api/endpoints/device_types.py @@ -15,6 +15,7 @@ from app.schemas.device_type import ( DeviceTypeUpdate, DeviceTypeResponse, ) +from app.core.service_account import PLATFORM_ACCOUNT_ID router = APIRouter(prefix="/device-types", tags=["device-types"]) @@ -28,8 +29,8 @@ async def list_device_types( select(DeviceType) .where( or_( - DeviceType.is_system.is_(True), - DeviceType.team_id == current_user.team_id, + DeviceType.account_id == PLATFORM_ACCOUNT_ID, + DeviceType.account_id == current_user.account_id, ) ) .order_by(DeviceType.category, DeviceType.sort_order, DeviceType.label) @@ -48,16 +49,16 @@ async def create_device_type( existing = await db.execute( select(DeviceType).where( DeviceType.slug == data.slug, - DeviceType.team_id == current_user.team_id, + DeviceType.account_id == current_user.account_id, ) ) if existing.scalar_one_or_none(): - raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your team") + raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your account") system_existing = await db.execute( select(DeviceType).where( DeviceType.slug == data.slug, - DeviceType.is_system.is_(True), + DeviceType.account_id == PLATFORM_ACCOUNT_ID, ) ) if system_existing.scalar_one_or_none(): @@ -68,7 +69,7 @@ async def create_device_type( label=data.label, category=data.category, is_system=False, - team_id=current_user.team_id, + account_id=current_user.account_id, sort_order=data.sort_order, ) db.add(device_type) @@ -89,7 +90,7 @@ async def update_device_type( raise HTTPException(status_code=404, detail="Device type not found") if device_type.is_system: raise HTTPException(status_code=403, detail="Cannot modify system device types") - if device_type.team_id != current_user.team_id: + if device_type.account_id != current_user.account_id: raise HTTPException(status_code=404, detail="Device type not found") update_data = data.model_dump(exclude_unset=True) @@ -112,7 +113,7 @@ async def delete_device_type( raise HTTPException(status_code=404, detail="Device type not found") if device_type.is_system: raise HTTPException(status_code=403, detail="Cannot delete system device types") - if device_type.team_id != current_user.team_id: + if device_type.account_id != current_user.account_id: raise HTTPException(status_code=404, detail="Device type not found") await db.delete(device_type) diff --git a/backend/app/api/endpoints/network_diagrams.py b/backend/app/api/endpoints/network_diagrams.py index c6841545..e00ecf7a 100644 --- a/backend/app/api/endpoints/network_diagrams.py +++ b/backend/app/api/endpoints/network_diagrams.py @@ -13,6 +13,7 @@ from app.api.deps import get_current_active_user from app.models.user import User from app.models.device_type import DeviceType from app.models.network_diagram import NetworkDiagram +from app.core.service_account import PLATFORM_ACCOUNT_ID from app.schemas.network_diagram import ( NetworkDiagramCreate, NetworkDiagramUpdate, @@ -49,11 +50,11 @@ router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"]) async def _get_diagram_or_404( diagram_id: UUID, - team_id: UUID, + account_id: UUID, db: AsyncSession, ) -> NetworkDiagram: diagram = await db.get(NetworkDiagram, diagram_id) - if not diagram or diagram.team_id != team_id or diagram.is_archived: + if not diagram or diagram.account_id != account_id or diagram.is_archived: raise HTTPException(status_code=404, detail="Diagram not found") return diagram @@ -88,9 +89,12 @@ def _diagram_to_list_item( ) -async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]: +async def _get_available_slugs(account_id: UUID, db: AsyncSession) -> set[str]: stmt = select(DeviceType.slug).where( - or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id) + or_( + DeviceType.account_id == PLATFORM_ACCOUNT_ID, + DeviceType.account_id == account_id, + ) ) result = await db.execute(stmt) return {row[0] for row in result.all()} @@ -104,7 +108,7 @@ async def list_client_names( stmt = ( select(NetworkDiagram.client_name) .where( - NetworkDiagram.team_id == current_user.team_id, + NetworkDiagram.account_id == current_user.account_id, NetworkDiagram.is_archived.is_(False), NetworkDiagram.client_name.isnot(None), NetworkDiagram.client_name != "", @@ -126,7 +130,7 @@ async def list_diagrams( stmt = ( select(NetworkDiagram) .where( - NetworkDiagram.team_id == current_user.team_id, + NetworkDiagram.account_id == current_user.account_id, NetworkDiagram.is_archived.is_(False), ) .order_by(NetworkDiagram.updated_at.desc()) @@ -148,7 +152,7 @@ async def list_diagrams( # Single query for custom device types so category_counts is accurate dt_stmt = select(DeviceType.slug, DeviceType.category).where( DeviceType.is_system.is_(False), - DeviceType.team_id == current_user.team_id, + DeviceType.account_id == current_user.account_id, ) dt_result = await db.execute(dt_stmt) custom_slug_category = {row[0]: row[1] for row in dt_result.all()} @@ -164,13 +168,8 @@ async def create_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> NetworkDiagramResponse: - if current_user.team_id is None: - raise HTTPException( - status_code=422, - detail="Network Diagrams require a team account. Assign your account to a team first.", - ) diagram = NetworkDiagram( - team_id=current_user.team_id, + account_id=current_user.account_id, name=data.name, client_name=data.client_name, asset_name=data.asset_name, @@ -191,7 +190,7 @@ async def get_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> NetworkDiagramResponse: - diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db) return _diagram_to_response(diagram) @@ -202,7 +201,7 @@ async def update_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> NetworkDiagramResponse: - diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db) update_data = data.model_dump(exclude_unset=True) if "nodes" in update_data and update_data["nodes"] is not None: @@ -225,7 +224,7 @@ async def archive_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> None: - diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db) diagram.is_archived = True diagram.updated_at = datetime.now(timezone.utc) await db.commit() @@ -237,9 +236,9 @@ async def duplicate_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> NetworkDiagramResponse: - source = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + source = await _get_diagram_or_404(diagram_id, current_user.account_id, db) copy = NetworkDiagram( - team_id=current_user.team_id, + account_id=current_user.account_id, name=f"Copy of {source.name}", client_name=source.client_name, asset_name=source.asset_name, @@ -260,7 +259,7 @@ async def export_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> DiagramExportResponse: - diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) + diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db) nodes = [DiagramNode(**n) for n in (diagram.nodes or [])] edges = [DiagramEdge(**e) for e in (diagram.edges or [])] return DiagramExportResponse( @@ -280,7 +279,7 @@ async def import_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> DiagramImportResponse: - available_slugs = await _get_available_slugs(current_user.team_id, db) + available_slugs = await _get_available_slugs(current_user.account_id, db) warnings: list[str] = [] for node in data.nodes: @@ -288,7 +287,7 @@ async def import_diagram( warnings.append(f"Unknown device type '{node.type}' — will render with default icon") diagram = NetworkDiagram( - team_id=current_user.team_id, + account_id=current_user.account_id, name=data.name, client_name=data.client_name, description=data.description, @@ -312,7 +311,7 @@ async def ai_generate_diagram( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> AIGenerateResponse: - available_slugs_set = await _get_available_slugs(current_user.team_id, db) + available_slugs_set = await _get_available_slugs(current_user.account_id, db) available_slugs = list(available_slugs_set) existing_node_ids: list[str] | None = None diff --git a/backend/app/models/device_type.py b/backend/app/models/device_type.py index 85da01e6..d2f9c756 100644 --- a/backend/app/models/device_type.py +++ b/backend/app/models/device_type.py @@ -10,7 +10,7 @@ from app.core.database import Base class DeviceType(Base): - """A device type for network diagram nodes (system or team-custom).""" + """A device type for network diagram nodes (platform or account-custom).""" __tablename__ = "device_types" id: Mapped[uuid.UUID] = mapped_column( @@ -32,11 +32,11 @@ class DeviceType(Base): Boolean, nullable=False, default=False, comment="True for built-in types that cannot be deleted", ) - team_id: Mapped[uuid.UUID | None] = mapped_column( + account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), - ForeignKey("teams.id", ondelete="CASCADE"), - nullable=True, - comment="NULL for system types, set for team-custom types", + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + comment="Platform account for system types, tenant account for custom types", ) sort_order: Mapped[int] = mapped_column( Integer, nullable=False, default=0, diff --git a/backend/app/models/network_diagram.py b/backend/app/models/network_diagram.py index 63932b5a..347216da 100644 --- a/backend/app/models/network_diagram.py +++ b/backend/app/models/network_diagram.py @@ -14,15 +14,15 @@ if TYPE_CHECKING: class NetworkDiagram(Base): - """A network topology diagram, team-scoped.""" + """A network topology diagram scoped to one account.""" __tablename__ = "network_diagrams" id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 ) - team_id: Mapped[uuid.UUID] = mapped_column( + account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), - ForeignKey("teams.id", ondelete="CASCADE"), + ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True, ) diff --git a/backend/app/schemas/device_type.py b/backend/app/schemas/device_type.py index aeab8ff5..665fd8a2 100644 --- a/backend/app/schemas/device_type.py +++ b/backend/app/schemas/device_type.py @@ -30,7 +30,7 @@ class DeviceTypeResponse(BaseModel): label: str category: str is_system: bool - team_id: UUID | None = None + account_id: UUID sort_order: int created_at: datetime diff --git a/backend/app/schemas/network_diagram.py b/backend/app/schemas/network_diagram.py index 24b98264..e31d5283 100644 --- a/backend/app/schemas/network_diagram.py +++ b/backend/app/schemas/network_diagram.py @@ -61,7 +61,7 @@ class NetworkDiagramUpdate(BaseModel): class NetworkDiagramResponse(BaseModel): id: UUID - team_id: UUID + account_id: UUID name: str client_name: str | None = None asset_name: str | None = None diff --git a/backend/tests/test_network_diagrams.py b/backend/tests/test_network_diagrams.py new file mode 100644 index 00000000..6b6cb93c --- /dev/null +++ b/backend/tests/test_network_diagrams.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import uuid + +import pytest +from sqlalchemy import select + +from app.models.device_type import DeviceType +from app.models.user import User +from app.core.service_account import PLATFORM_ACCOUNT_ID + + +async def _login_headers(client, email: str, password: str) -> dict[str, str]: + response = await client.post( + "/api/v1/auth/login/json", + json={"email": email, "password": password}, + ) + assert response.status_code == 200 + token = response.json()["access_token"] + return {"Authorization": f"Bearer {token}"} + + +@pytest.mark.asyncio +async def test_device_types_include_platform_and_account_custom(client, test_db, auth_headers, test_user): + result = await test_db.execute(select(User).where(User.email == test_user["email"])) + user = result.scalar_one() + + test_db.add( + DeviceType( + id=uuid.uuid4(), + slug="platform-router", + label="Platform Router", + category="network", + is_system=True, + account_id=PLATFORM_ACCOUNT_ID, + sort_order=0, + ) + ) + await test_db.commit() + + create_response = await client.post( + "/api/v1/device-types/", + json={ + "slug": "tenant-appliance", + "label": "Tenant Appliance", + "category": "network", + "sort_order": 3, + }, + headers=auth_headers, + ) + assert create_response.status_code == 201 + assert create_response.json()["account_id"] == str(user.account_id) + + list_response = await client.get("/api/v1/device-types/", headers=auth_headers) + assert list_response.status_code == 200 + payload = list_response.json() + slugs = {item["slug"] for item in payload} + + assert "platform-router" in slugs + assert "tenant-appliance" in slugs + + +@pytest.mark.asyncio +async def test_network_diagrams_are_account_scoped(client, test_db, auth_headers, test_user): + other_user = { + "email": "other-network@example.com", + "password": "TestPassword123!", + "name": "Other Network User", + } + register_response = await client.post("/api/v1/auth/register", json=other_user) + assert register_response.status_code in (200, 201) + other_headers = await _login_headers(client, other_user["email"], other_user["password"]) + + owner_result = await test_db.execute(select(User).where(User.email == test_user["email"])) + owner = owner_result.scalar_one() + + create_response = await client.post( + "/api/v1/network-diagrams/", + json={ + "name": "HQ Core", + "client_name": "Acme", + "description": "Primary topology", + "nodes": [], + "edges": [], + }, + headers=auth_headers, + ) + assert create_response.status_code == 201 + diagram = create_response.json() + assert diagram["account_id"] == str(owner.account_id) + + own_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=auth_headers) + assert own_get.status_code == 200 + + other_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=other_headers) + assert other_get.status_code == 404 diff --git a/frontend/src/types/network-diagram.ts b/frontend/src/types/network-diagram.ts index 77d92dbe..878984b7 100644 --- a/frontend/src/types/network-diagram.ts +++ b/frontend/src/types/network-diagram.ts @@ -37,7 +37,7 @@ export interface DeviceTypeResponse { label: string category: string is_system: boolean - team_id: string | null + account_id: string sort_order: number created_at: string } @@ -51,7 +51,7 @@ export interface DeviceTypeCreate { export interface NetworkDiagramResponse { id: string - team_id: string + account_id: string name: string client_name: string | null asset_name: string | null