Files
resolutionflow/backend/tests/test_phase1_migrations.py
chihlasm 2779a41b94 feat: Phase 1 Group 2 — add account_id to AI branching tables
Tables: session_branches, session_handoffs, fork_points,
        ai_session_steps, ai_suggestions
Backfill: session_id → ai_sessions.account_id (all except
ai_suggestions which uses user_id → users.account_id)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:12:18 +00:00

224 lines
7.4 KiB
Python

"""Phase 1 migration tests — verify account_id backfill correctness.
These tests create objects via ORM (which uses the updated models),
then verify account_id is populated correctly. They run against a
real PostgreSQL test DB (same as all other integration tests).
"""
import pytest
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.account import Account
from app.models.user import User
from app.models.tree import Tree
from app.models.session import Session
from app.models.attachment import Attachment
from app.models.supporting_data import SessionSupportingData
from app.models.session_resolution_output import SessionResolutionOutput
from app.models.ai_session import AISession
from app.core.security import get_password_hash
# ── Helpers ──────────────────────────────────────────────────────────────────
async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]:
    """Create and flush an Account plus one active engineer User inside it.

    ``suffix`` is embedded in the account name, user name and email so rows
    from different tests are distinguishable; random hex fragments keep the
    account display code and the email unique across runs.

    Returns the ``(account, user)`` pair, both already flushed.
    """
    new_account = Account(
        name=f"Corp {suffix}",
        display_code=uuid.uuid4().hex[:8],
    )
    db.add(new_account)
    # Flush so new_account.id is available for the user's account_id below.
    await db.flush()

    email_tag = uuid.uuid4().hex[:6]
    new_user = User(
        email=f"user-{suffix}-{email_tag}@example.com",
        name=f"User {suffix}",
        password_hash=get_password_hash("TestPass123!"),
        is_active=True,
        account_id=new_account.id,
        account_role="engineer",
    )
    db.add(new_user)
    await db.flush()
    return new_account, new_user
async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree:
    """Create and flush a published, team-visible troubleshooting Tree.

    The tree belongs to ``account``, is authored by ``user``, and its
    structure is a single childless ``start`` node named ``root``.
    """
    name_tag = uuid.uuid4().hex[:6]
    root_only = {"id": "root", "type": "start", "children": []}
    new_tree = Tree(
        name=f"Tree {name_tag}",
        account_id=account.id,
        author_id=user.id,
        visibility="team",
        tree_type="troubleshooting",
        tree_structure=root_only,
        is_active=True,
        status="published",
    )
    db.add(new_tree)
    await db.flush()
    return new_tree
async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session:
    """Create and flush a Session for ``user`` running ``tree`` under ``account``.

    The session is created with an empty ``tree_snapshot``.
    """
    new_session = Session(
        account_id=account.id,
        user_id=user.id,
        tree_id=tree.id,
        tree_snapshot={},
    )
    db.add(new_session)
    await db.flush()
    return new_session
