fix: align network map builder with account isolation

This commit is contained in:
chihlasm
2026-04-12 05:05:27 +00:00
parent bb24078d60
commit 3c2b1dd16e
10 changed files with 207 additions and 58 deletions

View File

@@ -1,8 +1,8 @@
"""Add device_types table with system seed data. """Add account-scoped device_types table with platform seed data.
Revision ID: 073 Revision ID: 073
Revises: 072 Revises: b3c7e9f2a1d8
Create Date: 2026-04-04 Create Date: 2026-04-12
""" """
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@@ -11,10 +11,18 @@ import uuid
revision = "073" revision = "073"
down_revision = "072" down_revision = "b3c7e9f2a1d8"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001"
_CURRENT_ACCOUNT = (
"COALESCE("
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
"'00000000-0000-0000-0000-000000000000'"
")::uuid"
)
SYSTEM_DEVICE_TYPES = [ SYSTEM_DEVICE_TYPES = [
("router", "Router", "network", 0), ("router", "Router", "network", 0),
("switch", "Switch", "network", 1), ("switch", "Switch", "network", 1),
@@ -55,16 +63,13 @@ def upgrade() -> None:
sa.Column("label", sa.String(100), nullable=False), sa.Column("label", sa.String(100), nullable=False),
sa.Column("category", sa.String(50), nullable=False), sa.Column("category", sa.String(50), nullable=False),
sa.Column("is_system", sa.Boolean(), nullable=False, server_default=sa.text("false")), sa.Column("is_system", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.Column("team_id", UUID(as_uuid=True), sa.ForeignKey("teams.id", ondelete="CASCADE"), nullable=True), sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
sa.Column("sort_order", sa.Integer(), nullable=False, server_default=sa.text("0")), sa.Column("sort_order", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
) )
op.execute( op.create_unique_constraint("uq_device_types_slug_account", "device_types", ["slug", "account_id"])
"ALTER TABLE device_types ADD CONSTRAINT uq_device_types_slug_team " op.create_index("ix_device_types_account_id", "device_types", ["account_id"])
"UNIQUE NULLS NOT DISTINCT (slug, team_id)"
)
op.create_index("idx_device_types_team", "device_types", ["team_id"])
device_types_table = sa.table( device_types_table = sa.table(
"device_types", "device_types",
@@ -73,7 +78,7 @@ def upgrade() -> None:
sa.column("label", sa.String), sa.column("label", sa.String),
sa.column("category", sa.String), sa.column("category", sa.String),
sa.column("is_system", sa.Boolean), sa.column("is_system", sa.Boolean),
sa.column("team_id", UUID(as_uuid=True)), sa.column("account_id", UUID(as_uuid=True)),
sa.column("sort_order", sa.Integer), sa.column("sort_order", sa.Integer),
) )
@@ -84,12 +89,44 @@ def upgrade() -> None:
"label": label, "label": label,
"category": category, "category": category,
"is_system": True, "is_system": True,
"team_id": None, "account_id": uuid.UUID(_PLATFORM_UUID),
"sort_order": sort_order, "sort_order": sort_order,
} }
for slug, label, category, sort_order in SYSTEM_DEVICE_TYPES for slug, label, category, sort_order in SYSTEM_DEVICE_TYPES
]) ])
op.execute("ALTER TABLE device_types ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE device_types FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY device_types_select ON device_types
FOR SELECT
USING (
account_id = {_CURRENT_ACCOUNT}
OR account_id = '{_PLATFORM_UUID}'::uuid
)
""")
op.execute(f"""
CREATE POLICY device_types_insert ON device_types
FOR INSERT
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
""")
op.execute(f"""
CREATE POLICY device_types_update ON device_types
FOR UPDATE
USING (account_id = {_CURRENT_ACCOUNT})
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
""")
op.execute(f"""
CREATE POLICY device_types_delete ON device_types
FOR DELETE
USING (account_id = {_CURRENT_ACCOUNT})
""")
def downgrade() -> None: def downgrade() -> None:
op.execute("DROP POLICY IF EXISTS device_types_delete ON device_types")
op.execute("DROP POLICY IF EXISTS device_types_update ON device_types")
op.execute("DROP POLICY IF EXISTS device_types_insert ON device_types")
op.execute("DROP POLICY IF EXISTS device_types_select ON device_types")
op.execute("ALTER TABLE device_types DISABLE ROW LEVEL SECURITY")
op.drop_table("device_types") op.drop_table("device_types")

View File

@@ -2,7 +2,7 @@
Revision ID: 074 Revision ID: 074
Revises: 073 Revises: 073
Create Date: 2026-04-04 Create Date: 2026-04-12
""" """
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@@ -14,12 +14,19 @@ down_revision = "073"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
_CURRENT_ACCOUNT = (
"COALESCE("
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
"'00000000-0000-0000-0000-000000000000'"
")::uuid"
)
def upgrade() -> None: def upgrade() -> None:
op.create_table( op.create_table(
"network_diagrams", "network_diagrams",
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column("team_id", UUID(as_uuid=True), sa.ForeignKey("teams.id", ondelete="CASCADE"), nullable=False), sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
sa.Column("name", sa.String(255), nullable=False), sa.Column("name", sa.String(255), nullable=False),
sa.Column("client_name", sa.String(255), nullable=True), sa.Column("client_name", sa.String(255), nullable=True),
sa.Column("asset_name", sa.String(255), nullable=True), sa.Column("asset_name", sa.String(255), nullable=True),
@@ -33,9 +40,18 @@ def upgrade() -> None:
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
) )
op.create_index("idx_network_diagrams_team", "network_diagrams", ["team_id"]) op.create_index("ix_network_diagrams_account_id", "network_diagrams", ["account_id"])
op.create_index("idx_network_diagrams_client", "network_diagrams", ["team_id", "client_name"]) op.create_index("idx_network_diagrams_account_client", "network_diagrams", ["account_id", "client_name"])
op.execute("ALTER TABLE network_diagrams ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE network_diagrams FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON network_diagrams
USING (account_id = {_CURRENT_ACCOUNT})
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
""")
def downgrade() -> None: def downgrade() -> None:
op.execute("DROP POLICY IF EXISTS tenant_isolation ON network_diagrams")
op.execute("ALTER TABLE network_diagrams DISABLE ROW LEVEL SECURITY")
op.drop_table("network_diagrams") op.drop_table("network_diagrams")

