"""Integration tests for BranchManager service."""

import uuid

import pytest
from httpx import AsyncClient

from app.models.ai_session import AISession
from app.models.session_branch import SessionBranch
from app.models.fork_point import ForkPoint
from app.models.ai_session_step import AISessionStep


async def _make_session(test_db, test_user, conversation_messages=None):
    """Add and flush a minimal active guided AISession owned by ``test_user``.

    Args:
        test_db: the async DB session fixture.
        test_user: the user fixture; ``test_user["user_data"]`` supplies
            ``id`` and ``account_id``.
        conversation_messages: initial messages; defaults to an empty list.

    Returns:
        The flushed ``AISession`` (its ``id`` is populated by the flush).
    """
    user_data = test_user["user_data"]
    session = AISession(
        user_id=user_data["id"],
        account_id=user_data["account_id"],
        session_type="guided",
        intake_type="free_text",
        intake_content={"text": "test"},
        status="active",
        confidence_tier="discovery",
        conversation_messages=(
            [] if conversation_messages is None else conversation_messages
        ),
    )
    test_db.add(session)
    await test_db.flush()
    return session


async def _make_step(test_db, session_id, text="test"):
    """Add and flush a single question step (step_order 0) for ``session_id``.

    Returns:
        The flushed ``AISessionStep`` (its ``id`` is populated by the flush).
    """
    step = AISessionStep(
        session_id=session_id,
        step_order=0,
        step_type="question",
        content={"text": text},
        confidence_at_step=0.5,
    )
    test_db.add(step)
    await test_db.flush()
    return step


def _make_manager(test_db):
    """Return a BranchManager bound to ``test_db``."""
    # Imported lazily, as every test in this module originally did inline,
    # so the service module is only loaded once a test actually runs.
    from app.services.branch_manager import BranchManager

    return BranchManager(test_db)


@pytest.mark.asyncio
async def test_create_root_branch(client: AsyncClient, test_user, auth_headers, test_db):
    """Creating a root branch sets is_branching=True and copies conversation_messages."""
    session = await _make_session(
        test_db,
        test_user,
        conversation_messages=[
            {"role": "user", "content": "test message"},
            {"role": "assistant", "content": "test response"},
        ],
    )
    manager = _make_manager(test_db)

    root = await manager.create_root_branch(session.id)

    assert root is not None
    assert root.parent_branch_id is None
    assert root.label == "Root"
    assert root.status == "active"
    assert root.branch_order == 1
    # Both seeded messages should have been copied onto the root branch.
    assert len(root.conversation_messages) == 2

    # Reload to observe the session-level flags the manager writes.
    await test_db.refresh(session)
    assert session.is_branching is True
    assert session.active_branch_id == root.id


@pytest.mark.asyncio
async def test_create_fork(client: AsyncClient, test_user, auth_headers, test_db):
    """Creating a fork produces a ForkPoint + N branches."""
    session = await _make_session(test_db, test_user)
    manager = _make_manager(test_db)
    root = await manager.create_root_branch(session.id)
    step = await _make_step(test_db, session.id, text="What's the issue?")

    fork_point, branches = await manager.create_fork(
        session_id=session.id,
        parent_branch_id=root.id,
        trigger_step_id=step.id,
        fork_reason="Two possible causes identified",
        options=[
            {"label": "Network connectivity", "description": "Check network stack"},
            {"label": "DNS resolution", "description": "Check DNS config"},
        ],
    )

    assert fork_point is not None
    assert len(branches) == 2
    assert branches[0].label == "Network connectivity"
    assert branches[0].status == "untried"
    assert branches[0].parent_branch_id == root.id
    assert branches[1].label == "DNS resolution"
    assert branches[1].branch_order == 2

    # The triggering step should now be linked to the new fork point.
    await test_db.refresh(step)
    assert step.is_fork_point is True
    assert step.fork_point_id == fork_point.id


@pytest.mark.asyncio
async def test_switch_branch(client: AsyncClient, test_user, auth_headers, test_db):
    """Switching branches updates active_branch_id."""
    session = await _make_session(test_db, test_user)
    manager = _make_manager(test_db)
    root = await manager.create_root_branch(session.id)
    step = await _make_step(test_db, session.id)

    _, branches = await manager.create_fork(
        session_id=session.id,
        parent_branch_id=root.id,
        trigger_step_id=step.id,
        fork_reason="test fork",
        options=[
            {"label": "Option A", "description": "desc A"},
            {"label": "Option B", "description": "desc B"},
        ],
    )
    branch_b = branches[1]

    result = await manager.switch_branch(session.id, branch_b.id)
    assert result.id == branch_b.id

    await test_db.refresh(session)
    assert session.active_branch_id == branch_b.id


@pytest.mark.asyncio
async def test_mark_branch_dead_end(client: AsyncClient, test_user, auth_headers, test_db):
    """Marking a branch as dead_end updates status."""
    session = await _make_session(test_db, test_user)
    manager = _make_manager(test_db)
    root = await manager.create_root_branch(session.id)

    updated = await manager.mark_branch_status(
        branch_id=root.id,
        status="dead_end",
        reason="Network was fine, not the cause",
        user_id=test_user["user_data"]["id"],
    )

    assert updated.status == "dead_end"
    assert updated.status_reason == "Network was fine, not the cause"
    assert updated.status_changed_at is not None


@pytest.mark.asyncio
async def test_get_branch_tree(client: AsyncClient, test_user, auth_headers, test_db):
    """get_branch_tree returns the full tree structure."""
    session = await _make_session(
        test_db,
        test_user,
        conversation_messages=[{"role": "user", "content": "help"}],
    )
    manager = _make_manager(test_db)
    root = await manager.create_root_branch(session.id)
    step = await _make_step(test_db, session.id)

    await manager.create_fork(
        session_id=session.id,
        parent_branch_id=root.id,
        trigger_step_id=step.id,
        fork_reason="test",
        options=[
            {"label": "A", "description": "a"},
            {"label": "B", "description": "b"},
        ],
    )

    tree = await manager.get_branch_tree(session.id)
    assert len(tree) == 3  # Root + 2 fork branches