fix: align network map builder with account isolation

This commit is contained in:
chihlasm
2026-04-12 05:05:27 +00:00
parent bb24078d60
commit 3c2b1dd16e
10 changed files with 207 additions and 58 deletions

View File

@@ -1,8 +1,8 @@
"""Add device_types table with system seed data.
"""Add account-scoped device_types table with platform seed data.
Revision ID: 073
Revises: 072
Create Date: 2026-04-04
Revises: b3c7e9f2a1d8
Create Date: 2026-04-12
"""
from alembic import op
import sqlalchemy as sa
@@ -11,10 +11,18 @@ import uuid
revision = "073"
down_revision = "072"
down_revision = "b3c7e9f2a1d8"
branch_labels = None
depends_on = None
_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001"
_CURRENT_ACCOUNT = (
"COALESCE("
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
"'00000000-0000-0000-0000-000000000000'"
")::uuid"
)
SYSTEM_DEVICE_TYPES = [
("router", "Router", "network", 0),
("switch", "Switch", "network", 1),
@@ -55,16 +63,13 @@ def upgrade() -> None:
sa.Column("label", sa.String(100), nullable=False),
sa.Column("category", sa.String(50), nullable=False),
sa.Column("is_system", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.Column("team_id", UUID(as_uuid=True), sa.ForeignKey("teams.id", ondelete="CASCADE"), nullable=True),
sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
sa.Column("sort_order", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
)
op.execute(
"ALTER TABLE device_types ADD CONSTRAINT uq_device_types_slug_team "
"UNIQUE NULLS NOT DISTINCT (slug, team_id)"
)
op.create_index("idx_device_types_team", "device_types", ["team_id"])
op.create_unique_constraint("uq_device_types_slug_account", "device_types", ["slug", "account_id"])
op.create_index("ix_device_types_account_id", "device_types", ["account_id"])
device_types_table = sa.table(
"device_types",
@@ -73,7 +78,7 @@ def upgrade() -> None:
sa.column("label", sa.String),
sa.column("category", sa.String),
sa.column("is_system", sa.Boolean),
sa.column("team_id", UUID(as_uuid=True)),
sa.column("account_id", UUID(as_uuid=True)),
sa.column("sort_order", sa.Integer),
)
@@ -84,12 +89,44 @@ def upgrade() -> None:
"label": label,
"category": category,
"is_system": True,
"team_id": None,
"account_id": uuid.UUID(_PLATFORM_UUID),
"sort_order": sort_order,
}
for slug, label, category, sort_order in SYSTEM_DEVICE_TYPES
])
op.execute("ALTER TABLE device_types ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE device_types FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY device_types_select ON device_types
FOR SELECT
USING (
account_id = {_CURRENT_ACCOUNT}
OR account_id = '{_PLATFORM_UUID}'::uuid
)
""")
op.execute(f"""
CREATE POLICY device_types_insert ON device_types
FOR INSERT
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
""")
op.execute(f"""
CREATE POLICY device_types_update ON device_types
FOR UPDATE
USING (account_id = {_CURRENT_ACCOUNT})
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
""")
op.execute(f"""
CREATE POLICY device_types_delete ON device_types
FOR DELETE
USING (account_id = {_CURRENT_ACCOUNT})
""")
def downgrade() -> None:
op.execute("DROP POLICY IF EXISTS device_types_delete ON device_types")
op.execute("DROP POLICY IF EXISTS device_types_update ON device_types")
op.execute("DROP POLICY IF EXISTS device_types_insert ON device_types")
op.execute("DROP POLICY IF EXISTS device_types_select ON device_types")
op.execute("ALTER TABLE device_types DISABLE ROW LEVEL SECURITY")
op.drop_table("device_types")

View File

