diff --git a/backend/app/services/branch_manager.py b/backend/app/services/branch_manager.py new file mode 100644 index 00000000..8dba3fa4 --- /dev/null +++ b/backend/app/services/branch_manager.py @@ -0,0 +1,238 @@ +"""Branch lifecycle management for conversational branching.""" +import uuid +import logging +from datetime import datetime, timezone +from typing import Any +from uuid import UUID + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.ai_session import AISession +from app.models.ai_session_step import AISessionStep +from app.models.session_branch import SessionBranch +from app.models.fork_point import ForkPoint + +logger = logging.getLogger(__name__) + + +class BranchManager: + """Branch lifecycle management.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def create_root_branch(self, session_id: UUID) -> SessionBranch: + """Create the root branch, copy conversation_messages, set is_branching=True.""" + result = await self.db.execute( + select(AISession).where(AISession.id == session_id) + ) + session = result.scalar_one_or_none() + if not session: + raise ValueError(f"Session {session_id} not found") + + root = SessionBranch( + id=uuid.uuid4(), + session_id=session_id, + parent_branch_id=None, + branch_order=1, + label="Root", + status="active", + conversation_messages=list(session.conversation_messages or []), + ) + self.db.add(root) + + session.is_branching = True + session.active_branch_id = root.id + + await self.db.flush() + return root + + async def create_fork( + self, + session_id: UUID, + parent_branch_id: UUID, + trigger_step_id: UUID | None, + fork_reason: str, + options: list[dict[str, str]], + ) -> tuple[ForkPoint, list[SessionBranch]]: + """Create a fork point with N branches.""" + branch_ids = [uuid.uuid4() for _ in options] + + fork_options = [] + for i, opt in enumerate(options): + fork_options.append({ + "label": opt["label"], + "description": opt["description"], + "branch_id": str(branch_ids[i]), + "status": "untried", + }) + + fork_point = ForkPoint( + id=uuid.uuid4(), + session_id=session_id, + parent_branch_id=parent_branch_id, + trigger_step_id=trigger_step_id, + fork_reason=fork_reason, + options=fork_options, + ) + self.db.add(fork_point) + + # Get parent branch messages for context inheritance + result = await self.db.execute( + select(SessionBranch).where(SessionBranch.id == parent_branch_id) + ) + parent = result.scalar_one_or_none() + parent_messages = list(parent.conversation_messages or []) if parent else [] + + branches = [] + for i, opt in enumerate(options): + branch = SessionBranch( + id=branch_ids[i], + session_id=session_id, + parent_branch_id=parent_branch_id, + fork_point_step_id=trigger_step_id, + branch_order=i + 1, + label=opt["label"], + status="untried", + conversation_messages=parent_messages, + ) + self.db.add(branch) + branches.append(branch) + + # Mark trigger step as fork point + if trigger_step_id: + step_result = await self.db.execute( + select(AISessionStep).where(AISessionStep.id == trigger_step_id) + ) + step = step_result.scalar_one_or_none() + if step: + step.is_fork_point = True + step.fork_point_id = fork_point.id + + await self.db.flush() + return fork_point, branches + + async def switch_branch(self, session_id: UUID, target_branch_id: UUID) -> SessionBranch: + """Switch the active branch for a session.""" + result = await self.db.execute( + select(SessionBranch).where( + SessionBranch.id == target_branch_id, + SessionBranch.session_id == session_id, + ) + ) + branch = result.scalar_one_or_none() + if not branch: + raise ValueError(f"Branch {target_branch_id} not found in session {session_id}") + + session_result = await self.db.execute( + select(AISession).where(AISession.id == session_id) + ) + session = session_result.scalar_one() + session.active_branch_id = target_branch_id + + if branch.status == "untried": + branch.status = "active" + branch.status_changed_at = datetime.now(timezone.utc) + + await self.db.flush() + return branch + + async def mark_branch_status( + self, + branch_id: UUID, + status: str, + reason: str | None = None, + user_id: UUID | None = None, + ) -> SessionBranch: + """Update a branch's status.""" + result = await self.db.execute( + select(SessionBranch).where(SessionBranch.id == branch_id) + ) + branch = result.scalar_one_or_none() + if not branch: + raise ValueError(f"Branch {branch_id} not found") + + branch.status = status + branch.status_reason = reason + branch.status_changed_at = datetime.now(timezone.utc) + branch.status_changed_by = user_id + + await self.db.flush() + return branch + + async def revive_branch( + self, + branch_id: UUID, + evidence_from_branch_id: UUID, + evidence_description: str, + ) -> SessionBranch: + """Revive a dead-end branch with evidence from another branch.""" + result = await self.db.execute( + select(SessionBranch).where(SessionBranch.id == branch_id) + ) + branch = result.scalar_one_or_none() + if not branch: + raise ValueError(f"Branch {branch_id} not found") + + branch.status = "revived" + branch.status_changed_at = datetime.now(timezone.utc) + branch.evidence_from_branch_id = evidence_from_branch_id + branch.evidence_description = evidence_description + + revival_msg = { + "role": "system", + "content": f"[Branch Revived] New evidence from another branch: {evidence_description}", + } + msgs = list(branch.conversation_messages or []) + msgs.append(revival_msg) + branch.conversation_messages = msgs + + await self.db.flush() + return branch + + async def get_branch_tree(self, session_id: UUID) -> list[SessionBranch]: + """Get all branches for a session, ordered by branch_order.""" + result = await self.db.execute( + select(SessionBranch) + .where(SessionBranch.session_id == session_id) + .order_by(SessionBranch.branch_order) + ) + return list(result.scalars().all()) + + async def build_cross_branch_context(self, branch_id: UUID) -> str: + """Build cross-branch context from sibling summaries.""" + result = await self.db.execute( + select(SessionBranch).where(SessionBranch.id == branch_id) + ) + branch = result.scalar_one_or_none() + if not branch: + return "" + + siblings_result = await self.db.execute( + select(SessionBranch) + .where( + SessionBranch.session_id == branch.session_id, + SessionBranch.id != branch_id, + ) + .order_by(SessionBranch.branch_order) + ) + siblings = list(siblings_result.scalars().all()) + + if not siblings: + return "" + + priority = {"active": 0, "untried": 1, "revived": 2, "dead_end": 3, "solved": 4} + siblings.sort(key=lambda b: priority.get(b.status, 5)) + + parts = ["\n## Cross-Branch Context"] + for sib in siblings: + summary = sib.context_summary + if summary: + tried = ", ".join(summary.get("tried", [])) + concluded = summary.get("concluded", "No conclusion yet") + parts.append(f"- **{sib.label}** [{sib.status}]: Tried: {tried}. {concluded}") + else: + parts.append(f"- **{sib.label}** [{sib.status}]: No summary yet.") + + return "\n".join(parts) diff --git a/backend/tests/test_branch_manager.py b/backend/tests/test_branch_manager.py new file mode 100644 index 00000000..923ee6b1 --- /dev/null +++ b/backend/tests/test_branch_manager.py @@ -0,0 +1,218 @@ +"""Integration tests for BranchManager service.""" +import uuid +import pytest +from httpx import AsyncClient + +from app.models.ai_session import AISession +from app.models.session_branch import SessionBranch +from app.models.fork_point import ForkPoint +from app.models.ai_session_step import AISessionStep + + +@pytest.mark.asyncio +async def test_create_root_branch(client: AsyncClient, test_user, auth_headers, test_db): + """Creating a root branch sets is_branching=True and copies conversation_messages.""" + session = AISession( + user_id=test_user["user_data"]["id"], + account_id=test_user["user_data"]["account_id"], + session_type="guided", + intake_type="free_text", + intake_content={"text": "test"}, + status="active", + confidence_tier="discovery", + conversation_messages=[ + {"role": "user", "content": "test message"}, + {"role": "assistant", "content": "test response"}, + ], + ) + test_db.add(session) + await test_db.flush() + + from app.services.branch_manager import BranchManager + manager = BranchManager(test_db) + root = await manager.create_root_branch(session.id) + + assert root is not None + assert root.parent_branch_id is None + assert root.label == "Root" + assert root.status == "active" + assert root.branch_order == 1 + assert len(root.conversation_messages) == 2 + + await test_db.refresh(session) + assert session.is_branching is True + assert session.active_branch_id == root.id + + +@pytest.mark.asyncio +async def test_create_fork(client: AsyncClient, test_user, auth_headers, test_db): + """Creating a fork produces a ForkPoint + N branches.""" + session = AISession( + user_id=test_user["user_data"]["id"], + account_id=test_user["user_data"]["account_id"], + session_type="guided", + intake_type="free_text", + intake_content={"text": "test"}, + status="active", + confidence_tier="discovery", + conversation_messages=[], + ) + test_db.add(session) + await test_db.flush() + + from app.services.branch_manager import BranchManager + manager = BranchManager(test_db) + root = await manager.create_root_branch(session.id) + + step = AISessionStep( + session_id=session.id, + step_order=0, + step_type="question", + content={"text": "What's the issue?"}, + confidence_at_step=0.5, + ) + test_db.add(step) + await test_db.flush() + + fork_point, branches = await manager.create_fork( + session_id=session.id, + parent_branch_id=root.id, + trigger_step_id=step.id, + fork_reason="Two possible causes identified", + options=[ + {"label": "Network connectivity", "description": "Check network stack"}, + {"label": "DNS resolution", "description": "Check DNS config"}, + ], + ) + + assert fork_point is not None + assert len(branches) == 2 + assert branches[0].label == "Network connectivity" + assert branches[0].status == "untried" + assert branches[0].parent_branch_id == root.id + assert branches[1].label == "DNS resolution" + assert branches[1].branch_order == 2 + + await test_db.refresh(step) + assert step.is_fork_point is True + assert step.fork_point_id == fork_point.id + + +@pytest.mark.asyncio +async def test_switch_branch(client: AsyncClient, test_user, auth_headers, test_db): + """Switching branches updates active_branch_id.""" + session = AISession( + user_id=test_user["user_data"]["id"], + account_id=test_user["user_data"]["account_id"], + session_type="guided", + intake_type="free_text", + intake_content={"text": "test"}, + status="active", + confidence_tier="discovery", + conversation_messages=[], + ) + test_db.add(session) + await test_db.flush() + + from app.services.branch_manager import BranchManager + manager = BranchManager(test_db) + root = await manager.create_root_branch(session.id) + + step = AISessionStep( + session_id=session.id, step_order=0, step_type="question", + content={"text": "test"}, confidence_at_step=0.5, + ) + test_db.add(step) + await test_db.flush() + + _, branches = await manager.create_fork( + session_id=session.id, + parent_branch_id=root.id, + trigger_step_id=step.id, + fork_reason="test fork", + options=[ + {"label": "Option A", "description": "desc A"}, + {"label": "Option B", "description": "desc B"}, + ], + ) + + branch_b = branches[1] + result = await manager.switch_branch(session.id, branch_b.id) + + assert result.id == branch_b.id + await test_db.refresh(session) + assert session.active_branch_id == branch_b.id + + +@pytest.mark.asyncio +async def test_mark_branch_dead_end(client: AsyncClient, test_user, auth_headers, test_db): + """Marking a branch as dead_end updates status.""" + session = AISession( + user_id=test_user["user_data"]["id"], + account_id=test_user["user_data"]["account_id"], + session_type="guided", + intake_type="free_text", + intake_content={"text": "test"}, + status="active", + confidence_tier="discovery", + conversation_messages=[], + ) + test_db.add(session) + await test_db.flush() + + from app.services.branch_manager import BranchManager + manager = BranchManager(test_db) + root = await manager.create_root_branch(session.id) + + updated = await manager.mark_branch_status( + branch_id=root.id, + status="dead_end", + reason="Network was fine, not the cause", + user_id=test_user["user_data"]["id"], + ) + + assert updated.status == "dead_end" + assert updated.status_reason == "Network was fine, not the cause" + assert updated.status_changed_at is not None + + +@pytest.mark.asyncio +async def test_get_branch_tree(client: AsyncClient, test_user, auth_headers, test_db): + """get_branch_tree returns the full tree structure.""" + session = AISession( + user_id=test_user["user_data"]["id"], + account_id=test_user["user_data"]["account_id"], + session_type="guided", + intake_type="free_text", + intake_content={"text": "test"}, + status="active", + confidence_tier="discovery", + conversation_messages=[{"role": "user", "content": "help"}], + ) + test_db.add(session) + await test_db.flush() + + from app.services.branch_manager import BranchManager + manager = BranchManager(test_db) + root = await manager.create_root_branch(session.id) + + step = AISessionStep( + session_id=session.id, step_order=0, step_type="question", + content={"text": "test"}, confidence_at_step=0.5, + ) + test_db.add(step) + await test_db.flush() + + await manager.create_fork( + session_id=session.id, + parent_branch_id=root.id, + trigger_step_id=step.id, + fork_reason="test", + options=[ + {"label": "A", "description": "a"}, + {"label": "B", "description": "b"}, + ], + ) + + tree = await manager.get_branch_tree(session.id) + assert len(tree) == 3 # Root + 2 fork branches