diff --git a/backend/alembic/versions/04f013768235_enable_rls_phase3.py b/backend/alembic/versions/04f013768235_enable_rls_phase3.py new file mode 100644 index 00000000..aab26bef --- /dev/null +++ b/backend/alembic/versions/04f013768235_enable_rls_phase3.py @@ -0,0 +1,59 @@ +"""Enable RLS on Phase 3 tables. + +Tables covered: + - step_ratings (account_id NOT NULL since migration 7167e9374b0c) + - step_usage_log (account_id NOT NULL since migration 7167e9374b0c) + - target_lists (account_id NOT NULL since migration 2c6aabd89bc6) + - session_shares (account_id NOT NULL since session_share model) + - audit_logs (account_id NOT NULL since migration 2a9056eddd90) + - tree_shares (account_id NOT NULL since migration a05e1a1bea7c) + +All use a standard intra-tenant isolation policy. +Token-based access to session_shares and tree_shares goes through +endpoints that use get_admin_db (BYPASSRLS), so a strict tenant +policy here is correct. + +Revision ID: 04f013768235 +Revises: a05e1a1bea7c +Create Date: 2026-04-11 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op + +revision: str = '04f013768235' +down_revision: Union[str, None] = 'a05e1a1bea7c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +_CURRENT_ACCOUNT = ( + "COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), " + "'00000000-0000-0000-0000-000000000000')::uuid" +) + +_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}" + +_PHASE3_TABLES = [ + "step_ratings", + "step_usage_log", + "target_lists", + "session_shares", + "audit_logs", + "tree_shares", +] + + +def upgrade() -> None: + for table in _PHASE3_TABLES: + op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON {table} + USING ({_STANDARD_USING}) + """) + + +def downgrade() -> None: + for table in _PHASE3_TABLES: + op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}") + op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY") diff --git a/backend/alembic/versions/172ad76d7d20_drop_team_id_target_lists.py b/backend/alembic/versions/172ad76d7d20_drop_team_id_target_lists.py new file mode 100644 index 00000000..e565470d --- /dev/null +++ b/backend/alembic/versions/172ad76d7d20_drop_team_id_target_lists.py @@ -0,0 +1,32 @@ +"""Drop team_id from target_lists. + +account_id (NOT NULL) is now the tenant isolation key; team_id is redundant. +All reads/writes use account_id via RLS + application filter. + +Revision ID: 172ad76d7d20 +Revises: 04f013768235 +Create Date: 2026-04-11 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '172ad76d7d20' +down_revision: Union[str, None] = '04f013768235' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.drop_index('ix_target_lists_team_id', table_name='target_lists', if_exists=True) + op.drop_constraint('target_lists_team_id_fkey', 'target_lists', type_='foreignkey') + op.drop_column('target_lists', 'team_id') + + +def downgrade() -> None: + op.add_column('target_lists', sa.Column('team_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'target_lists_team_id_fkey', 'target_lists', 'teams', + ['team_id'], ['id'], ondelete='CASCADE', + ) + op.create_index('ix_target_lists_team_id', 'target_lists', ['team_id']) diff --git a/backend/alembic/versions/2a9056eddd90_add_account_id_audit_logs.py b/backend/alembic/versions/2a9056eddd90_add_account_id_audit_logs.py new file mode 100644 index 00000000..8978186e --- /dev/null +++ b/backend/alembic/versions/2a9056eddd90_add_account_id_audit_logs.py @@ -0,0 +1,51 @@ +"""Add account_id to audit_logs and backfill via user_id. + +Revision ID: 2a9056eddd90 +Revises: 70a5dd746e83 +Create Date: 2026-04-11 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '2a9056eddd90' +down_revision: Union[str, None] = '70a5dd746e83' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('audit_logs', sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_audit_logs_account_id', 'audit_logs', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Backfill: derive from the acting user's account + op.execute(""" + UPDATE audit_logs al + SET account_id = u.account_id + FROM users u + WHERE al.user_id = u.id + AND u.account_id IS NOT NULL + AND al.account_id IS NULL + """) + + result = op.get_bind().execute( + sa.text("SELECT COUNT(*) FROM audit_logs WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} audit_logs rows have NULL account_id after backfill. " + "All audit log entries must have an associated user with an account." + ) + + op.alter_column('audit_logs', 'account_id', nullable=False) + op.create_index('ix_audit_logs_account_id', 'audit_logs', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_audit_logs_account_id', table_name='audit_logs') + op.drop_constraint('fk_audit_logs_account_id', 'audit_logs', type_='foreignkey') + op.drop_column('audit_logs', 'account_id') diff --git a/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py b/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py new file mode 100644 index 00000000..8d69aec3 --- /dev/null +++ b/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py @@ -0,0 +1,51 @@ +"""Add account_id to tree_shares and backfill via created_by user. + +Revision ID: a05e1a1bea7c +Revises: 2a9056eddd90 +Create Date: 2026-04-11 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = 'a05e1a1bea7c' +down_revision: Union[str, None] = '2a9056eddd90' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('tree_shares', sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_tree_shares_account_id', 'tree_shares', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Backfill: derive from the creating user's account + op.execute(""" + UPDATE tree_shares ts + SET account_id = u.account_id + FROM users u + WHERE ts.created_by = u.id + AND u.account_id IS NOT NULL + AND ts.account_id IS NULL + """) + + result = op.get_bind().execute( + sa.text("SELECT COUNT(*) FROM tree_shares WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} tree_shares rows have NULL account_id after backfill. " + "All share entries must have a creating user with an account." + ) + + op.alter_column('tree_shares', 'account_id', nullable=False) + op.create_index('ix_tree_shares_account_id', 'tree_shares', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_tree_shares_account_id', table_name='tree_shares') + op.drop_constraint('fk_tree_shares_account_id', 'tree_shares', type_='foreignkey') + op.drop_column('tree_shares', 'account_id') diff --git a/backend/app/api/endpoints/target_lists.py b/backend/app/api/endpoints/target_lists.py index 0bfac439..82e20a55 100644 --- a/backend/app/api/endpoints/target_lists.py +++ b/backend/app/api/endpoints/target_lists.py @@ -18,12 +18,10 @@ async def list_target_lists( current_user: Annotated[User, Depends(get_current_active_user)], db: Annotated[AsyncSession, Depends(get_db)], ): - """List all target lists for the current user's team.""" - if not current_user.team_id: - return [] + """List all target lists for the current user's account.""" result = await db.execute( select(TargetList) - .where(TargetList.team_id == current_user.team_id) + .where(TargetList.account_id == current_user.account_id) .order_by(TargetList.name) ) return result.scalars().all() @@ -36,11 +34,9 @@ async def create_target_list( db: Annotated[AsyncSession, Depends(get_db)], _: None = Depends(require_engineer_or_admin), ): - """Create a new target list for the current team.""" - if not current_user.team_id: - raise HTTPException(status_code=400, detail="User must belong to a team") + """Create a new target list for the current account.""" target_list = TargetList( - team_id=current_user.team_id, + account_id=current_user.account_id, created_by=current_user.id, name=data.name, description=data.description, @@ -61,7 +57,7 @@ async def get_target_list( result = await db.execute( select(TargetList).where( TargetList.id == list_id, - TargetList.team_id == current_user.team_id, + TargetList.account_id == current_user.account_id, ) ) target_list = result.scalar_one_or_none() @@ -81,7 +77,7 @@ async def update_target_list( result = await db.execute( select(TargetList).where( TargetList.id == list_id, - TargetList.team_id == current_user.team_id, + TargetList.account_id == current_user.account_id, ) ) target_list = result.scalar_one_or_none() @@ -91,7 +87,7 @@ async def update_target_list( if "name" in update_fields and data.name is not None: target_list.name = data.name if "description" in update_fields: - target_list.description = data.description # allow setting to None + target_list.description = data.description if "targets" in update_fields and data.targets is not None: target_list.targets = [t.model_dump() for t in data.targets] await db.commit() @@ -109,7 +105,7 @@ async def delete_target_list( result = await db.execute( select(TargetList).where( TargetList.id == list_id, - TargetList.team_id == current_user.team_id, + TargetList.account_id == current_user.account_id, ) ) target_list = result.scalar_one_or_none() diff --git a/backend/app/api/endpoints/trees.py b/backend/app/api/endpoints/trees.py index 6a16297e..fe5246c7 100644 --- a/backend/app/api/endpoints/trees.py +++ b/backend/app/api/endpoints/trees.py @@ -1048,6 +1048,7 @@ async def create_tree_share( # Create share tree_share = TreeShare( tree_id=tree.id, + account_id=current_user.account_id, share_token=share_token, created_by=current_user.id, allow_forking=share_data.allow_forking, diff --git a/backend/app/core/audit.py b/backend/app/core/audit.py index 58ecd620..b5640e0a 100644 --- a/backend/app/core/audit.py +++ b/backend/app/core/audit.py @@ -12,10 +12,19 @@ async def log_audit( resource_type: str, resource_id: Optional[UUID] = None, details: Optional[dict] = None, + account_id: Optional[UUID] = None, ) -> None: """Record an audit log entry. Does not commit — piggybacks on the caller's commit.""" + if account_id is None: + # Derive from the acting user's account as a fallback (one extra query). + from sqlalchemy import select + from app.models.user import User + result = await db.execute(select(User.account_id).where(User.id == user_id)) + account_id = result.scalar_one() + entry = AuditLog( user_id=user_id, + account_id=account_id, action=action, resource_type=resource_type, resource_id=resource_id, diff --git a/backend/app/models/audit_log.py b/backend/app/models/audit_log.py index 25fa669e..0e795222 100644 --- a/backend/app/models/audit_log.py +++ b/backend/app/models/audit_log.py @@ -21,6 +21,12 @@ class AuditLog(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True + ) action: Mapped[str] = mapped_column(String(50), nullable=False, index=True) resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True) resource_id: Mapped[Optional[uuid.UUID]] = mapped_column( diff --git a/backend/app/models/target_list.py b/backend/app/models/target_list.py index b1169d72..d1b64c0d 100644 --- a/backend/app/models/target_list.py +++ b/backend/app/models/target_list.py @@ -8,7 +8,6 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.user import User - from app.models.team import Team from app.models.account import Account @@ -18,10 +17,6 @@ class TargetList(Base): id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 ) - team_id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), - nullable=False, index=True - ) account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), diff --git a/backend/app/models/tree_share.py b/backend/app/models/tree_share.py index d5e42ad1..4af919aa 100644 --- a/backend/app/models/tree_share.py +++ b/backend/app/models/tree_share.py @@ -25,6 +25,12 @@ class TreeShare(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True + ) share_token: Mapped[str] = mapped_column( String(64), unique=True, diff --git a/backend/app/schemas/target_list.py b/backend/app/schemas/target_list.py index 0016d393..f7d39b59 100644 --- a/backend/app/schemas/target_list.py +++ b/backend/app/schemas/target_list.py @@ -23,7 +23,7 @@ class TargetListUpdate(BaseModel): class TargetListResponse(BaseModel): id: UUID - team_id: UUID + account_id: UUID created_by: Optional[UUID] name: str description: Optional[str] diff --git a/backend/tests/test_rls_isolation.py b/backend/tests/test_rls_isolation.py index 1934fdee..b083eda8 100644 --- a/backend/tests/test_rls_isolation.py +++ b/backend/tests/test_rls_isolation.py @@ -708,3 +708,249 @@ async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, await admin_conn.execute( f"DELETE FROM step_library WHERE id = '{public_step_id}'" ) + + +# =========================================================================== +# Phase 3 RLS isolation tests +# Tables: step_ratings, step_usage_log, target_lists, +# session_shares, audit_logs, tree_shares +# =========================================================================== + +# --------------------------------------------------------------------------- +# Helpers shared by Phase 3 fixtures +# --------------------------------------------------------------------------- + +async def _get_user_b_id(admin_conn) -> str: + row = await admin_conn.fetchrow( + "SELECT id FROM users WHERE email = 'rls-user-b@example.com'" + ) + return str(row["id"]) + + +async def _get_tree_b_id(admin_conn) -> str: + row = await admin_conn.fetchrow( + f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1" + ) + return str(row["id"]) + + +# --------------------------------------------------------------------------- +# step_ratings +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see step ratings belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + # Need a step_library row as FK target + step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_library ( + id, account_id, title, step_type, content, + visibility, is_active, created_at, updated_at + ) VALUES ( + '{step_id}', '{ACCOUNT_B_ID}', 'Phase3 RLS Step', 'action', + '{{}}'::jsonb, 'private', TRUE, NOW(), NOW() + ) + """) + + rating_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_ratings ( + id, step_id, user_id, account_id, is_verified_use, is_visible, + created_at, updated_at + ) VALUES ( + '{rating_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + FALSE, TRUE, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM step_ratings WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B step_ratings" + finally: + await admin_conn.execute(f"DELETE FROM step_ratings WHERE id = '{rating_id}'") + await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'") + + +# --------------------------------------------------------------------------- +# step_usage_log +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see step usage logs belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + tree_b_id = await _get_tree_b_id(admin_conn) + + step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_library ( + id, account_id, title, step_type, content, + visibility, is_active, created_at, updated_at + ) VALUES ( + '{step_id}', '{ACCOUNT_B_ID}', 'Phase3 Usage Step', 'action', + '{{}}'::jsonb, 'private', TRUE, NOW(), NOW() + ) + """) + + # Need a sessions row as FK for usage log + session_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO sessions ( + id, tree_id, user_id, account_id, tree_snapshot, + path_taken, decisions, custom_steps, started_at + ) VALUES ( + '{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW() + ) + """) + + log_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_usage_log ( + id, step_id, user_id, account_id, session_id, used_at + ) VALUES ( + '{log_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + '{session_id}', NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM step_usage_log WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B step_usage_log" + finally: + await admin_conn.execute(f"DELETE FROM step_usage_log WHERE id = '{log_id}'") + await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'") + await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'") + + +# --------------------------------------------------------------------------- +# target_lists +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see target lists belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + tl_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO target_lists ( + id, account_id, created_by, name, targets, created_at, updated_at + ) VALUES ( + '{tl_id}', '{ACCOUNT_B_ID}', '{user_b_id}', + 'Phase3 RLS Target List', '[]'::jsonb, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM target_lists WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B target_lists" + finally: + await admin_conn.execute(f"DELETE FROM target_lists WHERE id = '{tl_id}'") + + +# --------------------------------------------------------------------------- +# session_shares +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see session shares belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + tree_b_id = await _get_tree_b_id(admin_conn) + + # Need a sessions row as FK + session_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO sessions ( + id, tree_id, user_id, account_id, tree_snapshot, + path_taken, decisions, custom_steps, started_at + ) VALUES ( + '{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW() + ) + """) + + share_id = str(uuid.uuid4()) + share_token = f"phase3-rls-test-{share_id[:8]}" + await admin_conn.execute(f""" + INSERT INTO session_shares ( + id, session_id, account_id, share_token, visibility, + created_by, view_count, is_active, created_at, updated_at + ) VALUES ( + '{share_id}', '{session_id}', '{ACCOUNT_B_ID}', + '{share_token}', 'account', '{user_b_id}', + 0, TRUE, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM session_shares WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B session_shares" + finally: + await admin_conn.execute(f"DELETE FROM session_shares WHERE id = '{share_id}'") + await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'") + + +# --------------------------------------------------------------------------- +# audit_logs +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see audit logs belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + log_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO audit_logs ( + id, user_id, account_id, action, resource_type, created_at + ) VALUES ( + '{log_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + 'test.action', 'test_resource', NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM audit_logs WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B audit_logs" + finally: + await admin_conn.execute(f"DELETE FROM audit_logs WHERE id = '{log_id}'") + + +# --------------------------------------------------------------------------- +# tree_shares +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see tree shares belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + tree_b_id = await _get_tree_b_id(admin_conn) + + share_id = str(uuid.uuid4()) + share_token = f"phase3-tree-rls-{share_id[:8]}" + await admin_conn.execute(f""" + INSERT INTO tree_shares ( + id, tree_id, account_id, share_token, created_by, + allow_forking, created_at + ) VALUES ( + '{share_id}', '{tree_b_id}', '{ACCOUNT_B_ID}', + '{share_token}', '{user_b_id}', TRUE, NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM tree_shares WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B tree_shares" + finally: + await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'") diff --git a/backend/tests/test_target_lists.py b/backend/tests/test_target_lists.py index a40cfb48..1e4bca49 100644 --- a/backend/tests/test_target_lists.py +++ b/backend/tests/test_target_lists.py @@ -3,37 +3,10 @@ import pytest from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession -from app.models.team import Team from app.models.user import User from sqlalchemy import select -@pytest.fixture -async def auth_headers(client: AsyncClient, test_db: AsyncSession, test_user: dict): - """Override auth_headers to ensure the test user has a team_id assigned.""" - # Fetch the user from DB and assign a team - result = await test_db.execute(select(User).where(User.email == test_user["email"])) - user = result.scalar_one() - - # Create a team and assign the user to it - team = Team(name="Test Team") - test_db.add(team) - await test_db.flush() - - user.team_id = team.id - await test_db.commit() - - # Re-login to get a fresh token - login_data = { - "email": test_user["email"], - "password": test_user["password"], - } - resp = await client.post("/api/v1/auth/login/json", json=login_data) - assert resp.status_code == 200 - token_data = resp.json() - return {"Authorization": f"Bearer {token_data['access_token']}"} - - @pytest.mark.asyncio async def test_create_target_list(client: AsyncClient, auth_headers: dict): resp = await client.post( @@ -107,25 +80,28 @@ async def test_delete_target_list(client: AsyncClient, auth_headers: dict): assert get.status_code == 404 @pytest.mark.asyncio -async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: dict, test_db): - """User from team B cannot access team A's list.""" +async def test_cannot_access_other_accounts_list(client: AsyncClient, auth_headers: dict, test_db): + """User from account B cannot access account A's target list.""" import uuid - from app.models.team import Team + from app.models.account import Account from app.models.user import User from app.core.security import get_password_hash - # Create team A list using existing auth_headers + # Create account A list using existing auth_headers create = await client.post( "/api/v1/target-lists/", - json={"name": "Team A List", "targets": [{"label": "SRV-A"}]}, + json={"name": "Account A List", "targets": [{"label": "SRV-A"}]}, headers=auth_headers, ) assert create.status_code == 201 list_id = create.json()["id"] - # Create a separate team B with its own user - team_b = Team(name=f"Team B {uuid.uuid4()}") - test_db.add(team_b) + # Create a separate account B with its own user + account_b = Account( + name=f"Account B {uuid.uuid4()}", + display_code=f"AB{str(uuid.uuid4())[:6].upper()}", + ) + test_db.add(account_b) await test_db.flush() user_b = User( @@ -133,11 +109,13 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: password_hash=get_password_hash("password123"), name="User B", is_active=True, - team_id=team_b.id, + account_id=account_b.id, + account_role="engineer", role="engineer", ) test_db.add(user_b) await test_db.flush() + await test_db.commit() # Get auth token for user B login = await client.post( @@ -148,6 +126,6 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: token_b = login.json()["access_token"] headers_b = {"Authorization": f"Bearer {token_b}"} - # Team B cannot access Team A's list + # Account B cannot access Account A's list resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b) assert resp.status_code == 404