220 lines
7.1 KiB
Python
220 lines
7.1 KiB
Python
"""Integration tests for BranchManager service."""
|
|
import uuid
|
|
import pytest
|
|
from httpx import AsyncClient
|
|
|
|
from app.models.ai_session import AISession
|
|
from app.models.session_branch import SessionBranch
|
|
from app.models.fork_point import ForkPoint
|
|
from app.models.ai_session_step import AISessionStep
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_root_branch(client: AsyncClient, test_user, auth_headers, test_db):
    """Root-branch creation switches the session into branching mode.

    A guided session is seeded with a two-message conversation, then
    ``BranchManager.create_root_branch`` is called for it. The returned root
    must be a parentless, active branch labelled "Root" at order 1 that
    carries two conversation messages. Afterwards the session must report
    ``is_branching`` and point its ``active_branch_id`` at that root.
    """
    from app.services.branch_manager import BranchManager

    user_data = test_user["user_data"]
    history = [
        {"role": "user", "content": "test message"},
        {"role": "assistant", "content": "test response"},
    ]
    ai_session = AISession(
        user_id=user_data["id"],
        account_id=user_data["account_id"],
        session_type="guided",
        intake_type="free_text",
        intake_content={"text": "test"},
        status="active",
        confidence_tier="discovery",
        conversation_messages=history,
    )
    test_db.add(ai_session)
    # Flush so the session has a primary key before the manager references it.
    await test_db.flush()

    root_branch = await BranchManager(test_db).create_root_branch(ai_session.id)

    # Attributes of the newly created root branch.
    assert root_branch is not None
    assert root_branch.parent_branch_id is None
    assert root_branch.label == "Root"
    assert root_branch.status == "active"
    assert root_branch.branch_order == 1
    assert len(root_branch.conversation_messages) == 2

    # Reload the session so changes made by the manager are visible here.
    await test_db.refresh(ai_session)
    assert ai_session.is_branching is True
    assert ai_session.active_branch_id == root_branch.id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_fork(client: AsyncClient, test_user, auth_headers, test_db):
    """Forking at a step yields a ForkPoint plus one child branch per option.

    A root branch and a single question step are created, then the root is
    forked at that step with two options. The test checks:

    * the ForkPoint is returned;
    * the two children have the expected labels;
    * the first child is ``untried`` and parented to the root;
    * the second child has order 2;
    * the trigger step is flagged as a fork point that references the new
      ForkPoint.
    """
    from app.services.branch_manager import BranchManager

    user_data = test_user["user_data"]
    ai_session = AISession(
        user_id=user_data["id"],
        account_id=user_data["account_id"],
        session_type="guided",
        intake_type="free_text",
        intake_content={"text": "test"},
        status="active",
        confidence_tier="discovery",
        conversation_messages=[],
    )
    test_db.add(ai_session)
    await test_db.flush()

    branch_manager = BranchManager(test_db)
    root_branch = await branch_manager.create_root_branch(ai_session.id)

    trigger_step = AISessionStep(
        session_id=ai_session.id,
        account_id=ai_session.account_id,
        step_order=0,
        step_type="question",
        content={"text": "What's the issue?"},
        confidence_at_step=0.5,
    )
    test_db.add(trigger_step)
    # The step needs an id before it can be used as the fork trigger.
    await test_db.flush()

    fork_options = [
        {"label": "Network connectivity", "description": "Check network stack"},
        {"label": "DNS resolution", "description": "Check DNS config"},
    ]
    fork_point, children = await branch_manager.create_fork(
        session_id=ai_session.id,
        parent_branch_id=root_branch.id,
        trigger_step_id=trigger_step.id,
        fork_reason="Two possible causes identified",
        options=fork_options,
    )

    assert fork_point is not None
    assert len(children) == 2

    # Safe to index now that exactly two children are guaranteed.
    first_child = children[0]
    second_child = children[1]
    assert first_child.label == "Network connectivity"
    assert first_child.status == "untried"
    assert first_child.parent_branch_id == root_branch.id
    assert second_child.label == "DNS resolution"
    assert second_child.branch_order == 2

    # The trigger step itself should now be marked as the fork location.
    await test_db.refresh(trigger_step)
    assert trigger_step.is_fork_point is True
    assert trigger_step.fork_point_id == fork_point.id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_switch_branch(client: AsyncClient, test_user, auth_headers, test_db):
    """Switching to a fork child makes it the session's active branch.

    After forking the root into two options, the test switches to the second
    child. The switch call must return that branch, and the reloaded session
    must point its ``active_branch_id`` at it.
    """
    from app.services.branch_manager import BranchManager

    user_data = test_user["user_data"]
    ai_session = AISession(
        user_id=user_data["id"],
        account_id=user_data["account_id"],
        session_type="guided",
        intake_type="free_text",
        intake_content={"text": "test"},
        status="active",
        confidence_tier="discovery",
        conversation_messages=[],
    )
    test_db.add(ai_session)
    await test_db.flush()

    branch_manager = BranchManager(test_db)
    root_branch = await branch_manager.create_root_branch(ai_session.id)

    trigger_step = AISessionStep(
        session_id=ai_session.id,
        account_id=ai_session.account_id,
        step_order=0,
        step_type="question",
        content={"text": "test"},
        confidence_at_step=0.5,
    )
    test_db.add(trigger_step)
    await test_db.flush()

    fork_options = [
        {"label": "Option A", "description": "desc A"},
        {"label": "Option B", "description": "desc B"},
    ]
    _, children = await branch_manager.create_fork(
        session_id=ai_session.id,
        parent_branch_id=root_branch.id,
        trigger_step_id=trigger_step.id,
        fork_reason="test fork",
        options=fork_options,
    )

    # Switch to "Option B", the second child returned by create_fork.
    target_branch = children[1]
    switched = await branch_manager.switch_branch(ai_session.id, target_branch.id)

    assert switched.id == target_branch.id
    await test_db.refresh(ai_session)
    assert ai_session.active_branch_id == target_branch.id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mark_branch_dead_end(client: AsyncClient, test_user, auth_headers, test_db):
    """Marking a branch ``dead_end`` records the status, reason and timestamp.

    The root branch of a fresh session is marked ``dead_end`` with a reason
    string and the acting user's id. The returned branch must carry the new
    status and that reason, and have a non-null ``status_changed_at``.
    """
    from app.services.branch_manager import BranchManager

    user_data = test_user["user_data"]
    ai_session = AISession(
        user_id=user_data["id"],
        account_id=user_data["account_id"],
        session_type="guided",
        intake_type="free_text",
        intake_content={"text": "test"},
        status="active",
        confidence_tier="discovery",
        conversation_messages=[],
    )
    test_db.add(ai_session)
    await test_db.flush()

    branch_manager = BranchManager(test_db)
    root_branch = await branch_manager.create_root_branch(ai_session.id)

    reason_text = "Network was fine, not the cause"
    marked = await branch_manager.mark_branch_status(
        branch_id=root_branch.id,
        status="dead_end",
        reason=reason_text,
        user_id=user_data["id"],
    )

    assert marked.status == "dead_end"
    assert marked.status_reason == reason_text
    assert marked.status_changed_at is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_branch_tree(client: AsyncClient, test_user, auth_headers, test_db):
    """``get_branch_tree`` reports every branch created for a session.

    A root branch is created and then forked into two options at a single
    question step. The tree returned for the session must contain three
    entries.
    """
    from app.services.branch_manager import BranchManager

    user_data = test_user["user_data"]
    ai_session = AISession(
        user_id=user_data["id"],
        account_id=user_data["account_id"],
        session_type="guided",
        intake_type="free_text",
        intake_content={"text": "test"},
        status="active",
        confidence_tier="discovery",
        conversation_messages=[{"role": "user", "content": "help"}],
    )
    test_db.add(ai_session)
    await test_db.flush()

    branch_manager = BranchManager(test_db)
    root_branch = await branch_manager.create_root_branch(ai_session.id)

    trigger_step = AISessionStep(
        session_id=ai_session.id,
        account_id=ai_session.account_id,
        step_order=0,
        step_type="question",
        content={"text": "test"},
        confidence_at_step=0.5,
    )
    test_db.add(trigger_step)
    await test_db.flush()

    fork_options = [
        {"label": "A", "description": "a"},
        {"label": "B", "description": "b"},
    ]
    await branch_manager.create_fork(
        session_id=ai_session.id,
        parent_branch_id=root_branch.id,
        trigger_step_id=trigger_step.id,
        fork_reason="test",
        options=fork_options,
    )

    tree = await branch_manager.get_branch_tree(ai_session.id)
    # One root plus the two children produced by the fork above.
    assert len(tree) == 3
|