diff --git a/backend/tests/test_rls_isolation.py b/backend/tests/test_rls_isolation.py index 520582fe..9b608bdc 100644 --- a/backend/tests/test_rls_isolation.py +++ b/backend/tests/test_rls_isolation.py @@ -292,6 +292,11 @@ async def session_row_ids(admin_conn): f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1" ) + assert tree_a is not None, f"No tree found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first" + assert tree_b is not None, f"No tree found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first" + assert user_a is not None, f"No user found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first" + assert user_b is not None, f"No user found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first" + tree_a_id = str(tree_a["id"]) tree_b_id = str(tree_b["id"]) user_a_id = str(user_a["id"]) @@ -329,20 +334,157 @@ async def session_row_ids(admin_conn): NOW(), NOW()) """) - yield { - "session_a": session_a_id, - "session_b": session_b_id, - "ai_session_a": ai_session_a_id, - "ai_session_b": ai_session_b_id, - } + # ------------------------------------------------------------------------- + # Seed Account B rows for every "cannot-see" table that would otherwise be + # empty. Without these, isolation tests pass vacuously even when RLS is off. + # ------------------------------------------------------------------------- - # Cleanup - await admin_conn.execute( - f"DELETE FROM sessions WHERE id IN ('{session_a_id}', '{session_b_id}')" - ) - await admin_conn.execute( - f"DELETE FROM ai_sessions WHERE id IN ('{ai_session_a_id}', '{ai_session_b_id}')" - ) + # session_branches (FK: ai_sessions.id) + branch_b_row = await admin_conn.fetchrow(""" + INSERT INTO session_branches ( + id, session_id, account_id, branch_order, label, status, + conversation_messages, created_at, updated_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, 1, 'test-branch', 'active', + '[]'::jsonb, NOW(), NOW() + ) RETURNING id + """, ai_session_b_id, ACCOUNT_B_ID) + branch_b_id = str(branch_b_row["id"]) + + # session_supporting_data (FK: sessions.id) + supporting_data_b_row = await admin_conn.fetchrow(""" + INSERT INTO session_supporting_data ( + id, session_id, account_id, label, data_type, content, + sort_order, created_at, updated_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, 'test-data', 'text_snippet', + 'test content', 0, NOW(), NOW() + ) RETURNING id + """, session_b_id, ACCOUNT_B_ID) + supporting_data_b_id = str(supporting_data_b_row["id"]) + + # session_resolution_outputs (FK: ai_sessions.id) + resolution_output_b_row = await admin_conn.fetchrow(""" + INSERT INTO session_resolution_outputs ( + id, session_id, account_id, output_type, generated_content, + status, generated_by_model, created_at, updated_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, 'psa_ticket_notes', + 'test content', 'draft', 'test-model', NOW(), NOW() + ) RETURNING id + """, ai_session_b_id, ACCOUNT_B_ID) + resolution_output_b_id = str(resolution_output_b_row["id"]) + + # session_handoffs (FK: ai_sessions.id, users.id) + handoff_b_row = await admin_conn.fetchrow(""" + INSERT INTO session_handoffs ( + id, session_id, account_id, handed_off_by, intent, snapshot, + priority, psa_note_pushed, notification_sent, created_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, 'park', + '{}'::jsonb, 'normal', false, false, NOW() + ) RETURNING id + """, ai_session_b_id, ACCOUNT_B_ID, user_b_id) + handoff_b_id = str(handoff_b_row["id"]) + + # maintenance_schedules (FK: trees.id) + maintenance_b_row = await admin_conn.fetchrow(""" + INSERT INTO maintenance_schedules ( + id, tree_id, account_id, cron_expression, timezone, + created_at, updated_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, '0 9 * * 1', 'UTC', + NOW(), NOW() + ) RETURNING id + """, tree_b_id, ACCOUNT_B_ID) + maintenance_b_id = str(maintenance_b_row["id"]) + + # psa_post_log (FK: ai_sessions.id, users.id) + psa_log_b_row = await admin_conn.fetchrow(""" + INSERT INTO psa_post_log ( + id, ai_session_id, account_id, ticket_id, note_type, + content_posted, status, posted_by, posted_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, 'TEST-0001', 'internal', + 'test note', 'success', $3::uuid, NOW() + ) RETURNING id + """, ai_session_b_id, ACCOUNT_B_ID, user_b_id) + psa_log_b_id = str(psa_log_b_row["id"]) + + # script_templates requires a script_categories row — insert a temporary one + script_category_b_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO script_categories (id, name, slug, sort_order, is_active, created_at, updated_at) + VALUES ('{script_category_b_id}', 'RLS Test Category', 'rls-test-category-{script_category_b_id[:8]}', + 0, true, NOW(), NOW()) + """) + + script_template_b_row = await admin_conn.fetchrow(f""" + INSERT INTO script_templates ( + id, category_id, account_id, name, slug, script_body, + complexity, is_active, created_at, updated_at + ) VALUES ( + gen_random_uuid(), '{script_category_b_id}'::uuid, $1::uuid, + 'RLS Test Template', 'rls-test-template-b-' || gen_random_uuid()::text, + 'Write-Host "test"', 'beginner', true, NOW(), NOW() + ) RETURNING id + """, ACCOUNT_B_ID) + script_template_b_id = str(script_template_b_row["id"]) + + # script_generations (FK: script_templates.id, users.id) + script_gen_b_row = await admin_conn.fetchrow(""" + INSERT INTO script_generations ( + id, template_id, user_id, account_id, parameters_used, + generated_script, created_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, '{}'::jsonb, + 'test script', NOW() + ) RETURNING id + """, script_template_b_id, user_b_id, ACCOUNT_B_ID) + script_gen_b_id = str(script_gen_b_row["id"]) + + try: + yield { + "session_a": session_a_id, + "session_b": session_b_id, + "ai_session_a": ai_session_a_id, + "ai_session_b": ai_session_b_id, + } + finally: + # Cleanup in reverse FK order (children before parents) + await admin_conn.execute( + f"DELETE FROM script_generations WHERE id = '{script_gen_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM session_branches WHERE id = '{branch_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM session_supporting_data WHERE id = '{supporting_data_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM session_resolution_outputs WHERE id = '{resolution_output_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM session_handoffs WHERE id = '{handoff_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM maintenance_schedules WHERE id = '{maintenance_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM psa_post_log WHERE id = '{psa_log_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM script_templates WHERE id = '{script_template_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM script_categories WHERE id = '{script_category_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM sessions WHERE id IN ('{session_a_id}', '{session_b_id}')" + ) + await admin_conn.execute( + f"DELETE FROM ai_sessions WHERE id IN ('{ai_session_a_id}', '{ai_session_b_id}')" + ) # --------------------------------------------------------------------------- @@ -399,7 +541,7 @@ async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids): # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_session_branches_account_a_cannot_see_account_b(conn_a): +async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM session_branches WHERE account_id = '{ACCOUNT_B_ID}'" ) @@ -411,7 +553,7 @@ async def test_session_branches_account_a_cannot_see_account_b(conn_a): # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a): +async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM session_supporting_data WHERE account_id = '{ACCOUNT_B_ID}'" ) @@ -423,7 +565,7 @@ async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a): # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a): +async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM session_resolution_outputs WHERE account_id = '{ACCOUNT_B_ID}'" ) @@ -435,7 +577,7 @@ async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a) # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_session_handoffs_account_a_cannot_see_account_b(conn_a): +async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM session_handoffs WHERE account_id = '{ACCOUNT_B_ID}'" ) @@ -447,7 +589,7 @@ async def test_session_handoffs_account_a_cannot_see_account_b(conn_a): # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_script_templates_account_a_cannot_see_account_b(conn_a): +async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM script_templates WHERE account_id = '{ACCOUNT_B_ID}'" ) @@ -459,7 +601,7 @@ async def test_script_templates_account_a_cannot_see_account_b(conn_a): # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_script_generations_account_a_cannot_see_account_b(conn_a): +async def test_script_generations_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM script_generations WHERE account_id = '{ACCOUNT_B_ID}'" ) @@ -471,7 +613,7 @@ async def test_script_generations_account_a_cannot_see_account_b(conn_a): # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a): +async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM maintenance_schedules WHERE account_id = '{ACCOUNT_B_ID}'" ) @@ -483,7 +625,7 @@ async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a): # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_psa_post_log_account_a_cannot_see_account_b(conn_a): +async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM psa_post_log WHERE account_id = '{ACCOUNT_B_ID}'" ) @@ -495,13 +637,28 @@ async def test_psa_post_log_account_a_cannot_see_account_b(conn_a): # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_step_library_account_a_cannot_see_account_b_private_steps(conn_a): +async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_conn, conn_a): """Private/non-public steps owned by Account B must not be visible to Account A.""" - rows = await conn_a.fetch( - f"SELECT id FROM step_library " - f"WHERE account_id = '{ACCOUNT_B_ID}' AND visibility != 'public'" - ) - assert len(rows) == 0, "Account A should not see Account B's private step_library rows" + private_step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_library ( + id, account_id, title, step_type, content, + visibility, is_active, created_at, updated_at + ) VALUES ( + '{private_step_id}', '{ACCOUNT_B_ID}', 'RLS Private Step', 'action', + '{{}}'::jsonb, 'private', TRUE, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM step_library " + f"WHERE id = '{private_step_id}' AND visibility != 'public'" + ) + assert len(rows) == 0, "Account A should not see Account B's private step_library rows" + finally: + await admin_conn.execute( + f"DELETE FROM step_library WHERE id = '{private_step_id}'" + ) @pytest.mark.asyncio