"""Phase 1 migration tests — verify account_id backfill correctness. These tests create objects via ORM (which uses the updated models), then verify account_id is populated correctly. They run against a real PostgreSQL test DB (same as all other integration tests). """ import pytest import uuid from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, text from app.models.account import Account from app.models.user import User from app.models.tree import Tree from app.models.session import Session from app.models.attachment import Attachment from app.models.supporting_data import SessionSupportingData from app.models.session_resolution_output import SessionResolutionOutput from app.models.ai_session import AISession from app.core.security import get_password_hash # ── Helpers ────────────────────────────────────────────────────────────────── async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]: account = Account(name=f"Corp {suffix}", display_code=uuid.uuid4().hex[:8]) db.add(account) await db.flush() user = User( email=f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com", name=f"User {suffix}", password_hash=get_password_hash("TestPass123!"), is_active=True, account_id=account.id, account_role="engineer", ) db.add(user) await db.flush() return account, user async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree: tree = Tree( name=f"Tree {uuid.uuid4().hex[:6]}", account_id=account.id, author_id=user.id, visibility="team", tree_type="troubleshooting", tree_structure={"id": "root", "type": "start", "children": []}, is_active=True, status="published", ) db.add(tree) await db.flush() return tree async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session: s = Session( tree_id=tree.id, user_id=user.id, account_id=account.id, tree_snapshot={}, ) db.add(s) await db.flush() return s # ── Group 1: Core sessions ──────────────────────────────────────────────────── 
@pytest.mark.asyncio
async def test_session_account_id_matches_user(test_db: AsyncSession):
    """sessions.account_id must equal the user's account_id."""
    account, user = await _make_account_and_user(test_db, "s1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)
    await test_db.commit()

    # Read the row back through a query instead of trusting the in-memory object.
    result = await test_db.execute(select(Session).where(Session.id == session.id))
    row = result.scalar_one()
    assert row.account_id == account.id, f"Expected {account.id}, got {row.account_id}"


@pytest.mark.asyncio
async def test_attachment_account_id_matches_session(test_db: AsyncSession):
    """attachments.account_id must match the parent session's account_id."""
    account, user = await _make_account_and_user(test_db, "att1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)

    attachment = Attachment(
        session_id=session.id,
        account_id=account.id,
        file_name="test.png",
        file_type="image/png",
    )
    test_db.add(attachment)
    await test_db.commit()

    result = await test_db.execute(select(Attachment).where(Attachment.id == attachment.id))
    row = result.scalar_one()
    assert row.account_id == account.id


@pytest.mark.asyncio
async def test_session_supporting_data_account_id(test_db: AsyncSession):
    """session_supporting_data.account_id must match parent session's account_id."""
    account, user = await _make_account_and_user(test_db, "sd1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)

    sd = SessionSupportingData(
        session_id=session.id,
        account_id=account.id,
        label="Log snippet",
        data_type="text_snippet",
        content="error: connection refused",
    )
    test_db.add(sd)
    await test_db.commit()

    result = await test_db.execute(
        select(SessionSupportingData).where(SessionSupportingData.id == sd.id)
    )
    row = result.scalar_one()
    assert row.account_id == account.id


@pytest.mark.asyncio
async def test_session_resolution_output_account_id(test_db: AsyncSession):
    """session_resolution_outputs.account_id must match the parent ai_session's account_id.

    NOTE: session_resolution_outputs.session_id FK points to ai_sessions (not sessions).
    """
    account, user = await _make_account_and_user(test_db, "sro1")

    ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test resolution output",
        problem_domain="networking",
        status="active",
    )
    test_db.add(ai_session)
    # Flush before ai_session.id is used as the output's session_id.
    await test_db.flush()

    output = SessionResolutionOutput(
        session_id=ai_session.id,
        account_id=account.id,
        output_type="psa_ticket_notes",
        generated_content="Ticket notes content",
        generated_by_model="gpt-4",
    )
    test_db.add(output)
    await test_db.commit()

    result = await test_db.execute(
        select(SessionResolutionOutput).where(SessionResolutionOutput.id == output.id)
    )
    row = result.scalar_one()
    assert row.account_id == account.id


# ── Group 2: AI & branching ───────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession):
    """session_branches.account_id must match parent ai_session.account_id."""
    from app.models.session_branch import SessionBranch

    account, user = await _make_account_and_user(test_db, "sb1")

    ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test",
        problem_domain="networking",
        status="active",
    )
    test_db.add(ai_session)
    # Flush before ai_session.id is used as the branch's session_id.
    await test_db.flush()

    branch = SessionBranch(
        session_id=ai_session.id,
        account_id=account.id,
        label="Branch A",
        branch_order=1,
        conversation_messages=[],
    )
    test_db.add(branch)
    await test_db.commit()

    result = await test_db.execute(
        select(SessionBranch).where(SessionBranch.id == branch.id)
    )
    row = result.scalar_one()
    assert row.account_id == account.id


@pytest.mark.asyncio
async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession):
    """ai_suggestions.account_id must match the creating user's account_id."""
    from app.models.ai_suggestion import AISuggestion

    account, user = await _make_account_and_user(test_db, "ais1")
    tree = await _make_tree(test_db, account, user)

    suggestion = AISuggestion(
        tree_id=tree.id,
        user_id=user.id,
        account_id=account.id,
        action_type="add_node",
        changes_json={},
        status="pending",
    )
    test_db.add(suggestion)
    await test_db.commit()

    result = await test_db.execute(
        select(AISuggestion).where(AISuggestion.id == suggestion.id)
    )
    row = result.scalar_one()
    assert row.account_id == account.id


