- trees.py: change account_id=current_user.account_id → account_id=tree.account_id so super-admin cross-account shares land in the tree's tenant where RLS will see them. - migration a05e1a1bea7c: fix backfill to join tree_shares → trees instead of tree_shares → users(created_by). Same logic: historical shares belong to the tree's tenant. - test_tree_sharing.py: add test_share_account_id_matches_tree_not_actor to assert share.account_id == tree.account_id after POST /share; also add missing account_id to all direct TreeShare(...) constructors in existing tests. - test_phase1_migrations.py: remove team_id= from TargetList constructor (column dropped in Phase 3). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
545 lines
19 KiB
Python
545 lines
19 KiB
Python
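"""Phase 1 migration tests — verify account_id correctness after the Phase 1 backfill.

These tests create objects via the ORM (which uses the updated models),
then verify that account_id is populated correctly on every table touched
by the Phase 1 migration. They also check the Phase 1 schema constraints:
account_id is NOT NULL, and the global content tables have no account_id.
They run against a real PostgreSQL test DB (same as all other integration tests).
"""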
|
|
import pytest
|
|
import uuid
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, text
|
|
|
|
from app.models.account import Account
|
|
from app.models.user import User
|
|
from app.models.tree import Tree
|
|
from app.models.session import Session
|
|
from app.models.attachment import Attachment
|
|
from app.models.supporting_data import SessionSupportingData
|
|
from app.models.session_resolution_output import SessionResolutionOutput
|
|
from app.models.ai_session import AISession
|
|
from app.core.security import get_password_hash
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]:
    """Create a fresh Account and one engineer User that belongs to it.

    Both rows are flushed (not committed) so their ids are usable by the caller.

    Args:
        db: Session the rows are added to.
        suffix: Human-readable tag embedded in the account name and user email;
            a random hex fragment is appended to the email to keep it unique.

    Returns:
        The ``(account, user)`` pair.
    """
    new_account = Account(name=f"Corp {suffix}", display_code=uuid.uuid4().hex[:8])
    db.add(new_account)
    # The user row needs the account's id, so the account is flushed first.
    await db.flush()

    unique_tag = uuid.uuid4().hex[:6]
    new_user = User(
        email=f"user-{suffix}-{unique_tag}@example.com",
        name=f"User {suffix}",
        password_hash=get_password_hash("TestPass123!"),
        is_active=True,
        account_id=new_account.id,
        account_role="engineer",
    )
    db.add(new_user)
    await db.flush()

    return new_account, new_user
|
|
|
|
|
|
async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree:
    """Create and flush a published troubleshooting Tree authored by *user* in *account*.

    The tree has a random name suffix, "team" visibility and a minimal
    single-root structure.
    """
    tree_fields = dict(
        name=f"Tree {uuid.uuid4().hex[:6]}",
        account_id=account.id,
        author_id=user.id,
        visibility="team",
        tree_type="troubleshooting",
        tree_structure={"id": "root", "type": "start", "children": []},
        is_active=True,
        status="published",
    )
    new_tree = Tree(**tree_fields)
    db.add(new_tree)
    await db.flush()
    return new_tree
|
|
|
|
|
|
async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session:
    """Create and flush a Session for *user* running *tree*, tagged with *account*.

    The session starts with an empty ``tree_snapshot``.
    """
    new_session = Session(
        tree_id=tree.id,
        user_id=user.id,
        account_id=account.id,
        tree_snapshot={},
    )
    db.add(new_session)
    await db.flush()
    return new_session
|
|
|
|
|
|
# ── Group 1: Core sessions ────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_session_account_id_matches_user(test_db: AsyncSession):
    """sessions.account_id must equal the user's account_id.

    The stored column is selected directly instead of re-loading the ORM
    entity. Re-selecting the entity would return the identity-mapped instance
    with its in-memory attributes, so the assertion would only see the value
    the test itself assigned.
    """
    account, user = await _make_account_and_user(test_db, "s1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)
    # Capture ids before commit so nothing below depends on expire_on_commit.
    expected_account_id = account.id
    session_id = session.id
    await test_db.commit()

    result = await test_db.execute(
        select(Session.account_id).where(Session.id == session_id)
    )
    stored_account_id = result.scalar_one()
    assert stored_account_id == expected_account_id, (
        f"Expected {expected_account_id}, got {stored_account_id}"
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_attachment_account_id_matches_session(test_db: AsyncSession):
    """attachments.account_id must match the parent session's account_id.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    account, user = await _make_account_and_user(test_db, "att1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)

    attachment = Attachment(
        session_id=session.id,
        account_id=account.id,
        file_name="test.png",
        file_type="image/png",
    )
    test_db.add(attachment)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    attachment_id = attachment.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(Attachment.account_id).where(Attachment.id == attachment_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_supporting_data_account_id(test_db: AsyncSession):
    """session_supporting_data.account_id must match parent session's account_id.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    account, user = await _make_account_and_user(test_db, "sd1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)

    sd = SessionSupportingData(
        session_id=session.id,
        account_id=account.id,
        label="Log snippet",
        data_type="text_snippet",
        content="error: connection refused",
    )
    test_db.add(sd)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    sd_id = sd.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(SessionSupportingData.account_id).where(SessionSupportingData.id == sd_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_resolution_output_account_id(test_db: AsyncSession):
    """session_resolution_outputs.account_id must match the parent ai_session's account_id.

    NOTE: session_resolution_outputs.session_id FK points to ai_sessions (not sessions).

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    account, user = await _make_account_and_user(test_db, "sro1")

    ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test resolution output",
        problem_domain="networking",
        status="active",
    )
    test_db.add(ai_session)
    await test_db.flush()

    output = SessionResolutionOutput(
        session_id=ai_session.id,
        account_id=account.id,
        output_type="psa_ticket_notes",
        generated_content="Ticket notes content",
        generated_by_model="gpt-4",
    )
    test_db.add(output)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    output_id = output.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(SessionResolutionOutput.account_id).where(
            SessionResolutionOutput.id == output_id
        )
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
# ── Group 2: AI & branching ───────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession):
    """session_branches.account_id must match parent ai_session.account_id.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.session_branch import SessionBranch

    account, user = await _make_account_and_user(test_db, "sb1")
    ai_session = AISession(
        user_id=user.id,
        account_id=account.id,
        problem_summary="test",
        problem_domain="networking",
        status="active",
    )
    test_db.add(ai_session)
    await test_db.flush()

    branch = SessionBranch(
        session_id=ai_session.id,
        account_id=account.id,
        label="Branch A",
        branch_order=1,
        conversation_messages=[],
    )
    test_db.add(branch)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    branch_id = branch.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(SessionBranch.account_id).where(SessionBranch.id == branch_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession):
    """ai_suggestions.account_id must match the creating user's account_id.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.ai_suggestion import AISuggestion

    account, user = await _make_account_and_user(test_db, "ais1")
    tree = await _make_tree(test_db, account, user)

    suggestion = AISuggestion(
        tree_id=tree.id,
        user_id=user.id,
        account_id=account.id,
        action_type="add_node",
        changes_json={},
        status="pending",
    )
    test_db.add(suggestion)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    suggestion_id = suggestion.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(AISuggestion.account_id).where(AISuggestion.id == suggestion_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
# ── Group 3: Steps & ratings ──────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_step_rating_account_id_is_rater_account(test_db: AsyncSession):
    """step_ratings.account_id must be the RATER's account, not the step's account.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.step_library import StepLibrary, StepRating

    account_a, user_a = await _make_account_and_user(test_db, "sr-rater")
    account_b, user_b = await _make_account_and_user(test_db, "sr-step-owner")
    # The test only discriminates rater vs. owner if the accounts differ.
    assert account_a.id != account_b.id

    # Step owned by account_b
    step = StepLibrary(
        title="A step",
        step_type="action",
        content={"text": "do something"},
        created_by=user_b.id,
        account_id=account_b.id,
        visibility="public",
    )
    test_db.add(step)
    await test_db.flush()

    # user_a (account_a) rates the step
    rating = StepRating(
        step_id=step.id,
        user_id=user_a.id,
        account_id=account_a.id,  # rater's account, not step owner's
        was_helpful=True,
        is_verified_use=False,
        is_visible=True,
    )
    test_db.add(rating)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    rating_id = rating.id
    rater_account_id = account_a.id
    await test_db.commit()

    result = await test_db.execute(
        select(StepRating.account_id).where(StepRating.id == rating_id)
    )
    stored_account_id = result.scalar_one()
    assert stored_account_id == rater_account_id, (
        f"account_id should be rater's account ({rater_account_id}), got {stored_account_id}"
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_step_usage_log_account_id_is_logger_account(test_db: AsyncSession):
    """step_usage_log.account_id must be the LOGGER's account (user who used the step).

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.step_library import StepLibrary, StepUsageLog

    account, user = await _make_account_and_user(test_db, "sul1")
    tree = await _make_tree(test_db, account, user)
    session = await _make_session(test_db, account, user, tree)

    step = StepLibrary(
        title="A usage step",
        step_type="action",
        content={"text": "do something"},
        created_by=user.id,
        account_id=account.id,
        visibility="team",
    )
    test_db.add(step)
    await test_db.flush()

    log = StepUsageLog(
        step_id=step.id,
        user_id=user.id,
        account_id=account.id,
        session_id=session.id,
    )
    test_db.add(log)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    log_id = log.id
    logger_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(StepUsageLog.account_id).where(StepUsageLog.id == log_id)
    )
    stored_account_id = result.scalar_one()
    assert stored_account_id == logger_account_id, (
        f"account_id should be logger's account ({logger_account_id}), got {stored_account_id}"
    )
|
|
|
|
|
|
# ── Group 4: User personalization ────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_user_folder_account_id_matches_user(test_db: AsyncSession):
    """user_folders.account_id must match the owning user's account_id.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.folder import UserFolder

    account, user = await _make_account_and_user(test_db, "uf1")
    folder = UserFolder(
        user_id=user.id,
        account_id=account.id,
        name="My Folder",
        color="#6366f1",
        icon="folder",
        display_order=0,
    )
    test_db.add(folder)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    folder_id = folder.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(UserFolder.account_id).where(UserFolder.id == folder_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_user_pinned_tree_account_id_matches_user(test_db: AsyncSession):
    """user_pinned_trees.account_id must match the pinning user's account_id.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.user_pinned_tree import UserPinnedTree

    account, user = await _make_account_and_user(test_db, "pt1")
    tree = await _make_tree(test_db, account, user)
    pin = UserPinnedTree(
        user_id=user.id,
        tree_id=tree.id,
        account_id=account.id,
        display_order=0,
    )
    test_db.add(pin)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    pin_id = pin.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(UserPinnedTree.account_id).where(UserPinnedTree.id == pin_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
# ── Group 5: PSA & notifications ─────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_psa_member_mapping_account_id_matches_connection(test_db: AsyncSession):
    """psa_member_mappings.account_id must match psa_connection's account_id.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.psa_connection import PsaConnection
    from app.models.psa_member_mapping import PsaMemberMapping

    account, user = await _make_account_and_user(test_db, "psa1")
    conn = PsaConnection(
        account_id=account.id,
        provider="connectwise",
        display_name="Test CW",
        site_url="https://cw.example.com",
        company_id="TEST",
        credentials_encrypted="placeholder",
    )
    test_db.add(conn)
    await test_db.flush()

    mapping = PsaMemberMapping(
        psa_connection_id=conn.id,
        user_id=user.id,
        account_id=account.id,
        external_member_id="cw-123",
        external_member_name="Test User",
        matched_by="manual_admin",
    )
    test_db.add(mapping)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    mapping_id = mapping.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(PsaMemberMapping.account_id).where(PsaMemberMapping.id == mapping_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
# ── Group 6: Maintenance ──────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_maintenance_schedule_account_id_matches_tree(test_db: AsyncSession):
    """maintenance_schedules.account_id must match the tree's account_id.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.maintenance_schedule import MaintenanceSchedule

    account, user = await _make_account_and_user(test_db, "ms1")
    tree = Tree(
        name="Maintenance Flow",
        account_id=account.id,
        author_id=user.id,
        visibility="team",
        tree_type="maintenance",
        tree_structure={"id": "root", "type": "start", "children": []},
        is_active=True,
        status="published",
    )
    test_db.add(tree)
    await test_db.flush()

    schedule = MaintenanceSchedule(
        tree_id=tree.id,
        account_id=account.id,
        created_by=user.id,
        cron_expression="0 9 * * 1",
        timezone="UTC",
        is_active=True,
    )
    test_db.add(schedule)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    schedule_id = schedule.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(MaintenanceSchedule.account_id).where(MaintenanceSchedule.id == schedule_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
# ── Group 7: Legacy team_id tables ───────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_script_builder_session_account_id(test_db: AsyncSession):
    """script_builder_sessions.account_id must match user's account_id.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.script_builder_session import ScriptBuilderSession

    account, user = await _make_account_and_user(test_db, "sbs1")
    sbs = ScriptBuilderSession(
        user_id=user.id,
        account_id=account.id,
        language="powershell",
    )
    test_db.add(sbs)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    sbs_id = sbs.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(ScriptBuilderSession.account_id).where(ScriptBuilderSession.id == sbs_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
# ── Group 8: TargetList ────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
    """target_lists.account_id must be set to the team admin's account_id.

    NOTE(review): target_lists.team_id was dropped in Phase 3. The team-admin
    setup below therefore no longer feeds into the TargetList row; it only
    reproduces the scenario this test is named after.

    The stored column is selected directly (not the ORM entity), so the
    assertion checks the persisted value rather than the identity-mapped
    in-memory attribute.
    """
    from app.models.target_list import TargetList
    from app.models.team import Team

    account, user = await _make_account_and_user(test_db, "tl1")
    # Make user a team admin
    team = Team(name=f"Team {uuid.uuid4().hex[:6]}")
    test_db.add(team)
    await test_db.flush()

    user.team_id = team.id
    user.is_team_admin = True
    await test_db.flush()

    target_list = TargetList(
        account_id=account.id,
        created_by=user.id,
        name="Server Targets",
        targets=[{"label": "SRV-01"}],
    )
    test_db.add(target_list)
    # Flush so the primary key is assigned, then capture ids before commit.
    await test_db.flush()
    target_list_id = target_list.id
    expected_account_id = account.id
    await test_db.commit()

    result = await test_db.execute(
        select(TargetList.account_id).where(TargetList.id == target_list_id)
    )
    assert result.scalar_one() == expected_account_id
|
|
|
|
|
|
# ── Group 10 (runs first): Global content tables ──────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_template_trees_table_exists_and_has_no_account_id(test_db: AsyncSession):
    """template_trees must exist and must NOT have an account_id column.

    The information_schema lookup is restricted to the connection's current
    schema, so a same-named table in another schema cannot add columns. A
    missing table is reported as such, rather than as a missing ``id`` column.
    """
    result = await test_db.execute(text("""
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name = 'template_trees'
          AND table_schema = current_schema()
    """))
    columns = {row[0] for row in result.fetchall()}
    assert columns, "template_trees table must exist"
    assert 'id' in columns, "template_trees.id must exist"
    assert 'account_id' not in columns, "template_trees must not have account_id (global content)"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_platform_steps_table_exists_and_has_no_account_id(test_db: AsyncSession):
    """platform_steps must exist and must NOT have an account_id column.

    The information_schema lookup is restricted to the connection's current
    schema, so a same-named table in another schema cannot add columns. A
    missing table is reported as such, rather than as a missing ``id`` column.
    """
    result = await test_db.execute(text("""
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name = 'platform_steps'
          AND table_schema = current_schema()
    """))
    columns = {row[0] for row in result.fetchall()}
    assert columns, "platform_steps table must exist"
    assert 'id' in columns, "platform_steps.id must exist"
    assert 'account_id' not in columns, "platform_steps must not have account_id (global content)"
|
|
|
|
|
|
# ── Group 9: SET NOT NULL on existing nullable columns ────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_tree_account_id_is_not_null(test_db: AsyncSession):
    """trees.account_id must be NOT NULL after Phase 1 — enforced at DB level.

    The expected IntegrityError is pinned to the account_id column. Without
    that, the test could pass because a different constraint rejected the row.
    """
    from sqlalchemy.exc import IntegrityError

    test_db.add(Tree(
        name="Bad tree",
        # account_id intentionally omitted
        author_id=None,
        visibility="private",
        tree_type="troubleshooting",
        tree_structure={},
        is_active=True,
        status="draft",
    ))
    # PostgreSQL's not-null violation message names the column:
    # 'null value in column "account_id" ...'
    with pytest.raises(IntegrityError, match=r'column "account_id"'):
        await test_db.flush()
    # A failed flush leaves the session requiring a rollback before reuse.
    await test_db.rollback()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_user_account_id_is_not_null(test_db: AsyncSession):
    """users.account_id must be NOT NULL after Phase 1.

    The expected IntegrityError is pinned to the account_id column. Without
    that, the test could pass because a different constraint rejected the row.
    """
    from sqlalchemy.exc import IntegrityError

    test_db.add(User(
        email=f"orphan-{uuid.uuid4().hex[:6]}@example.com",
        name="Orphan",
        password_hash=get_password_hash("x"),
        is_active=True,
        role="engineer",
        account_role="engineer",
        # account_id intentionally omitted
    ))
    # PostgreSQL's not-null violation message names the column:
    # 'null value in column "account_id" ...'
    with pytest.raises(IntegrityError, match=r'column "account_id"'):
        await test_db.flush()
    # A failed flush leaves the session requiring a rollback before reuse.
    await test_db.rollback()
|