"""Branch lifecycle management for conversational branching."""

import uuid
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.ai_session import AISession
from app.models.ai_session_step import AISessionStep
from app.models.session_branch import SessionBranch
from app.models.fork_point import ForkPoint

logger = logging.getLogger(__name__)


class BranchManager:
    """Branch lifecycle management.

    Every mutating method ends with ``flush()`` on the injected
    ``AsyncSession`` and never calls ``commit()``; committing (or rolling
    back) the transaction is left to the caller.
    """

    # Order in which other branches are listed in the cross-branch context:
    # live/unexplored branches first, finished ones last. Statuses missing
    # from this table sort after all of these.
    _STATUS_PRIORITY: dict[str, int] = {
        "active": 0,
        "untried": 1,
        "revived": 2,
        "dead_end": 3,
        "solved": 4,
    }

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _get_session(self, session_id: UUID) -> AISession | None:
        """Return the AISession with ``session_id``, or None if it does not exist."""
        result = await self.db.execute(
            select(AISession).where(AISession.id == session_id)
        )
        return result.scalar_one_or_none()

    async def _get_branch(self, branch_id: UUID) -> SessionBranch | None:
        """Return the SessionBranch with ``branch_id``, or None if it does not exist."""
        result = await self.db.execute(
            select(SessionBranch).where(SessionBranch.id == branch_id)
        )
        return result.scalar_one_or_none()

    async def create_root_branch(self, session_id: UUID) -> SessionBranch:
        """Create the root branch of a session and switch it into branching mode.

        The session's current ``conversation_messages`` are copied onto the
        new root branch (status ``"active"``, ``branch_order`` 1), the session
        is flagged ``is_branching = True``, and the root becomes the session's
        active branch.

        Raises:
            ValueError: if the session does not exist.
        """
        session = await self._get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        root = SessionBranch(
            id=uuid.uuid4(),
            session_id=session_id,
            account_id=session.account_id,
            parent_branch_id=None,
            branch_order=1,
            label="Root",
            status="active",
            # Copy so the branch never aliases the session's message list.
            conversation_messages=list(session.conversation_messages or []),
        )
        self.db.add(root)
        session.is_branching = True
        # root.id is assigned explicitly above, so it is usable before flush.
        session.active_branch_id = root.id
        await self.db.flush()
        return root

    async def create_fork(
        self,
        session_id: UUID,
        parent_branch_id: UUID,
        trigger_step_id: UUID | None,
        fork_reason: str,
        options: list[dict[str, str]],
    ) -> tuple[ForkPoint, list[SessionBranch]]:
        """Create a fork point under a parent branch with one new branch per option.

        Each option must provide ``"label"`` and ``"description"`` keys. The
        fork point stores a snapshot of the options, each tagged with the id
        of its branch and status ``"untried"``. Every new branch starts in
        status ``"untried"`` with its own copy of the parent branch's
        conversation messages. If ``trigger_step_id`` refers to an existing
        step, that step is marked as the fork point.

        Returns:
            The new ForkPoint and the created branches, in option order.

        Raises:
            ValueError: if ``options`` is empty, the session does not exist,
                or the parent branch does not exist in the session.
            KeyError: if an option lacks ``"label"`` or ``"description"``.
        """
        if not options:
            raise ValueError("A fork requires at least one option")

        # The session supplies account_id for the FK columns on the new rows.
        session = await self._get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")
        account_id = session.account_id

        # The parent branch supplies the inherited conversation context; it
        # must belong to the same session the fork is created in.
        parent_result = await self.db.execute(
            select(SessionBranch).where(
                SessionBranch.id == parent_branch_id,
                SessionBranch.session_id == session_id,
            )
        )
        parent = parent_result.scalar_one_or_none()
        if not parent:
            raise ValueError(
                f"Branch {parent_branch_id} not found in session {session_id}"
            )
        parent_messages = list(parent.conversation_messages or [])

        # Branch ids are generated up front so the fork point's option
        # snapshot can reference them.
        branch_ids = [uuid.uuid4() for _ in options]
        fork_options: list[dict[str, Any]] = [
            {
                "label": opt["label"],
                "description": opt["description"],
                "branch_id": str(branch_id),
                "status": "untried",
            }
            for opt, branch_id in zip(options, branch_ids)
        ]

        fork_point = ForkPoint(
            id=uuid.uuid4(),
            session_id=session_id,
            account_id=account_id,
            parent_branch_id=parent_branch_id,
            trigger_step_id=trigger_step_id,
            fork_reason=fork_reason,
            options=fork_options,
        )
        self.db.add(fork_point)

        branches: list[SessionBranch] = []
        for order, (opt, branch_id) in enumerate(zip(options, branch_ids), start=1):
            branch = SessionBranch(
                id=branch_id,
                session_id=session_id,
                account_id=account_id,
                parent_branch_id=parent_branch_id,
                fork_point_step_id=trigger_step_id,
                branch_order=order,
                label=opt["label"],
                status="untried",
                # One list per branch: sharing a single list object across
                # siblings would let an in-place edit of one branch's messages
                # leak into every other branch created by this fork.
                conversation_messages=list(parent_messages),
            )
            self.db.add(branch)
            branches.append(branch)

        # Mark the triggering step (if any, and if it exists) as the fork point.
        if trigger_step_id is not None:
            step_result = await self.db.execute(
                select(AISessionStep).where(AISessionStep.id == trigger_step_id)
            )
            step = step_result.scalar_one_or_none()
            if step:
                step.is_fork_point = True
                step.fork_point_id = fork_point.id

        await self.db.flush()
        return fork_point, branches

    async def switch_branch(self, session_id: UUID, target_branch_id: UUID) -> SessionBranch:
        """Make ``target_branch_id`` the active branch of ``session_id``.

        An ``"untried"`` target is promoted to ``"active"`` (with
        ``status_changed_at`` stamped); any other status is left as-is. The
        status of the previously active branch is not modified.

        Raises:
            ValueError: if the branch does not exist in the session.
        """
        result = await self.db.execute(
            select(SessionBranch).where(
                SessionBranch.id == target_branch_id,
                SessionBranch.session_id == session_id,
            )
        )
        branch = result.scalar_one_or_none()
        if not branch:
            raise ValueError(f"Branch {target_branch_id} not found in session {session_id}")

        session_result = await self.db.execute(
            select(AISession).where(AISession.id == session_id)
        )
        # scalar_one(): a branch referencing this session was just found, so
        # the session row is presumably guaranteed to exist (FK) — raises
        # NoResultFound otherwise.
        session = session_result.scalar_one()
        session.active_branch_id = target_branch_id

        if branch.status == "untried":
            branch.status = "active"
            branch.status_changed_at = datetime.now(timezone.utc)

        await self.db.flush()
        return branch

    async def mark_branch_status(
        self,
        branch_id: UUID,
        status: str,
        reason: str | None = None,
        user_id: UUID | None = None,
    ) -> SessionBranch:
        """Set a branch's status, recording reason, timestamp and actor.

        ``reason`` and ``user_id`` overwrite any previous values, including
        with None. NOTE(review): ``status`` is not validated here — any string
        is stored as given.

        Raises:
            ValueError: if the branch does not exist.
        """
        branch = await self._get_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")

        branch.status = status
        branch.status_reason = reason
        branch.status_changed_at = datetime.now(timezone.utc)
        branch.status_changed_by = user_id
        await self.db.flush()
        return branch

    async def revive_branch(
        self,
        branch_id: UUID,
        evidence_from_branch_id: UUID,
        evidence_description: str,
    ) -> SessionBranch:
        """Revive a branch using evidence discovered on another branch.

        Sets status ``"revived"``, records where the evidence came from, and
        appends a system message describing it to the branch's conversation.
        NOTE(review): intended for dead-end branches, but the current status
        is not checked; ``status_reason``/``status_changed_by`` are untouched.

        Raises:
            ValueError: if the branch does not exist.
        """
        branch = await self._get_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")

        branch.status = "revived"
        branch.status_changed_at = datetime.now(timezone.utc)
        branch.evidence_from_branch_id = evidence_from_branch_id
        branch.evidence_description = evidence_description

        revival_msg = {
            "role": "system",
            "content": f"[Branch Revived] New evidence from another branch: {evidence_description}",
        }
        # Assign a new list instead of appending in place so the change is
        # picked up even if the column does not track in-place mutation.
        branch.conversation_messages = [*(branch.conversation_messages or []), revival_msg]

        await self.db.flush()
        return branch

    async def get_branch_tree(self, session_id: UUID) -> list[SessionBranch]:
        """Return all branches of a session, ordered by ``branch_order``.

        The result is flat; the tree shape is recoverable from each branch's
        ``parent_branch_id``.
        """
        result = await self.db.execute(
            select(SessionBranch)
            .where(SessionBranch.session_id == session_id)
            .order_by(SessionBranch.branch_order)
        )
        return list(result.scalars().all())

    async def build_cross_branch_context(self, branch_id: UUID) -> str:
        """Build a markdown summary of the session's other branches.

        "Other branches" means every branch in the same session except
        ``branch_id`` (not only those sharing its parent). They are listed by
        status priority (see ``_STATUS_PRIORITY``), then ``branch_order``.

        Returns:
            The context text, or ``""`` if the branch does not exist or has
            no other branches in its session.
        """
        branch = await self._get_branch(branch_id)
        if not branch:
            return ""

        others_result = await self.db.execute(
            select(SessionBranch)
            .where(
                SessionBranch.session_id == branch.session_id,
                SessionBranch.id != branch_id,
            )
            .order_by(SessionBranch.branch_order)
        )
        others = list(others_result.scalars().all())
        if not others:
            return ""

        # Stable sort: within one status the branch_order from SQL is kept.
        others.sort(key=lambda b: self._STATUS_PRIORITY.get(b.status, 5))

        parts = ["\n## Cross-Branch Context"]
        for other in others:
            # Presumably a dict with optional "tried" (list) and "concluded"
            # (str) keys — TODO confirm against the summary writer.
            summary = other.context_summary
            if summary:
                tried = ", ".join(str(item) for item in summary.get("tried") or [])
                concluded = summary.get("concluded") or "No conclusion yet"
                parts.append(f"- **{other.label}** [{other.status}]: Tried: {tried}. {concluded}")
            else:
                parts.append(f"- **{other.label}** [{other.status}]: No summary yet.")
        return "\n".join(parts)