# ── Group 3: Steps & ratings ──────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_step_rating_account_id_is_rater_account(test_db: AsyncSession):
    """step_ratings.account_id must be the RATER's account, not the step's account."""
    from app.models.step_library import StepLibrary, StepRating

    # Two distinct accounts so the rater's and the step owner's ids cannot coincide.
    account_a, user_a = await _make_account_and_user(test_db, "sr-rater")
    account_b, user_b = await _make_account_and_user(test_db, "sr-step-owner")

    # Step owned by account_b
    step = StepLibrary(
        title="A step",
        step_type="action",
        content={"text": "do something"},
        created_by=user_b.id,
        account_id=account_b.id,
        visibility="public",
    )
    test_db.add(step)
    await test_db.flush()

    # user_a (account_a) rates the step
    rating = StepRating(
        step_id=step.id,
        user_id=user_a.id,
        account_id=account_a.id,  # rater's account, not step owner's
        was_helpful=True,
        is_verified_use=False,
        is_visible=True,
    )
    test_db.add(rating)
    await test_db.commit()

    result = await test_db.execute(select(StepRating).where(StepRating.id == rating.id))
    row = result.scalar_one()
    assert row.account_id == account_a.id, (
        f"account_id should be rater's account ({account_a.id}), got {row.account_id}"
    )


