diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 18bf85ec..37f62d1f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,8 @@ jobs: SECRET_KEY: ci-test-secret-key-not-for-production DEBUG: "true" APP_NAME: ResolutionFlow + TEST_DB_NAME: resolutionflow_test + DB_APP_ROLE_PASSWORD: app_secret_ci steps: - uses: actions/checkout@v5 @@ -47,6 +49,9 @@ jobs: - name: Install dependencies run: pip install -r backend/requirements.txt -r backend/requirements-dev.txt + - name: Run Alembic migrations + run: cd backend && alembic upgrade head + - name: Check tenant filter enforcement run: cd backend && python scripts/check_tenant_filters.py # Warn mode only (exits 0). Switch to --fail after Phase 1 backlog clears. diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 7f9503b4..a3b662d4 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -29,13 +29,37 @@ from app.models.session_branch import SessionBranch # noqa: F401 from app.models.fork_point import ForkPoint # noqa: F401 from app.models.session_handoff import SessionHandoff # noqa: F401 from app.models.session_resolution_output import SessionResolutionOutput # noqa: F401 + from app.core.config import settings + +def _alembic_sync_url() -> str: + """Return a psycopg2-compatible sync URL for Alembic. + + Priority order: + 1. DATABASE_URL_SYNC — in Railway this is set as a reference variable + (${{pgvector.DATABASE_URL}}) that resolves to the correct postgres + superuser credentials for the current environment (production, PR preview, + etc.). This always works even on fresh databases before any custom roles + have been created, because it uses the postgres superuser. + 2. ADMIN_DATABASE_URL (resolutionflow_admin, BYPASSRLS) converted to a sync + driver — fallback for local dev where DATABASE_URL_SYNC may not be set. + """ + if settings.DATABASE_URL_SYNC: + return settings.DATABASE_URL_SYNC + + admin_url = settings.ADMIN_DATABASE_URL + if admin_url and "+asyncpg" in admin_url: + return admin_url.replace("postgresql+asyncpg://", "postgresql://") + + return settings.DATABASE_URL_SYNC + + # this is the Alembic Config object config = context.config # Override sqlalchemy.url with the sync version for migrations -config.set_main_option("sqlalchemy.url", settings.DATABASE_URL_SYNC) +config.set_main_option("sqlalchemy.url", _alembic_sync_url()) # Interpret the config file for Python logging. if config.config_file_name is not None: @@ -86,7 +110,7 @@ def run_migrations_online() -> None: from sqlalchemy import create_engine connectable = create_engine( - settings.DATABASE_URL_SYNC, + _alembic_sync_url(), poolclass=pool.NullPool, ) diff --git a/backend/alembic/versions/70a5dd746e83_enable_rls_phase2.py b/backend/alembic/versions/70a5dd746e83_enable_rls_phase2.py new file mode 100644 index 00000000..a70f1555 --- /dev/null +++ b/backend/alembic/versions/70a5dd746e83_enable_rls_phase2.py @@ -0,0 +1,90 @@ +"""Enable RLS on Phase 2 session and supporting tables. + +10 tables use a standard tenant-only policy. +step_library uses a visibility-aware policy — public steps visible to all tenants. + +NOTE: session_messages does not exist in this codebase (removed from plan). +script_generations is the correct table name (not script_template_generations). +sessions and ai_sessions are two separate tables, both in scope. + +Prerequisites: +- Phase 1 migration must have run (resolutionflow_app role exists, Phase 1 tables have RLS) +- NOT NULL write-path bugs fixed (P2-A commits b641ac6) +- shares.py cross-tenant session fix deployed (P2-B commit ac2b193) + +Revision ID: 70a5dd746e83 +Revises: c5f48b9890f9 +Create Date: 2026-04-10 06:54:49.431817 + +""" +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '70a5dd746e83' +down_revision: Union[str, None] = 'c5f48b9890f9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +_NULL_UUID = "00000000-0000-0000-0000-000000000000" +_CURRENT_ACCOUNT = ( + f"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), " + f"'{_NULL_UUID}')::uuid" +) + +# Standard tenant-only policy — account_id must match the current tenant. +# When no tenant context is set, COALESCE returns the nil UUID so zero rows +# are visible (fail-closed). +_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}" + +# Visibility-aware policy for step_library — public steps (visibility='public') +# must be visible to ALL tenants regardless of account_id. This covers the +# visibility='public' arm of build_step_visibility_filter() in app/core/filters.py. +# The created_by arm (private steps visible to their author) is covered +# transitively: private steps share account_id with their creator, so the +# account_id match handles it. This relies on account_id NOT NULL on step_library. +_STEP_LIBRARY_USING = f"account_id = {_CURRENT_ACCOUNT} OR visibility = 'public'" + +# Standard tables: strict tenant isolation, no cross-tenant visibility. +_STANDARD_TABLES = [ + "sessions", + "ai_sessions", + "session_branches", + "session_supporting_data", + "session_resolution_outputs", + "session_handoffs", + "script_templates", + "script_generations", + "maintenance_schedules", + "psa_post_log", +] + + +def upgrade() -> None: + # ── Standard tenant-isolation tables ──────────────────────────────────── + for table in _STANDARD_TABLES: + op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON {table} + USING ({_STANDARD_USING}) + """) + + # ── step_library ──────────────────────────────────────────────────────── + # Public steps (visibility='public') must be readable by all tenants so + # the Solutions Library browsing experience works without tenant context. + # Private/team steps remain tenant-scoped. + op.execute("ALTER TABLE step_library ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE step_library FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON step_library + USING ({_STEP_LIBRARY_USING}) + """) + + +def downgrade() -> None: + for table in _STANDARD_TABLES + ["step_library"]: + op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}") + op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY") diff --git a/backend/app/api/endpoints/ai_suggestions.py b/backend/app/api/endpoints/ai_suggestions.py index e5f0919f..f3dd7f53 100644 --- a/backend/app/api/endpoints/ai_suggestions.py +++ b/backend/app/api/endpoints/ai_suggestions.py @@ -43,6 +43,7 @@ async def create_suggestion( suggestion = AISuggestion( tree_id=data.tree_id, user_id=current_user.id, + account_id=current_user.account_id, session_id=data.session_id, action_type=data.action_type, target_node_id=data.target_node_id, diff --git a/backend/app/api/endpoints/maintenance_schedules.py b/backend/app/api/endpoints/maintenance_schedules.py index 506da0e3..d43980a9 100644 --- a/backend/app/api/endpoints/maintenance_schedules.py +++ b/backend/app/api/endpoints/maintenance_schedules.py @@ -69,6 +69,7 @@ async def create_schedule( schedule = MaintenanceSchedule( tree_id=data.tree_id, + account_id=current_user.account_id, created_by=current_user.id, cron_expression=data.cron_expression, timezone=data.timezone, diff --git a/backend/app/api/endpoints/ratings.py b/backend/app/api/endpoints/ratings.py index 23f07054..3f0dc2f9 100644 --- a/backend/app/api/endpoints/ratings.py +++ b/backend/app/api/endpoints/ratings.py @@ -91,6 +91,7 @@ async def submit_step_feedback( new_rating = StepRating( step_id=step_id, user_id=current_user.id, + account_id=current_user.account_id, session_id=session_uuid, was_helpful=data.was_helpful, # rating is nullable now — thumbs-only mode diff --git a/backend/app/api/endpoints/sessions.py b/backend/app/api/endpoints/sessions.py index 3ad51f5d..80543347 100644 --- a/backend/app/api/endpoints/sessions.py +++ b/backend/app/api/endpoints/sessions.py @@ -196,6 +196,7 @@ async def start_session( new_session = Session( tree_id=tree.id, user_id=current_user.id, + account_id=current_user.account_id, tree_snapshot=tree_snapshot, path_taken=[], decisions=[], @@ -693,6 +694,7 @@ async def prepare_session( new_session = Session( tree_id=tree.id, user_id=data.assigned_to_id or current_user.id, + account_id=current_user.account_id, tree_snapshot=tree_snapshot, path_taken=[], decisions=[], @@ -770,6 +772,7 @@ async def batch_launch_sessions( session = Session( tree_id=tree.id, user_id=current_user.id, + account_id=current_user.account_id, tree_snapshot=tree_snapshot, path_taken=[], decisions=[], @@ -1102,6 +1105,7 @@ async def psa_post_to_ticket( # Log to audit trail log_entry = PsaPostLog( session_id=session.id, + account_id=session.account_id, psa_connection_id=psa_connection.id if psa_connection else None, ticket_id=session.psa_ticket_id, note_type=data.note_type, diff --git a/backend/app/api/endpoints/shares.py b/backend/app/api/endpoints/shares.py index 3d67207d..ca04dadf 100644 --- a/backend/app/api/endpoints/shares.py +++ b/backend/app/api/endpoints/shares.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import joinedload from sqlalchemy.exc import IntegrityError from app.core.database import get_db +from app.core.admin_database import get_admin_db from app.models.session import Session from app.models.session_share import SessionShare, SessionShareView from app.models.user import User @@ -210,7 +211,7 @@ async def _get_optional_user(request: Request, db: AsyncSession) -> Optional[Use async def access_share( share_token: str, request: Request, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], ): """Access a shared session via share token. diff --git a/backend/app/api/endpoints/steps.py b/backend/app/api/endpoints/steps.py index 992c1725..0ec2af71 100644 --- a/backend/app/api/endpoints/steps.py +++ b/backend/app/api/endpoints/steps.py @@ -460,6 +460,7 @@ async def rate_step( rating = StepRating( step_id=step_id, user_id=current_user.id, + account_id=current_user.account_id, rating=rating_data.rating, was_helpful=rating_data.was_helpful, review_text=rating_data.review_text, diff --git a/backend/app/api/endpoints/supporting_data.py b/backend/app/api/endpoints/supporting_data.py index b7d0a33a..ae8c5d79 100644 --- a/backend/app/api/endpoints/supporting_data.py +++ b/backend/app/api/endpoints/supporting_data.py @@ -103,6 +103,7 @@ async def create_supporting_data( item = SessionSupportingData( session_id=session_id, + account_id=session.account_id, label=data.label, data_type=data.data_type, content=data.content, diff --git a/backend/app/core/admin_database.py b/backend/app/core/admin_database.py index 1e84a132..9d845c34 100644 --- a/backend/app/core/admin_database.py +++ b/backend/app/core/admin_database.py @@ -2,8 +2,10 @@ """ Admin database engine — connects as resolutionflow_admin (BYPASSRLS). -Use ONLY for /admin/* endpoints and internal tooling. -Never use this engine from user-facing endpoints. +Use ONLY where explicit application-level access control makes database-layer +tenant filtering unnecessary: /admin/* endpoints, internal tooling, and public +endpoints that enforce their own authorization before returning data (e.g. +share access via opaque token + visibility check). """ from collections.abc import AsyncGenerator @@ -25,7 +27,7 @@ _admin_session_factory = async_sessionmaker( async def get_admin_db() -> AsyncGenerator[AsyncSession, None]: - """Yield an admin DB session (BYPASSRLS). Use only on /admin/* endpoints.""" + """Yield an admin DB session (BYPASSRLS). See module docstring for approved use cases.""" async with _admin_session_factory() as session: try: yield session diff --git a/backend/app/services/branch_manager.py b/backend/app/services/branch_manager.py index 8dba3fa4..ac6d1609 100644 --- a/backend/app/services/branch_manager.py +++ b/backend/app/services/branch_manager.py @@ -34,6 +34,7 @@ class BranchManager: root = SessionBranch( id=uuid.uuid4(), session_id=session_id, + account_id=session.account_id, parent_branch_id=None, branch_order=1, label="Root", @@ -68,9 +69,17 @@ class BranchManager: "status": "untried", }) + # Load session to get account_id for FK constraints + session_result = await self.db.execute( + select(AISession).where(AISession.id == session_id) + ) + session = session_result.scalar_one_or_none() + account_id = session.account_id if session else None + fork_point = ForkPoint( id=uuid.uuid4(), session_id=session_id, + account_id=account_id, parent_branch_id=parent_branch_id, trigger_step_id=trigger_step_id, fork_reason=fork_reason, @@ -90,6 +99,7 @@ class BranchManager: branch = SessionBranch( id=branch_ids[i], session_id=session_id, + account_id=account_id, parent_branch_id=parent_branch_id, fork_point_step_id=trigger_step_id, branch_order=i + 1, diff --git a/backend/app/services/handoff_manager.py b/backend/app/services/handoff_manager.py index 8751e8b4..c79461ba 100644 --- a/backend/app/services/handoff_manager.py +++ b/backend/app/services/handoff_manager.py @@ -56,6 +56,7 @@ class HandoffManager: handoff = SessionHandoff( session_id=session_id, + account_id=session.account_id, handed_off_by=user_id, intent=intent, source_branch_id=session.active_branch_id, diff --git a/backend/app/services/psa_documentation_service.py b/backend/app/services/psa_documentation_service.py index 17a62587..558f21c5 100644 --- a/backend/app/services/psa_documentation_service.py +++ b/backend/app/services/psa_documentation_service.py @@ -371,6 +371,7 @@ async def push_documentation( # Log success log_entry = PsaPostLog( id=uuid.uuid4(), + account_id=session.account_id, ai_session_id=session.id, psa_connection_id=session.psa_connection_id, ticket_id=session.psa_ticket_id, @@ -394,6 +395,7 @@ async def push_documentation( # Log failure with retry scheduling log_entry = PsaPostLog( id=uuid.uuid4(), + account_id=session.account_id, ai_session_id=session.id, psa_connection_id=session.psa_connection_id, ticket_id=session.psa_ticket_id, diff --git a/backend/app/services/resolution_output_generator.py b/backend/app/services/resolution_output_generator.py index 1b317d5c..022f658e 100644 --- a/backend/app/services/resolution_output_generator.py +++ b/backend/app/services/resolution_output_generator.py @@ -45,6 +45,7 @@ class ResolutionOutputGenerator: output = SessionResolutionOutput( session_id=session_id, + account_id=session.account_id, output_type=output_type, generated_content=content, status="draft", diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index b2c10429..9c1f60e6 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -75,6 +75,19 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]: ('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]') """)) + # Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by + # global categories, gallery items, and other platform-owned content. + await conn.execute(sa.text(""" + INSERT INTO accounts (id, name, display_code, created_at, updated_at) + VALUES ( + '00000000-0000-0000-0000-000000000001', + 'ResolutionFlow System', + 'RF-SYS-1', + NOW(), NOW() + ) + ON CONFLICT (id) DO NOTHING + """)) + # Create async session maker async_session_maker = async_sessionmaker( engine, diff --git a/backend/tests/test_admin_gallery.py b/backend/tests/test_admin_gallery.py index e611950a..5d93651b 100644 --- a/backend/tests/test_admin_gallery.py +++ b/backend/tests/test_admin_gallery.py @@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.models.tree import Tree from app.models.script_template import ScriptTemplate, ScriptCategory +_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001") # --------------------------------------------------------------------------- # Helpers @@ -22,6 +23,7 @@ async def _create_tree(db: AsyncSession, admin_user_id: str) -> Tree: name="Gallery Test Flow", tree_type="troubleshooting", visibility="public", + account_id=_PLATFORM_ACCOUNT_ID, is_gallery_featured=False, gallery_sort_order=0, tree_structure={ @@ -53,6 +55,7 @@ async def _create_script(db: AsyncSession, admin_user_id: str) -> ScriptTemplate script = ScriptTemplate( id=uuid.uuid4(), category_id=category.id, + account_id=_PLATFORM_ACCOUNT_ID, name="Gallery Test Script", slug=f"gallery-test-script-{uuid.uuid4().hex[:6]}", script_body="Write-Host 'Test'", diff --git a/backend/tests/test_analytics_phase5.py b/backend/tests/test_analytics_phase5.py index 6992e0a8..a2e3dd3c 100644 --- a/backend/tests/test_analytics_phase5.py +++ b/backend/tests/test_analytics_phase5.py @@ -594,6 +594,7 @@ class TestPsaMetrics: post_log = PsaPostLog( id=uuid.uuid4(), ai_session_id=push_session_id, + account_id=account_id, ticket_id="TICKET-123", note_type="internal", content_posted="Session summary", diff --git a/backend/tests/test_branding.py b/backend/tests/test_branding.py index c94c696a..d7d23f11 100644 --- a/backend/tests/test_branding.py +++ b/backend/tests/test_branding.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from app.core.security import get_password_hash +from app.models.account import Account from app.models.team import Team from app.models.user import User @@ -23,6 +24,8 @@ async def _create_team_with_admin( team_name: str = "Branding Test Team", ) -> tuple[dict, str, Team]: """Create a team + team admin user. Returns (auth_headers, team_id_str, team).""" + account = Account(name=team_name, display_code=uuid.uuid4().hex[:8].upper()) + test_db.add(account) team = Team(name=team_name) test_db.add(team) await test_db.flush() @@ -36,6 +39,8 @@ async def _create_team_with_admin( team_id=team.id, is_team_admin=True, role="engineer", + account_id=account.id, + account_role="engineer", ) test_db.add(user) await test_db.commit() @@ -58,6 +63,15 @@ async def _create_team_member( is_team_admin: bool = False, ) -> dict: """Create a regular team member. Returns auth_headers.""" + # Look up the account associated with this team via an existing member + from sqlalchemy import select as _select + from app.models.user import User as _User + result = await test_db.execute( + _select(_User).where(_User.team_id == team.id).limit(1) + ) + team_member = result.scalar_one_or_none() + member_account_id = team_member.account_id if team_member else None + email = f"member_{uuid.uuid4().hex[:8]}@test.com" user = User( email=email, @@ -67,6 +81,8 @@ async def _create_team_member( team_id=team.id, is_team_admin=is_team_admin, role="engineer", + account_id=member_account_id, + account_role="engineer", ) test_db.add(user) await test_db.commit() diff --git a/backend/tests/test_draft_trees.py b/backend/tests/test_draft_trees.py index 97aae49a..45538f78 100644 --- a/backend/tests/test_draft_trees.py +++ b/backend/tests/test_draft_trees.py @@ -334,12 +334,13 @@ class TestDraftTreesAPI: """Test that migration defaults existing trees to published status.""" # Create a tree without specifying status (relies on DB default) from uuid import UUID, uuid4 + _platform_id = UUID("00000000-0000-0000-0000-000000000001") tree = Tree( name="Legacy Tree", description="Created before status field", tree_structure={"id": "root", "type": "solution", "title": "Fix"}, author_id=None, - account_id=None + account_id=_platform_id, ) test_db.add(tree) await test_db.commit() diff --git a/backend/tests/test_maintenance_schedules.py b/backend/tests/test_maintenance_schedules.py index 42a7fd58..2ba700e7 100644 --- a/backend/tests/test_maintenance_schedules.py +++ b/backend/tests/test_maintenance_schedules.py @@ -127,10 +127,12 @@ async def test_cannot_schedule_other_teams_tree(client: AsyncClient, auth_header test_db.add(other_team) await test_db.flush() + from uuid import UUID as _UUID other_tree = Tree( name="Other Team Tree", tree_type="maintenance", team_id=other_team.id, + account_id=_UUID("00000000-0000-0000-0000-000000000001"), tree_structure={ "steps": [ {"id": "s1", "type": "procedure_step", "title": "Step", diff --git a/backend/tests/test_public_templates.py b/backend/tests/test_public_templates.py index d1d972f0..04542020 100644 --- a/backend/tests/test_public_templates.py +++ b/backend/tests/test_public_templates.py @@ -11,6 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.models.script_template import ScriptCategory, ScriptTemplate from app.models.tree import Tree +_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001") + # --------------------------------------------------------------------------- # Helpers @@ -41,6 +43,7 @@ async def _create_featured_tree(db: AsyncSession, name: str = "Featured Flow", f description="A featured flow for the gallery", tree_type="troubleshooting", tree_structure=_make_tree_structure(4), + account_id=_PLATFORM_ACCOUNT_ID, is_gallery_featured=featured, is_active=True, usage_count=42, @@ -74,6 +77,7 @@ async def _create_featured_script( ) -> ScriptTemplate: script = ScriptTemplate( category_id=category.id, + account_id=_PLATFORM_ACCOUNT_ID, name=name, slug=name.lower().replace(" ", "-"), description="A gallery-featured script", @@ -312,7 +316,7 @@ class TestCategoriesEndpoint: from app.models.category import TreeCategory # Create a category and a featured tree in that category - cat = TreeCategory(name="Networking", slug="networking", is_active=True) + cat = TreeCategory(name="Networking", slug="networking", is_active=True, account_id=_PLATFORM_ACCOUNT_ID) test_db.add(cat) await test_db.commit() await test_db.refresh(cat) @@ -321,6 +325,7 @@ class TestCategoriesEndpoint: name="Router Diagnostics", tree_type="troubleshooting", tree_structure=_make_tree_structure(2), + account_id=_PLATFORM_ACCOUNT_ID, is_gallery_featured=True, is_active=True, usage_count=5, diff --git a/backend/tests/test_resolution_outputs.py b/backend/tests/test_resolution_outputs.py index a852ebca..2a853252 100644 --- a/backend/tests/test_resolution_outputs.py +++ b/backend/tests/test_resolution_outputs.py @@ -62,6 +62,7 @@ async def test_edit_output(client: AsyncClient, test_user, auth_headers, test_db output = SessionResolutionOutput( session_id=session.id, + account_id=session.account_id, output_type="psa_ticket_notes", generated_content="Original notes", status="draft", diff --git a/backend/tests/test_rls_isolation.py b/backend/tests/test_rls_isolation.py index 5d6572e2..1934fdee 100644 --- a/backend/tests/test_rls_isolation.py +++ b/backend/tests/test_rls_isolation.py @@ -16,7 +16,10 @@ Run with: The test DB is patherly_test (matches conftest.py default). """ import os +import subprocess +import sys import uuid +from pathlib import Path import asyncpg import pytest @@ -37,7 +40,25 @@ ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" # --------------------------------------------------------------------------- @pytest.fixture(scope="module") -async def admin_conn(): +def _ensure_rls_schema(): + """Re-apply Alembic migrations before the module runs. + + Function-scoped test_db fixtures in other modules drop and recreate the + public schema using Base.metadata.create_all, which does not enable RLS + or create DB roles. This fixture re-runs 'alembic upgrade head' so that + the full migration-managed schema (including RLS policies) is in place. + """ + backend_dir = Path(__file__).parent.parent + subprocess.run( + [sys.executable, "-m", "alembic", "upgrade", "head"], + cwd=backend_dir, + check=True, + capture_output=True, + ) + + +@pytest.fixture(scope="module") +async def admin_conn(_ensure_rls_schema): """Superuser asyncpg connection for fixture setup and teardown.""" conn = await asyncpg.connect(_ADMIN_DSN) yield conn @@ -264,3 +285,426 @@ async def test_flow_proposals_account_a_cannot_see_account_b(conn_a): f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'" ) assert len(rows) == 0 + + +# --------------------------------------------------------------------------- +# Phase 2 fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +async def session_row_ids(admin_conn): + """ + Insert one `sessions` row and one `ai_sessions` row for each of + ACCOUNT_A and ACCOUNT_B using the superuser connection (BYPASSRLS). + Returns a dict with the inserted IDs for use in tests. + Cleans up on exit. + """ + # Resolve a valid tree_id and user_id for each account + tree_a = await admin_conn.fetchrow( + f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}' LIMIT 1" + ) + tree_b = await admin_conn.fetchrow( + f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1" + ) + user_a = await admin_conn.fetchrow( + f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}' LIMIT 1" + ) + user_b = await admin_conn.fetchrow( + f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1" + ) + + assert tree_a is not None, f"No tree found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first" + assert tree_b is not None, f"No tree found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first" + assert user_a is not None, f"No user found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first" + assert user_b is not None, f"No user found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first" + + tree_a_id = str(tree_a["id"]) + tree_b_id = str(tree_b["id"]) + user_a_id = str(user_a["id"]) + user_b_id = str(user_b["id"]) + + session_a_id = str(uuid.uuid4()) + session_b_id = str(uuid.uuid4()) + ai_session_a_id = str(uuid.uuid4()) + ai_session_b_id = str(uuid.uuid4()) + + # Insert sessions rows (sessions uses started_at not created_at) + await admin_conn.execute(f""" + INSERT INTO sessions ( + id, tree_id, user_id, account_id, tree_snapshot, + path_taken, decisions, custom_steps, started_at + ) VALUES + ('{session_a_id}', '{tree_a_id}', '{user_a_id}', '{ACCOUNT_A_ID}', + '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()), + ('{session_b_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()) + """) + + # Insert ai_sessions rows + # confidence_tier valid values: 'guided' | 'exploring' | 'discovery' + await admin_conn.execute(f""" + INSERT INTO ai_sessions ( + id, user_id, account_id, session_type, intake_type, + intake_content, status, confidence_tier, confidence_score, + created_at, updated_at + ) VALUES + ('{ai_session_a_id}', '{user_a_id}', '{ACCOUNT_A_ID}', + 'guided', 'free_text', '{{}}'::jsonb, 'active', 'guided', 0.0, + NOW(), NOW()), + ('{ai_session_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + 'guided', 'free_text', '{{}}'::jsonb, 'active', 'guided', 0.0, + NOW(), NOW()) + """) + + # ------------------------------------------------------------------------- + # Seed Account B rows for every "cannot-see" table that would otherwise be + # empty. Without these, isolation tests pass vacuously even when RLS is off. + # ------------------------------------------------------------------------- + + # session_branches (FK: ai_sessions.id) + branch_b_row = await admin_conn.fetchrow(""" + INSERT INTO session_branches ( + id, session_id, account_id, branch_order, label, status, + conversation_messages, created_at, updated_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, 1, 'test-branch', 'active', + '[]'::jsonb, NOW(), NOW() + ) RETURNING id + """, ai_session_b_id, ACCOUNT_B_ID) + branch_b_id = str(branch_b_row["id"]) + + # session_supporting_data (FK: sessions.id) + supporting_data_b_row = await admin_conn.fetchrow(""" + INSERT INTO session_supporting_data ( + id, session_id, account_id, label, data_type, content, + sort_order, created_at, updated_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, 'test-data', 'text_snippet', + 'test content', 0, NOW(), NOW() + ) RETURNING id + """, session_b_id, ACCOUNT_B_ID) + supporting_data_b_id = str(supporting_data_b_row["id"]) + + # session_resolution_outputs (FK: ai_sessions.id) + resolution_output_b_row = await admin_conn.fetchrow(""" + INSERT INTO session_resolution_outputs ( + id, session_id, account_id, output_type, generated_content, + status, generated_by_model, created_at, updated_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, 'psa_ticket_notes', + 'test content', 'draft', 'test-model', NOW(), NOW() + ) RETURNING id + """, ai_session_b_id, ACCOUNT_B_ID) + resolution_output_b_id = str(resolution_output_b_row["id"]) + + # session_handoffs (FK: ai_sessions.id, users.id) + handoff_b_row = await admin_conn.fetchrow(""" + INSERT INTO session_handoffs ( + id, session_id, account_id, handed_off_by, intent, snapshot, + priority, psa_note_pushed, notification_sent, created_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, 'park', + '{}'::jsonb, 'normal', false, false, NOW() + ) RETURNING id + """, ai_session_b_id, ACCOUNT_B_ID, user_b_id) + handoff_b_id = str(handoff_b_row["id"]) + + # maintenance_schedules (FK: trees.id) + maintenance_b_row = await admin_conn.fetchrow(""" + INSERT INTO maintenance_schedules ( + id, tree_id, account_id, cron_expression, timezone, + created_at, updated_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, '0 9 * * 1', 'UTC', + NOW(), NOW() + ) RETURNING id + """, tree_b_id, ACCOUNT_B_ID) + maintenance_b_id = str(maintenance_b_row["id"]) + + # psa_post_log (FK: ai_sessions.id, users.id) + psa_log_b_row = await admin_conn.fetchrow(""" + INSERT INTO psa_post_log ( + id, ai_session_id, account_id, ticket_id, note_type, + content_posted, status, posted_by, posted_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, 'TEST-0001', 'internal', + 'test note', 'success', $3::uuid, NOW() + ) RETURNING id + """, ai_session_b_id, ACCOUNT_B_ID, user_b_id) + psa_log_b_id = str(psa_log_b_row["id"]) + + # script_templates requires a script_categories row — insert a temporary one + script_category_b_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO script_categories (id, name, slug, sort_order, is_active, created_at, updated_at) + VALUES ('{script_category_b_id}', 'RLS Test Category', 'rls-test-category-{script_category_b_id[:8]}', + 0, true, NOW(), NOW()) + """) + + script_template_b_row = await admin_conn.fetchrow(f""" + INSERT INTO script_templates ( + id, category_id, account_id, name, slug, script_body, + complexity, is_active, created_at, updated_at + ) VALUES ( + gen_random_uuid(), '{script_category_b_id}'::uuid, $1::uuid, + 'RLS Test Template', 'rls-test-template-b-' || gen_random_uuid()::text, + 'Write-Host "test"', 'beginner', true, NOW(), NOW() + ) RETURNING id + """, ACCOUNT_B_ID) + script_template_b_id = str(script_template_b_row["id"]) + + # script_generations (FK: script_templates.id, users.id) + script_gen_b_row = await admin_conn.fetchrow(""" + INSERT INTO script_generations ( + id, template_id, user_id, account_id, parameters_used, + generated_script, created_at + ) VALUES ( + gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, '{}'::jsonb, + 'test script', NOW() + ) RETURNING id + """, script_template_b_id, user_b_id, ACCOUNT_B_ID) + script_gen_b_id = str(script_gen_b_row["id"]) + + try: + yield { + "session_a": session_a_id, + "session_b": session_b_id, + "ai_session_a": ai_session_a_id, + "ai_session_b": ai_session_b_id, + } + finally: + # Cleanup in reverse FK order (children before parents) + await admin_conn.execute( + f"DELETE FROM script_generations WHERE id = '{script_gen_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM session_branches WHERE id = '{branch_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM session_supporting_data WHERE id = '{supporting_data_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM session_resolution_outputs WHERE id = '{resolution_output_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM session_handoffs WHERE id = '{handoff_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM maintenance_schedules WHERE id = '{maintenance_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM psa_post_log WHERE id = '{psa_log_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM script_templates WHERE id = '{script_template_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM script_categories WHERE id = '{script_category_b_id}'" + ) + await admin_conn.execute( + f"DELETE FROM sessions WHERE id IN ('{session_a_id}', '{session_b_id}')" + ) + await admin_conn.execute( + f"DELETE FROM ai_sessions WHERE id IN ('{ai_session_a_id}', '{ai_session_b_id}')" + ) + + +# --------------------------------------------------------------------------- +# sessions +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_b']}'" + ) + assert len(rows) == 0, "Account A should not see Account B sessions" + + +@pytest.mark.asyncio +async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_a']}'" + ) + assert len(rows) == 1, "Account A should see its own sessions" + + +@pytest.mark.asyncio +async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids): + rows = await conn_no_context.fetch( + f"SELECT id FROM sessions WHERE id IN " + f"('{session_row_ids['session_a']}', '{session_row_ids['session_b']}')" + ) + assert len(rows) == 0, "No-context connection should see no sessions" + + +# --------------------------------------------------------------------------- +# ai_sessions +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_b']}'" + ) + assert len(rows) == 0, "Account A should not see Account B ai_sessions" + + +@pytest.mark.asyncio +async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_a']}'" + ) + assert len(rows) == 1, "Account A should see its own ai_sessions" + + +# --------------------------------------------------------------------------- +# session_branches +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM session_branches WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B session_branches" + + +# --------------------------------------------------------------------------- +# session_supporting_data +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM session_supporting_data WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B session_supporting_data" + + +# --------------------------------------------------------------------------- +# session_resolution_outputs +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM session_resolution_outputs WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B session_resolution_outputs" + + +# --------------------------------------------------------------------------- +# session_handoffs +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM session_handoffs WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B session_handoffs" + + +# --------------------------------------------------------------------------- +# script_templates +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM script_templates WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B script_templates" + + +# --------------------------------------------------------------------------- +# script_generations +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_script_generations_account_a_cannot_see_account_b(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM script_generations WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B script_generations" + + +# --------------------------------------------------------------------------- +# maintenance_schedules +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM maintenance_schedules WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B maintenance_schedules" + + +# --------------------------------------------------------------------------- +# psa_post_log +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_ids): + rows = await conn_a.fetch( + f"SELECT id FROM psa_post_log WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B psa_post_log" + + +# --------------------------------------------------------------------------- +# step_library — visibility-aware policy +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_conn, conn_a): + """Private/non-public steps owned by Account B must not be visible to Account A.""" + private_step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_library ( + id, account_id, title, step_type, content, + visibility, is_active, created_at, updated_at + ) VALUES ( + '{private_step_id}', '{ACCOUNT_B_ID}', 'RLS Private Step', 'action', + '{{}}'::jsonb, 'private', TRUE, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM step_library " + f"WHERE id = '{private_step_id}' AND visibility != 'public'" + ) + assert len(rows) == 0, "Account A should not see Account B's private step_library rows" + finally: + await admin_conn.execute( + f"DELETE FROM step_library WHERE id = '{private_step_id}'" + ) + + +@pytest.mark.asyncio +async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, conn_a): + """Public steps owned by Account B MUST be visible to Account A (cross-tenant visibility).""" + public_step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO step_library ( + id, account_id, title, step_type, content, + visibility, is_active, created_at, updated_at + ) VALUES ( + '{public_step_id}', '{ACCOUNT_B_ID}', 'RLS Public Step', 'action', + '{{}}'::jsonb, 'public', TRUE, NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM step_library WHERE id = '{public_step_id}'" + ) + assert len(rows) == 1, ( + "Account A should see public steps owned by Account B " + "(cross-tenant public visibility policy)" + ) + finally: + await admin_conn.execute( + f"DELETE FROM step_library WHERE id = '{public_step_id}'" + ) diff --git a/backend/tests/test_save_session_as_tree.py b/backend/tests/test_save_session_as_tree.py index bf011674..f137862c 100644 --- a/backend/tests/test_save_session_as_tree.py +++ b/backend/tests/test_save_session_as_tree.py @@ -155,6 +155,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=UUID(test_user["user_data"]["id"]), + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[{"node_id": "root", "timestamp": datetime.now(timezone.utc).isoformat()}], @@ -199,6 +200,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=UUID(test_user["user_data"]["id"]), + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[], @@ -239,6 +241,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=UUID(test_user["user_data"]["id"]), + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[], @@ -279,6 +282,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=UUID(test_user["user_data"]["id"]), + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[], @@ -352,6 +356,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=other_user.id, + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[], diff --git a/frontend/src/components/assistant/TaskLane.tsx b/frontend/src/components/assistant/TaskLane.tsx index 81d36a23..c3d9458e 100644 --- a/frontend/src/components/assistant/TaskLane.tsx +++ b/frontend/src/components/assistant/TaskLane.tsx @@ -57,6 +57,7 @@ function loadTaskState(sessionId: string): TaskResponse[] | null { } catch { return null } } +// eslint-disable-next-line react-refresh/only-export-components export function clearTaskState(sessionId: string) { try { sessionStorage.removeItem(`${TASK_LANE_STORAGE_KEY}:${sessionId}`) } catch { /* ignore */ } } diff --git a/frontend/src/components/dashboard/TeamSummary.tsx b/frontend/src/components/dashboard/TeamSummary.tsx index b0557de4..b7b19836 100644 --- a/frontend/src/components/dashboard/TeamSummary.tsx +++ b/frontend/src/components/dashboard/TeamSummary.tsx @@ -9,10 +9,10 @@ export function TeamSummary() { const { isAccountOwner } = usePermissions() const navigate = useNavigate() const [escalationCount, setEscalationCount] = useState(0) - const [loading, setLoading] = useState(true) + const [loading, setLoading] = useState(!!isAccountOwner) useEffect(() => { - if (!isAccountOwner) { setLoading(false); return } + if (!isAccountOwner) return aiSessionsApi.getEscalationQueue() .then((esc) => setEscalationCount(esc.length)) .catch(() => {}) diff --git a/frontend/src/components/script-editor/ScriptBodyEditor.tsx b/frontend/src/components/script-editor/ScriptBodyEditor.tsx index 03a651ba..c967f424 100644 --- a/frontend/src/components/script-editor/ScriptBodyEditor.tsx +++ b/frontend/src/components/script-editor/ScriptBodyEditor.tsx @@ -1,4 +1,4 @@ -import { useCallback, useRef } from 'react' +import { useCallback, useEffect, useRef } from 'react' import Editor, { type BeforeMount } from '@monaco-editor/react' import { resolutionFlowTheme, THEME_ID } from '@/components/tree-editor/code-mode/resolutionFlowTheme' import { Spinner } from '@/components/common/Spinner' @@ -11,7 +11,9 @@ interface Props { export function ScriptBodyEditor({ value, onChange, disabled }: Props) { const lastValueRef = useRef(value) - lastValueRef.current = value + useEffect(() => { + lastValueRef.current = value + }, [value]) const handleBeforeMount: BeforeMount = useCallback((monaco) => { // Register our dark theme if not already defined