fix: propagate account_id through all write paths missing NOT NULL coverage

Service layer (production code):
- branch_manager: set account_id on SessionBranch (root + fork) and ForkPoint
  from session.account_id; load session in create_fork for this purpose
- handoff_manager: set account_id on SessionHandoff from session.account_id
- ai_suggestions endpoint: set account_id on AISuggestion from current_user
- steps endpoint (/feedback): set account_id on StepRating from current_user
- ratings endpoint: set account_id on StepRating from current_user

Test infrastructure:
- conftest.py: seed PLATFORM_ACCOUNT_ID (00000000-...-0001) account after
  Base.metadata.create_all so global categories and gallery items have a valid FK
- test_rls_isolation: add _ensure_rls_schema fixture that runs
  'alembic upgrade head' before module tests — previous function-scoped
  test_db fixtures drop the schema, leaving the RLS tests with no tables
- test_branding: create Account before User in helper functions
- test_admin_gallery: set account_id=PLATFORM_ACCOUNT_ID on Tree/ScriptTemplate
- test_public_templates: set account_id=PLATFORM_ACCOUNT_ID on Tree,
  ScriptTemplate, TreeCategory
- test_resolution_outputs: set account_id=session.account_id on
  SessionResolutionOutput
- test_analytics_phase5: set account_id on PsaPostLog
- test_draft_trees: replace account_id=None with PLATFORM_ACCOUNT_ID in
  migration default test (NOT NULL now enforced)
- test_maintenance_schedules: set account_id on other_tree
- test_save_session_as_tree: set account_id on all 5 Session() constructors

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-04-11 04:24:36 +00:00
parent b9fcdd5d73
commit 758cd61621
15 changed files with 85 additions and 3 deletions

View File

