Files
resolutionflow/backend/tests/test_phase1_migrations.py
chihlasm 893b8a5008 fix: tree_shares.account_id must come from tree owner, not the actor
- trees.py: change account_id=current_user.account_id →
  account_id=tree.account_id so super-admin cross-account shares land in
  the tree's tenant where RLS will see them.

- migration a05e1a1bea7c: fix backfill to join tree_shares → trees instead
  of tree_shares → users(created_by). Same logic: historical shares belong
  to the tree's tenant.

- test_tree_sharing.py: add test_share_account_id_matches_tree_not_actor
  to assert share.account_id == tree.account_id after POST /share; also
  add missing account_id to all direct TreeShare(...) constructors in
  existing tests.

- test_phase1_migrations.py: remove team_id= from TargetList constructor
  (column dropped in Phase 3).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 07:02:35 +00:00

545 lines
19 KiB
Python

"""Phase 1 migration tests — verify account_id backfill correctness.
These tests create objects via ORM (which uses the updated models),
then verify account_id is populated correctly. They run against a
real PostgreSQL test DB (same as all other integration tests).
"""
import pytest
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text
from app.models.account import Account
from app.models.user import User
from app.models.tree import Tree
from app.models.session import Session
from app.models.attachment import Attachment
from app.models.supporting_data import SessionSupportingData
from app.models.session_resolution_output import SessionResolutionOutput
from app.models.ai_session import AISession
from app.core.security import get_password_hash
# ── Helpers ──────────────────────────────────────────────────────────────────
async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]:
    """Build a fresh Account and a User that belongs to it, flushing both.

    Args:
        db: Session the new rows are added to. Nothing is committed here;
            callers decide when to commit.
        suffix: Label woven into the account name, the user name and the
            email address so rows from different tests are easy to tell apart.

    Returns:
        The ``(account, user)`` pair, both already flushed.
    """
    owner_account = Account(
        name=f"Corp {suffix}",
        display_code=uuid.uuid4().hex[:8],
    )
    db.add(owner_account)
    # The account is flushed before its id is read for the user row below.
    await db.flush()

    # A random fragment keeps the address distinct even if two callers pass
    # the same suffix.
    unique_email = f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com"
    member = User(
        email=unique_email,
        name=f"User {suffix}",
        password_hash=get_password_hash("TestPass123!"),
        is_active=True,
        account_id=owner_account.id,
        account_role="engineer",
    )
    db.add(member)
    await db.flush()

    return owner_account, member
async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree:
    """Add and flush a published troubleshooting Tree for the given account.

    Args:
        db: Session the tree is added to (flushed, not committed).
        account: Account whose id is stored in ``Tree.account_id``.
        user: User whose id is stored in ``Tree.author_id``.

    Returns:
        The flushed Tree, named with a random suffix.
    """
    # Smallest structure used by these tests: a lone start node, no children.
    root_only = {"id": "root", "type": "start", "children": []}
    random_name = f"Tree {uuid.uuid4().hex[:6]}"

    new_tree = Tree(
        name=random_name,
        account_id=account.id,
        author_id=user.id,
        visibility="team",
        tree_type="troubleshooting",
        tree_structure=root_only,
        is_active=True,
        status="published",
    )
    db.add(new_tree)
    await db.flush()
    return new_tree
async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session:
    """Add and flush a Session row linking the given tree, user and account.

    Args:
        db: Session (database) the row is added to; flushed, not committed.
        account: Supplies ``account_id``.
        user: Supplies ``user_id``.
        tree: Supplies ``tree_id``.

    Returns:
        The flushed Session model instance, created with an empty
        ``tree_snapshot``.
    """
    new_session = Session(
        tree_id=tree.id,
        user_id=user.id,
        account_id=account.id,
        tree_snapshot={},
    )
    db.add(new_session)
    await db.flush()
    return new_session
