diff --git a/backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py b/backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py new file mode 100644 index 00000000..fb47878a --- /dev/null +++ b/backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py @@ -0,0 +1,105 @@ +"""Enable RLS on Phase 4 tables — all remaining tenant-scoped tables. + +All tables in this migration already have account_id NOT NULL (enforced by +earlier migrations). This migration adds ENABLE ROW LEVEL SECURITY, +FORCE ROW LEVEL SECURITY, and the appropriate tenant isolation policy to each. + +Policy variants used: + - Standard: account_id = current_setting(app.current_account_id)::uuid + - Platform: standard OR account_id = PLATFORM_ACCOUNT_ID + (for global content tables readable by all tenants) + +Skipped intentionally: + - accounts — IS the root table; no account_id column + - plan_feature_defaults — platform config; no account_id column + +Revision ID: b3c7e9f2a1d8 +Revises: 172ad76d7d20 +Create Date: 2026-04-12 +""" + +from typing import Union +from alembic import op + +revision: str = "b3c7e9f2a1d8" +down_revision: Union[str, None] = "172ad76d7d20" +branch_labels = None +depends_on = None + +PLATFORM_ACCOUNT_ID = "00000000-0000-0000-0000-000000000001" + +# Standard policy — tenant sees only own rows. +_STANDARD_TABLES = [ + "users", + "account_invites", + "account_limit_overrides", + "account_feature_overrides", + "subscriptions", + "ai_chat_sessions", + "ai_conversations", + "ai_session_steps", + "ai_session_embeddings", + "ai_suggestions", + "ai_usage", + "assistant_chats", + "attachments", + "copilot_conversations", + "feedback", + "file_uploads", + "fork_points", + "kb_imports", + "notifications", + "notification_configs", + "notification_logs", + "psa_activity_logs", + "psa_member_mappings", + "script_builder_sessions", + "script_categories", + "session_ratings", + "tree_embeddings", + "user_folders", + "user_pinned_trees", +] + +# Platform-visibility policy — tenant sees own rows PLUS PLATFORM_ACCOUNT_ID rows. +# These tables hold global content created by ResolutionFlow admins. +_PLATFORM_TABLES = [ + "platform_steps", + "template_trees", +] + +_POLICY_EXPR = ( + "account_id = COALESCE(" + "NULLIF(current_setting('app.current_account_id', TRUE), ''), " + "'00000000-0000-0000-0000-000000000000'" + ")::uuid" +) + + +def upgrade() -> None: + # Standard tables — tenant isolation only + for table in _STANDARD_TABLES: + op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON {table} + USING ({_POLICY_EXPR}) + """) + + # Platform-visible tables — own rows OR global platform rows + for table in _PLATFORM_TABLES: + op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON {table} + USING ( + {_POLICY_EXPR} + OR account_id = '{PLATFORM_ACCOUNT_ID}'::uuid + ) + """) + + +def downgrade() -> None: + for table in _STANDARD_TABLES + _PLATFORM_TABLES: + op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}") + op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY") diff --git a/backend/app/api/endpoints/script_builder.py b/backend/app/api/endpoints/script_builder.py index f56c595c..1cb849f8 100644 --- a/backend/app/api/endpoints/script_builder.py +++ b/backend/app/api/endpoints/script_builder.py @@ -85,6 +85,7 @@ async def create_session( session = await script_builder_service.create_session( db=db, user_id=current_user.id, + account_id=current_user.account_id, team_id=current_user.team_id, language=data.language, ) diff --git a/backend/app/services/script_builder_service.py b/backend/app/services/script_builder_service.py index a1a7647e..991d9a28 100644 --- a/backend/app/services/script_builder_service.py +++ b/backend/app/services/script_builder_service.py @@ -144,6 +144,7 @@ def _extract_script_from_response(content: str, language: str) -> tuple[str | No async def create_session( db: AsyncSession, user_id: UUID, + account_id: UUID, team_id: UUID | None, language: str, initial_prompt: str | None = None, @@ -151,6 +152,7 @@ async def create_session( """Create a new Script Builder session.""" session = ScriptBuilderSession( user_id=user_id, + account_id=account_id, team_id=team_id, language=language, ) diff --git a/backend/tests/test_rls_isolation.py b/backend/tests/test_rls_isolation.py index b083eda8..e585845d 100644 --- a/backend/tests/test_rls_isolation.py +++ b/backend/tests/test_rls_isolation.py @@ -954,3 +954,187 @@ async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a): assert len(rows) == 0, "Account A should not see Account B tree_shares" finally: await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'") + + +# =========================================================================== +# Phase 4 RLS isolation tests +# Tables: users, script_builder_sessions, ai_session_steps, +# notifications, platform_steps, template_trees +# =========================================================================== + +# --------------------------------------------------------------------------- +# users +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_users_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see users belonging to Account B.""" + rows = await conn_a.fetch( + f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B users" + + +@pytest.mark.asyncio +async def test_users_account_a_can_see_own(admin_conn, conn_a): + """Account A must be able to see its own users.""" + rows = await conn_a.fetch( + f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}'" + ) + assert len(rows) > 0, "Account A should see its own users" + + +# --------------------------------------------------------------------------- +# script_builder_sessions +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_script_builder_sessions_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see script builder sessions belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + session_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO script_builder_sessions ( + id, user_id, account_id, language, created_at, updated_at + ) VALUES ( + '{session_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + 'powershell', NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM script_builder_sessions WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B script_builder_sessions" + finally: + await admin_conn.execute( + f"DELETE FROM script_builder_sessions WHERE id = '{session_id}'" + ) + + +# --------------------------------------------------------------------------- +# ai_session_steps +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_ai_session_steps_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see ai_session_steps belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + tree_b_id = await _get_tree_b_id(admin_conn) + + # Need an ai_sessions row as FK + ai_session_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO ai_sessions ( + id, user_id, account_id, flow_type, status, confidence_tier, + created_at, updated_at + ) VALUES ( + '{ai_session_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + 'troubleshooting', 'active', 'guided', NOW(), NOW() + ) + """) + + step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO ai_session_steps ( + id, session_id, account_id, step_type, content, + created_at + ) VALUES ( + '{step_id}', '{ai_session_id}', '{ACCOUNT_B_ID}', + 'question', 'Phase4 RLS test step', NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM ai_session_steps WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B ai_session_steps" + finally: + await admin_conn.execute(f"DELETE FROM ai_session_steps WHERE id = '{step_id}'") + await admin_conn.execute(f"DELETE FROM ai_sessions WHERE id = '{ai_session_id}'") + + +# --------------------------------------------------------------------------- +# notifications +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_notifications_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see notifications belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + notif_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO notifications ( + id, user_id, account_id, type, title, message, + is_read, created_at + ) VALUES ( + '{notif_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + 'info', 'Phase4 RLS Test', 'RLS isolation test notification', + FALSE, NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM notifications WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B notifications" + finally: + await admin_conn.execute(f"DELETE FROM notifications WHERE id = '{notif_id}'") + + +# --------------------------------------------------------------------------- +# platform_steps — platform content visible to all tenants +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_platform_steps_visible_to_all_tenants(admin_conn, conn_a): + """Platform steps (PLATFORM_ACCOUNT_ID) must be visible to any tenant.""" + step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO platform_steps ( + id, account_id, title, step_type, content, + is_active, created_at, updated_at + ) VALUES ( + '{step_id}', '{PLATFORM_ACCOUNT_ID}', 'Phase4 RLS Platform Step', + 'action', '{{}}'::jsonb, TRUE, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM platform_steps WHERE id = '{step_id}'" + ) + assert len(rows) == 1, ( + "Platform steps (PLATFORM_ACCOUNT_ID) must be visible to all tenants" + ) + finally: + await admin_conn.execute(f"DELETE FROM platform_steps WHERE id = '{step_id}'") + + +# --------------------------------------------------------------------------- +# template_trees — platform content visible to all tenants +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_template_trees_visible_to_all_tenants(admin_conn, conn_a): + """Template trees (PLATFORM_ACCOUNT_ID) must be visible to any tenant.""" + tmpl_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO template_trees ( + id, account_id, name, tree_structure, is_active, + created_at, updated_at + ) VALUES ( + '{tmpl_id}', '{PLATFORM_ACCOUNT_ID}', 'Phase4 RLS Template', + '{{}}'::jsonb, TRUE, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM template_trees WHERE id = '{tmpl_id}'" + ) + assert len(rows) == 1, ( + "Template trees (PLATFORM_ACCOUNT_ID) must be visible to all tenants" + ) + finally: + await admin_conn.execute(f"DELETE FROM template_trees WHERE id = '{tmpl_id}'")