"""Phase 1 migration tests — account_id backfill correctness.

Every test builds rows through the ORM (i.e. with the updated models) and then
checks that the ``account_id`` column carries the expected account. The suite
runs against the real PostgreSQL test database shared by all other
integration tests.
"""

import uuid

import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.account import Account
from app.models.user import User
from app.models.tree import Tree
from app.models.session import Session
from app.models.attachment import Attachment
from app.models.supporting_data import SessionSupportingData
from app.models.session_resolution_output import SessionResolutionOutput
from app.models.ai_session import AISession
from app.core.security import get_password_hash


# ── Helpers ──────────────────────────────────────────────────────────────────


async def _persist(db: AsyncSession, obj):
    """Add *obj* to *db*, flush so generated values such as its id are available, and return it."""
    db.add(obj)
    await db.flush()
    return obj


async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]:
    """Insert an account plus one active ``engineer`` user that belongs to it.

    Both rows are flushed but not committed. *suffix* is embedded in the
    generated names and email; random hex fragments are added to the account
    display code and the email (presumably to avoid collisions with rows from
    other tests).
    """
    account = await _persist(
        db,
        Account(name=f"Corp {suffix}", display_code=uuid.uuid4().hex[:8]),
    )
    user = await _persist(
        db,
        User(
            email=f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com",
            name=f"User {suffix}",
            password_hash=get_password_hash("TestPass123!"),
            is_active=True,
            account_id=account.id,
            account_role="engineer",
        ),
    )
    return account, user


async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree:
    """Insert an active, published troubleshooting tree (visibility "team") authored by *user*."""
    root_node = {"id": "root", "type": "start", "children": []}
    return await _persist(
        db,
        Tree(
            name=f"Tree {uuid.uuid4().hex[:6]}",
            account_id=account.id,
            author_id=user.id,
            visibility="team",
            tree_type="troubleshooting",
            tree_structure=root_node,
            is_active=True,
            status="published",
        ),
    )


async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session:
    """Insert a session of *tree* for *user* under *account*, with an empty tree snapshot."""
    return await _persist(
        db,
        Session(
            tree_id=tree.id,
            user_id=user.id,
            account_id=account.id,
            tree_snapshot={},
        ),
    )


# ── Group 1: Core sessions ────────────────────────────────────────────────────
async def _commit_and_read_account_id(db: AsyncSession, obj):
    """Commit the pending work on *db* and return ``obj.account_id`` as stored in the DB.

    *obj* must already have been added to *db*.

    The primary key is read after a flush but before the commit, because an
    AsyncSession with ``expire_on_commit=True`` would otherwise need implicit
    I/O to reload it afterwards. The stored value is fetched by selecting the
    ``account_id`` column itself. Column-only selects are not resolved
    through the session's identity map, so the caller asserts on what was
    persisted rather than on the attribute still cached on *obj*.
    """
    await db.flush()  # make sure obj.id has been assigned
    model, pk = type(obj), obj.id
    await db.commit()
    result = await db.execute(select(model.account_id).where(model.id == pk))
    return result.scalar_one()