@@ -2,7 +2,7 @@
Revision ID: 074
Revises: 073
Create Date: 2026-04-04
Create Date: 2026-04-12
"""
from alembic import op
import sqlalchemy as sa
@@ -14,12 +14,19 @@ down_revision = "073"
branch_labels = None
depends_on = None
_CURRENT_ACCOUNT = (
"COALESCE("
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
"'00000000-0000-0000-0000-000000000000'"
")::uuid"
)
def upgrade() -> None:
op.create_table(
"network_diagrams",
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column("team_id", UUID(as_uuid=True), sa.ForeignKey("teams.id", ondelete="CASCADE"), nullable=False),
sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("client_name", sa.String(255), nullable=True),
sa.Column("asset_name", sa.String(255), nullable=True),
@@ -33,9 +40,18 @@ def upgrade() -> None:
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
)
op.create_index("idx_network_diagrams_team", "network_diagrams", ["team_id"])
op.create_index("idx_network_diagrams_client", "network_diagrams", ["team_id", "client_name"])
op.create_index("ix_network_diagrams_account_id", "network_diagrams", ["account_id"])
op.create_index("idx_network_diagrams_account_client", "network_diagrams", ["account_id", "client_name"])
op.execute("ALTER TABLE network_diagrams ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE network_diagrams FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON network_diagrams
USING (account_id = {_CURRENT_ACCOUNT})
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
""")
def downgrade() -> None:
op.execute("DROP POLICY IF EXISTS tenant_isolation ON network_diagrams")
op.execute("ALTER TABLE network_diagrams DISABLE ROW LEVEL SECURITY")
op.drop_table("network_diagrams")

View File