View File

@@ -15,6 +15,7 @@ from app.schemas.device_type import (
DeviceTypeUpdate, DeviceTypeUpdate,
DeviceTypeResponse, DeviceTypeResponse,
) )
from app.core.service_account import PLATFORM_ACCOUNT_ID
router = APIRouter(prefix="/device-types", tags=["device-types"]) router = APIRouter(prefix="/device-types", tags=["device-types"])
@@ -28,8 +29,8 @@ async def list_device_types(
select(DeviceType) select(DeviceType)
.where( .where(
or_( or_(
DeviceType.is_system.is_(True), DeviceType.account_id == PLATFORM_ACCOUNT_ID,
DeviceType.team_id == current_user.team_id, DeviceType.account_id == current_user.account_id,
) )
) )
.order_by(DeviceType.category, DeviceType.sort_order, DeviceType.label) .order_by(DeviceType.category, DeviceType.sort_order, DeviceType.label)
@@ -48,16 +49,16 @@ async def create_device_type(
existing = await db.execute( existing = await db.execute(
select(DeviceType).where( select(DeviceType).where(
DeviceType.slug == data.slug, DeviceType.slug == data.slug,
DeviceType.team_id == current_user.team_id, DeviceType.account_id == current_user.account_id,
) )
) )
if existing.scalar_one_or_none(): if existing.scalar_one_or_none():
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your team") raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your account")
system_existing = await db.execute( system_existing = await db.execute(
select(DeviceType).where( select(DeviceType).where(
DeviceType.slug == data.slug, DeviceType.slug == data.slug,
DeviceType.is_system.is_(True), DeviceType.account_id == PLATFORM_ACCOUNT_ID,
) )
) )
if system_existing.scalar_one_or_none(): if system_existing.scalar_one_or_none():
@@ -68,7 +69,7 @@ async def create_device_type(
label=data.label, label=data.label,
category=data.category, category=data.category,
is_system=False, is_system=False,
team_id=current_user.team_id, account_id=current_user.account_id,
sort_order=data.sort_order, sort_order=data.sort_order,
) )
db.add(device_type) db.add(device_type)
@@ -89,7 +90,7 @@ async def update_device_type(
raise HTTPException(status_code=404, detail="Device type not found") raise HTTPException(status_code=404, detail="Device type not found")
if device_type.is_system: if device_type.is_system:
raise HTTPException(status_code=403, detail="Cannot modify system device types") raise HTTPException(status_code=403, detail="Cannot modify system device types")
if device_type.team_id != current_user.team_id: if device_type.account_id != current_user.account_id:
raise HTTPException(status_code=404, detail="Device type not found") raise HTTPException(status_code=404, detail="Device type not found")
update_data = data.model_dump(exclude_unset=True) update_data = data.model_dump(exclude_unset=True)
@@ -112,7 +113,7 @@ async def delete_device_type(
raise HTTPException(status_code=404, detail="Device type not found") raise HTTPException(status_code=404, detail="Device type not found")
if device_type.is_system: if device_type.is_system:
raise HTTPException(status_code=403, detail="Cannot delete system device types") raise HTTPException(status_code=403, detail="Cannot delete system device types")
if device_type.team_id != current_user.team_id: if device_type.account_id != current_user.account_id:
raise HTTPException(status_code=404, detail="Device type not found") raise HTTPException(status_code=404, detail="Device type not found")
await db.delete(device_type) await db.delete(device_type)

View File

@@ -13,6 +13,7 @@ from app.api.deps import get_current_active_user
from app.models.user import User from app.models.user import User
from app.models.device_type import DeviceType from app.models.device_type import DeviceType
from app.models.network_diagram import NetworkDiagram from app.models.network_diagram import NetworkDiagram
from app.core.service_account import PLATFORM_ACCOUNT_ID
from app.schemas.network_diagram import ( from app.schemas.network_diagram import (
NetworkDiagramCreate, NetworkDiagramCreate,
NetworkDiagramUpdate, NetworkDiagramUpdate,
@@ -49,11 +50,11 @@ router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"])
async def _get_diagram_or_404( async def _get_diagram_or_404(
diagram_id: UUID, diagram_id: UUID,
team_id: UUID, account_id: UUID,
db: AsyncSession, db: AsyncSession,
) -> NetworkDiagram: ) -> NetworkDiagram:
diagram = await db.get(NetworkDiagram, diagram_id) diagram = await db.get(NetworkDiagram, diagram_id)
if not diagram or diagram.team_id != team_id or diagram.is_archived: if not diagram or diagram.account_id != account_id or diagram.is_archived:
raise HTTPException(status_code=404, detail="Diagram not found") raise HTTPException(status_code=404, detail="Diagram not found")
return diagram return diagram
@@ -88,9 +89,12 @@ def _diagram_to_list_item(
) )
async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]: async def _get_available_slugs(account_id: UUID, db: AsyncSession) -> set[str]:
stmt = select(DeviceType.slug).where( stmt = select(DeviceType.slug).where(
or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id) or_(
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
DeviceType.account_id == account_id,
)
) )
result = await db.execute(stmt) result = await db.execute(stmt)
return {row[0] for row in result.all()} return {row[0] for row in result.all()}
@@ -104,7 +108,7 @@ async def list_client_names(
stmt = ( stmt = (
select(NetworkDiagram.client_name) select(NetworkDiagram.client_name)
.where( .where(
NetworkDiagram.team_id == current_user.team_id, NetworkDiagram.account_id == current_user.account_id,
NetworkDiagram.is_archived.is_(False), NetworkDiagram.is_archived.is_(False),
NetworkDiagram.client_name.isnot(None), NetworkDiagram.client_name.isnot(None),
NetworkDiagram.client_name != "", NetworkDiagram.client_name != "",
@@ -126,7 +130,7 @@ async def list_diagrams(
stmt = ( stmt = (
select(NetworkDiagram) select(NetworkDiagram)
.where( .where(
NetworkDiagram.team_id == current_user.team_id, NetworkDiagram.account_id == current_user.account_id,
NetworkDiagram.is_archived.is_(False), NetworkDiagram.is_archived.is_(False),
) )
.order_by(NetworkDiagram.updated_at.desc()) .order_by(NetworkDiagram.updated_at.desc())
@@ -148,7 +152,7 @@ async def list_diagrams(
# Single query for custom device types so category_counts is accurate # Single query for custom device types so category_counts is accurate
dt_stmt = select(DeviceType.slug, DeviceType.category).where( dt_stmt = select(DeviceType.slug, DeviceType.category).where(
DeviceType.is_system.is_(False), DeviceType.is_system.is_(False),
DeviceType.team_id == current_user.team_id, DeviceType.account_id == current_user.account_id,
) )
dt_result = await db.execute(dt_stmt) dt_result = await db.execute(dt_stmt)
custom_slug_category = {row[0]: row[1] for row in dt_result.all()} custom_slug_category = {row[0]: row[1] for row in dt_result.all()}
@@ -164,13 +168,8 @@ async def create_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse: ) -> NetworkDiagramResponse:
if current_user.team_id is None:
raise HTTPException(
status_code=422,
detail="Network Diagrams require a team account. Assign your account to a team first.",
)
diagram = NetworkDiagram( diagram = NetworkDiagram(
team_id=current_user.team_id, account_id=current_user.account_id,
name=data.name, name=data.name,
client_name=data.client_name, client_name=data.client_name,
asset_name=data.asset_name, asset_name=data.asset_name,
@@ -191,7 +190,7 @@ async def get_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse: ) -> NetworkDiagramResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
return _diagram_to_response(diagram) return _diagram_to_response(diagram)
@@ -202,7 +201,7 @@ async def update_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse: ) -> NetworkDiagramResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
update_data = data.model_dump(exclude_unset=True) update_data = data.model_dump(exclude_unset=True)
if "nodes" in update_data and update_data["nodes"] is not None: if "nodes" in update_data and update_data["nodes"] is not None:
@@ -225,7 +224,7 @@ async def archive_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> None: ) -> None:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
diagram.is_archived = True diagram.is_archived = True
diagram.updated_at = datetime.now(timezone.utc) diagram.updated_at = datetime.now(timezone.utc)
await db.commit() await db.commit()
@@ -237,9 +236,9 @@ async def duplicate_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse: ) -> NetworkDiagramResponse:
source = await _get_diagram_or_404(diagram_id, current_user.team_id, db) source = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
copy = NetworkDiagram( copy = NetworkDiagram(
team_id=current_user.team_id, account_id=current_user.account_id,
name=f"Copy of {source.name}", name=f"Copy of {source.name}",
client_name=source.client_name, client_name=source.client_name,
asset_name=source.asset_name, asset_name=source.asset_name,
@@ -260,7 +259,7 @@ async def export_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramExportResponse: ) -> DiagramExportResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db) diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
nodes = [DiagramNode(**n) for n in (diagram.nodes or [])] nodes = [DiagramNode(**n) for n in (diagram.nodes or [])]
edges = [DiagramEdge(**e) for e in (diagram.edges or [])] edges = [DiagramEdge(**e) for e in (diagram.edges or [])]
return DiagramExportResponse( return DiagramExportResponse(
@@ -280,7 +279,7 @@ async def import_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse: ) -> DiagramImportResponse:
available_slugs = await _get_available_slugs(current_user.team_id, db) available_slugs = await _get_available_slugs(current_user.account_id, db)
warnings: list[str] = [] warnings: list[str] = []
for node in data.nodes: for node in data.nodes:
@@ -288,7 +287,7 @@ async def import_diagram(
warnings.append(f"Unknown device type '{node.type}' — will render with default icon") warnings.append(f"Unknown device type '{node.type}' — will render with default icon")
diagram = NetworkDiagram( diagram = NetworkDiagram(
team_id=current_user.team_id, account_id=current_user.account_id,
name=data.name, name=data.name,
client_name=data.client_name, client_name=data.client_name,
description=data.description, description=data.description,
@@ -312,7 +311,7 @@ async def ai_generate_diagram(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse: ) -> AIGenerateResponse:
available_slugs_set = await _get_available_slugs(current_user.team_id, db) available_slugs_set = await _get_available_slugs(current_user.account_id, db)
available_slugs = list(available_slugs_set) available_slugs = list(available_slugs_set)
existing_node_ids: list[str] | None = None existing_node_ids: list[str] | None = None

View File

@@ -10,7 +10,7 @@ from app.core.database import Base
class DeviceType(Base): class DeviceType(Base):
"""A device type for network diagram nodes (system or team-custom).""" """A device type for network diagram nodes (platform or account-custom)."""
__tablename__ = "device_types" __tablename__ = "device_types"
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
@@ -32,11 +32,11 @@ class DeviceType(Base):
Boolean, nullable=False, default=False, Boolean, nullable=False, default=False,
comment="True for built-in types that cannot be deleted", comment="True for built-in types that cannot be deleted",
) )
team_id: Mapped[uuid.UUID | None] = mapped_column( account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="CASCADE"), ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True, nullable=False,
comment="NULL for system types, set for team-custom types", comment="Platform account for system types, tenant account for custom types",
) )
sort_order: Mapped[int] = mapped_column( sort_order: Mapped[int] = mapped_column(
Integer, nullable=False, default=0, Integer, nullable=False, default=0,

View File

@@ -14,15 +14,15 @@ if TYPE_CHECKING:
class NetworkDiagram(Base): class NetworkDiagram(Base):
"""A network topology diagram, team-scoped.""" """A network topology diagram scoped to one account."""
__tablename__ = "network_diagrams" __tablename__ = "network_diagrams"
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
) )
team_id: Mapped[uuid.UUID] = mapped_column( account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="CASCADE"), ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False, nullable=False,
index=True, index=True,
) )

View File

@@ -30,7 +30,7 @@ class DeviceTypeResponse(BaseModel):
label: str label: str
category: str category: str
is_system: bool is_system: bool
team_id: UUID | None = None account_id: UUID
sort_order: int sort_order: int
created_at: datetime created_at: datetime

View File

@@ -61,7 +61,7 @@ class NetworkDiagramUpdate(BaseModel):
class NetworkDiagramResponse(BaseModel): class NetworkDiagramResponse(BaseModel):
id: UUID id: UUID
team_id: UUID account_id: UUID
name: str name: str
client_name: str | None = None client_name: str | None = None
asset_name: str | None = None asset_name: str | None = None

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
import uuid
import pytest
from sqlalchemy import select
from app.models.device_type import DeviceType
from app.models.user import User
from app.core.service_account import PLATFORM_ACCOUNT_ID
async def _login_headers(client, email: str, password: str) -> dict[str, str]:
response = await client.post(
"/api/v1/auth/login/json",
json={"email": email, "password": password},
)
assert response.status_code == 200
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.asyncio
async def test_device_types_include_platform_and_account_custom(client, test_db, auth_headers, test_user):
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
user = result.scalar_one()
test_db.add(
DeviceType(
id=uuid.uuid4(),
slug="platform-router",
label="Platform Router",
category="network",
is_system=True,
account_id=PLATFORM_ACCOUNT_ID,
sort_order=0,
)
)
await test_db.commit()
create_response = await client.post(
"/api/v1/device-types/",
json={
"slug": "tenant-appliance",
"label": "Tenant Appliance",
"category": "network",
"sort_order": 3,
},
headers=auth_headers,
)
assert create_response.status_code == 201
assert create_response.json()["account_id"] == str(user.account_id)
list_response = await client.get("/api/v1/device-types/", headers=auth_headers)
assert list_response.status_code == 200
payload = list_response.json()
slugs = {item["slug"] for item in payload}
assert "platform-router" in slugs
assert "tenant-appliance" in slugs
@pytest.mark.asyncio
async def test_network_diagrams_are_account_scoped(client, test_db, auth_headers, test_user):
other_user = {
"email": "other-network@example.com",
"password": "TestPassword123!",
"name": "Other Network User",
}
register_response = await client.post("/api/v1/auth/register", json=other_user)
assert register_response.status_code in (200, 201)
other_headers = await _login_headers(client, other_user["email"], other_user["password"])
owner_result = await test_db.execute(select(User).where(User.email == test_user["email"]))
owner = owner_result.scalar_one()
create_response = await client.post(
"/api/v1/network-diagrams/",
json={
"name": "HQ Core",
"client_name": "Acme",
"description": "Primary topology",
"nodes": [],
"edges": [],
},
headers=auth_headers,
)
assert create_response.status_code == 201
diagram = create_response.json()
assert diagram["account_id"] == str(owner.account_id)
own_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=auth_headers)
assert own_get.status_code == 200
other_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=other_headers)
assert other_get.status_code == 404

View File

@@ -37,7 +37,7 @@ export interface DeviceTypeResponse {
label: string label: string
category: string category: string
is_system: boolean is_system: boolean
team_id: string | null account_id: string
sort_order: number sort_order: number
created_at: string created_at: string
} }
@@ -51,7 +51,7 @@ export interface DeviceTypeCreate {
export interface NetworkDiagramResponse { export interface NetworkDiagramResponse {
id: string id: string
team_id: string account_id: string
name: string name: string
client_name: string | null client_name: string | null
asset_name: string | null asset_name: string | null