feat: add BranchManager service with integration tests
Implements branch lifecycle management for conversational branching: create_root_branch, create_fork, switch_branch, mark_branch_status, revive_branch, get_branch_tree, and build_cross_branch_context. Five integration tests cover the full lifecycle from root creation through forking, switching, dead-end marking, and tree retrieval. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
238
backend/app/services/branch_manager.py
Normal file
238
backend/app/services/branch_manager.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Branch lifecycle management for conversational branching."""
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.ai_session_step import AISessionStep
|
||||
from app.models.session_branch import SessionBranch
|
||||
from app.models.fork_point import ForkPoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BranchManager:
|
||||
"""Branch lifecycle management."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def create_root_branch(self, session_id: UUID) -> SessionBranch:
|
||||
"""Create the root branch, copy conversation_messages, set is_branching=True."""
|
||||
result = await self.db.execute(
|
||||
select(AISession).where(AISession.id == session_id)
|
||||
)
|
||||
session = result.scalar_one_or_none()
|
||||
if not session:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
root = SessionBranch(
|
||||
id=uuid.uuid4(),
|
||||
session_id=session_id,
|
||||
parent_branch_id=None,
|
||||
branch_order=1,
|
||||
label="Root",
|
||||
status="active",
|
||||
conversation_messages=list(session.conversation_messages or []),
|
||||
)
|
||||
self.db.add(root)
|
||||
|
||||
session.is_branching = True
|
||||
session.active_branch_id = root.id
|
||||
|
||||
await self.db.flush()
|
||||
return root
|
||||
|
||||
async def create_fork(
|
||||
self,
|
||||
session_id: UUID,
|
||||
parent_branch_id: UUID,
|
||||
trigger_step_id: UUID | None,
|
||||
fork_reason: str,
|
||||
options: list[dict[str, str]],
|
||||
) -> tuple[ForkPoint, list[SessionBranch]]:
|
||||
"""Create a fork point with N branches."""
|
||||
branch_ids = [uuid.uuid4() for _ in options]
|
||||
|
||||
fork_options = []
|
||||
for i, opt in enumerate(options):
|
||||
fork_options.append({
|
||||
"label": opt["label"],
|
||||
"description": opt["description"],
|
||||
"branch_id": str(branch_ids[i]),
|
||||
"status": "untried",
|
||||
})
|
||||
|
||||
fork_point = ForkPoint(
|
||||
id=uuid.uuid4(),
|
||||
session_id=session_id,
|
||||
parent_branch_id=parent_branch_id,
|
||||
trigger_step_id=trigger_step_id,
|
||||
fork_reason=fork_reason,
|
||||
options=fork_options,
|
||||
)
|
||||
self.db.add(fork_point)
|
||||
|
||||
# Get parent branch messages for context inheritance
|
||||
result = await self.db.execute(
|
||||
select(SessionBranch).where(SessionBranch.id == parent_branch_id)
|
||||
)
|
||||
parent = result.scalar_one_or_none()
|
||||
parent_messages = list(parent.conversation_messages or []) if parent else []
|
||||
|
||||
branches = []
|
||||
for i, opt in enumerate(options):
|
||||
branch = SessionBranch(
|
||||
id=branch_ids[i],
|
||||
session_id=session_id,
|
||||
parent_branch_id=parent_branch_id,
|
||||
fork_point_step_id=trigger_step_id,
|
||||
branch_order=i + 1,
|
||||
label=opt["label"],
|
||||
status="untried",
|
||||
conversation_messages=parent_messages,
|
||||
)
|
||||
self.db.add(branch)
|
||||
branches.append(branch)
|
||||
|
||||
# Mark trigger step as fork point
|
||||
if trigger_step_id:
|
||||
step_result = await self.db.execute(
|
||||
select(AISessionStep).where(AISessionStep.id == trigger_step_id)
|
||||
)
|
||||
step = step_result.scalar_one_or_none()
|
||||
if step:
|
||||
step.is_fork_point = True
|
||||
step.fork_point_id = fork_point.id
|
||||
|
||||
await self.db.flush()
|
||||
return fork_point, branches
|
||||
|
||||
async def switch_branch(self, session_id: UUID, target_branch_id: UUID) -> SessionBranch:
|
||||
"""Switch the active branch for a session."""
|
||||
result = await self.db.execute(
|
||||
select(SessionBranch).where(
|
||||
SessionBranch.id == target_branch_id,
|
||||
SessionBranch.session_id == session_id,
|
||||
)
|
||||
)
|
||||
branch = result.scalar_one_or_none()
|
||||
if not branch:
|
||||
raise ValueError(f"Branch {target_branch_id} not found in session {session_id}")
|
||||
|
||||
session_result = await self.db.execute(
|
||||
select(AISession).where(AISession.id == session_id)
|
||||
)
|
||||
session = session_result.scalar_one()
|
||||
session.active_branch_id = target_branch_id
|
||||
|
||||
if branch.status == "untried":
|
||||
branch.status = "active"
|
||||
branch.status_changed_at = datetime.now(timezone.utc)
|
||||
|
||||
await self.db.flush()
|
||||
return branch
|
||||
|
||||
async def mark_branch_status(
|
||||
self,
|
||||
branch_id: UUID,
|
||||
status: str,
|
||||
reason: str | None = None,
|
||||
user_id: UUID | None = None,
|
||||
) -> SessionBranch:
|
||||
"""Update a branch's status."""
|
||||
result = await self.db.execute(
|
||||
select(SessionBranch).where(SessionBranch.id == branch_id)
|
||||
)
|
||||
branch = result.scalar_one_or_none()
|
||||
if not branch:
|
||||
raise ValueError(f"Branch {branch_id} not found")
|
||||
|
||||
branch.status = status
|
||||
branch.status_reason = reason
|
||||
branch.status_changed_at = datetime.now(timezone.utc)
|
||||
branch.status_changed_by = user_id
|
||||
|
||||
await self.db.flush()
|
||||
return branch
|
||||
|
||||
async def revive_branch(
|
||||
self,
|
||||
branch_id: UUID,
|
||||
evidence_from_branch_id: UUID,
|
||||
evidence_description: str,
|
||||
) -> SessionBranch:
|
||||
"""Revive a dead-end branch with evidence from another branch."""
|
||||
result = await self.db.execute(
|
||||
select(SessionBranch).where(SessionBranch.id == branch_id)
|
||||
)
|
||||
branch = result.scalar_one_or_none()
|
||||
if not branch:
|
||||
raise ValueError(f"Branch {branch_id} not found")
|
||||
|
||||
branch.status = "revived"
|
||||
branch.status_changed_at = datetime.now(timezone.utc)
|
||||
branch.evidence_from_branch_id = evidence_from_branch_id
|
||||
branch.evidence_description = evidence_description
|
||||
|
||||
revival_msg = {
|
||||
"role": "system",
|
||||
"content": f"[Branch Revived] New evidence from another branch: {evidence_description}",
|
||||
}
|
||||
msgs = list(branch.conversation_messages or [])
|
||||
msgs.append(revival_msg)
|
||||
branch.conversation_messages = msgs
|
||||
|
||||
await self.db.flush()
|
||||
return branch
|
||||
|
||||
async def get_branch_tree(self, session_id: UUID) -> list[SessionBranch]:
|
||||
"""Get all branches for a session, ordered by branch_order."""
|
||||
result = await self.db.execute(
|
||||
select(SessionBranch)
|
||||
.where(SessionBranch.session_id == session_id)
|
||||
.order_by(SessionBranch.branch_order)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def build_cross_branch_context(self, branch_id: UUID) -> str:
|
||||
"""Build cross-branch context from sibling summaries."""
|
||||
result = await self.db.execute(
|
||||
select(SessionBranch).where(SessionBranch.id == branch_id)
|
||||
)
|
||||
branch = result.scalar_one_or_none()
|
||||
if not branch:
|
||||
return ""
|
||||
|
||||
siblings_result = await self.db.execute(
|
||||
select(SessionBranch)
|
||||
.where(
|
||||
SessionBranch.session_id == branch.session_id,
|
||||
SessionBranch.id != branch_id,
|
||||
)
|
||||
.order_by(SessionBranch.branch_order)
|
||||
)
|
||||
siblings = list(siblings_result.scalars().all())
|
||||
|
||||
if not siblings:
|
||||
return ""
|
||||
|
||||
priority = {"active": 0, "untried": 1, "revived": 2, "dead_end": 3, "solved": 4}
|
||||
siblings.sort(key=lambda b: priority.get(b.status, 5))
|
||||
|
||||
parts = ["\n## Cross-Branch Context"]
|
||||
for sib in siblings:
|
||||
summary = sib.context_summary
|
||||
if summary:
|
||||
tried = ", ".join(summary.get("tried", []))
|
||||
concluded = summary.get("concluded", "No conclusion yet")
|
||||
parts.append(f"- **{sib.label}** [{sib.status}]: Tried: {tried}. {concluded}")
|
||||
else:
|
||||
parts.append(f"- **{sib.label}** [{sib.status}]: No summary yet.")
|
||||
|
||||
return "\n".join(parts)
|
||||
218
backend/tests/test_branch_manager.py
Normal file
218
backend/tests/test_branch_manager.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Integration tests for BranchManager service."""
|
||||
import uuid
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.session_branch import SessionBranch
|
||||
from app.models.fork_point import ForkPoint
|
||||
from app.models.ai_session_step import AISessionStep
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_root_branch(client: AsyncClient, test_user, auth_headers, test_db):
|
||||
"""Creating a root branch sets is_branching=True and copies conversation_messages."""
|
||||
session = AISession(
|
||||
user_id=test_user["user_data"]["id"],
|
||||
account_id=test_user["user_data"]["account_id"],
|
||||
session_type="guided",
|
||||
intake_type="free_text",
|
||||
intake_content={"text": "test"},
|
||||
status="active",
|
||||
confidence_tier="discovery",
|
||||
conversation_messages=[
|
||||
{"role": "user", "content": "test message"},
|
||||
{"role": "assistant", "content": "test response"},
|
||||
],
|
||||
)
|
||||
test_db.add(session)
|
||||
await test_db.flush()
|
||||
|
||||
from app.services.branch_manager import BranchManager
|
||||
manager = BranchManager(test_db)
|
||||
root = await manager.create_root_branch(session.id)
|
||||
|
||||
assert root is not None
|
||||
assert root.parent_branch_id is None
|
||||
assert root.label == "Root"
|
||||
assert root.status == "active"
|
||||
assert root.branch_order == 1
|
||||
assert len(root.conversation_messages) == 2
|
||||
|
||||
await test_db.refresh(session)
|
||||
assert session.is_branching is True
|
||||
assert session.active_branch_id == root.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_fork(client: AsyncClient, test_user, auth_headers, test_db):
|
||||
"""Creating a fork produces a ForkPoint + N branches."""
|
||||
session = AISession(
|
||||
user_id=test_user["user_data"]["id"],
|
||||
account_id=test_user["user_data"]["account_id"],
|
||||
session_type="guided",
|
||||
intake_type="free_text",
|
||||
intake_content={"text": "test"},
|
||||
status="active",
|
||||
confidence_tier="discovery",
|
||||
conversation_messages=[],
|
||||
)
|
||||
test_db.add(session)
|
||||
await test_db.flush()
|
||||
|
||||
from app.services.branch_manager import BranchManager
|
||||
manager = BranchManager(test_db)
|
||||
root = await manager.create_root_branch(session.id)
|
||||
|
||||
step = AISessionStep(
|
||||
session_id=session.id,
|
||||
step_order=0,
|
||||
step_type="question",
|
||||
content={"text": "What's the issue?"},
|
||||
confidence_at_step=0.5,
|
||||
)
|
||||
test_db.add(step)
|
||||
await test_db.flush()
|
||||
|
||||
fork_point, branches = await manager.create_fork(
|
||||
session_id=session.id,
|
||||
parent_branch_id=root.id,
|
||||
trigger_step_id=step.id,
|
||||
fork_reason="Two possible causes identified",
|
||||
options=[
|
||||
{"label": "Network connectivity", "description": "Check network stack"},
|
||||
{"label": "DNS resolution", "description": "Check DNS config"},
|
||||
],
|
||||
)
|
||||
|
||||
assert fork_point is not None
|
||||
assert len(branches) == 2
|
||||
assert branches[0].label == "Network connectivity"
|
||||
assert branches[0].status == "untried"
|
||||
assert branches[0].parent_branch_id == root.id
|
||||
assert branches[1].label == "DNS resolution"
|
||||
assert branches[1].branch_order == 2
|
||||
|
||||
await test_db.refresh(step)
|
||||
assert step.is_fork_point is True
|
||||
assert step.fork_point_id == fork_point.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_branch(client: AsyncClient, test_user, auth_headers, test_db):
|
||||
"""Switching branches updates active_branch_id."""
|
||||
session = AISession(
|
||||
user_id=test_user["user_data"]["id"],
|
||||
account_id=test_user["user_data"]["account_id"],
|
||||
session_type="guided",
|
||||
intake_type="free_text",
|
||||
intake_content={"text": "test"},
|
||||
status="active",
|
||||
confidence_tier="discovery",
|
||||
conversation_messages=[],
|
||||
)
|
||||
test_db.add(session)
|
||||
await test_db.flush()
|
||||
|
||||
from app.services.branch_manager import BranchManager
|
||||
manager = BranchManager(test_db)
|
||||
root = await manager.create_root_branch(session.id)
|
||||
|
||||
step = AISessionStep(
|
||||
session_id=session.id, step_order=0, step_type="question",
|
||||
content={"text": "test"}, confidence_at_step=0.5,
|
||||
)
|
||||
test_db.add(step)
|
||||
await test_db.flush()
|
||||
|
||||
_, branches = await manager.create_fork(
|
||||
session_id=session.id,
|
||||
parent_branch_id=root.id,
|
||||
trigger_step_id=step.id,
|
||||
fork_reason="test fork",
|
||||
options=[
|
||||
{"label": "Option A", "description": "desc A"},
|
||||
{"label": "Option B", "description": "desc B"},
|
||||
],
|
||||
)
|
||||
|
||||
branch_b = branches[1]
|
||||
result = await manager.switch_branch(session.id, branch_b.id)
|
||||
|
||||
assert result.id == branch_b.id
|
||||
await test_db.refresh(session)
|
||||
assert session.active_branch_id == branch_b.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_branch_dead_end(client: AsyncClient, test_user, auth_headers, test_db):
|
||||
"""Marking a branch as dead_end updates status."""
|
||||
session = AISession(
|
||||
user_id=test_user["user_data"]["id"],
|
||||
account_id=test_user["user_data"]["account_id"],
|
||||
session_type="guided",
|
||||
intake_type="free_text",
|
||||
intake_content={"text": "test"},
|
||||
status="active",
|
||||
confidence_tier="discovery",
|
||||
conversation_messages=[],
|
||||
)
|
||||
test_db.add(session)
|
||||
await test_db.flush()
|
||||
|
||||
from app.services.branch_manager import BranchManager
|
||||
manager = BranchManager(test_db)
|
||||
root = await manager.create_root_branch(session.id)
|
||||
|
||||
updated = await manager.mark_branch_status(
|
||||
branch_id=root.id,
|
||||
status="dead_end",
|
||||
reason="Network was fine, not the cause",
|
||||
user_id=test_user["user_data"]["id"],
|
||||
)
|
||||
|
||||
assert updated.status == "dead_end"
|
||||
assert updated.status_reason == "Network was fine, not the cause"
|
||||
assert updated.status_changed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_branch_tree(client: AsyncClient, test_user, auth_headers, test_db):
|
||||
"""get_branch_tree returns the full tree structure."""
|
||||
session = AISession(
|
||||
user_id=test_user["user_data"]["id"],
|
||||
account_id=test_user["user_data"]["account_id"],
|
||||
session_type="guided",
|
||||
intake_type="free_text",
|
||||
intake_content={"text": "test"},
|
||||
status="active",
|
||||
confidence_tier="discovery",
|
||||
conversation_messages=[{"role": "user", "content": "help"}],
|
||||
)
|
||||
test_db.add(session)
|
||||
await test_db.flush()
|
||||
|
||||
from app.services.branch_manager import BranchManager
|
||||
manager = BranchManager(test_db)
|
||||
root = await manager.create_root_branch(session.id)
|
||||
|
||||
step = AISessionStep(
|
||||
session_id=session.id, step_order=0, step_type="question",
|
||||
content={"text": "test"}, confidence_at_step=0.5,
|
||||
)
|
||||
test_db.add(step)
|
||||
await test_db.flush()
|
||||
|
||||
await manager.create_fork(
|
||||
session_id=session.id,
|
||||
parent_branch_id=root.id,
|
||||
trigger_step_id=step.id,
|
||||
fork_reason="test",
|
||||
options=[
|
||||
{"label": "A", "description": "a"},
|
||||
{"label": "B", "description": "b"},
|
||||
],
|
||||
)
|
||||
|
||||
tree = await manager.get_branch_tree(session.id)
|
||||
assert len(tree) == 3 # Root + 2 fork branches
|
||||
Reference in New Issue
Block a user