@@ -43,6 +43,7 @@ async def create_suggestion(
suggestion = AISuggestion(
tree_id=data.tree_id,
user_id=current_user.id,
account_id=current_user.account_id,
session_id=data.session_id,
action_type=data.action_type,
target_node_id=data.target_node_id,

View File

@@ -91,6 +91,7 @@ async def submit_step_feedback(
new_rating = StepRating(
step_id=step_id,
user_id=current_user.id,
account_id=current_user.account_id,
session_id=session_uuid,
was_helpful=data.was_helpful,
# rating is nullable now — thumbs-only mode

View File

@@ -460,6 +460,7 @@ async def rate_step(
rating = StepRating(
step_id=step_id,
user_id=current_user.id,
account_id=current_user.account_id,
rating=rating_data.rating,
was_helpful=rating_data.was_helpful,
review_text=rating_data.review_text,

View File

@@ -34,6 +34,7 @@ class BranchManager:
root = SessionBranch(
id=uuid.uuid4(),
session_id=session_id,
account_id=session.account_id,
parent_branch_id=None,
branch_order=1,
label="Root",
@@ -68,9 +69,17 @@ class BranchManager:
"status": "untried",
})
# Load session to get account_id for FK constraints
session_result = await self.db.execute(
select(AISession).where(AISession.id == session_id)
)
session = session_result.scalar_one_or_none()
account_id = session.account_id if session else None
fork_point = ForkPoint(
id=uuid.uuid4(),
session_id=session_id,
account_id=account_id,
parent_branch_id=parent_branch_id,
trigger_step_id=trigger_step_id,
fork_reason=fork_reason,
@@ -90,6 +99,7 @@ class BranchManager:
branch = SessionBranch(
id=branch_ids[i],
session_id=session_id,
account_id=account_id,
parent_branch_id=parent_branch_id,
fork_point_step_id=trigger_step_id,
branch_order=i + 1,

View File

@@ -56,6 +56,7 @@ class HandoffManager:
handoff = SessionHandoff(
session_id=session_id,
account_id=session.account_id,
handed_off_by=user_id,
intent=intent,
source_branch_id=session.active_branch_id,

View File

@@ -75,6 +75,19 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
"""))
# Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by
# global categories, gallery items, and other platform-owned content.
await conn.execute(sa.text("""
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
VALUES (
'00000000-0000-0000-0000-000000000001',
'ResolutionFlow System',
'RF-SYS-1',
NOW(), NOW()
)
ON CONFLICT (id) DO NOTHING
"""))
# Create async session maker
async_session_maker = async_sessionmaker(
engine,

View File

@@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.models.tree import Tree
from app.models.script_template import ScriptTemplate, ScriptCategory
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
# ---------------------------------------------------------------------------
# Helpers
@@ -22,6 +23,7 @@ async def _create_tree(db: AsyncSession, admin_user_id: str) -> Tree:
name="Gallery Test Flow",
tree_type="troubleshooting",
visibility="public",
account_id=_PLATFORM_ACCOUNT_ID,
is_gallery_featured=False,
gallery_sort_order=0,
tree_structure={
@@ -53,6 +55,7 @@ async def _create_script(db: AsyncSession, admin_user_id: str) -> ScriptTemplate
script = ScriptTemplate(
id=uuid.uuid4(),
category_id=category.id,
account_id=_PLATFORM_ACCOUNT_ID,
name="Gallery Test Script",
slug=f"gallery-test-script-{uuid.uuid4().hex[:6]}",
script_body="Write-Host 'Test'",

View File

@@ -594,6 +594,7 @@ class TestPsaMetrics:
post_log = PsaPostLog(
id=uuid.uuid4(),
ai_session_id=push_session_id,
account_id=account_id,
ticket_id="TICKET-123",
note_type="internal",
content_posted="Session summary",

View File

@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.security import get_password_hash
from app.models.account import Account
from app.models.team import Team
from app.models.user import User
@@ -23,6 +24,8 @@ async def _create_team_with_admin(
team_name: str = "Branding Test Team",
) -> tuple[dict, str, Team]:
"""Create a team + team admin user. Returns (auth_headers, team_id_str, team)."""
account = Account(name=team_name, display_code=uuid.uuid4().hex[:8].upper())
test_db.add(account)
team = Team(name=team_name)
test_db.add(team)
await test_db.flush()
@@ -36,6 +39,8 @@ async def _create_team_with_admin(
team_id=team.id,
is_team_admin=True,
role="engineer",
account_id=account.id,
account_role="engineer",
)
test_db.add(user)
await test_db.commit()
@@ -58,6 +63,15 @@ async def _create_team_member(
is_team_admin: bool = False,
) -> dict:
"""Create a regular team member. Returns auth_headers."""
# Look up the account associated with this team via an existing member
from sqlalchemy import select as _select
from app.models.user import User as _User
result = await test_db.execute(
_select(_User).where(_User.team_id == team.id).limit(1)
)
team_member = result.scalar_one_or_none()
member_account_id = team_member.account_id if team_member else None
email = f"member_{uuid.uuid4().hex[:8]}@test.com"
user = User(
email=email,
@@ -67,6 +81,8 @@ async def _create_team_member(
team_id=team.id,
is_team_admin=is_team_admin,
role="engineer",
account_id=member_account_id,
account_role="engineer",
)
test_db.add(user)
await test_db.commit()

View File

@@ -334,12 +334,13 @@ class TestDraftTreesAPI:
"""Test that migration defaults existing trees to published status."""
# Create a tree without specifying status (relies on DB default)
from uuid import UUID, uuid4
_platform_id = UUID("00000000-0000-0000-0000-000000000001")
tree = Tree(
name="Legacy Tree",
description="Created before status field",
tree_structure={"id": "root", "type": "solution", "title": "Fix"},
author_id=None,
account_id=None
account_id=_platform_id,
)
test_db.add(tree)
await test_db.commit()

View File

@@ -127,10 +127,12 @@ async def test_cannot_schedule_other_teams_tree(client: AsyncClient, auth_header
test_db.add(other_team)
await test_db.flush()
from uuid import UUID as _UUID
other_tree = Tree(
name="Other Team Tree",
tree_type="maintenance",
team_id=other_team.id,
account_id=_UUID("00000000-0000-0000-0000-000000000001"),
tree_structure={
"steps": [
{"id": "s1", "type": "procedure_step", "title": "Step",

View File

@@ -11,6 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.models.script_template import ScriptCategory, ScriptTemplate
from app.models.tree import Tree
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
# ---------------------------------------------------------------------------
# Helpers
@@ -41,6 +43,7 @@ async def _create_featured_tree(db: AsyncSession, name: str = "Featured Flow", f
description="A featured flow for the gallery",
tree_type="troubleshooting",
tree_structure=_make_tree_structure(4),
account_id=_PLATFORM_ACCOUNT_ID,
is_gallery_featured=featured,
is_active=True,
usage_count=42,
@@ -74,6 +77,7 @@ async def _create_featured_script(
) -> ScriptTemplate:
script = ScriptTemplate(
category_id=category.id,
account_id=_PLATFORM_ACCOUNT_ID,
name=name,
slug=name.lower().replace(" ", "-"),
description="A gallery-featured script",
@@ -312,7 +316,7 @@ class TestCategoriesEndpoint:
from app.models.category import TreeCategory
# Create a category and a featured tree in that category
cat = TreeCategory(name="Networking", slug="networking", is_active=True)
cat = TreeCategory(name="Networking", slug="networking", is_active=True, account_id=_PLATFORM_ACCOUNT_ID)
test_db.add(cat)
await test_db.commit()
await test_db.refresh(cat)
@@ -321,6 +325,7 @@ class TestCategoriesEndpoint:
name="Router Diagnostics",
tree_type="troubleshooting",
tree_structure=_make_tree_structure(2),
account_id=_PLATFORM_ACCOUNT_ID,
is_gallery_featured=True,
is_active=True,
usage_count=5,

View File

@@ -62,6 +62,7 @@ async def test_edit_output(client: AsyncClient, test_user, auth_headers, test_db
output = SessionResolutionOutput(
session_id=session.id,
account_id=session.account_id,
output_type="psa_ticket_notes",
generated_content="Original notes",
status="draft",

View File

@@ -16,7 +16,10 @@ Run with:
The test DB is patherly_test (matches conftest.py default).
"""
import os
import subprocess
import sys
import uuid
from pathlib import Path
import asyncpg
import pytest
@@ -37,7 +40,25 @@ ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
async def admin_conn():
def _ensure_rls_schema():
"""Re-apply Alembic migrations before the module runs.
Function-scoped test_db fixtures in other modules drop and recreate the
public schema using Base.metadata.create_all, which does not enable RLS
or create DB roles. This fixture re-runs 'alembic upgrade head' so that
the full migration-managed schema (including RLS policies) is in place.
"""
backend_dir = Path(__file__).parent.parent
subprocess.run(
[sys.executable, "-m", "alembic", "upgrade", "head"],
cwd=backend_dir,
check=True,
capture_output=True,
)
@pytest.fixture(scope="module")
async def admin_conn(_ensure_rls_schema):
"""Superuser asyncpg connection for fixture setup and teardown."""
conn = await asyncpg.connect(_ADMIN_DSN)
yield conn

View File

@@ -155,6 +155,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=UUID(test_user["user_data"]["id"]),
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[{"node_id": "root", "timestamp": datetime.now(timezone.utc).isoformat()}],
@@ -199,6 +200,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=UUID(test_user["user_data"]["id"]),
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[],
@@ -239,6 +241,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=UUID(test_user["user_data"]["id"]),
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[],
@@ -279,6 +282,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=UUID(test_user["user_data"]["id"]),
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[],
@@ -352,6 +356,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=other_user.id,
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[],