Files
resolutionflow/backend/app/services/branch_manager.py
chihlasm cc77f2858d feat: add BranchManager service with integration tests
Implements branch lifecycle management for conversational branching:
create_root_branch, create_fork, switch_branch, mark_branch_status,
revive_branch, get_branch_tree, and build_cross_branch_context.
Five integration tests cover the full lifecycle from root creation
through forking, switching, dead-end marking, and tree retrieval.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-24 08:34:49 +00:00

239 lines
8.0 KiB
Python

"""Branch lifecycle management for conversational branching."""
import uuid
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ai_session import AISession
from app.models.ai_session_step import AISessionStep
from app.models.session_branch import SessionBranch
from app.models.fork_point import ForkPoint
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class BranchManager:
    """Branch lifecycle management for a session's conversational branches.

    Wraps an ``AsyncSession`` and creates, forks, switches, re-labels and
    revives ``SessionBranch`` rows. Mutating methods ``flush()`` so their
    changes are visible inside the current transaction; none of them commit,
    so committing is left to the caller.
    """

    # Display order used by ``build_cross_branch_context``; statuses not
    # listed here sort after all of these.
    _STATUS_PRIORITY = {
        "active": 0,
        "untried": 1,
        "revived": 2,
        "dead_end": 3,
        "solved": 4,
    }

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _get_branch(
        self, branch_id: UUID, session_id: UUID | None = None
    ) -> SessionBranch | None:
        """Return the branch with ``branch_id`` (optionally within ``session_id``), or None."""
        stmt = select(SessionBranch).where(SessionBranch.id == branch_id)
        if session_id is not None:
            stmt = stmt.where(SessionBranch.session_id == session_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def create_root_branch(self, session_id: UUID) -> SessionBranch:
        """Create the root branch for a session and switch the session into branching mode.

        The session's current ``conversation_messages`` are copied into the
        root branch, ``is_branching`` is set to True, and the root becomes the
        session's active branch.

        Args:
            session_id: id of the ``AISession`` to branch.

        Returns:
            The newly added (and flushed) root ``SessionBranch``.

        Raises:
            ValueError: if no ``AISession`` with ``session_id`` exists.
        """
        result = await self.db.execute(
            select(AISession).where(AISession.id == session_id)
        )
        session = result.scalar_one_or_none()
        if not session:
            raise ValueError(f"Session {session_id} not found")

        root = SessionBranch(
            # The id is generated client-side so it can be assigned to
            # ``active_branch_id`` below before anything is flushed.
            id=uuid.uuid4(),
            session_id=session_id,
            parent_branch_id=None,
            branch_order=1,
            label="Root",
            status="active",
            # Copy so the branch does not share a list object with the session.
            conversation_messages=list(session.conversation_messages or []),
        )
        self.db.add(root)
        session.is_branching = True
        session.active_branch_id = root.id
        await self.db.flush()
        return root

    async def create_fork(
        self,
        session_id: UUID,
        parent_branch_id: UUID,
        trigger_step_id: UUID | None,
        fork_reason: str,
        options: list[dict[str, str]],
    ) -> tuple[ForkPoint, list[SessionBranch]]:
        """Create a fork point under ``parent_branch_id`` with one new branch per option.

        Args:
            session_id: session that owns the fork.
            parent_branch_id: branch being forked; its messages are inherited.
            trigger_step_id: optional ``AISessionStep`` that caused the fork;
                if found, it is flagged as a fork point.
            fork_reason: free-text reason stored on the ``ForkPoint``.
            options: one dict per new branch; each must provide ``"label"``
                and ``"description"`` (a missing key raises ``KeyError``).

        Returns:
            The new ``ForkPoint`` and the new branches in option order. Every
            branch starts as ``"untried"`` with its own copy of the parent
            branch's conversation messages.
        """
        branch_ids = [uuid.uuid4() for _ in options]
        fork_point = ForkPoint(
            id=uuid.uuid4(),
            session_id=session_id,
            parent_branch_id=parent_branch_id,
            trigger_step_id=trigger_step_id,
            fork_reason=fork_reason,
            options=[
                {
                    "label": opt["label"],
                    "description": opt["description"],
                    "branch_id": str(branch_id),
                    "status": "untried",
                }
                for opt, branch_id in zip(options, branch_ids)
            ],
        )
        self.db.add(fork_point)

        # Child branches inherit the parent's conversation so far. The lookup
        # is scoped to this session so a mismatched parent id can never copy
        # another session's messages into this fork.
        parent = await self._get_branch(parent_branch_id, session_id)
        if parent is None:
            logger.warning(
                "Parent branch %s not found in session %s; "
                "forking without inherited messages",
                parent_branch_id,
                session_id,
            )
            parent_messages = []
        else:
            parent_messages = list(parent.conversation_messages or [])

        branches = []
        for order, (opt, branch_id) in enumerate(zip(options, branch_ids), start=1):
            branch = SessionBranch(
                id=branch_id,
                session_id=session_id,
                parent_branch_id=parent_branch_id,
                fork_point_step_id=trigger_step_id,
                branch_order=order,
                label=opt["label"],
                status="untried",
                # A fresh list per branch: sharing one list object would make
                # an in-place edit on one branch show up on all its siblings.
                conversation_messages=list(parent_messages),
            )
            self.db.add(branch)
            branches.append(branch)

        # Mark the trigger step (if any, and if it exists) as the fork point.
        if trigger_step_id is not None:
            step_result = await self.db.execute(
                select(AISessionStep).where(AISessionStep.id == trigger_step_id)
            )
            step = step_result.scalar_one_or_none()
            if step:
                step.is_fork_point = True
                step.fork_point_id = fork_point.id

        await self.db.flush()
        return fork_point, branches

    async def switch_branch(self, session_id: UUID, target_branch_id: UUID) -> SessionBranch:
        """Make ``target_branch_id`` the active branch of ``session_id``.

        An ``"untried"`` target is promoted to ``"active"`` and its
        ``status_changed_at`` is stamped; other statuses are left unchanged.

        Raises:
            ValueError: if the branch does not exist in this session.
        """
        branch = await self._get_branch(target_branch_id, session_id)
        if branch is None:
            raise ValueError(f"Branch {target_branch_id} not found in session {session_id}")

        session_result = await self.db.execute(
            select(AISession).where(AISession.id == session_id)
        )
        session = session_result.scalar_one()
        session.active_branch_id = target_branch_id

        if branch.status == "untried":
            branch.status = "active"
            branch.status_changed_at = datetime.now(timezone.utc)
        await self.db.flush()
        return branch

    async def mark_branch_status(
        self,
        branch_id: UUID,
        status: str,
        reason: str | None = None,
        user_id: UUID | None = None,
    ) -> SessionBranch:
        """Set a branch's status, overwriting reason, timestamp (UTC) and actor.

        ``status`` is stored as given; no validation is done here.

        Raises:
            ValueError: if the branch does not exist.
        """
        branch = await self._get_branch(branch_id)
        if branch is None:
            raise ValueError(f"Branch {branch_id} not found")
        branch.status = status
        branch.status_reason = reason
        branch.status_changed_at = datetime.now(timezone.utc)
        branch.status_changed_by = user_id
        await self.db.flush()
        return branch

    async def revive_branch(
        self,
        branch_id: UUID,
        evidence_from_branch_id: UUID,
        evidence_description: str,
    ) -> SessionBranch:
        """Set a branch to ``"revived"`` and record the evidence that justified it.

        A system message describing the evidence is appended to the branch's
        conversation. The current status is not checked, so any branch can be
        revived.

        Raises:
            ValueError: if the branch does not exist.
        """
        branch = await self._get_branch(branch_id)
        if branch is None:
            raise ValueError(f"Branch {branch_id} not found")
        branch.status = "revived"
        branch.status_changed_at = datetime.now(timezone.utc)
        branch.evidence_from_branch_id = evidence_from_branch_id
        branch.evidence_description = evidence_description

        revival_msg = {
            "role": "system",
            "content": f"[Branch Revived] New evidence from another branch: {evidence_description}",
        }
        # Build a new list and reassign it rather than appending in place, so
        # the ORM sees the attribute as changed.
        branch.conversation_messages = [*(branch.conversation_messages or []), revival_msg]
        await self.db.flush()
        return branch

    async def get_branch_tree(self, session_id: UUID) -> list[SessionBranch]:
        """Return all branches of a session as a flat list ordered by ``branch_order``."""
        result = await self.db.execute(
            select(SessionBranch)
            .where(SessionBranch.session_id == session_id)
            .order_by(SessionBranch.branch_order)
        )
        return list(result.scalars().all())

    async def build_cross_branch_context(self, branch_id: UUID) -> str:
        """Render a markdown summary of the session's other branches.

        Every other branch in the same session is included, not only branches
        that share this branch's parent. Branches are ordered by status
        priority (``_STATUS_PRIORITY``, unknown statuses last), then by
        ``branch_order``.

        Returns:
            The rendered section, or ``""`` if the branch does not exist or
            the session has no other branches.
        """
        branch = await self._get_branch(branch_id)
        if branch is None:
            return ""

        others_result = await self.db.execute(
            select(SessionBranch)
            .where(
                SessionBranch.session_id == branch.session_id,
                SessionBranch.id != branch_id,
            )
            .order_by(SessionBranch.branch_order)
        )
        unknown_rank = len(self._STATUS_PRIORITY)
        # sorted() is stable, so branch_order is kept within each status.
        others = sorted(
            others_result.scalars().all(),
            key=lambda b: self._STATUS_PRIORITY.get(b.status, unknown_rank),
        )
        if not others:
            return ""

        parts = ["\n## Cross-Branch Context"]
        for sib in others:
            # Assumes context_summary is a dict with an optional "tried" list
            # and "concluded" text. Entries are coerced to str and null values
            # tolerated so a malformed summary cannot break prompt assembly.
            summary = sib.context_summary
            if summary:
                tried = ", ".join(str(item) for item in summary.get("tried") or [])
                concluded = summary.get("concluded") or "No conclusion yet"
                parts.append(f"- **{sib.label}** [{sib.status}]: Tried: {tried}. {concluded}")
            else:
                parts.append(f"- **{sib.label}** [{sib.status}]: No summary yet.")
        return "\n".join(parts)