Files
resolutionflow/backend/app/services/branch_manager.py
chihlasm 758cd61621 fix: propagate account_id through all write paths missing NOT NULL coverage
Service layer (production code):
- branch_manager: set account_id on SessionBranch (root + fork) and ForkPoint
  from session.account_id; load session in create_fork for this purpose
- handoff_manager: set account_id on SessionHandoff from session.account_id
- ai_suggestions endpoint: set account_id on AISuggestion from current_user
- steps endpoint (/feedback): set account_id on StepRating from current_user
- ratings endpoint: set account_id on StepRating from current_user

Test infrastructure:
- conftest.py: seed PLATFORM_ACCOUNT_ID (00000000-...-0001) account after
  Base.metadata.create_all so global categories and gallery items have a valid FK
- test_rls_isolation: add _ensure_rls_schema fixture that runs
  'alembic upgrade head' before module tests — previous function-scoped
  test_db fixtures drop the schema, leaving the RLS tests with no tables
- test_branding: create Account before User in helper functions
- test_admin_gallery: set account_id=PLATFORM_ACCOUNT_ID on Tree/ScriptTemplate
- test_public_templates: set account_id=PLATFORM_ACCOUNT_ID on Tree,
  ScriptTemplate, TreeCategory
- test_resolution_outputs: set account_id=session.account_id on
  SessionResolutionOutput
- test_analytics_phase5: set account_id on PsaPostLog
- test_draft_trees: replace account_id=None with PLATFORM_ACCOUNT_ID in
  migration default test (NOT NULL now enforced)
- test_maintenance_schedules: set account_id on other_tree
- test_save_session_as_tree: set account_id on all 5 Session() constructors

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 04:24:36 +00:00

249 lines
8.4 KiB
Python

"""Branch lifecycle management for conversational branching."""
import uuid
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ai_session import AISession
from app.models.ai_session_step import AISessionStep
from app.models.session_branch import SessionBranch
from app.models.fork_point import ForkPoint
logger = logging.getLogger(__name__)


