feat: add branch API endpoints with integration tests

Six REST endpoints for branch lifecycle management (list, fork, update
status, switch, revive, branch-message) with BranchManager + BranchAwarePromptBuilder
integration. Registered session_branches router in api/router.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-03-24 08:41:17 +00:00
parent d8312c24a5
commit 395f157578
3 changed files with 394 additions and 0 deletions

View File

@@ -0,0 +1,274 @@
"""Branch management endpoints for conversational branching.
GET /ai-sessions/{id}/branches — List all branches (tree)
POST /ai-sessions/{id}/branches/fork — Create fork with N branches
PATCH /ai-sessions/{id}/branches/{bid} — Update branch status
POST /ai-sessions/{id}/branches/{bid}/switch — Switch active branch
POST /ai-sessions/{id}/branches/{bid}/revive — Revive dead-end branch
POST /ai-sessions/{id}/branches/{bid}/message — Send message on branch
"""
import logging
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_active_user, get_db
from app.models.user import User
from app.models.ai_session import AISession
from app.models.session_branch import SessionBranch
from app.services.branch_manager import BranchManager
from app.services.branch_aware_prompt_builder import BranchAwarePromptBuilder
from app.services.assistant_chat_service import _call_ai
from app.schemas.session_branch import (
BranchTreeResponse,
BranchResponse,
BranchUpdate,
ForkCreateRequest,
ForkPointResponse,
BranchSwitchResponse,
ReviveRequest,
BranchMessageRequest,
BranchMessageResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ai-sessions/{session_id}/branches", tags=["session-branches"])
async def _get_user_session(
session_id: UUID, user: User, db: AsyncSession
) -> AISession:
"""Fetch session owned by user, or raise 404."""
result = await db.execute(
select(AISession).where(
AISession.id == session_id,
AISession.user_id == user.id,
)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
return session
@router.get("", response_model=BranchTreeResponse)
async def list_branches(
session_id: UUID,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchTreeResponse:
"""Get branch tree for a session."""
session = await _get_user_session(session_id, current_user, db)
manager = BranchManager(db)
branches = await manager.get_branch_tree(session_id)
branch_responses = []
for b in branches:
branch_responses.append(BranchResponse(
id=b.id,
session_id=b.session_id,
parent_branch_id=b.parent_branch_id,
fork_point_step_id=b.fork_point_step_id,
branch_order=b.branch_order,
label=b.label,
status=b.status,
status_reason=b.status_reason,
status_changed_at=b.status_changed_at,
context_summary=b.context_summary,
evidence_from_branch_id=b.evidence_from_branch_id,
evidence_description=b.evidence_description,
created_at=b.created_at,
updated_at=b.updated_at,
))
return BranchTreeResponse(
branches=branch_responses,
active_branch_id=session.active_branch_id,
)
@router.post("/fork", response_model=ForkPointResponse, status_code=status.HTTP_201_CREATED)
async def create_fork(
session_id: UUID,
body: ForkCreateRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> ForkPointResponse:
"""Create a fork point with N branches."""
session = await _get_user_session(session_id, current_user, db)
if session.status not in ("active", "paused"):
raise HTTPException(status_code=400, detail=f"Cannot fork a {session.status} session")
manager = BranchManager(db)
# Ensure branching is initialized
if not session.is_branching:
await manager.create_root_branch(session_id)
await db.refresh(session)
# Use the active branch as parent
parent_branch_id = session.active_branch_id
if not parent_branch_id:
raise HTTPException(status_code=400, detail="No active branch to fork from")
options = [{"label": o.label, "description": o.description} for o in body.options]
fork_point, branches = await manager.create_fork(
session_id=session_id,
parent_branch_id=parent_branch_id,
trigger_step_id=None,
fork_reason=body.fork_reason,
options=options,
)
await db.commit()
return ForkPointResponse.model_validate(fork_point)
@router.patch("/{branch_id}", response_model=BranchResponse)
async def update_branch_status(
session_id: UUID,
branch_id: UUID,
body: BranchUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchResponse:
"""Update a branch's status."""
await _get_user_session(session_id, current_user, db)
manager = BranchManager(db)
try:
branch = await manager.mark_branch_status(
branch_id=branch_id,
status=body.status,
reason=body.status_reason,
user_id=current_user.id,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
await db.commit()
return BranchResponse.model_validate(branch)
@router.post("/{branch_id}/switch", response_model=BranchSwitchResponse)
async def switch_branch(
session_id: UUID,
branch_id: UUID,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchSwitchResponse:
"""Switch the active branch."""
await _get_user_session(session_id, current_user, db)
manager = BranchManager(db)
try:
branch = await manager.switch_branch(session_id, branch_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
await db.commit()
return BranchSwitchResponse(
active_branch_id=branch.id,
branch=BranchResponse.model_validate(branch),
conversation_messages=branch.conversation_messages,
)
@router.post("/{branch_id}/revive", response_model=BranchResponse)
async def revive_branch(
session_id: UUID,
branch_id: UUID,
body: ReviveRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchResponse:
"""Revive a dead-end branch with new evidence."""
await _get_user_session(session_id, current_user, db)
manager = BranchManager(db)
try:
branch = await manager.revive_branch(
branch_id=branch_id,
evidence_from_branch_id=body.evidence_from_branch_id,
evidence_description=body.evidence_description,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
await db.commit()
return BranchResponse.model_validate(branch)
@router.post("/{branch_id}/message", response_model=BranchMessageResponse)
async def send_branch_message(
session_id: UUID,
branch_id: UUID,
body: BranchMessageRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchMessageResponse:
"""Send a message on a specific branch."""
session = await _get_user_session(session_id, current_user, db)
if session.status not in ("active", "paused"):
raise HTTPException(status_code=400, detail=f"Cannot message a {session.status} session")
manager = BranchManager(db)
# Switch to branch if not already active
if session.active_branch_id != branch_id:
await manager.switch_branch(session_id, branch_id)
await db.refresh(session)
# Get branch
result = await db.execute(
select(SessionBranch).where(SessionBranch.id == branch_id)
)
branch = result.scalar_one_or_none()
if not branch:
raise HTTPException(status_code=404, detail="Branch not found")
# Build cross-branch context
sibling_ctx = await manager.build_cross_branch_context(branch_id)
# Build prompt
builder = BranchAwarePromptBuilder()
session_context = f"Problem: {session.problem_summary or 'Unknown'}. Domain: {session.problem_domain or 'Unknown'}."
prompt_args = builder.build(
branch_messages=branch.conversation_messages,
sibling_summaries=sibling_ctx,
session_context=session_context,
attachments=[],
new_message=body.message,
revival_context=branch.evidence_description if branch.status == "revived" else None,
)
# Call AI
ai_content, input_tokens, output_tokens = await _call_ai(**prompt_args)
# Update branch conversation
msgs = list(branch.conversation_messages or [])
msgs.append({"role": "user", "content": body.message})
msgs.append({"role": "assistant", "content": ai_content})
branch.conversation_messages = msgs
# Update session token counts
session.total_input_tokens += input_tokens
session.total_output_tokens += output_tokens
# Resume if paused
if session.status == "paused":
session.status = "active"
await db.commit()
return BranchMessageResponse(
content=ai_content,
branch_id=branch_id,
)

View File

@@ -30,6 +30,7 @@ from app.api.endpoints import admin_gallery
from app.api.endpoints import uploads
from app.api.endpoints import script_builder
from app.api.endpoints import beta_feedback
from app.api.endpoints import session_branches
api_router = APIRouter()
@@ -85,3 +86,4 @@ api_router.include_router(admin_gallery.router)
api_router.include_router(uploads.router)
api_router.include_router(script_builder.router)
api_router.include_router(beta_feedback.router)
api_router.include_router(session_branches.router)

View File

@@ -0,0 +1,118 @@
"""API endpoint tests for session branches."""
import pytest
from httpx import AsyncClient
from app.models.ai_session import AISession
from app.models.ai_session_step import AISessionStep
@pytest.mark.asyncio
async def test_list_branches_empty(client: AsyncClient, test_user, auth_headers, test_db):
"""GET /ai-sessions/{id}/branches returns empty for non-branching session."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[],
)
test_db.add(session)
await test_db.commit()
resp = await client.get(
f"/api/v1/ai-sessions/{session.id}/branches",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["branches"] == []
assert data["active_branch_id"] is None
@pytest.mark.asyncio
async def test_create_fork(client: AsyncClient, test_user, auth_headers, test_db):
"""POST /ai-sessions/{id}/branches/fork creates branches."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[{"role": "user", "content": "help"}],
)
test_db.add(session)
await test_db.flush()
step = AISessionStep(
session_id=session.id, step_order=0, step_type="question",
content={"text": "test"}, confidence_at_step=0.5,
)
test_db.add(step)
await test_db.commit()
resp = await client.post(
f"/api/v1/ai-sessions/{session.id}/branches/fork",
headers=auth_headers,
json={
"fork_reason": "Two possible causes",
"options": [
{"label": "Network issue", "description": "Check connectivity"},
{"label": "DNS problem", "description": "Check DNS"},
],
},
)
assert resp.status_code == 201
data = resp.json()
assert len(data["options"]) == 2
@pytest.mark.asyncio
async def test_switch_branch(client: AsyncClient, test_user, auth_headers, test_db):
"""POST /ai-sessions/{id}/branches/{bid}/switch changes active branch."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[{"role": "user", "content": "help"}],
)
test_db.add(session)
await test_db.flush()
step = AISessionStep(
session_id=session.id, step_order=0, step_type="question",
content={"text": "test"}, confidence_at_step=0.5,
)
test_db.add(step)
await test_db.commit()
# Create fork first
fork_resp = await client.post(
f"/api/v1/ai-sessions/{session.id}/branches/fork",
headers=auth_headers,
json={
"fork_reason": "test",
"options": [
{"label": "A", "description": "a"},
{"label": "B", "description": "b"},
],
},
)
fork_data = fork_resp.json()
branch_b_id = fork_data["options"][1]["branch_id"]
# Switch to branch B
resp = await client.post(
f"/api/v1/ai-sessions/{session.id}/branches/{branch_b_id}/switch",
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json()["active_branch_id"] == branch_b_id