feat: add BranchManager service with integration tests

Implements branch lifecycle management for conversational branching:
create_root_branch, create_fork, switch_branch, mark_branch_status,
revive_branch, get_branch_tree, and build_cross_branch_context.
Five integration tests cover the full lifecycle from root creation
through forking, switching, dead-end marking, and tree retrieval.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-03-24 08:34:49 +00:00
parent 52fa1153c4
commit cc77f2858d
2 changed files with 456 additions and 0 deletions

View File

@@ -0,0 +1,238 @@
"""Branch lifecycle management for conversational branching."""
import uuid
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ai_session import AISession
from app.models.ai_session_step import AISessionStep
from app.models.session_branch import SessionBranch
from app.models.fork_point import ForkPoint
logger = logging.getLogger(__name__)
class BranchManager:
"""Branch lifecycle management."""
def __init__(self, db: AsyncSession):
self.db = db
async def create_root_branch(self, session_id: UUID) -> SessionBranch:
"""Create the root branch, copy conversation_messages, set is_branching=True."""
result = await self.db.execute(
select(AISession).where(AISession.id == session_id)
)
session = result.scalar_one_or_none()
if not session:
raise ValueError(f"Session {session_id} not found")
root = SessionBranch(
id=uuid.uuid4(),
session_id=session_id,
parent_branch_id=None,
branch_order=1,
label="Root",
status="active",
conversation_messages=list(session.conversation_messages or []),
)
self.db.add(root)
session.is_branching = True
session.active_branch_id = root.id
await self.db.flush()
return root
async def create_fork(
self,
session_id: UUID,
parent_branch_id: UUID,
trigger_step_id: UUID | None,
fork_reason: str,
options: list[dict[str, str]],
) -> tuple[ForkPoint, list[SessionBranch]]:
"""Create a fork point with N branches."""
branch_ids = [uuid.uuid4() for _ in options]
fork_options = []
for i, opt in enumerate(options):
fork_options.append({
"label": opt["label"],
"description": opt["description"],
"branch_id": str(branch_ids[i]),
"status": "untried",
})
fork_point = ForkPoint(
id=uuid.uuid4(),
session_id=session_id,
parent_branch_id=parent_branch_id,
trigger_step_id=trigger_step_id,
fork_reason=fork_reason,
options=fork_options,
)
self.db.add(fork_point)
# Get parent branch messages for context inheritance
result = await self.db.execute(
select(SessionBranch).where(SessionBranch.id == parent_branch_id)
)
parent = result.scalar_one_or_none()
parent_messages = list(parent.conversation_messages or []) if parent else []
branches = []
for i, opt in enumerate(options):
branch = SessionBranch(
id=branch_ids[i],
session_id=session_id,
parent_branch_id=parent_branch_id,
fork_point_step_id=trigger_step_id,
branch_order=i + 1,
label=opt["label"],
status="untried",
conversation_messages=parent_messages,
)
self.db.add(branch)
branches.append(branch)
# Mark trigger step as fork point
if trigger_step_id:
step_result = await self.db.execute(
select(AISessionStep).where(AISessionStep.id == trigger_step_id)
)
step = step_result.scalar_one_or_none()
if step:
step.is_fork_point = True
step.fork_point_id = fork_point.id
await self.db.flush()
return fork_point, branches
async def switch_branch(self, session_id: UUID, target_branch_id: UUID) -> SessionBranch:
"""Switch the active branch for a session."""
result = await self.db.execute(
select(SessionBranch).where(
SessionBranch.id == target_branch_id,
SessionBranch.session_id == session_id,
)
)
branch = result.scalar_one_or_none()
if not branch:
raise ValueError(f"Branch {target_branch_id} not found in session {session_id}")
session_result = await self.db.execute(
select(AISession).where(AISession.id == session_id)
)
session = session_result.scalar_one()
session.active_branch_id = target_branch_id
if branch.status == "untried":
branch.status = "active"
branch.status_changed_at = datetime.now(timezone.utc)
await self.db.flush()
return branch
async def mark_branch_status(
self,
branch_id: UUID,
status: str,
reason: str | None = None,
user_id: UUID | None = None,
) -> SessionBranch:
"""Update a branch's status."""
result = await self.db.execute(
select(SessionBranch).where(SessionBranch.id == branch_id)
)
branch = result.scalar_one_or_none()
if not branch:
raise ValueError(f"Branch {branch_id} not found")
branch.status = status
branch.status_reason = reason
branch.status_changed_at = datetime.now(timezone.utc)
branch.status_changed_by = user_id
await self.db.flush()
return branch
async def revive_branch(
self,
branch_id: UUID,
evidence_from_branch_id: UUID,
evidence_description: str,
) -> SessionBranch:
"""Revive a dead-end branch with evidence from another branch."""
result = await self.db.execute(
select(SessionBranch).where(SessionBranch.id == branch_id)
)
branch = result.scalar_one_or_none()
if not branch:
raise ValueError(f"Branch {branch_id} not found")
branch.status = "revived"
branch.status_changed_at = datetime.now(timezone.utc)
branch.evidence_from_branch_id = evidence_from_branch_id
branch.evidence_description = evidence_description
revival_msg = {
"role": "system",
"content": f"[Branch Revived] New evidence from another branch: {evidence_description}",
}
msgs = list(branch.conversation_messages or [])
msgs.append(revival_msg)
branch.conversation_messages = msgs
await self.db.flush()
return branch
async def get_branch_tree(self, session_id: UUID) -> list[SessionBranch]:
"""Get all branches for a session, ordered by branch_order."""
result = await self.db.execute(
select(SessionBranch)
.where(SessionBranch.session_id == session_id)
.order_by(SessionBranch.branch_order)
)
return list(result.scalars().all())
async def build_cross_branch_context(self, branch_id: UUID) -> str:
"""Build cross-branch context from sibling summaries."""
result = await self.db.execute(
select(SessionBranch).where(SessionBranch.id == branch_id)
)
branch = result.scalar_one_or_none()
if not branch:
return ""
siblings_result = await self.db.execute(
select(SessionBranch)
.where(
SessionBranch.session_id == branch.session_id,
SessionBranch.id != branch_id,
)
.order_by(SessionBranch.branch_order)
)
siblings = list(siblings_result.scalars().all())
if not siblings:
return ""
priority = {"active": 0, "untried": 1, "revived": 2, "dead_end": 3, "solved": 4}
siblings.sort(key=lambda b: priority.get(b.status, 5))
parts = ["\n## Cross-Branch Context"]
for sib in siblings:
summary = sib.context_summary
if summary:
tried = ", ".join(summary.get("tried", []))
concluded = summary.get("concluded", "No conclusion yet")
parts.append(f"- **{sib.label}** [{sib.status}]: Tried: {tried}. {concluded}")
else:
parts.append(f"- **{sib.label}** [{sib.status}]: No summary yet.")
return "\n".join(parts)