class BranchManager:
    """Branch lifecycle management for conversational branching.

    Wraps an ``AsyncSession`` and provides operations to:
    - create the root branch;
    - fork a branch into sibling branches;
    - switch the active branch;
    - update or revive a branch's status;
    - assemble cross-branch context text.

    Mutating methods call ``flush()`` but never ``commit()``, so the caller
    owns the transaction.
    """

    # Display order used for siblings in the cross-branch context.
    # Statuses missing from this map sort after all known ones.
    _STATUS_PRIORITY: dict[str, int] = {
        "active": 0,
        "untried": 1,
        "revived": 2,
        "dead_end": 3,
        "solved": 4,
    }
    _UNKNOWN_STATUS_PRIORITY = 5

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _get_session(self, session_id: UUID) -> AISession:
        """Load the AISession with ``session_id``; raise ValueError if absent."""
        result = await self.db.execute(
            select(AISession).where(AISession.id == session_id)
        )
        session = result.scalar_one_or_none()
        if session is None:
            raise ValueError(f"Session {session_id} not found")
        return session

    async def _find_branch(self, branch_id: UUID) -> SessionBranch | None:
        """Return the SessionBranch with ``branch_id``, or None if absent."""
        result = await self.db.execute(
            select(SessionBranch).where(SessionBranch.id == branch_id)
        )
        return result.scalar_one_or_none()

    async def create_root_branch(self, session_id: UUID) -> SessionBranch:
        """Create the root branch, copy conversation_messages, set is_branching=True.

        The new branch is labelled "Root", has ``branch_order=1`` and status
        "active", and receives a copy of the session's current
        ``conversation_messages``. The session is marked as branching, and
        its ``active_branch_id`` is pointed at the new root.

        Raises:
            ValueError: if the session does not exist.
        """
        session = await self._get_session(session_id)
        root = SessionBranch(
            id=uuid.uuid4(),
            session_id=session_id,
            account_id=session.account_id,
            parent_branch_id=None,
            branch_order=1,
            label="Root",
            status="active",
            conversation_messages=list(session.conversation_messages or []),
        )
        self.db.add(root)
        session.is_branching = True
        session.active_branch_id = root.id
        await self.db.flush()
        return root

    async def create_fork(
        self,
        session_id: UUID,
        parent_branch_id: UUID,
        trigger_step_id: UUID | None,
        fork_reason: str,
        options: list[dict[str, str]],
    ) -> tuple[ForkPoint, list[SessionBranch]]:
        """Create a fork point under ``parent_branch_id`` with one branch per option.

        Each option must provide ``"label"`` and ``"description"``. Every new
        branch:
        - starts with status "untried";
        - gets ``branch_order`` 1..N in option order;
        - inherits ``account_id`` from the session;
        - receives its own copy of the parent branch's conversation messages.

        If ``trigger_step_id`` refers to an existing step, that step is
        flagged as the fork point.

        Returns:
            The new ForkPoint and the list of new branches, in option order.

        Raises:
            KeyError: if an option lacks "label" or "description".
            ValueError: if the session does not exist, or if the parent
                branch does not exist within this session.
        """
        # Pre-assign ids so the fork point's options can reference branches
        # that are created further down.
        branch_ids = [uuid.uuid4() for _ in options]
        fork_options = [
            {
                "label": opt["label"],
                "description": opt["description"],
                "branch_id": str(branch_id),
                "status": "untried",
            }
            for branch_id, opt in zip(branch_ids, options)
        ]

        # Validate before adding anything, so a failure leaves no pending
        # ForkPoint/SessionBranch objects in the unit of work. account_id is
        # NOT NULL on the new rows, so a missing session must fail here.
        session = await self._get_session(session_id)
        account_id = session.account_id

        # Scope the parent lookup to this session. This prevents inheriting
        # conversation history from another session's (or account's) branch.
        parent_result = await self.db.execute(
            select(SessionBranch).where(
                SessionBranch.id == parent_branch_id,
                SessionBranch.session_id == session_id,
            )
        )
        parent = parent_result.scalar_one_or_none()
        if parent is None:
            raise ValueError(
                f"Branch {parent_branch_id} not found in session {session_id}"
            )
        parent_messages = list(parent.conversation_messages or [])

        fork_point = ForkPoint(
            id=uuid.uuid4(),
            session_id=session_id,
            account_id=account_id,
            parent_branch_id=parent_branch_id,
            trigger_step_id=trigger_step_id,
            fork_reason=fork_reason,
            options=fork_options,
        )
        self.db.add(fork_point)

        branches: list[SessionBranch] = []
        for order, (branch_id, opt) in enumerate(zip(branch_ids, options), start=1):
            branch = SessionBranch(
                id=branch_id,
                session_id=session_id,
                account_id=account_id,
                parent_branch_id=parent_branch_id,
                fork_point_step_id=trigger_step_id,
                branch_order=order,
                label=opt["label"],
                status="untried",
                # Copy per branch so siblings never share one mutable list.
                conversation_messages=list(parent_messages),
            )
            self.db.add(branch)
            branches.append(branch)

        # Mark the trigger step as the fork point, if it exists.
        if trigger_step_id:
            step_result = await self.db.execute(
                select(AISessionStep).where(AISessionStep.id == trigger_step_id)
            )
            step = step_result.scalar_one_or_none()
            if step:
                step.is_fork_point = True
                step.fork_point_id = fork_point.id

        await self.db.flush()
        return fork_point, branches

    async def switch_branch(self, session_id: UUID, target_branch_id: UUID) -> SessionBranch:
        """Switch the active branch for a session.

        Sets the session's ``active_branch_id`` to ``target_branch_id``. If
        the target was "untried", it becomes "active" and its
        ``status_changed_at`` is stamped. Other branches' statuses are not
        modified.

        Raises:
            ValueError: if the branch does not exist within the session.
        """
        result = await self.db.execute(
            select(SessionBranch).where(
                SessionBranch.id == target_branch_id,
                SessionBranch.session_id == session_id,
            )
        )
        branch = result.scalar_one_or_none()
        if not branch:
            raise ValueError(f"Branch {target_branch_id} not found in session {session_id}")

        # The branch is known to belong to this session, so the session exists.
        session_result = await self.db.execute(
            select(AISession).where(AISession.id == session_id)
        )
        session = session_result.scalar_one()
        session.active_branch_id = target_branch_id

        if branch.status == "untried":
            branch.status = "active"
            branch.status_changed_at = datetime.now(timezone.utc)
        await self.db.flush()
        return branch

    async def mark_branch_status(
        self,
        branch_id: UUID,
        status: str,
        reason: str | None = None,
        user_id: UUID | None = None,
    ) -> SessionBranch:
        """Update a branch's status.

        Writes the following fields:
        - ``status``
        - ``status_reason``
        - ``status_changed_at`` (now, UTC)
        - ``status_changed_by``

        ``status`` is stored as given and is not validated here.

        Raises:
            ValueError: if the branch does not exist.
        """
        branch = await self._find_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")
        branch.status = status
        branch.status_reason = reason
        branch.status_changed_at = datetime.now(timezone.utc)
        branch.status_changed_by = user_id
        await self.db.flush()
        return branch

    async def revive_branch(
        self,
        branch_id: UUID,
        evidence_from_branch_id: UUID,
        evidence_description: str,
    ) -> SessionBranch:
        """Revive a dead-end branch with evidence from another branch.

        Sets status "revived" and records where the evidence came from. It
        also appends a system message describing the evidence to the
        branch's conversation. The current status is not checked before
        reviving.

        Raises:
            ValueError: if the branch does not exist.
        """
        branch = await self._find_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")
        branch.status = "revived"
        branch.status_changed_at = datetime.now(timezone.utc)
        branch.evidence_from_branch_id = evidence_from_branch_id
        branch.evidence_description = evidence_description
        revival_msg = {
            "role": "system",
            "content": f"[Branch Revived] New evidence from another branch: {evidence_description}",
        }
        # Reassign a new list rather than appending in place, so the ORM
        # sees an attribute change on the column.
        branch.conversation_messages = [*(branch.conversation_messages or []), revival_msg]
        await self.db.flush()
        return branch

    async def get_branch_tree(self, session_id: UUID) -> list[SessionBranch]:
        """Get all branches for a session, ordered by branch_order."""
        result = await self.db.execute(
            select(SessionBranch)
            .where(SessionBranch.session_id == session_id)
            .order_by(SessionBranch.branch_order)
        )
        return list(result.scalars().all())

    async def build_cross_branch_context(self, branch_id: UUID) -> str:
        """Build cross-branch context text from sibling summaries.

        Siblings are all other branches of the same session. They are listed
        by status priority, then by ``branch_order``, under a
        "## Cross-Branch Context" heading, with one bullet per sibling. The
        bullet uses the sibling's ``context_summary`` ("tried" items and
        "concluded") when present.

        Returns:
            The formatted text, or "" if the branch does not exist or has no
            siblings.
        """
        branch = await self._find_branch(branch_id)
        if not branch:
            return ""

        siblings_result = await self.db.execute(
            select(SessionBranch)
            .where(
                SessionBranch.session_id == branch.session_id,
                SessionBranch.id != branch_id,
            )
            .order_by(SessionBranch.branch_order)
        )
        # Stable sort: siblings with equal priority keep branch_order.
        siblings = sorted(
            siblings_result.scalars().all(),
            key=lambda b: self._STATUS_PRIORITY.get(b.status, self._UNKNOWN_STATUS_PRIORITY),
        )
        if not siblings:
            return ""

        parts = ["\n## Cross-Branch Context"]
        for sib in siblings:
            summary = sib.context_summary
            if summary:
                # Tolerate explicit nulls / non-string items in stored summaries.
                tried = ", ".join(str(item) for item in (summary.get("tried") or []))
                concluded = summary.get("concluded") or "No conclusion yet"
                parts.append(f"- **{sib.label}** [{sib.status}]: Tried: {tried}. {concluded}")
            else:
                parts.append(f"- **{sib.label}** [{sib.status}]: No summary yet.")
        return "\n".join(parts)