@pytest.mark.asyncio
async def test_step_usage_log_account_id_is_logger_account(test_db: AsyncSession):
    """step_usage_log.account_id must be the LOGGER's account (user who used the step)."""
    from app.models.step_library import StepLibrary, StepUsageLog

    account, user = await _make_account_and_user(test_db, "sul1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)

    step = StepLibrary(
        title="A usage step",
        step_type="action",
        content={"text": "do something"},
        created_by=user.id,
        account_id=account.id,
        visibility="team",
    )
    test_db.add(step)
    await test_db.flush()

    log = StepUsageLog(
        step_id=step.id,
        user_id=user.id,
        account_id=account.id,
        session_id=session.id,
    )
    test_db.add(log)
    await test_db.commit()

    result = await test_db.execute(select(StepUsageLog).where(StepUsageLog.id == log.id))
    row = result.scalar_one()
    assert row.account_id == account.id, (
        f"account_id should be logger's account ({account.id}), got {row.account_id}"
    )


# ── Group 4: User personalization ────────────────────────────────────────────

@pytest.mark.asyncio
async def test_user_folder_account_id_matches_user(test_db: AsyncSession):
    """user_folders.account_id must match the owning user's account_id."""
    from app.models.folder import UserFolder

    account, user = await _make_account_and_user(test_db, "uf1")

    folder = UserFolder(
        user_id=user.id,
        account_id=account.id,
        name="My Folder",
        color="#6366f1",
        icon="folder",
        display_order=0,
    )
    test_db.add(folder)
    await test_db.commit()

    result = await test_db.execute(select(UserFolder).where(UserFolder.id == folder.id))
    row = result.scalar_one()
    assert row.account_id == account.id


@pytest.mark.asyncio
async def test_user_pinned_tree_account_id_matches_user(test_db: AsyncSession):
    """user_pinned_trees.account_id must match the pinning user's account_id."""
    from app.models.user_pinned_tree import UserPinnedTree

    account, user = await _make_account_and_user(test_db, "pt1")
    tree = await _make_tree(test_db, account, user)

    pin = UserPinnedTree(
        user_id=user.id,
        tree_id=tree.id,
        account_id=account.id,
        display_order=0,
    )
    test_db.add(pin)
    await test_db.commit()

    result = await test_db.execute(select(UserPinnedTree).where(UserPinnedTree.id == pin.id))
    row = result.scalar_one()
    assert row.account_id == account.id


# ── Group 5: PSA & notifications ─────────────────────────────────────────────

@pytest.mark.asyncio
async def test_psa_member_mapping_account_id_matches_connection(test_db: AsyncSession):
    """psa_member_mappings.account_id must match psa_connection's account_id."""
    from app.models.psa_connection import PsaConnection
    from app.models.psa_member_mapping import PsaMemberMapping

    account, user = await _make_account_and_user(test_db, "psa1")

    conn = PsaConnection(
        account_id=account.id,
        provider="connectwise",
        display_name="Test CW",
        site_url="https://cw.example.com",
        company_id="TEST",
        credentials_encrypted="placeholder",
    )
    test_db.add(conn)
    # Flush before conn.id is used as the mapping's psa_connection_id.
    await test_db.flush()

    mapping = PsaMemberMapping(
        psa_connection_id=conn.id,
        user_id=user.id,
        account_id=account.id,
        external_member_id="cw-123",
        external_member_name="Test User",
        matched_by="manual_admin",
    )
    test_db.add(mapping)
    await test_db.commit()

    result = await test_db.execute(
        select(PsaMemberMapping).where(PsaMemberMapping.id == mapping.id)
    )
    row = result.scalar_one()
    assert row.account_id == account.id


# ── Group 6: Maintenance ──────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_maintenance_schedule_account_id_matches_tree(test_db: AsyncSession):
    """maintenance_schedules.account_id must match the tree's account_id."""
    from app.models.maintenance_schedule import MaintenanceSchedule

    account, user = await _make_account_and_user(test_db, "ms1")

    # Built inline rather than via _make_tree because this test needs
    # tree_type="maintenance".
    tree = Tree(
        name="Maintenance Flow",
        account_id=account.id,
        author_id=user.id,
        visibility="team",
        tree_type="maintenance",
        tree_structure={"id": "root", "type": "start", "children": []},
        is_active=True,
        status="published",
    )
    test_db.add(tree)
    await test_db.flush()

    schedule = MaintenanceSchedule(
        tree_id=tree.id,
        account_id=account.id,
        created_by=user.id,
        cron_expression="0 9 * * 1",
        timezone="UTC",
        is_active=True,
    )
    test_db.add(schedule)
    await test_db.commit()

    result = await test_db.execute(
        select(MaintenanceSchedule).where(MaintenanceSchedule.id == schedule.id)
    )
    row = result.scalar_one()
    assert row.account_id == account.id


# ── Group 7: Legacy team_id tables ───────────────────────────────────────────

@pytest.mark.asyncio
async def test_script_builder_session_account_id(test_db: AsyncSession):
    """script_builder_sessions.account_id must match user's account_id."""
    from app.models.script_builder_session import ScriptBuilderSession

    account, user = await _make_account_and_user(test_db, "sbs1")

    sbs = ScriptBuilderSession(
        user_id=user.id,
        account_id=account.id,
        language="powershell",
    )
    test_db.add(sbs)
    await test_db.commit()

    result = await test_db.execute(
        select(ScriptBuilderSession).where(ScriptBuilderSession.id == sbs.id)
    )
    row = result.scalar_one()
    assert row.account_id == account.id


# ── Group 8: TargetList ────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
    """target_lists.account_id must be set to the team admin's account_id."""
    from app.models.target_list import TargetList
    from app.models.team import Team

    account, user = await _make_account_and_user(test_db, "tl1")

    # Make user a team admin
    team = Team(name=f"Team {uuid.uuid4().hex[:6]}")
    test_db.add(team)
    await test_db.flush()
    user.team_id = team.id
    user.is_team_admin = True
    await test_db.flush()

    target_list = TargetList(
        account_id=account.id,
        created_by=user.id,
        name="Server Targets",
        targets=[{"label": "SRV-01"}],
    )
    test_db.add(target_list)
    await test_db.commit()

    result = await test_db.execute(
        select(TargetList).where(TargetList.id == target_list.id)
    )
    row = result.scalar_one()
    assert row.account_id == account.id


# ── Group 10 (runs first): Global content tables ──────────────────────────────

@pytest.mark.asyncio
async def test_template_trees_table_exists_and_has_no_account_id(test_db: AsyncSession):
    """template_trees must exist and must NOT have an account_id column."""
    # NOTE(review): the lookup filters on table_name only; a same-named table in
    # another schema would contribute its columns too — confirm whether a
    # table_schema filter is wanted.
    result = await test_db.execute(text("""
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name = 'template_trees'
    """))
    columns = {row[0] for row in result.fetchall()}

    # An empty column set (table missing) fails the 'id' check below.
    assert 'id' in columns, "template_trees.id must exist"
    assert 'account_id' not in columns, "template_trees must not have account_id (global content)"


@pytest.mark.asyncio
async def test_platform_steps_table_exists_and_has_no_account_id(test_db: AsyncSession):
    """platform_steps must exist and must NOT have an account_id column."""
    # NOTE(review): the lookup filters on table_name only; a same-named table in
    # another schema would contribute its columns too — confirm whether a
    # table_schema filter is wanted.
    result = await test_db.execute(text("""
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name = 'platform_steps'
    """))
    columns = {row[0] for row in result.fetchall()}

    # An empty column set (table missing) fails the 'id' check below.
    assert 'id' in columns, "platform_steps.id must exist"
    assert 'account_id' not in columns, "platform_steps must not have account_id (global content)"


# ── Group 9: SET NOT NULL on existing nullable columns ────────────────────────

@pytest.mark.asyncio
async def test_tree_account_id_is_not_null(test_db: AsyncSession):
    """trees.account_id must be NOT NULL after Phase 1 — enforced at DB level."""
    from sqlalchemy.exc import IntegrityError

    # NOTE(review): author_id is None as well; if trees.author_id is itself
    # NOT NULL the IntegrityError could come from that column instead of
    # account_id — confirm against the schema.
    test_db.add(Tree(
        name="Bad tree",
        # account_id intentionally omitted
        author_id=None,
        visibility="private",
        tree_type="troubleshooting",
        tree_structure={},
        is_active=True,
        status="draft",
    ))

    # Only the flush is expected to raise; building and adding the object
    # stays outside the raises-block so an unrelated error there is not
    # mistaken for the constraint firing.
    with pytest.raises(IntegrityError):
        await test_db.flush()

    # The failed flush leaves the session's transaction unusable; roll back
    # explicitly so the session is clean when the fixture tears down.
    await test_db.rollback()


@pytest.mark.asyncio
async def test_user_account_id_is_not_null(test_db: AsyncSession):
    """users.account_id must be NOT NULL after Phase 1."""
    from sqlalchemy.exc import IntegrityError

    test_db.add(User(
        email=f"orphan-{uuid.uuid4().hex[:6]}@example.com",
        name="Orphan",
        password_hash=get_password_hash("x"),
        is_active=True,
        role="engineer",
        account_role="engineer",
        # account_id intentionally omitted
    ))

    # Only the flush is expected to raise; see the narrower raises-block.
    with pytest.raises(IntegrityError):
        await test_db.flush()

    # The failed flush leaves the session's transaction unusable; roll back
    # explicitly so the session is clean when the fixture tears down.
    await test_db.rollback()