@pytest.mark.asyncio
async def test_session_account_id_matches_user(test_db: AsyncSession):
    """sessions.account_id must equal the user's account_id."""
    account, user = await _make_account_and_user(test_db, "s1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)
    expected = account.id  # read before the commit can expire `account`

    stored = await _commit_and_read_account_id(test_db, session)
    assert stored == expected, f"Expected {expected}, got {stored}"


@pytest.mark.asyncio
async def test_attachment_account_id_matches_session(test_db: AsyncSession):
    """attachments.account_id must match the parent session's account_id."""
    account, user = await _make_account_and_user(test_db, "att1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)
    attachment = Attachment(
        session_id=session.id,
        account_id=account.id,
        file_name="test.png",
        file_type="image/png",
    )
    test_db.add(attachment)
    expected = account.id

    stored = await _commit_and_read_account_id(test_db, attachment)
    assert stored == expected


@pytest.mark.asyncio
async def test_session_supporting_data_account_id(test_db: AsyncSession):
    """session_supporting_data.account_id must match parent session's account_id."""
    account, user = await _make_account_and_user(test_db, "sd1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)
    sd = SessionSupportingData(
        session_id=session.id,
        account_id=account.id,
        label="Log snippet",
        data_type="text_snippet",
        content="error: connection refused",
    )
    test_db.add(sd)
    expected = account.id

    stored = await _commit_and_read_account_id(test_db, sd)
    assert stored == expected


@pytest.mark.asyncio
async def test_session_resolution_output_account_id(test_db: AsyncSession):
    """session_resolution_outputs.account_id must match the parent ai_session's account_id.

    NOTE: session_resolution_outputs.session_id FK points to ai_sessions (not sessions).
    """
    account, user = await _make_account_and_user(test_db, "sro1")
    ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test resolution output",
        problem_domain="networking",
        status="active",
    )
    test_db.add(ai_session)
    await test_db.flush()

    output = SessionResolutionOutput(
        session_id=ai_session.id,
        account_id=account.id,
        output_type="psa_ticket_notes",
        generated_content="Ticket notes content",
        generated_by_model="gpt-4",
    )
    test_db.add(output)
    expected = account.id

    stored = await _commit_and_read_account_id(test_db, output)
    assert stored == expected


# ── Group 2: AI & branching ───────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession):
    """session_branches.account_id must match parent ai_session.account_id."""
    from app.models.session_branch import SessionBranch

    account, user = await _make_account_and_user(test_db, "sb1")
    ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test",
        problem_domain="networking",
        status="active",
    )
    test_db.add(ai_session)
    await test_db.flush()

    branch = SessionBranch(
        session_id=ai_session.id,
        account_id=account.id,
        label="Branch A",
        branch_order=1,
        conversation_messages=[],
    )
    test_db.add(branch)
    expected = account.id

    stored = await _commit_and_read_account_id(test_db, branch)
    assert stored == expected


@pytest.mark.asyncio
async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession):
    """ai_suggestions.account_id must match the creating user's account_id."""
    from app.models.ai_suggestion import AISuggestion

    account, user = await _make_account_and_user(test_db, "ais1")
    tree = await _make_tree(test_db, account, user)
    suggestion = AISuggestion(
        tree_id=tree.id,
        user_id=user.id,
        account_id=account.id,
        action_type="add_node",
        changes_json={},
        status="pending",
    )
    test_db.add(suggestion)
    expected = account.id

    stored = await _commit_and_read_account_id(test_db, suggestion)
    assert stored == expected


# ── Group 3: Steps & ratings ──────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_step_rating_account_id_is_rater_account(test_db: AsyncSession):
    """step_ratings.account_id must be the RATER's account, not the step's account."""
    from app.models.step_library import StepLibrary, StepRating

    account_a, user_a = await _make_account_and_user(test_db, "sr-rater")
    account_b, user_b = await _make_account_and_user(test_db, "sr-step-owner")

    # Step owned by account_b
    step = StepLibrary(
        title="A step",
        step_type="action",
        content={"text": "do something"},
        created_by=user_b.id,
        account_id=account_b.id,
        visibility="public",
    )
    test_db.add(step)
    await test_db.flush()

    # user_a (account_a) rates the step
    rating = StepRating(
        step_id=step.id,
        user_id=user_a.id,
        account_id=account_a.id,  # rater's account, not step owner's
        was_helpful=True,
        is_verified_use=False,
        is_visible=True,
    )
    test_db.add(rating)
    expected = account_a.id

    stored = await _commit_and_read_account_id(test_db, rating)
    assert stored == expected, (
        f"account_id should be rater's account ({expected}), got {stored}"
    )


@pytest.mark.asyncio
async def test_step_usage_log_account_id_is_logger_account(test_db: AsyncSession):
    """step_usage_log.account_id must be the LOGGER's account (user who used the step)."""
    from app.models.step_library import StepLibrary, StepUsageLog

    # Two distinct accounts (as in the rating test) so the logger's account and
    # the step owner's account differ; with a single account the assertion
    # could not tell them apart.
    account_a, user_a = await _make_account_and_user(test_db, "sul-logger")
    account_b, user_b = await _make_account_and_user(test_db, "sul-step-owner")
    tree = await _make_tree(test_db, account_a, user_a)
    session = await _make_session(test_db, account_a, user_a, tree)

    # Step owned by account_b; "public" mirrors the cross-account step used in
    # the rating test.
    step = StepLibrary(
        title="A usage step",
        step_type="action",
        content={"text": "do something"},
        created_by=user_b.id,
        account_id=account_b.id,
        visibility="public",
    )
    test_db.add(step)
    await test_db.flush()

    # user_a (account_a) logs a use of account_b's step in their own session
    log = StepUsageLog(
        step_id=step.id,
        user_id=user_a.id,
        account_id=account_a.id,  # logger's account, not step owner's
        session_id=session.id,
    )
    test_db.add(log)
    expected = account_a.id

    stored = await _commit_and_read_account_id(test_db, log)
    assert stored == expected, (
        f"account_id should be logger's account ({expected}), got {stored}"
    )