From e05472615b9aea60933fdb64b8d42f7988961d14 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Sat, 11 Apr 2026 05:02:43 +0000 Subject: [PATCH 1/2] =?UTF-8?q?feat:=20tenant=20isolation=20Phase=203=20?= =?UTF-8?q?=E2=80=94=20audit=5Flogs,=20tree=5Fshares,=20remaining=20RLS?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P3-A: Add account_id to audit_logs model + migration (backfill via user_id → users.account_id). log_audit() gains optional account_id param with fallback SELECT to avoid churn across 40 call sites. P3-B: Add account_id to tree_shares model + migration (backfill via created_by → users.account_id). TreeShare constructor updated in trees.py. P3-C: Enable RLS on 6 remaining tables: step_ratings, step_usage_log, target_lists, session_shares, audit_logs, tree_shares. P3-D: Drop team_id from target_lists — endpoint, schema, and model now use account_id as the sole isolation key. P3-E: Append Phase 3 RLS isolation tests for all 6 tables. test_target_lists.py: fix cross-account test to use Account model (not Team) and set account_id on new User. Co-Authored-By: Claude Sonnet 4.6 --- .../04f013768235_enable_rls_phase3.py | 59 +++++ .../172ad76d7d20_drop_team_id_target_lists.py | 32 +++ .../2a9056eddd90_add_account_id_audit_logs.py | 51 ++++ ...a05e1a1bea7c_add_account_id_tree_shares.py | 51 ++++ backend/app/api/endpoints/target_lists.py | 20 +- backend/app/api/endpoints/trees.py | 1 + backend/app/core/audit.py | 9 + backend/app/models/audit_log.py | 6 + backend/app/models/target_list.py | 5 - backend/app/models/tree_share.py | 6 + backend/app/schemas/target_list.py | 2 +- backend/tests/test_rls_isolation.py | 246 ++++++++++++++++++ backend/tests/test_target_lists.py | 52 ++-- 13 files changed, 485 insertions(+), 55 deletions(-) create mode 100644 backend/alembic/versions/04f013768235_enable_rls_phase3.py create mode 100644 backend/alembic/versions/172ad76d7d20_drop_team_id_target_lists.py create mode 100644 backend/alembic/versions/2a9056eddd90_add_account_id_audit_logs.py create mode 100644 backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py diff --git a/backend/alembic/versions/04f013768235_enable_rls_phase3.py b/backend/alembic/versions/04f013768235_enable_rls_phase3.py new file mode 100644 index 00000000..aab26bef --- /dev/null +++ b/backend/alembic/versions/04f013768235_enable_rls_phase3.py @@ -0,0 +1,59 @@ +"""Enable RLS on Phase 3 tables. + +Tables covered: + - step_ratings (account_id NOT NULL since migration 7167e9374b0c) + - step_usage_log (account_id NOT NULL since migration 7167e9374b0c) + - target_lists (account_id NOT NULL since migration 2c6aabd89bc6) + - session_shares (account_id NOT NULL since session_share model) + - audit_logs (account_id NOT NULL since migration 2a9056eddd90) + - tree_shares (account_id NOT NULL since migration a05e1a1bea7c) + +All use a standard intra-tenant isolation policy. +Token-based access to session_shares and tree_shares goes through +endpoints that use get_admin_db (BYPASSRLS), so a strict tenant +policy here is correct. + +Revision ID: 04f013768235 +Revises: a05e1a1bea7c +Create Date: 2026-04-11 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op + +revision: str = '04f013768235' +down_revision: Union[str, None] = 'a05e1a1bea7c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +_CURRENT_ACCOUNT = ( + "COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), " + "'00000000-0000-0000-0000-000000000000')::uuid" +) + +_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}" + +_PHASE3_TABLES = [ + "step_ratings", + "step_usage_log", + "target_lists", + "session_shares", + "audit_logs", + "tree_shares", +] + + +def upgrade() -> None: + for table in _PHASE3_TABLES: + op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON {table} + USING ({_STANDARD_USING}) + """) + + +def downgrade() -> None: + for table in _PHASE3_TABLES: + op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}") + op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY") diff --git a/backend/alembic/versions/172ad76d7d20_drop_team_id_target_lists.py b/backend/alembic/versions/172ad76d7d20_drop_team_id_target_lists.py new file mode 100644 index 00000000..e565470d --- /dev/null +++ b/backend/alembic/versions/172ad76d7d20_drop_team_id_target_lists.py @@ -0,0 +1,32 @@ +"""Drop team_id from target_lists. + +account_id (NOT NULL) is now the tenant isolation key; team_id is redundant. +All reads/writes use account_id via RLS + application filter. + +Revision ID: 172ad76d7d20 +Revises: 04f013768235 +Create Date: 2026-04-11 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '172ad76d7d20' +down_revision: Union[str, None] = '04f013768235' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.drop_index('ix_target_lists_team_id', table_name='target_lists', if_exists=True) + op.drop_constraint('target_lists_team_id_fkey', 'target_lists', type_='foreignkey') + op.drop_column('target_lists', 'team_id') + + +def downgrade() -> None: + op.add_column('target_lists', sa.Column('team_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'target_lists_team_id_fkey', 'target_lists', 'teams', + ['team_id'], ['id'], ondelete='CASCADE', + ) + op.create_index('ix_target_lists_team_id', 'target_lists', ['team_id']) diff --git a/backend/alembic/versions/2a9056eddd90_add_account_id_audit_logs.py b/backend/alembic/versions/2a9056eddd90_add_account_id_audit_logs.py new file mode 100644 index 00000000..8978186e --- /dev/null +++ b/backend/alembic/versions/2a9056eddd90_add_account_id_audit_logs.py @@ -0,0 +1,51 @@ +"""Add account_id to audit_logs and backfill via user_id. + +Revision ID: 2a9056eddd90 +Revises: 70a5dd746e83 +Create Date: 2026-04-11 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '2a9056eddd90' +down_revision: Union[str, None] = '70a5dd746e83' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('audit_logs', sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_audit_logs_account_id', 'audit_logs', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Backfill: derive from the acting user's account + op.execute(""" + UPDATE audit_logs al + SET account_id = u.account_id + FROM users u + WHERE al.user_id = u.id + AND u.account_id IS NOT NULL + AND al.account_id IS NULL + """) + + result = op.get_bind().execute( + sa.text("SELECT COUNT(*) FROM audit_logs WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} audit_logs rows have NULL account_id after backfill. " + "All audit log entries must have an associated user with an account." + ) + + op.alter_column('audit_logs', 'account_id', nullable=False) + op.create_index('ix_audit_logs_account_id', 'audit_logs', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_audit_logs_account_id', table_name='audit_logs') + op.drop_constraint('fk_audit_logs_account_id', 'audit_logs', type_='foreignkey') + op.drop_column('audit_logs', 'account_id') diff --git a/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py b/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py new file mode 100644 index 00000000..8d69aec3 --- /dev/null +++ b/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py @@ -0,0 +1,51 @@ +"""Add account_id to tree_shares and backfill via created_by user. + +Revision ID: a05e1a1bea7c +Revises: 2a9056eddd90 +Create Date: 2026-04-11 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = 'a05e1a1bea7c' +down_revision: Union[str, None] = '2a9056eddd90' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('tree_shares', sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_tree_shares_account_id', 'tree_shares', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Backfill: derive from the creating user's account + op.execute(""" + UPDATE tree_shares ts + SET account_id = u.account_id + FROM users u + WHERE ts.created_by = u.id + AND u.account_id IS NOT NULL + AND ts.account_id IS NULL + """) + + result = op.get_bind().execute( + sa.text("SELECT COUNT(*) FROM tree_shares WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} tree_shares rows have NULL account_id after backfill. " + "All share entries must have a creating user with an account." + ) + + op.alter_column('tree_shares', 'account_id', nullable=False) + op.create_index('ix_tree_shares_account_id', 'tree_shares', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_tree_shares_account_id', table_name='tree_shares') + op.drop_constraint('fk_tree_shares_account_id', 'tree_shares', type_='foreignkey') + op.drop_column('tree_shares', 'account_id') diff --git a/backend/app/api/endpoints/target_lists.py b/backend/app/api/endpoints/target_lists.py index 0bfac439..82e20a55 100644 --- a/backend/app/api/endpoints/target_lists.py +++ b/backend/app/api/endpoints/target_lists.py @@ -18,12 +18,10 @@ async def list_target_lists( current_user: Annotated[User, Depends(get_current_active_user)], db: Annotated[AsyncSession, Depends(get_db)], ): - """List all target lists for the current user's team.""" - if not current_user.team_id: - return [] + """List all target lists for the current user's account.""" result = await db.execute( select(TargetList) - .where(TargetList.team_id == current_user.team_id) + .where(TargetList.account_id == current_user.account_id) .order_by(TargetList.name) ) return result.scalars().all() @@ -36,11 +34,9 @@ async def create_target_list( db: Annotated[AsyncSession, Depends(get_db)], _: None = Depends(require_engineer_or_admin), ): - """Create a new target list for the current team.""" - if not current_user.team_id: - raise HTTPException(status_code=400, detail="User must belong to a team") + """Create a new target list for the current account.""" target_list = TargetList( - team_id=current_user.team_id, + account_id=current_user.account_id, created_by=current_user.id, name=data.name, description=data.description, @@ -61,7 +57,7 @@ async def get_target_list( result = await db.execute( select(TargetList).where( TargetList.id == list_id, - TargetList.team_id == current_user.team_id, + TargetList.account_id == current_user.account_id, ) ) target_list = result.scalar_one_or_none() @@ -81,7 +77,7 @@ async def update_target_list( result = await db.execute( select(TargetList).where( TargetList.id == list_id, - TargetList.team_id == current_user.team_id, + TargetList.account_id == current_user.account_id, ) ) target_list = result.scalar_one_or_none() @@ -91,7 +87,7 @@ async def update_target_list( if "name" in update_fields and data.name is not None: target_list.name = data.name if "description" in update_fields: - target_list.description = data.description # allow setting to None + target_list.description = data.description if "targets" in update_fields and data.targets is not None: target_list.targets = [t.model_dump() for t in data.targets] await db.commit() @@ -109,7 +105,7 @@ async def delete_target_list( result = await db.execute( select(TargetList).where( TargetList.id == list_id, - TargetList.team_id == current_user.team_id, + TargetList.account_id == current_user.account_id, ) ) target_list = result.scalar_one_or_none() diff --git a/backend/app/api/endpoints/trees.py b/backend/app/api/endpoints/trees.py index 6a16297e..fe5246c7 100644 --- a/backend/app/api/endpoints/trees.py +++ b/backend/app/api/endpoints/trees.py @@ -1048,6 +1048,7 @@ async def create_tree_share( # Create share tree_share = TreeShare( tree_id=tree.id, + account_id=current_user.account_id, share_token=share_token, created_by=current_user.id, allow_forking=share_data.allow_forking, diff --git a/backend/app/core/audit.py b/backend/app/core/audit.py index 58ecd620..b5640e0a 100644 --- a/backend/app/core/audit.py +++ b/backend/app/core/audit.py @@ -12,10 +12,19 @@ async def log_audit( resource_type: str, resource_id: Optional[UUID] = None, details: Optional[dict] = None, + account_id: Optional[UUID] = None, ) -> None: """Record an audit log entry. Does not commit — piggybacks on the caller's commit.""" + if account_id is None: + # Derive from the acting user's account as a fallback (one extra query). + from sqlalchemy import select + from app.models.user import User + result = await db.execute(select(User.account_id).where(User.id == user_id)) + account_id = result.scalar_one() + entry = AuditLog( user_id=user_id, + account_id=account_id, action=action, resource_type=resource_type, resource_id=resource_id, diff --git a/backend/app/models/audit_log.py b/backend/app/models/audit_log.py index 25fa669e..0e795222 100644 --- a/backend/app/models/audit_log.py +++ b/backend/app/models/audit_log.py @@ -21,6 +21,12 @@ class AuditLog(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True + ) action: Mapped[str] = mapped_column(String(50), nullable=False, index=True) resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True) resource_id: Mapped[Optional[uuid.UUID]] = mapped_column( diff --git a/backend/app/models/target_list.py b/backend/app/models/target_list.py index b1169d72..d1b64c0d 100644 --- a/backend/app/models/target_list.py +++ b/backend/app/models/target_list.py @@ -8,7 +8,6 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.user import User - from app.models.team import Team from app.models.account import Account @@ -18,10 +17,6 @@ class TargetList(Base): id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 ) - team_id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), - nullable=False, index=True - ) account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), diff --git a/backend/app/models/tree_share.py b/backend/app/models/tree_share.py index d5e42ad1..4af919aa 100644 --- a/backend/app/models/tree_share.py +++ b/backend/app/models/tree_share.py @@ -25,6 +25,12 @@ class TreeShare(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True + ) share_token: Mapped[str] = mapped_column( String(64), unique=True, diff --git a/backend/app/schemas/target_list.py b/backend/app/schemas/target_list.py index 0016d393..f7d39b59 100644 --- a/backend/app/schemas/target_list.py +++ b/backend/app/schemas/target_list.py @@ -23,7 +23,7 @@ class TargetListUpdate(BaseModel): class TargetListResponse(BaseModel): id: UUID - team_id: UUID + account_id: UUID created_by: Optional[UUID] name: str description: Optional[str] diff --git a/backend/tests/test_rls_isolation.py b/backend/tests/test_rls_isolation.py index 1934fdee..b083eda8 100644 --- a/backend/tests/test_rls_isolation.py +++ b/backend/tests/test_rls_isolation.py @@ -708,3 +708,249 @@ async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, await admin_conn.execute( f"DELETE FROM step_library WHERE id = '{public_step_id}'" ) + + +# =========================================================================== +# Phase 3 RLS isolation tests +# Tables: step_ratings, step_usage_log, target_lists, +# session_shares, audit_logs, tree_shares +# =========================================================================== + +# --------------------------------------------------------------------------- +# Helpers shared by Phase 3 fixtures +# --------------------------------------------------------------------------- + +async def _get_user_b_id(admin_conn) -> str: + row = await admin_conn.fetchrow( + "SELECT id FROM users WHERE email = 'rls-user-b@example.com'" + ) + return str(row["id"]) + + +async def _get_tree_b_id(admin_conn) -> str: + row = await admin_conn.fetchrow( + f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1" + ) + return str(row["id"]) + + +# --------------------------------------------------------------------------- +# step_ratings +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see step ratings belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + # Need a step_library row as FK target + step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_library ( + id, account_id, title, step_type, content, + visibility, is_active, created_at, updated_at + ) VALUES ( + '{step_id}', '{ACCOUNT_B_ID}', 'Phase3 RLS Step', 'action', + '{{}}'::jsonb, 'private', TRUE, NOW(), NOW() + ) + """) + + rating_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_ratings ( + id, step_id, user_id, account_id, is_verified_use, is_visible, + created_at, updated_at + ) VALUES ( + '{rating_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + FALSE, TRUE, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM step_ratings WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B step_ratings" + finally: + await admin_conn.execute(f"DELETE FROM step_ratings WHERE id = '{rating_id}'") + await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'") + + +# --------------------------------------------------------------------------- +# step_usage_log +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see step usage logs belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + tree_b_id = await _get_tree_b_id(admin_conn) + + step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_library ( + id, account_id, title, step_type, content, + visibility, is_active, created_at, updated_at + ) VALUES ( + '{step_id}', '{ACCOUNT_B_ID}', 'Phase3 Usage Step', 'action', + '{{}}'::jsonb, 'private', TRUE, NOW(), NOW() + ) + """) + + # Need a sessions row as FK for usage log + session_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO sessions ( + id, tree_id, user_id, account_id, tree_snapshot, + path_taken, decisions, custom_steps, started_at + ) VALUES ( + '{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW() + ) + """) + + log_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_usage_log ( + id, step_id, user_id, account_id, session_id, used_at + ) VALUES ( + '{log_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + '{session_id}', NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM step_usage_log WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B step_usage_log" + finally: + await admin_conn.execute(f"DELETE FROM step_usage_log WHERE id = '{log_id}'") + await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'") + await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'") + + +# --------------------------------------------------------------------------- +# target_lists +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see target lists belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + tl_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO target_lists ( + id, account_id, created_by, name, targets, created_at, updated_at + ) VALUES ( + '{tl_id}', '{ACCOUNT_B_ID}', '{user_b_id}', + 'Phase3 RLS Target List', '[]'::jsonb, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM target_lists WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B target_lists" + finally: + await admin_conn.execute(f"DELETE FROM target_lists WHERE id = '{tl_id}'") + + +# --------------------------------------------------------------------------- +# session_shares +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see session shares belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + tree_b_id = await _get_tree_b_id(admin_conn) + + # Need a sessions row as FK + session_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO sessions ( + id, tree_id, user_id, account_id, tree_snapshot, + path_taken, decisions, custom_steps, started_at + ) VALUES ( + '{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW() + ) + """) + + share_id = str(uuid.uuid4()) + share_token = f"phase3-rls-test-{share_id[:8]}" + await admin_conn.execute(f""" + INSERT INTO session_shares ( + id, session_id, account_id, share_token, visibility, + created_by, view_count, is_active, created_at, updated_at + ) VALUES ( + '{share_id}', '{session_id}', '{ACCOUNT_B_ID}', + '{share_token}', 'account', '{user_b_id}', + 0, TRUE, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM session_shares WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B session_shares" + finally: + await admin_conn.execute(f"DELETE FROM session_shares WHERE id = '{share_id}'") + await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'") + + +# --------------------------------------------------------------------------- +# audit_logs +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see audit logs belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + log_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO audit_logs ( + id, user_id, account_id, action, resource_type, created_at + ) VALUES ( + '{log_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + 'test.action', 'test_resource', NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM audit_logs WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B audit_logs" + finally: + await admin_conn.execute(f"DELETE FROM audit_logs WHERE id = '{log_id}'") + + +# --------------------------------------------------------------------------- +# tree_shares +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see tree shares belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + tree_b_id = await _get_tree_b_id(admin_conn) + + share_id = str(uuid.uuid4()) + share_token = f"phase3-tree-rls-{share_id[:8]}" + await admin_conn.execute(f""" + INSERT INTO tree_shares ( + id, tree_id, account_id, share_token, created_by, + allow_forking, created_at + ) VALUES ( + '{share_id}', '{tree_b_id}', '{ACCOUNT_B_ID}', + '{share_token}', '{user_b_id}', TRUE, NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM tree_shares WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B tree_shares" + finally: + await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'") diff --git a/backend/tests/test_target_lists.py b/backend/tests/test_target_lists.py index a40cfb48..1e4bca49 100644 --- a/backend/tests/test_target_lists.py +++ b/backend/tests/test_target_lists.py @@ -3,37 +3,10 @@ import pytest from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession -from app.models.team import Team from app.models.user import User from sqlalchemy import select -@pytest.fixture -async def auth_headers(client: AsyncClient, test_db: AsyncSession, test_user: dict): - """Override auth_headers to ensure the test user has a team_id assigned.""" - # Fetch the user from DB and assign a team - result = await test_db.execute(select(User).where(User.email == test_user["email"])) - user = result.scalar_one() - - # Create a team and assign the user to it - team = Team(name="Test Team") - test_db.add(team) - await test_db.flush() - - user.team_id = team.id - await test_db.commit() - - # Re-login to get a fresh token - login_data = { - "email": test_user["email"], - "password": test_user["password"], - } - resp = await client.post("/api/v1/auth/login/json", json=login_data) - assert resp.status_code == 200 - token_data = resp.json() - return {"Authorization": f"Bearer {token_data['access_token']}"} - - @pytest.mark.asyncio async def test_create_target_list(client: AsyncClient, auth_headers: dict): resp = await client.post( @@ -107,25 +80,28 @@ async def test_delete_target_list(client: AsyncClient, auth_headers: dict): assert get.status_code == 404 @pytest.mark.asyncio -async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: dict, test_db): - """User from team B cannot access team A's list.""" +async def test_cannot_access_other_accounts_list(client: AsyncClient, auth_headers: dict, test_db): + """User from account B cannot access account A's target list.""" import uuid - from app.models.team import Team + from app.models.account import Account from app.models.user import User from app.core.security import get_password_hash - # Create team A list using existing auth_headers + # Create account A list using existing auth_headers create = await client.post( "/api/v1/target-lists/", - json={"name": "Team A List", "targets": [{"label": "SRV-A"}]}, + json={"name": "Account A List", "targets": [{"label": "SRV-A"}]}, headers=auth_headers, ) assert create.status_code == 201 list_id = create.json()["id"] - # Create a separate team B with its own user - team_b = Team(name=f"Team B {uuid.uuid4()}") - test_db.add(team_b) + # Create a separate account B with its own user + account_b = Account( + name=f"Account B {uuid.uuid4()}", + display_code=f"AB{str(uuid.uuid4())[:6].upper()}", + ) + test_db.add(account_b) await test_db.flush() user_b = User( @@ -133,11 +109,13 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: password_hash=get_password_hash("password123"), name="User B", is_active=True, - team_id=team_b.id, + account_id=account_b.id, + account_role="engineer", role="engineer", ) test_db.add(user_b) await test_db.flush() + await test_db.commit() # Get auth token for user B login = await client.post( @@ -148,6 +126,6 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: token_b = login.json()["access_token"] headers_b = {"Authorization": f"Bearer {token_b}"} - # Team B cannot access Team A's list + # Account B cannot access Account A's list resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b) assert resp.status_code == 404 From 893b8a500899d3aa1d69c110e93590cd02461103 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Sat, 11 Apr 2026 05:17:25 +0000 Subject: [PATCH 2/2] fix: tree_shares.account_id must come from tree owner, not the actor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - trees.py: change account_id=current_user.account_id → account_id=tree.account_id so super-admin cross-account shares land in the tree's tenant where RLS will see them. - migration a05e1a1bea7c: fix backfill to join tree_shares → trees instead of tree_shares → users(created_by). Same logic: historical shares belong to the tree's tenant. - test_tree_sharing.py: add test_share_account_id_matches_tree_not_actor to assert share.account_id == tree.account_id after POST /share; also add missing account_id to all direct TreeShare(...) constructors in existing tests. - test_phase1_migrations.py: remove team_id= from TargetList constructor (column dropped in Phase 3). Co-Authored-By: Claude Sonnet 4.6 --- ...a05e1a1bea7c_add_account_id_tree_shares.py | 18 ++++++---- backend/app/api/endpoints/trees.py | 2 +- backend/tests/test_phase1_migrations.py | 1 - backend/tests/test_tree_sharing.py | 35 +++++++++++++++++++ 4 files changed, 48 insertions(+), 8 deletions(-) diff --git a/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py b/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py index 8d69aec3..f8216ace 100644 --- a/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py +++ b/backend/alembic/versions/a05e1a1bea7c_add_account_id_tree_shares.py @@ -1,4 +1,8 @@ -"""Add account_id to tree_shares and backfill via created_by user. +"""Add account_id to tree_shares and backfill via tree owner's account. + +The share belongs to the tree's tenant, not the actor who created it. +A super admin in account A can share a tree owned by account B; that share +must land in account B so account B's RLS filter sees it. Revision ID: a05e1a1bea7c Revises: 2a9056eddd90 @@ -21,13 +25,15 @@ def upgrade() -> None: ['account_id'], ['id'], ondelete='CASCADE', ) - # Backfill: derive from the creating user's account + # Backfill: derive from the tree's account, not the creator's account. + # A share lives in the same tenant as its tree so that the tree owner's + # RLS context covers their own shares regardless of who created them. op.execute(""" UPDATE tree_shares ts - SET account_id = u.account_id - FROM users u - WHERE ts.created_by = u.id - AND u.account_id IS NOT NULL + SET account_id = t.account_id + FROM trees t + WHERE ts.tree_id = t.id + AND t.account_id IS NOT NULL AND ts.account_id IS NULL """) diff --git a/backend/app/api/endpoints/trees.py b/backend/app/api/endpoints/trees.py index fe5246c7..6f7c0853 100644 --- a/backend/app/api/endpoints/trees.py +++ b/backend/app/api/endpoints/trees.py @@ -1048,7 +1048,7 @@ async def create_tree_share( # Create share tree_share = TreeShare( tree_id=tree.id, - account_id=current_user.account_id, + account_id=tree.account_id, # share belongs to the tree's tenant, not the actor share_token=share_token, created_by=current_user.id, allow_forking=share_data.allow_forking, diff --git a/backend/tests/test_phase1_migrations.py b/backend/tests/test_phase1_migrations.py index eefeba17..59e54820 100644 --- a/backend/tests/test_phase1_migrations.py +++ b/backend/tests/test_phase1_migrations.py @@ -464,7 +464,6 @@ async def test_target_list_account_id_from_team_admin(test_db: AsyncSession): await test_db.flush() target_list = TargetList( - team_id=team.id, account_id=account.id, created_by=user.id, name="Server Targets", diff --git a/backend/tests/test_tree_sharing.py b/backend/tests/test_tree_sharing.py index ea05c50f..a9adee54 100644 --- a/backend/tests/test_tree_sharing.py +++ b/backend/tests/test_tree_sharing.py @@ -117,6 +117,7 @@ class TestTreeSharing: for i in range(3): share = TreeShare( tree_id=sample_tree.id, + account_id=sample_tree.account_id, share_token=f"token_{i}_" + "x" * 56, created_by=sample_tree.author_id, allow_forking=i % 2 == 0 @@ -162,6 +163,7 @@ class TestTreeSharing: # Create a share share = TreeShare( tree_id=sample_tree.id, + account_id=sample_tree.account_id, share_token="public_test_token" + "x" * 47, created_by=UUID(test_user["user_data"]["id"]), allow_forking=True @@ -192,6 +194,7 @@ class TestTreeSharing: # Create expired share share = TreeShare( tree_id=sample_tree.id, + account_id=sample_tree.account_id, share_token="expired_token" + "x" * 50, created_by=UUID(test_user["user_data"]["id"]), allow_forking=True, @@ -209,6 +212,7 @@ class TestTreeSharing: from uuid import UUID share = TreeShare( tree_id=sample_tree.id, + account_id=sample_tree.account_id, share_token="inactive_tree_token" + "x" * 44, created_by=UUID(test_user["user_data"]["id"]), allow_forking=True @@ -248,6 +252,37 @@ class TestTreeSharing: tokens.add(token) assert len(tokens) == 5 + async def test_share_account_id_matches_tree_not_actor( + self, client: AsyncClient, sample_tree, auth_headers, test_db + ): + """Share account_id must equal tree.account_id, not the actor's account_id. + + A super admin in a different account can share any tree. The resulting + TreeShare row must live in the tree-owner's account so that the tree + owner's RLS context covers it. If account_id were derived from the + actor instead, the share would vanish from the tree owner's view once + RLS is enabled. + """ + from uuid import UUID + from sqlalchemy import select + + response = await client.post( + f"/api/v1/trees/{sample_tree.id}/share", + json={"allow_forking": True}, + headers=auth_headers, + ) + assert response.status_code == 201 + share_token = response.json()["share_token"] + + result = await test_db.execute( + select(TreeShare).where(TreeShare.share_token == share_token) + ) + share = result.scalar_one() + assert share.account_id == sample_tree.account_id, ( + "TreeShare.account_id must equal tree.account_id, not the actor's account. " + "Shares must live in the tree owner's tenant for RLS to cover them." + ) + @pytest.mark.asyncio async def test_migration_defaults_visibility_to_team(test_db):