View File

@@ -0,0 +1,218 @@
"""Integration tests for BranchManager service."""
import uuid
import pytest
from httpx import AsyncClient
from app.models.ai_session import AISession
from app.models.session_branch import SessionBranch
from app.models.fork_point import ForkPoint
from app.models.ai_session_step import AISessionStep
@pytest.mark.asyncio
async def test_create_root_branch(client: AsyncClient, test_user, auth_headers, test_db):
"""Creating a root branch sets is_branching=True and copies conversation_messages."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[
{"role": "user", "content": "test message"},
{"role": "assistant", "content": "test response"},
],
)
test_db.add(session)
await test_db.flush()
from app.services.branch_manager import BranchManager
manager = BranchManager(test_db)
root = await manager.create_root_branch(session.id)
assert root is not None
assert root.parent_branch_id is None
assert root.label == "Root"
assert root.status == "active"
assert root.branch_order == 1
assert len(root.conversation_messages) == 2
await test_db.refresh(session)
assert session.is_branching is True
assert session.active_branch_id == root.id
@pytest.mark.asyncio
async def test_create_fork(client: AsyncClient, test_user, auth_headers, test_db):
"""Creating a fork produces a ForkPoint + N branches."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[],
)
test_db.add(session)
await test_db.flush()
from app.services.branch_manager import BranchManager
manager = BranchManager(test_db)
root = await manager.create_root_branch(session.id)
step = AISessionStep(
session_id=session.id,
step_order=0,
step_type="question",
content={"text": "What's the issue?"},
confidence_at_step=0.5,
)
test_db.add(step)
await test_db.flush()
fork_point, branches = await manager.create_fork(
session_id=session.id,
parent_branch_id=root.id,
trigger_step_id=step.id,
fork_reason="Two possible causes identified",
options=[
{"label": "Network connectivity", "description": "Check network stack"},
{"label": "DNS resolution", "description": "Check DNS config"},
],
)
assert fork_point is not None
assert len(branches) == 2
assert branches[0].label == "Network connectivity"
assert branches[0].status == "untried"
assert branches[0].parent_branch_id == root.id
assert branches[1].label == "DNS resolution"
assert branches[1].branch_order == 2
await test_db.refresh(step)
assert step.is_fork_point is True
assert step.fork_point_id == fork_point.id
@pytest.mark.asyncio
async def test_switch_branch(client: AsyncClient, test_user, auth_headers, test_db):
"""Switching branches updates active_branch_id."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[],
)
test_db.add(session)
await test_db.flush()
from app.services.branch_manager import BranchManager
manager = BranchManager(test_db)
root = await manager.create_root_branch(session.id)
step = AISessionStep(
session_id=session.id, step_order=0, step_type="question",
content={"text": "test"}, confidence_at_step=0.5,
)
test_db.add(step)
await test_db.flush()
_, branches = await manager.create_fork(
session_id=session.id,
parent_branch_id=root.id,
trigger_step_id=step.id,
fork_reason="test fork",
options=[
{"label": "Option A", "description": "desc A"},
{"label": "Option B", "description": "desc B"},
],
)
branch_b = branches[1]
result = await manager.switch_branch(session.id, branch_b.id)
assert result.id == branch_b.id
await test_db.refresh(session)
assert session.active_branch_id == branch_b.id
@pytest.mark.asyncio
async def test_mark_branch_dead_end(client: AsyncClient, test_user, auth_headers, test_db):
"""Marking a branch as dead_end updates status."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[],
)
test_db.add(session)
await test_db.flush()
from app.services.branch_manager import BranchManager
manager = BranchManager(test_db)
root = await manager.create_root_branch(session.id)
updated = await manager.mark_branch_status(
branch_id=root.id,
status="dead_end",
reason="Network was fine, not the cause",
user_id=test_user["user_data"]["id"],
)
assert updated.status == "dead_end"
assert updated.status_reason == "Network was fine, not the cause"
assert updated.status_changed_at is not None
@pytest.mark.asyncio
async def test_get_branch_tree(client: AsyncClient, test_user, auth_headers, test_db):
"""get_branch_tree returns the full tree structure."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[{"role": "user", "content": "help"}],
)
test_db.add(session)
await test_db.flush()
from app.services.branch_manager import BranchManager
manager = BranchManager(test_db)
root = await manager.create_root_branch(session.id)
step = AISessionStep(
session_id=session.id, step_order=0, step_type="question",
content={"text": "test"}, confidence_at_step=0.5,
)
test_db.add(step)
await test_db.flush()
await manager.create_fork(
session_id=session.id,
parent_branch_id=root.id,
trigger_step_id=step.id,
fork_reason="test",
options=[
{"label": "A", "description": "a"},
{"label": "B", "description": "b"},
],
)
tree = await manager.get_branch_tree(session.id)
assert len(tree) == 3 # Root + 2 fork branches