diff --git a/backend/app/api/endpoints/ai_suggestions.py b/backend/app/api/endpoints/ai_suggestions.py index e5f0919f..f3dd7f53 100644 --- a/backend/app/api/endpoints/ai_suggestions.py +++ b/backend/app/api/endpoints/ai_suggestions.py @@ -43,6 +43,7 @@ async def create_suggestion( suggestion = AISuggestion( tree_id=data.tree_id, user_id=current_user.id, + account_id=current_user.account_id, session_id=data.session_id, action_type=data.action_type, target_node_id=data.target_node_id, diff --git a/backend/app/api/endpoints/ratings.py b/backend/app/api/endpoints/ratings.py index 23f07054..3f0dc2f9 100644 --- a/backend/app/api/endpoints/ratings.py +++ b/backend/app/api/endpoints/ratings.py @@ -91,6 +91,7 @@ async def submit_step_feedback( new_rating = StepRating( step_id=step_id, user_id=current_user.id, + account_id=current_user.account_id, session_id=session_uuid, was_helpful=data.was_helpful, # rating is nullable now — thumbs-only mode diff --git a/backend/app/api/endpoints/steps.py b/backend/app/api/endpoints/steps.py index 992c1725..0ec2af71 100644 --- a/backend/app/api/endpoints/steps.py +++ b/backend/app/api/endpoints/steps.py @@ -460,6 +460,7 @@ async def rate_step( rating = StepRating( step_id=step_id, user_id=current_user.id, + account_id=current_user.account_id, rating=rating_data.rating, was_helpful=rating_data.was_helpful, review_text=rating_data.review_text, diff --git a/backend/app/services/branch_manager.py b/backend/app/services/branch_manager.py index 8dba3fa4..ac6d1609 100644 --- a/backend/app/services/branch_manager.py +++ b/backend/app/services/branch_manager.py @@ -34,6 +34,7 @@ class BranchManager: root = SessionBranch( id=uuid.uuid4(), session_id=session_id, + account_id=session.account_id, parent_branch_id=None, branch_order=1, label="Root", @@ -68,9 +69,17 @@ class BranchManager: "status": "untried", }) + # Load session to get account_id for FK constraints + session_result = await self.db.execute( + select(AISession).where(AISession.id == session_id) + ) + session = session_result.scalar_one_or_none() + account_id = session.account_id if session else None + fork_point = ForkPoint( id=uuid.uuid4(), session_id=session_id, + account_id=account_id, parent_branch_id=parent_branch_id, trigger_step_id=trigger_step_id, fork_reason=fork_reason, @@ -90,6 +99,7 @@ class BranchManager: branch = SessionBranch( id=branch_ids[i], session_id=session_id, + account_id=account_id, parent_branch_id=parent_branch_id, fork_point_step_id=trigger_step_id, branch_order=i + 1, diff --git a/backend/app/services/handoff_manager.py b/backend/app/services/handoff_manager.py index 8751e8b4..c79461ba 100644 --- a/backend/app/services/handoff_manager.py +++ b/backend/app/services/handoff_manager.py @@ -56,6 +56,7 @@ class HandoffManager: handoff = SessionHandoff( session_id=session_id, + account_id=session.account_id, handed_off_by=user_id, intent=intent, source_branch_id=session.active_branch_id, diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index b2c10429..9c1f60e6 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -75,6 +75,19 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]: ('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]') """)) + # Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by + # global categories, gallery items, and other platform-owned content. + await conn.execute(sa.text(""" + INSERT INTO accounts (id, name, display_code, created_at, updated_at) + VALUES ( + '00000000-0000-0000-0000-000000000001', + 'ResolutionFlow System', + 'RF-SYS-1', + NOW(), NOW() + ) + ON CONFLICT (id) DO NOTHING + """)) + # Create async session maker async_session_maker = async_sessionmaker( engine, diff --git a/backend/tests/test_admin_gallery.py b/backend/tests/test_admin_gallery.py index e611950a..5d93651b 100644 --- a/backend/tests/test_admin_gallery.py +++ b/backend/tests/test_admin_gallery.py @@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.models.tree import Tree from app.models.script_template import ScriptTemplate, ScriptCategory +_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001") # --------------------------------------------------------------------------- # Helpers @@ -22,6 +23,7 @@ async def _create_tree(db: AsyncSession, admin_user_id: str) -> Tree: name="Gallery Test Flow", tree_type="troubleshooting", visibility="public", + account_id=_PLATFORM_ACCOUNT_ID, is_gallery_featured=False, gallery_sort_order=0, tree_structure={ @@ -53,6 +55,7 @@ async def _create_script(db: AsyncSession, admin_user_id: str) -> ScriptTemplate script = ScriptTemplate( id=uuid.uuid4(), category_id=category.id, + account_id=_PLATFORM_ACCOUNT_ID, name="Gallery Test Script", slug=f"gallery-test-script-{uuid.uuid4().hex[:6]}", script_body="Write-Host 'Test'", diff --git a/backend/tests/test_analytics_phase5.py b/backend/tests/test_analytics_phase5.py index 6992e0a8..a2e3dd3c 100644 --- a/backend/tests/test_analytics_phase5.py +++ b/backend/tests/test_analytics_phase5.py @@ -594,6 +594,7 @@ class TestPsaMetrics: post_log = PsaPostLog( id=uuid.uuid4(), ai_session_id=push_session_id, + account_id=account_id, ticket_id="TICKET-123", note_type="internal", content_posted="Session summary", diff --git a/backend/tests/test_branding.py b/backend/tests/test_branding.py index c94c696a..d7d23f11 100644 --- a/backend/tests/test_branding.py +++ b/backend/tests/test_branding.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from app.core.security import get_password_hash +from app.models.account import Account from app.models.team import Team from app.models.user import User @@ -23,6 +24,8 @@ async def _create_team_with_admin( team_name: str = "Branding Test Team", ) -> tuple[dict, str, Team]: """Create a team + team admin user. Returns (auth_headers, team_id_str, team).""" + account = Account(name=team_name, display_code=uuid.uuid4().hex[:8].upper()) + test_db.add(account) team = Team(name=team_name) test_db.add(team) await test_db.flush() @@ -36,6 +39,8 @@ async def _create_team_with_admin( team_id=team.id, is_team_admin=True, role="engineer", + account_id=account.id, + account_role="engineer", ) test_db.add(user) await test_db.commit() @@ -58,6 +63,15 @@ async def _create_team_member( is_team_admin: bool = False, ) -> dict: """Create a regular team member. Returns auth_headers.""" + # Look up the account associated with this team via an existing member + from sqlalchemy import select as _select + from app.models.user import User as _User + result = await test_db.execute( + _select(_User).where(_User.team_id == team.id).limit(1) + ) + team_member = result.scalar_one_or_none() + member_account_id = team_member.account_id if team_member else None + email = f"member_{uuid.uuid4().hex[:8]}@test.com" user = User( email=email, @@ -67,6 +81,8 @@ async def _create_team_member( team_id=team.id, is_team_admin=is_team_admin, role="engineer", + account_id=member_account_id, + account_role="engineer", ) test_db.add(user) await test_db.commit() diff --git a/backend/tests/test_draft_trees.py b/backend/tests/test_draft_trees.py index 97aae49a..45538f78 100644 --- a/backend/tests/test_draft_trees.py +++ b/backend/tests/test_draft_trees.py @@ -334,12 +334,13 @@ class TestDraftTreesAPI: """Test that migration defaults existing trees to published status.""" # Create a tree without specifying status (relies on DB default) from uuid import UUID, uuid4 + _platform_id = UUID("00000000-0000-0000-0000-000000000001") tree = Tree( name="Legacy Tree", description="Created before status field", tree_structure={"id": "root", "type": "solution", "title": "Fix"}, author_id=None, - account_id=None + account_id=_platform_id, ) test_db.add(tree) await test_db.commit() diff --git a/backend/tests/test_maintenance_schedules.py b/backend/tests/test_maintenance_schedules.py index 42a7fd58..2ba700e7 100644 --- a/backend/tests/test_maintenance_schedules.py +++ b/backend/tests/test_maintenance_schedules.py @@ -127,10 +127,12 @@ async def test_cannot_schedule_other_teams_tree(client: AsyncClient, auth_header test_db.add(other_team) await test_db.flush() + from uuid import UUID as _UUID other_tree = Tree( name="Other Team Tree", tree_type="maintenance", team_id=other_team.id, + account_id=_UUID("00000000-0000-0000-0000-000000000001"), tree_structure={ "steps": [ {"id": "s1", "type": "procedure_step", "title": "Step", diff --git a/backend/tests/test_public_templates.py b/backend/tests/test_public_templates.py index d1d972f0..04542020 100644 --- a/backend/tests/test_public_templates.py +++ b/backend/tests/test_public_templates.py @@ -11,6 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.models.script_template import ScriptCategory, ScriptTemplate from app.models.tree import Tree +_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001") + # --------------------------------------------------------------------------- # Helpers @@ -41,6 +43,7 @@ async def _create_featured_tree(db: AsyncSession, name: str = "Featured Flow", f description="A featured flow for the gallery", tree_type="troubleshooting", tree_structure=_make_tree_structure(4), + account_id=_PLATFORM_ACCOUNT_ID, is_gallery_featured=featured, is_active=True, usage_count=42, @@ -74,6 +77,7 @@ async def _create_featured_script( ) -> ScriptTemplate: script = ScriptTemplate( category_id=category.id, + account_id=_PLATFORM_ACCOUNT_ID, name=name, slug=name.lower().replace(" ", "-"), description="A gallery-featured script", @@ -312,7 +316,7 @@ class TestCategoriesEndpoint: from app.models.category import TreeCategory # Create a category and a featured tree in that category - cat = TreeCategory(name="Networking", slug="networking", is_active=True) + cat = TreeCategory(name="Networking", slug="networking", is_active=True, account_id=_PLATFORM_ACCOUNT_ID) test_db.add(cat) await test_db.commit() await test_db.refresh(cat) @@ -321,6 +325,7 @@ class TestCategoriesEndpoint: name="Router Diagnostics", tree_type="troubleshooting", tree_structure=_make_tree_structure(2), + account_id=_PLATFORM_ACCOUNT_ID, is_gallery_featured=True, is_active=True, usage_count=5, diff --git a/backend/tests/test_resolution_outputs.py b/backend/tests/test_resolution_outputs.py index a852ebca..2a853252 100644 --- a/backend/tests/test_resolution_outputs.py +++ b/backend/tests/test_resolution_outputs.py @@ -62,6 +62,7 @@ async def test_edit_output(client: AsyncClient, test_user, auth_headers, test_db output = SessionResolutionOutput( session_id=session.id, + account_id=session.account_id, output_type="psa_ticket_notes", generated_content="Original notes", status="draft", diff --git a/backend/tests/test_rls_isolation.py b/backend/tests/test_rls_isolation.py index 9b608bdc..ee8136af 100644 --- a/backend/tests/test_rls_isolation.py +++ b/backend/tests/test_rls_isolation.py @@ -16,7 +16,10 @@ Run with: The test DB is patherly_test (matches conftest.py default). """ import os +import subprocess +import sys import uuid +from pathlib import Path import asyncpg import pytest @@ -37,7 +40,25 @@ ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" # --------------------------------------------------------------------------- @pytest.fixture(scope="module") -async def admin_conn(): +def _ensure_rls_schema(): + """Re-apply Alembic migrations before the module runs. + + Function-scoped test_db fixtures in other modules drop and recreate the + public schema using Base.metadata.create_all, which does not enable RLS + or create DB roles. This fixture re-runs 'alembic upgrade head' so that + the full migration-managed schema (including RLS policies) is in place. + """ + backend_dir = Path(__file__).parent.parent + subprocess.run( + [sys.executable, "-m", "alembic", "upgrade", "head"], + cwd=backend_dir, + check=True, + capture_output=True, + ) + + +@pytest.fixture(scope="module") +async def admin_conn(_ensure_rls_schema): """Superuser asyncpg connection for fixture setup and teardown.""" conn = await asyncpg.connect(_ADMIN_DSN) yield conn diff --git a/backend/tests/test_save_session_as_tree.py b/backend/tests/test_save_session_as_tree.py index bf011674..f137862c 100644 --- a/backend/tests/test_save_session_as_tree.py +++ b/backend/tests/test_save_session_as_tree.py @@ -155,6 +155,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=UUID(test_user["user_data"]["id"]), + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[{"node_id": "root", "timestamp": datetime.now(timezone.utc).isoformat()}], @@ -199,6 +200,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=UUID(test_user["user_data"]["id"]), + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[], @@ -239,6 +241,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=UUID(test_user["user_data"]["id"]), + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[], @@ -279,6 +282,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=UUID(test_user["user_data"]["id"]), + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[], @@ -352,6 +356,7 @@ class TestSaveSessionAsTreeAPI: session = Session( tree_id=tree.id, user_id=other_user.id, + account_id=UUID(test_user["user_data"]["account_id"]), tree_snapshot=tree.tree_structure, path_taken=["root"], decisions=[],