@@ -15,6 +15,7 @@ from app.schemas.device_type import (
DeviceTypeUpdate,
DeviceTypeResponse,
)
from app.core.service_account import PLATFORM_ACCOUNT_ID
router = APIRouter(prefix="/device-types", tags=["device-types"])
@@ -28,8 +29,8 @@ async def list_device_types(
select(DeviceType)
.where(
or_(
DeviceType.is_system.is_(True),
DeviceType.team_id == current_user.team_id,
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
DeviceType.account_id == current_user.account_id,
)
)
.order_by(DeviceType.category, DeviceType.sort_order, DeviceType.label)
@@ -48,16 +49,16 @@ async def create_device_type(
existing = await db.execute(
select(DeviceType).where(
DeviceType.slug == data.slug,
DeviceType.team_id == current_user.team_id,
DeviceType.account_id == current_user.account_id,
)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your team")
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your account")
system_existing = await db.execute(
select(DeviceType).where(
DeviceType.slug == data.slug,
DeviceType.is_system.is_(True),
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
)
)
if system_existing.scalar_one_or_none():
@@ -68,7 +69,7 @@ async def create_device_type(
label=data.label,
category=data.category,
is_system=False,
team_id=current_user.team_id,
account_id=current_user.account_id,
sort_order=data.sort_order,
)
db.add(device_type)
@@ -89,7 +90,7 @@ async def update_device_type(
raise HTTPException(status_code=404, detail="Device type not found")
if device_type.is_system:
raise HTTPException(status_code=403, detail="Cannot modify system device types")
if device_type.team_id != current_user.team_id:
if device_type.account_id != current_user.account_id:
raise HTTPException(status_code=404, detail="Device type not found")
update_data = data.model_dump(exclude_unset=True)
@@ -112,7 +113,7 @@ async def delete_device_type(
raise HTTPException(status_code=404, detail="Device type not found")
if device_type.is_system:
raise HTTPException(status_code=403, detail="Cannot delete system device types")
if device_type.team_id != current_user.team_id:
if device_type.account_id != current_user.account_id:
raise HTTPException(status_code=404, detail="Device type not found")
await db.delete(device_type)

View File

@@ -13,6 +13,7 @@ from app.api.deps import get_current_active_user
from app.models.user import User
from app.models.device_type import DeviceType
from app.models.network_diagram import NetworkDiagram
from app.core.service_account import PLATFORM_ACCOUNT_ID
from app.schemas.network_diagram import (
NetworkDiagramCreate,
NetworkDiagramUpdate,
@@ -49,11 +50,11 @@ router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"])
async def _get_diagram_or_404(
diagram_id: UUID,
team_id: UUID,
account_id: UUID,
db: AsyncSession,
) -> NetworkDiagram:
diagram = await db.get(NetworkDiagram, diagram_id)
if not diagram or diagram.team_id != team_id or diagram.is_archived:
if not diagram or diagram.account_id != account_id or diagram.is_archived:
raise HTTPException(status_code=404, detail="Diagram not found")
return diagram
@@ -88,9 +89,12 @@ def _diagram_to_list_item(
)
async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]:
async def _get_available_slugs(account_id: UUID, db: AsyncSession) -> set[str]:
stmt = select(DeviceType.slug).where(
or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id)
or_(
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
DeviceType.account_id == account_id,
)
)
result = await db.execute(stmt)
return {row[0] for row in result.all()}
@@ -104,7 +108,7 @@ async def list_client_names(
stmt = (
select(NetworkDiagram.client_name)
.where(
NetworkDiagram.team_id == current_user.team_id,
NetworkDiagram.account_id == current_user.account_id,
NetworkDiagram.is_archived.is_(False),
NetworkDiagram.client_name.isnot(None),
NetworkDiagram.client_name != "",
@@ -126,7 +130,7 @@ async def list_diagrams(
stmt = (
select(NetworkDiagram)
.where(
NetworkDiagram.team_id == current_user.team_id,
NetworkDiagram.account_id == current_user.account_id,
NetworkDiagram.is_archived.is_(False),
)
.order_by(NetworkDiagram.updated_at.desc())
@@ -148,7 +152,7 @@ async def list_diagrams(
# Single query for custom device types so category_counts is accurate
dt_stmt = select(DeviceType.slug, DeviceType.category).where(
DeviceType.is_system.is_(False),
DeviceType.team_id == current_user.team_id,
DeviceType.account_id == current_user.account_id,
)
dt_result = await db.execute(dt_stmt)
custom_slug_category = {row[0]: row[1] for row in dt_result.all()}
@@ -164,13 +168,8 @@ async def create_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
if current_user.team_id is None:
raise HTTPException(
status_code=422,
detail="Network Diagrams require a team account. Assign your account to a team first.",
)
diagram = NetworkDiagram(
team_id=current_user.team_id,
account_id=current_user.account_id,
name=data.name,
client_name=data.client_name,
asset_name=data.asset_name,
@@ -191,7 +190,7 @@ async def get_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
return _diagram_to_response(diagram)
@@ -202,7 +201,7 @@ async def update_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
update_data = data.model_dump(exclude_unset=True)
if "nodes" in update_data and update_data["nodes"] is not None:
@@ -225,7 +224,7 @@ async def archive_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
diagram.is_archived = True
diagram.updated_at = datetime.now(timezone.utc)
await db.commit()
@@ -237,9 +236,9 @@ async def duplicate_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
source = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
source = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
copy = NetworkDiagram(
team_id=current_user.team_id,
account_id=current_user.account_id,
name=f"Copy of {source.name}",
client_name=source.client_name,
asset_name=source.asset_name,
@@ -260,7 +259,7 @@ async def export_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramExportResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
nodes = [DiagramNode(**n) for n in (diagram.nodes or [])]
edges = [DiagramEdge(**e) for e in (diagram.edges or [])]
return DiagramExportResponse(
@@ -280,7 +279,7 @@ async def import_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse:
available_slugs = await _get_available_slugs(current_user.team_id, db)
available_slugs = await _get_available_slugs(current_user.account_id, db)
warnings: list[str] = []
for node in data.nodes:
@@ -288,7 +287,7 @@ async def import_diagram(
warnings.append(f"Unknown device type '{node.type}' — will render with default icon")
diagram = NetworkDiagram(
team_id=current_user.team_id,
account_id=current_user.account_id,
name=data.name,
client_name=data.client_name,
description=data.description,
@@ -312,7 +311,7 @@ async def ai_generate_diagram(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse:
available_slugs_set = await _get_available_slugs(current_user.team_id, db)
available_slugs_set = await _get_available_slugs(current_user.account_id, db)
available_slugs = list(available_slugs_set)
existing_node_ids: list[str] | None = None

View File

@@ -10,7 +10,7 @@ from app.core.database import Base
class DeviceType(Base):
"""A device type for network diagram nodes (system or team-custom)."""
"""A device type for network diagram nodes (platform or account-custom)."""
__tablename__ = "device_types"
id: Mapped[uuid.UUID] = mapped_column(
@@ -32,11 +32,11 @@ class DeviceType(Base):
Boolean, nullable=False, default=False,
comment="True for built-in types that cannot be deleted",
)
team_id: Mapped[uuid.UUID | None] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="CASCADE"),
nullable=True,
comment="NULL for system types, set for team-custom types",
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
comment="Platform account for system types, tenant account for custom types",
)
sort_order: Mapped[int] = mapped_column(
Integer, nullable=False, default=0,

View File

@@ -14,15 +14,15 @@ if TYPE_CHECKING:
class NetworkDiagram(Base):
"""A network topology diagram, team-scoped."""
"""A network topology diagram scoped to one account."""
__tablename__ = "network_diagrams"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
team_id: Mapped[uuid.UUID] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="CASCADE"),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)

View File

@@ -30,7 +30,7 @@ class DeviceTypeResponse(BaseModel):
label: str
category: str
is_system: bool
team_id: UUID | None = None
account_id: UUID
sort_order: int
created_at: datetime

View File

@@ -61,7 +61,7 @@ class NetworkDiagramUpdate(BaseModel):
class NetworkDiagramResponse(BaseModel):
id: UUID
team_id: UUID
account_id: UUID
name: str
client_name: str | None = None
asset_name: str | None = None

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
import uuid
import pytest
from sqlalchemy import select
from app.models.device_type import DeviceType
from app.models.user import User
from app.core.service_account import PLATFORM_ACCOUNT_ID
async def _login_headers(client, email: str, password: str) -> dict[str, str]:
response = await client.post(
"/api/v1/auth/login/json",
json={"email": email, "password": password},
)
assert response.status_code == 200
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.asyncio
async def test_device_types_include_platform_and_account_custom(client, test_db, auth_headers, test_user):
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
user = result.scalar_one()
test_db.add(
DeviceType(
id=uuid.uuid4(),
slug="platform-router",
label="Platform Router",
category="network",
is_system=True,
account_id=PLATFORM_ACCOUNT_ID,
sort_order=0,
)
)
await test_db.commit()
create_response = await client.post(
"/api/v1/device-types/",
json={
"slug": "tenant-appliance",
"label": "Tenant Appliance",
"category": "network",
"sort_order": 3,
},
headers=auth_headers,
)
assert create_response.status_code == 201
assert create_response.json()["account_id"] == str(user.account_id)
list_response = await client.get("/api/v1/device-types/", headers=auth_headers)
assert list_response.status_code == 200
payload = list_response.json()
slugs = {item["slug"] for item in payload}
assert "platform-router" in slugs
assert "tenant-appliance" in slugs
@pytest.mark.asyncio
async def test_network_diagrams_are_account_scoped(client, test_db, auth_headers, test_user):
other_user = {
"email": "other-network@example.com",
"password": "TestPassword123!",
"name": "Other Network User",
}
register_response = await client.post("/api/v1/auth/register", json=other_user)
assert register_response.status_code in (200, 201)
other_headers = await _login_headers(client, other_user["email"], other_user["password"])
owner_result = await test_db.execute(select(User).where(User.email == test_user["email"]))
owner = owner_result.scalar_one()
create_response = await client.post(
"/api/v1/network-diagrams/",
json={
"name": "HQ Core",
"client_name": "Acme",
"description": "Primary topology",
"nodes": [],
"edges": [],
},
headers=auth_headers,
)
assert create_response.status_code == 201
diagram = create_response.json()
assert diagram["account_id"] == str(owner.account_id)
own_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=auth_headers)
assert own_get.status_code == 200
other_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=other_headers)
assert other_get.status_code == 404

View File

@@ -37,7 +37,7 @@ export interface DeviceTypeResponse {
label: string
category: string
is_system: boolean
team_id: string | null
account_id: string
sort_order: number
created_at: string
}
@@ -51,7 +51,7 @@ export interface DeviceTypeCreate {
export interface NetworkDiagramResponse {
id: string
team_id: string
account_id: string
name: string
client_name: string | null
asset_name: string | null