# ── Group 1: Core sessions ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_session_account_id_matches_user(test_db: AsyncSession):
    """A committed Session row reads back with the account_id of its user's account."""
    account, user = await _make_account_and_user(test_db, "s1")
    tree = await _make_tree(test_db, account, user)
    created = await _make_session(test_db, account, user, tree)
    await test_db.commit()

    # Re-select by primary key and compare the stored tenant id.
    lookup = select(Session).where(Session.id == created.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id, f"Expected {account.id}, got {fetched.account_id}"
@pytest.mark.asyncio
async def test_attachment_account_id_matches_session(test_db: AsyncSession):
    """An Attachment created for a session reads back with that session's account_id."""
    account, user = await _make_account_and_user(test_db, "att1")
    tree = await _make_tree(test_db, account, user)
    parent_session = await _make_session(test_db, account, user, tree)

    uploaded = Attachment(
        session_id=parent_session.id,
        account_id=account.id,
        file_name="test.png",
        file_type="image/png",
    )
    test_db.add(uploaded)
    await test_db.commit()

    lookup = select(Attachment).where(Attachment.id == uploaded.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
@pytest.mark.asyncio
async def test_session_supporting_data_account_id(test_db: AsyncSession):
    """A SessionSupportingData row reads back with its parent session's account_id."""
    account, user = await _make_account_and_user(test_db, "sd1")
    tree = await _make_tree(test_db, account, user)
    parent_session = await _make_session(test_db, account, user, tree)

    snippet = SessionSupportingData(
        session_id=parent_session.id,
        account_id=account.id,
        label="Log snippet",
        data_type="text_snippet",
        content="error: connection refused",
    )
    test_db.add(snippet)
    await test_db.commit()

    lookup = select(SessionSupportingData).where(SessionSupportingData.id == snippet.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
@pytest.mark.asyncio
async def test_session_resolution_output_account_id(test_db: AsyncSession):
    """A SessionResolutionOutput reads back with its parent AISession's account_id.

    NOTE: session_resolution_outputs.session_id is a foreign key to
    ai_sessions (not sessions), so the parent built here is an AISession.
    """
    account, user = await _make_account_and_user(test_db, "sro1")

    parent_ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test resolution output",
        problem_domain="networking",
        status="active",
    )
    test_db.add(parent_ai_session)
    # Flushed before its id is referenced by the output row below.
    await test_db.flush()

    ticket_notes = SessionResolutionOutput(
        session_id=parent_ai_session.id,
        account_id=account.id,
        output_type="psa_ticket_notes",
        generated_content="Ticket notes content",
        generated_by_model="gpt-4",
    )
    test_db.add(ticket_notes)
    await test_db.commit()

    lookup = select(SessionResolutionOutput).where(
        SessionResolutionOutput.id == ticket_notes.id
    )
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
# ── Group 2: AI & branching ───────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession):
    """A SessionBranch reads back with the account_id of its parent AISession."""
    from app.models.session_branch import SessionBranch

    account, user = await _make_account_and_user(test_db, "sb1")

    parent_ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test",
        problem_domain="networking",
        status="active",
    )
    test_db.add(parent_ai_session)
    # Flushed before its id is referenced by the branch row below.
    await test_db.flush()

    first_branch = SessionBranch(
        session_id=parent_ai_session.id,
        account_id=account.id,
        label="Branch A",
        branch_order=1,
        conversation_messages=[],
    )
    test_db.add(first_branch)
    await test_db.commit()

    lookup = select(SessionBranch).where(SessionBranch.id == first_branch.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
@pytest.mark.asyncio
async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession):
    """An AISuggestion reads back with the account_id of the user who created it."""
    from app.models.ai_suggestion import AISuggestion

    account, user = await _make_account_and_user(test_db, "ais1")
    tree = await _make_tree(test_db, account, user)

    pending_suggestion = AISuggestion(
        tree_id=tree.id,
        user_id=user.id,
        account_id=account.id,
        action_type="add_node",
        changes_json={},
        status="pending",
    )
    test_db.add(pending_suggestion)
    await test_db.commit()

    lookup = select(AISuggestion).where(AISuggestion.id == pending_suggestion.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
# ── Group 3: Steps & ratings ──────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_step_rating_account_id_is_rater_account(test_db: AsyncSession):
    """A StepRating carries the RATER's account_id, not the step owner's.

    Two accounts are created: one owns a public step, the other rates it.
    The stored rating must read back with the rating user's account.
    """
    from app.models.step_library import StepLibrary, StepRating

    rater_account, rater = await _make_account_and_user(test_db, "sr-rater")
    owner_account, owner = await _make_account_and_user(test_db, "sr-step-owner")

    # The step lives in the owner's account.
    shared_step = StepLibrary(
        title="A step",
        step_type="action",
        content={"text": "do something"},
        created_by=owner.id,
        account_id=owner_account.id,
        visibility="public",
    )
    test_db.add(shared_step)
    await test_db.flush()

    # A user from the other account rates it; the rating is tagged with the
    # rater's account rather than the step owner's.
    helpful_vote = StepRating(
        step_id=shared_step.id,
        user_id=rater.id,
        account_id=rater_account.id,
        was_helpful=True,
        is_verified_use=False,
        is_visible=True,
    )
    test_db.add(helpful_vote)
    await test_db.commit()

    lookup = select(StepRating).where(StepRating.id == helpful_vote.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == rater_account.id, (
        f"account_id should be rater's account ({rater_account.id}), got {fetched.account_id}"
    )
@pytest.mark.asyncio
async def test_step_usage_log_account_id_is_logger_account(test_db: AsyncSession):
    """A StepUsageLog carries the account_id of the user who used the step."""
    from app.models.step_library import StepLibrary, StepUsageLog

    account, user = await _make_account_and_user(test_db, "sul1")
    tree = await _make_tree(test_db, account, user)
    run_session = await _make_session(test_db, account, user, tree)

    used_step = StepLibrary(
        title="A usage step",
        step_type="action",
        content={"text": "do something"},
        created_by=user.id,
        account_id=account.id,
        visibility="team",
    )
    test_db.add(used_step)
    # Flushed before its id is referenced by the usage log below.
    await test_db.flush()

    usage_entry = StepUsageLog(
        step_id=used_step.id,
        user_id=user.id,
        account_id=account.id,
        session_id=run_session.id,
    )
    test_db.add(usage_entry)
    await test_db.commit()

    lookup = select(StepUsageLog).where(StepUsageLog.id == usage_entry.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id, (
        f"account_id should be logger's account ({account.id}), got {fetched.account_id}"
    )
# ── Group 4: User personalization ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_user_folder_account_id_matches_user(test_db: AsyncSession):
    """A UserFolder reads back with the account_id of the user who owns it."""
    from app.models.folder import UserFolder

    account, user = await _make_account_and_user(test_db, "uf1")

    personal_folder = UserFolder(
        user_id=user.id,
        account_id=account.id,
        name="My Folder",
        color="#6366f1",
        icon="folder",
        display_order=0,
    )
    test_db.add(personal_folder)
    await test_db.commit()

    lookup = select(UserFolder).where(UserFolder.id == personal_folder.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
@pytest.mark.asyncio
async def test_user_pinned_tree_account_id_matches_user(test_db: AsyncSession):
    """A UserPinnedTree reads back with the account_id of the user who pinned it."""
    from app.models.user_pinned_tree import UserPinnedTree

    account, user = await _make_account_and_user(test_db, "pt1")
    tree = await _make_tree(test_db, account, user)

    pinned = UserPinnedTree(
        user_id=user.id,
        tree_id=tree.id,
        account_id=account.id,
        display_order=0,
    )
    test_db.add(pinned)
    await test_db.commit()

    lookup = select(UserPinnedTree).where(UserPinnedTree.id == pinned.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
# ── Group 5: PSA & notifications ─────────────────────────────────────────────
@pytest.mark.asyncio
async def test_psa_member_mapping_account_id_matches_connection(test_db: AsyncSession):
    """A PsaMemberMapping reads back with the account_id of its PsaConnection."""
    from app.models.psa_connection import PsaConnection
    from app.models.psa_member_mapping import PsaMemberMapping

    account, user = await _make_account_and_user(test_db, "psa1")

    cw_connection = PsaConnection(
        account_id=account.id,
        provider="connectwise",
        display_name="Test CW",
        site_url="https://cw.example.com",
        company_id="TEST",
        credentials_encrypted="placeholder",
    )
    test_db.add(cw_connection)
    # Flushed before its id is referenced by the mapping below.
    await test_db.flush()

    member_link = PsaMemberMapping(
        psa_connection_id=cw_connection.id,
        user_id=user.id,
        account_id=account.id,
        external_member_id="cw-123",
        external_member_name="Test User",
        matched_by="manual_admin",
    )
    test_db.add(member_link)
    await test_db.commit()

    lookup = select(PsaMemberMapping).where(PsaMemberMapping.id == member_link.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
# ── Group 6: Maintenance ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_maintenance_schedule_account_id_matches_tree(test_db: AsyncSession):
    """A MaintenanceSchedule reads back with the account_id of the tree it schedules."""
    from app.models.maintenance_schedule import MaintenanceSchedule

    account, user = await _make_account_and_user(test_db, "ms1")

    # Built inline rather than via _make_tree because this test needs
    # tree_type="maintenance" and a fixed name.
    maintenance_tree = Tree(
        name="Maintenance Flow",
        account_id=account.id,
        author_id=user.id,
        visibility="team",
        tree_type="maintenance",
        tree_structure={"id": "root", "type": "start", "children": []},
        is_active=True,
        status="published",
    )
    test_db.add(maintenance_tree)
    await test_db.flush()

    weekly_schedule = MaintenanceSchedule(
        tree_id=maintenance_tree.id,
        account_id=account.id,
        created_by=user.id,
        cron_expression="0 9 * * 1",
        timezone="UTC",
        is_active=True,
    )
    test_db.add(weekly_schedule)
    await test_db.commit()

    lookup = select(MaintenanceSchedule).where(
        MaintenanceSchedule.id == weekly_schedule.id
    )
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
# ── Group 7: Legacy team_id tables ───────────────────────────────────────────
@pytest.mark.asyncio
async def test_script_builder_session_account_id(test_db: AsyncSession):
    """A ScriptBuilderSession reads back with the account_id of its user."""
    from app.models.script_builder_session import ScriptBuilderSession

    account, user = await _make_account_and_user(test_db, "sbs1")

    builder_session = ScriptBuilderSession(
        user_id=user.id,
        account_id=account.id,
        language="powershell",
    )
    test_db.add(builder_session)
    await test_db.commit()

    lookup = select(ScriptBuilderSession).where(
        ScriptBuilderSession.id == builder_session.id
    )
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
# ── Group 8: TargetList ────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
    """A TargetList created by a team-admin user reads back with that user's account_id.

    The creating user is first attached to a new Team and flagged as team
    admin; the TargetList itself is built with ``account_id`` and
    ``created_by`` only.

    NOTE(review): the account_id is passed directly to the constructor, so
    the team setup does not feed into the value being asserted — confirm
    whether it is still needed.
    """
    from app.models.target_list import TargetList
    from app.models.team import Team

    account, user = await _make_account_and_user(test_db, "tl1")

    # Promote the user to team admin of a freshly created team.
    admin_team = Team(name=f"Team {uuid.uuid4().hex[:6]}")
    test_db.add(admin_team)
    await test_db.flush()
    user.team_id = admin_team.id
    user.is_team_admin = True
    await test_db.flush()

    server_targets = TargetList(
        account_id=account.id,
        created_by=user.id,
        name="Server Targets",
        targets=[{"label": "SRV-01"}],
    )
    test_db.add(server_targets)
    await test_db.commit()

    lookup = select(TargetList).where(TargetList.id == server_targets.id)
    fetched = (await test_db.execute(lookup)).scalar_one()

    assert fetched.account_id == account.id
# ── Group 10 (runs first): Global content tables ──────────────────────────────
@pytest.mark.asyncio
async def test_template_trees_table_exists_and_has_no_account_id(test_db: AsyncSession):
    """template_trees must exist and must NOT have an account_id column.

    The column list is read from ``information_schema.columns``. The lookup
    is restricted to the connection's current schema so that a table with the
    same name in another schema cannot add columns to the result.
    """
    result = await test_db.execute(text("""
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = current_schema()
          AND table_name = 'template_trees'
    """))
    columns = {row[0] for row in result.fetchall()}

    # An empty set (table missing) fails here, since 'id' will not be present.
    assert 'id' in columns, "template_trees.id must exist"
    assert 'account_id' not in columns, "template_trees must not have account_id (global content)"
@pytest.mark.asyncio
async def test_platform_steps_table_exists_and_has_no_account_id(test_db: AsyncSession):
    """platform_steps must exist and must NOT have an account_id column.

    The column list is read from ``information_schema.columns``. The lookup
    is restricted to the connection's current schema so that a table with the
    same name in another schema cannot add columns to the result.
    """
    result = await test_db.execute(text("""
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = current_schema()
          AND table_name = 'platform_steps'
    """))
    columns = {row[0] for row in result.fetchall()}

    # An empty set (table missing) fails here, since 'id' will not be present.
    assert 'id' in columns, "platform_steps.id must exist"
    assert 'account_id' not in columns, "platform_steps must not have account_id (global content)"
# ── Group 9: SET NOT NULL on existing nullable columns ────────────────────────
@pytest.mark.asyncio
async def test_tree_account_id_is_not_null(test_db: AsyncSession):
    """trees.account_id must be NOT NULL after Phase 1 — enforced at DB level.

    A Tree is added without an account_id and the flush is expected to raise
    IntegrityError.

    NOTE(review): author_id=None is passed as well; if trees.author_id is
    also NOT NULL the IntegrityError could come from that column instead —
    confirm against the schema.
    """
    from sqlalchemy.exc import IntegrityError

    test_db.add(Tree(
        name="Bad tree",
        # account_id intentionally omitted
        author_id=None,
        visibility="private",
        tree_type="troubleshooting",
        tree_structure={},
        is_active=True,
        status="draft",
    ))
    # Only the flush is expected to raise, so it alone sits inside the
    # raises block.
    with pytest.raises(IntegrityError):
        await test_db.flush()

    # A failed flush leaves the session unusable until it is rolled back;
    # reset it so later use of the session does not trip over this test's
    # expected failure.
    await test_db.rollback()
@pytest.mark.asyncio
async def test_user_account_id_is_not_null(test_db: AsyncSession):
    """users.account_id must be NOT NULL after Phase 1.

    A User is added without an account_id and the flush is expected to raise
    IntegrityError.
    """
    from sqlalchemy.exc import IntegrityError

    test_db.add(User(
        email=f"orphan-{uuid.uuid4().hex[:6]}@example.com",
        name="Orphan",
        password_hash=get_password_hash("x"),
        is_active=True,
        role="engineer",
        account_role="engineer",
        # account_id intentionally omitted
    ))
    # Only the flush is expected to raise, so it alone sits inside the
    # raises block.
    with pytest.raises(IntegrityError):
        await test_db.flush()

    # A failed flush leaves the session unusable until it is rolled back;
    # reset it so later use of the session does not trip over this test's
    # expected failure.
    await test_db.rollback()