# ── Group 1: Core sessions ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_session_account_id_matches_user(test_db: AsyncSession):
    """sessions.account_id must equal the user's account_id."""
    account, user = await _make_account_and_user(test_db, "s1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)
    # Capture plain values before commit: depending on the fixture's
    # expire_on_commit setting, reading ORM attributes after commit could
    # require an implicit refresh, which AsyncSession cannot do lazily.
    session_id = session.id
    expected_account_id = user.account_id
    await test_db.commit()
    # Select the column, not the entity: an entity select returns the object
    # already in the identity map and does not overwrite its loaded
    # attributes, so it would only re-check the in-memory value.
    result = await test_db.execute(
        select(Session.account_id).where(Session.id == session_id)
    )
    stored_account_id = result.scalar_one()
    assert stored_account_id == expected_account_id, (
        f"Expected {expected_account_id}, got {stored_account_id}"
    )
@pytest.mark.asyncio
async def test_attachment_account_id_matches_session(test_db: AsyncSession):
    """attachments.account_id must match the parent session's account_id."""
    account, user = await _make_account_and_user(test_db, "att1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)
    attachment = Attachment(
        session_id=session.id,
        account_id=account.id,
        file_name="test.png",
        file_type="image/png",
    )
    test_db.add(attachment)
    # Flush so attachment.id is assigned and can be captured before commit.
    await test_db.flush()
    attachment_id = attachment.id
    expected_account_id = session.account_id
    await test_db.commit()
    # Column select reads the persisted value rather than the identity-mapped
    # object's in-memory attribute.
    result = await test_db.execute(
        select(Attachment.account_id).where(Attachment.id == attachment_id)
    )
    assert result.scalar_one() == expected_account_id
@pytest.mark.asyncio
async def test_session_supporting_data_account_id(test_db: AsyncSession):
    """session_supporting_data.account_id must match parent session's account_id."""
    account, user = await _make_account_and_user(test_db, "sd1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)
    sd = SessionSupportingData(
        session_id=session.id,
        account_id=account.id,
        label="Log snippet",
        data_type="text_snippet",
        content="error: connection refused",
    )
    test_db.add(sd)
    # Flush so sd.id is assigned and can be captured before commit.
    await test_db.flush()
    sd_id = sd.id
    expected_account_id = session.account_id
    await test_db.commit()
    # Column select reads the persisted value rather than the identity-mapped
    # object's in-memory attribute.
    result = await test_db.execute(
        select(SessionSupportingData.account_id).where(SessionSupportingData.id == sd_id)
    )
    assert result.scalar_one() == expected_account_id
@pytest.mark.asyncio
async def test_session_resolution_output_account_id(test_db: AsyncSession):
    """session_resolution_outputs.account_id must match the parent ai_session's account_id.
    NOTE: session_resolution_outputs.session_id FK points to ai_sessions (not sessions).
    """
    account, user = await _make_account_and_user(test_db, "sro1")
    ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test resolution output",
        problem_domain="networking",
        status="active",
    )
    test_db.add(ai_session)
    await test_db.flush()
    output = SessionResolutionOutput(
        session_id=ai_session.id,
        account_id=account.id,
        output_type="psa_ticket_notes",
        generated_content="Ticket notes content",
        generated_by_model="gpt-4",
    )
    test_db.add(output)
    # Flush so output.id is assigned and can be captured before commit.
    await test_db.flush()
    output_id = output.id
    expected_account_id = ai_session.account_id
    await test_db.commit()
    # Column select reads the persisted value rather than the identity-mapped
    # object's in-memory attribute.
    result = await test_db.execute(
        select(SessionResolutionOutput.account_id).where(
            SessionResolutionOutput.id == output_id
        )
    )
    assert result.scalar_one() == expected_account_id
# ── Group 2: AI & branching ───────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession):
    """session_branches.account_id must match parent ai_session.account_id."""
    from app.models.session_branch import SessionBranch
    account, user = await _make_account_and_user(test_db, "sb1")
    ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test",
        problem_domain="networking",
        status="active",
    )
    test_db.add(ai_session)
    await test_db.flush()
    branch = SessionBranch(
        session_id=ai_session.id,
        account_id=account.id,
        label="Branch A",
        branch_order=1,
        conversation_messages=[],
    )
    test_db.add(branch)
    # Flush so branch.id is assigned and can be captured before commit.
    await test_db.flush()
    branch_id = branch.id
    expected_account_id = ai_session.account_id
    await test_db.commit()
    # Column select reads the persisted value rather than the identity-mapped
    # object's in-memory attribute.
    result = await test_db.execute(
        select(SessionBranch.account_id).where(SessionBranch.id == branch_id)
    )
    assert result.scalar_one() == expected_account_id
@pytest.mark.asyncio
async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession):
    """ai_suggestions.account_id must match the creating user's account_id."""
    from app.models.ai_suggestion import AISuggestion
    account, user = await _make_account_and_user(test_db, "ais1")
    tree = await _make_tree(test_db, account, user)
    suggestion = AISuggestion(
        tree_id=tree.id,
        user_id=user.id,
        account_id=account.id,
        action_type="add_node",
        changes_json={},
        status="pending",
    )
    test_db.add(suggestion)
    # Flush so suggestion.id is assigned and can be captured before commit.
    await test_db.flush()
    suggestion_id = suggestion.id
    expected_account_id = user.account_id
    await test_db.commit()
    # Column select reads the persisted value rather than the identity-mapped
    # object's in-memory attribute.
    result = await test_db.execute(
        select(AISuggestion.account_id).where(AISuggestion.id == suggestion_id)
    )
    assert result.scalar_one() == expected_account_id