Implements branch lifecycle management for conversational branching: create_root_branch, create_fork, switch_branch, mark_branch_status, revive_branch, get_branch_tree, and build_cross_branch_context. Five integration tests cover the full lifecycle from root creation through forking, switching, dead-end marking, and tree retrieval. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
239 lines
8.0 KiB
Python
239 lines
8.0 KiB
Python
"""Branch lifecycle management for conversational branching."""
|
|
import uuid
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.ai_session import AISession
|
|
from app.models.ai_session_step import AISessionStep
|
|
from app.models.session_branch import SessionBranch
|
|
from app.models.fork_point import ForkPoint
|
|
|
|
# Module-level logger named after this module; the BranchManager methods shown here do not log through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BranchManager:
    """Branch lifecycle management for conversational branching.

    Operates on ``AISession`` / ``SessionBranch`` / ``ForkPoint`` rows through the
    injected ``AsyncSession``. Every mutating method calls ``flush()`` so generated
    state is visible in the current transaction, but none of them calls
    ``commit()``; committing is left to the caller.
    """

    # Sort order for other branches in the cross-branch context: live work first,
    # finished work last. Statuses not listed here sort after all of these (5).
    _STATUS_PRIORITY = {"active": 0, "untried": 1, "revived": 2, "dead_end": 3, "solved": 4}

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _get_branch(
        self, branch_id: UUID, session_id: UUID | None = None
    ) -> SessionBranch | None:
        """Load one branch by id, optionally restricted to ``session_id``; None if absent."""
        stmt = select(SessionBranch).where(SessionBranch.id == branch_id)
        if session_id is not None:
            stmt = stmt.where(SessionBranch.session_id == session_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def create_root_branch(self, session_id: UUID) -> SessionBranch:
        """Create the root branch for a session and switch the session into branching mode.

        The root branch copies the session's current ``conversation_messages``
        (a new list; the message dicts themselves are shared). The method sets
        ``is_branching = True`` and makes the root the active branch.

        NOTE(review): there is no guard against calling this twice for the same
        session, which would create a second root. Confirm callers only invoke it once.

        Args:
            session_id: Id of the ``AISession`` to branch.

        Returns:
            The newly added root ``SessionBranch`` (flushed, not committed).

        Raises:
            ValueError: If no session with ``session_id`` exists.
        """
        result = await self.db.execute(
            select(AISession).where(AISession.id == session_id)
        )
        session = result.scalar_one_or_none()
        if not session:
            raise ValueError(f"Session {session_id} not found")

        root = SessionBranch(
            id=uuid.uuid4(),
            session_id=session_id,
            parent_branch_id=None,
            branch_order=1,
            label="Root",
            status="active",
            conversation_messages=list(session.conversation_messages or []),
        )
        self.db.add(root)

        session.is_branching = True
        # The id is assigned client-side above, so it is usable before flush.
        session.active_branch_id = root.id

        await self.db.flush()
        return root

    async def create_fork(
        self,
        session_id: UUID,
        parent_branch_id: UUID,
        trigger_step_id: UUID | None,
        fork_reason: str,
        options: list[dict[str, str]],
    ) -> tuple[ForkPoint, list[SessionBranch]]:
        """Create a fork point under a parent branch with one new branch per option.

        Each new branch starts as ``"untried"``, is numbered ``1..N`` in option
        order, and receives its own copy of the parent's conversation history.
        If ``trigger_step_id`` refers to an existing step, that step is flagged as
        a fork point and linked to the new ``ForkPoint``.

        NOTE(review): the per-option ``"status"`` stored in ``ForkPoint.options``
        is only written here. ``mark_branch_status`` / ``revive_branch`` update the
        branch row, not this JSON.

        Args:
            session_id: Session that owns the parent branch.
            parent_branch_id: Branch being forked; must belong to ``session_id``.
            trigger_step_id: Optional step that triggered the fork.
            fork_reason: Free-text reason stored on the fork point.
            options: One dict per alternative, each with ``"label"`` and
                ``"description"`` keys.

        Returns:
            The new ``ForkPoint`` and the new branches in option order
            (flushed, not committed).

        Raises:
            ValueError: If ``options`` is empty, or if the parent branch does not
                exist in ``session_id``.
        """
        if not options:
            raise ValueError("create_fork requires at least one option")

        # Validate the parent before adding anything to the unit of work, so a bad
        # id cannot leave a half-built fork pending in the session.
        parent = await self._get_branch(parent_branch_id, session_id)
        if parent is None:
            raise ValueError(f"Branch {parent_branch_id} not found in session {session_id}")
        parent_messages = list(parent.conversation_messages or [])

        branch_ids = [uuid.uuid4() for _ in options]

        fork_options = [
            {
                "label": opt["label"],
                "description": opt["description"],
                "branch_id": str(branch_id),
                "status": "untried",
            }
            for opt, branch_id in zip(options, branch_ids)
        ]

        fork_point = ForkPoint(
            id=uuid.uuid4(),
            session_id=session_id,
            parent_branch_id=parent_branch_id,
            trigger_step_id=trigger_step_id,
            fork_reason=fork_reason,
            options=fork_options,
        )
        self.db.add(fork_point)

        branches = []
        for order, (opt, branch_id) in enumerate(zip(options, branch_ids), start=1):
            branch = SessionBranch(
                id=branch_id,
                session_id=session_id,
                parent_branch_id=parent_branch_id,
                fork_point_step_id=trigger_step_id,
                branch_order=order,
                label=opt["label"],
                status="untried",
                # A fresh list per branch. Handing every sibling the same list
                # object would make an in-place append on one branch appear on
                # all of them. (Message dicts are still shared; shallow copy.)
                conversation_messages=list(parent_messages),
            )
            self.db.add(branch)
            branches.append(branch)

        # Mark the trigger step (if it exists) as the fork point.
        if trigger_step_id is not None:
            step_result = await self.db.execute(
                select(AISessionStep).where(AISessionStep.id == trigger_step_id)
            )
            step = step_result.scalar_one_or_none()
            if step:
                step.is_fork_point = True
                step.fork_point_id = fork_point.id

        await self.db.flush()
        return fork_point, branches

    async def switch_branch(self, session_id: UUID, target_branch_id: UUID) -> SessionBranch:
        """Make ``target_branch_id`` the session's active branch.

        An ``"untried"`` target is promoted to ``"active"`` and its
        ``status_changed_at`` is stamped. Other statuses are left unchanged, and the
        previously active branch is not modified.

        Args:
            session_id: Session whose active branch changes.
            target_branch_id: Branch to activate; must belong to ``session_id``.

        Returns:
            The target branch (flushed, not committed).

        Raises:
            ValueError: If the branch does not exist in ``session_id``.
        """
        branch = await self._get_branch(target_branch_id, session_id)
        if not branch:
            raise ValueError(f"Branch {target_branch_id} not found in session {session_id}")

        session_result = await self.db.execute(
            select(AISession).where(AISession.id == session_id)
        )
        # scalar_one(): raises if the owning session row is missing.
        session = session_result.scalar_one()
        session.active_branch_id = target_branch_id

        if branch.status == "untried":
            branch.status = "active"
            branch.status_changed_at = datetime.now(timezone.utc)

        await self.db.flush()
        return branch

    async def mark_branch_status(
        self,
        branch_id: UUID,
        status: str,
        reason: str | None = None,
        user_id: UUID | None = None,
    ) -> SessionBranch:
        """Set a branch's status, reason, change time (UTC) and changing user.

        ``status`` is stored as given and is not validated. ``reason`` and
        ``user_id`` overwrite any previous values, including with ``None``.

        Returns:
            The updated branch (flushed, not committed).

        Raises:
            ValueError: If no branch with ``branch_id`` exists.
        """
        branch = await self._get_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")

        branch.status = status
        branch.status_reason = reason
        branch.status_changed_at = datetime.now(timezone.utc)
        branch.status_changed_by = user_id

        await self.db.flush()
        return branch

    async def revive_branch(
        self,
        branch_id: UUID,
        evidence_from_branch_id: UUID,
        evidence_description: str,
    ) -> SessionBranch:
        """Revive a branch using evidence found on another branch.

        The method sets status ``"revived"``, records the evidence source and
        description, and appends a ``system`` message announcing the revival to
        the branch's conversation. The history is rebuilt as a new list rather
        than appended in place.

        NOTE(review): the current status is not checked (it need not be
        ``"dead_end"``), and ``status_reason`` / ``status_changed_by`` keep their
        previous values. Confirm this is intended.

        Returns:
            The revived branch (flushed, not committed).

        Raises:
            ValueError: If no branch with ``branch_id`` exists.
        """
        branch = await self._get_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")

        branch.status = "revived"
        branch.status_changed_at = datetime.now(timezone.utc)
        branch.evidence_from_branch_id = evidence_from_branch_id
        branch.evidence_description = evidence_description

        revival_msg = {
            "role": "system",
            "content": f"[Branch Revived] New evidence from another branch: {evidence_description}",
        }
        # Assign a new list so attribute-level change detection sees the update.
        branch.conversation_messages = [*(branch.conversation_messages or []), revival_msg]

        await self.db.flush()
        return branch

    async def get_branch_tree(self, session_id: UUID) -> list[SessionBranch]:
        """Return every branch of a session as a flat list ordered by ``branch_order``.

        Branches at different depths can share a ``branch_order`` value (the root
        and the first branch of each fork are both 1). Their relative order is
        left to the database.
        """
        result = await self.db.execute(
            select(SessionBranch)
            .where(SessionBranch.session_id == session_id)
            .order_by(SessionBranch.branch_order)
        )
        return list(result.scalars().all())

    async def build_cross_branch_context(self, branch_id: UUID) -> str:
        """Build a Markdown summary of the session's other branches for prompting.

        Every other branch in the same session is included (not only direct
        siblings). They are ordered by status priority, then by ``branch_order``.
        A branch with a ``context_summary`` contributes its ``"tried"`` list and
        ``"concluded"`` text; one without a summary gets a placeholder line.

        Returns:
            The Markdown block, or ``""`` if the branch does not exist or has no
            other branches.
        """
        branch = await self._get_branch(branch_id)
        if not branch:
            return ""

        siblings_result = await self.db.execute(
            select(SessionBranch)
            .where(
                SessionBranch.session_id == branch.session_id,
                SessionBranch.id != branch_id,
            )
            .order_by(SessionBranch.branch_order)
        )
        siblings = list(siblings_result.scalars().all())
        if not siblings:
            return ""

        # Stable sort: ties keep the branch_order ordering from the query.
        siblings.sort(key=lambda b: self._STATUS_PRIORITY.get(b.status, 5))

        parts = ["\n## Cross-Branch Context"]
        for sib in siblings:
            summary = sib.context_summary
            if summary:
                tried = ", ".join(summary.get("tried", []))
                concluded = summary.get("concluded", "No conclusion yet")
                parts.append(f"- **{sib.label}** [{sib.status}]: Tried: {tried}. {concluded}")
            else:
                parts.append(f"- **{sib.label}** [{sib.status}]: No summary yet.")

        return "\n".join(parts)
|