feat: add branch API endpoints with integration tests
Six REST endpoints for branch lifecycle management (list, fork, update status, switch, revive, branch-message) with BranchManager + BranchAwarePromptBuilder integration. Registered session_branches router in api/router.py. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
274
backend/app/api/endpoints/session_branches.py
Normal file
274
backend/app/api/endpoints/session_branches.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Branch management endpoints for conversational branching.
|
||||
|
||||
GET /ai-sessions/{id}/branches — List all branches (tree)
|
||||
POST /ai-sessions/{id}/branches/fork — Create fork with N branches
|
||||
PATCH /ai-sessions/{id}/branches/{bid} — Update branch status
|
||||
POST /ai-sessions/{id}/branches/{bid}/switch — Switch active branch
|
||||
POST /ai-sessions/{id}/branches/{bid}/revive — Revive dead-end branch
|
||||
POST /ai-sessions/{id}/branches/{bid}/message — Send message on branch
|
||||
"""
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_active_user, get_db
|
||||
from app.models.user import User
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.session_branch import SessionBranch
|
||||
from app.services.branch_manager import BranchManager
|
||||
from app.services.branch_aware_prompt_builder import BranchAwarePromptBuilder
|
||||
from app.services.assistant_chat_service import _call_ai
|
||||
from app.schemas.session_branch import (
|
||||
BranchTreeResponse,
|
||||
BranchResponse,
|
||||
BranchUpdate,
|
||||
ForkCreateRequest,
|
||||
ForkPointResponse,
|
||||
BranchSwitchResponse,
|
||||
ReviveRequest,
|
||||
BranchMessageRequest,
|
||||
BranchMessageResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ai-sessions/{session_id}/branches", tags=["session-branches"])
|
||||
|
||||
|
||||
async def _get_user_session(
|
||||
session_id: UUID, user: User, db: AsyncSession
|
||||
) -> AISession:
|
||||
"""Fetch session owned by user, or raise 404."""
|
||||
result = await db.execute(
|
||||
select(AISession).where(
|
||||
AISession.id == session_id,
|
||||
AISession.user_id == user.id,
|
||||
)
|
||||
)
|
||||
session = result.scalar_one_or_none()
|
||||
if not session:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
|
||||
return session
|
||||
|
||||
|
||||
@router.get("", response_model=BranchTreeResponse)
|
||||
async def list_branches(
|
||||
session_id: UUID,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> BranchTreeResponse:
|
||||
"""Get branch tree for a session."""
|
||||
session = await _get_user_session(session_id, current_user, db)
|
||||
manager = BranchManager(db)
|
||||
branches = await manager.get_branch_tree(session_id)
|
||||
|
||||
branch_responses = []
|
||||
for b in branches:
|
||||
branch_responses.append(BranchResponse(
|
||||
id=b.id,
|
||||
session_id=b.session_id,
|
||||
parent_branch_id=b.parent_branch_id,
|
||||
fork_point_step_id=b.fork_point_step_id,
|
||||
branch_order=b.branch_order,
|
||||
label=b.label,
|
||||
status=b.status,
|
||||
status_reason=b.status_reason,
|
||||
status_changed_at=b.status_changed_at,
|
||||
context_summary=b.context_summary,
|
||||
evidence_from_branch_id=b.evidence_from_branch_id,
|
||||
evidence_description=b.evidence_description,
|
||||
created_at=b.created_at,
|
||||
updated_at=b.updated_at,
|
||||
))
|
||||
|
||||
return BranchTreeResponse(
|
||||
branches=branch_responses,
|
||||
active_branch_id=session.active_branch_id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/fork", response_model=ForkPointResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_fork(
|
||||
session_id: UUID,
|
||||
body: ForkCreateRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> ForkPointResponse:
|
||||
"""Create a fork point with N branches."""
|
||||
session = await _get_user_session(session_id, current_user, db)
|
||||
|
||||
if session.status not in ("active", "paused"):
|
||||
raise HTTPException(status_code=400, detail=f"Cannot fork a {session.status} session")
|
||||
|
||||
manager = BranchManager(db)
|
||||
|
||||
# Ensure branching is initialized
|
||||
if not session.is_branching:
|
||||
await manager.create_root_branch(session_id)
|
||||
await db.refresh(session)
|
||||
|
||||
# Use the active branch as parent
|
||||
parent_branch_id = session.active_branch_id
|
||||
if not parent_branch_id:
|
||||
raise HTTPException(status_code=400, detail="No active branch to fork from")
|
||||
|
||||
options = [{"label": o.label, "description": o.description} for o in body.options]
|
||||
|
||||
fork_point, branches = await manager.create_fork(
|
||||
session_id=session_id,
|
||||
parent_branch_id=parent_branch_id,
|
||||
trigger_step_id=None,
|
||||
fork_reason=body.fork_reason,
|
||||
options=options,
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
return ForkPointResponse.model_validate(fork_point)
|
||||
|
||||
|
||||
@router.patch("/{branch_id}", response_model=BranchResponse)
|
||||
async def update_branch_status(
|
||||
session_id: UUID,
|
||||
branch_id: UUID,
|
||||
body: BranchUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> BranchResponse:
|
||||
"""Update a branch's status."""
|
||||
await _get_user_session(session_id, current_user, db)
|
||||
manager = BranchManager(db)
|
||||
|
||||
try:
|
||||
branch = await manager.mark_branch_status(
|
||||
branch_id=branch_id,
|
||||
status=body.status,
|
||||
reason=body.status_reason,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
await db.commit()
|
||||
return BranchResponse.model_validate(branch)
|
||||
|
||||
|
||||
@router.post("/{branch_id}/switch", response_model=BranchSwitchResponse)
|
||||
async def switch_branch(
|
||||
session_id: UUID,
|
||||
branch_id: UUID,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> BranchSwitchResponse:
|
||||
"""Switch the active branch."""
|
||||
await _get_user_session(session_id, current_user, db)
|
||||
manager = BranchManager(db)
|
||||
|
||||
try:
|
||||
branch = await manager.switch_branch(session_id, branch_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
await db.commit()
|
||||
return BranchSwitchResponse(
|
||||
active_branch_id=branch.id,
|
||||
branch=BranchResponse.model_validate(branch),
|
||||
conversation_messages=branch.conversation_messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{branch_id}/revive", response_model=BranchResponse)
|
||||
async def revive_branch(
|
||||
session_id: UUID,
|
||||
branch_id: UUID,
|
||||
body: ReviveRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> BranchResponse:
|
||||
"""Revive a dead-end branch with new evidence."""
|
||||
await _get_user_session(session_id, current_user, db)
|
||||
manager = BranchManager(db)
|
||||
|
||||
try:
|
||||
branch = await manager.revive_branch(
|
||||
branch_id=branch_id,
|
||||
evidence_from_branch_id=body.evidence_from_branch_id,
|
||||
evidence_description=body.evidence_description,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
await db.commit()
|
||||
return BranchResponse.model_validate(branch)
|
||||
|
||||
|
||||
@router.post("/{branch_id}/message", response_model=BranchMessageResponse)
|
||||
async def send_branch_message(
|
||||
session_id: UUID,
|
||||
branch_id: UUID,
|
||||
body: BranchMessageRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> BranchMessageResponse:
|
||||
"""Send a message on a specific branch."""
|
||||
session = await _get_user_session(session_id, current_user, db)
|
||||
|
||||
if session.status not in ("active", "paused"):
|
||||
raise HTTPException(status_code=400, detail=f"Cannot message a {session.status} session")
|
||||
|
||||
manager = BranchManager(db)
|
||||
|
||||
# Switch to branch if not already active
|
||||
if session.active_branch_id != branch_id:
|
||||
await manager.switch_branch(session_id, branch_id)
|
||||
await db.refresh(session)
|
||||
|
||||
# Get branch
|
||||
result = await db.execute(
|
||||
select(SessionBranch).where(SessionBranch.id == branch_id)
|
||||
)
|
||||
branch = result.scalar_one_or_none()
|
||||
if not branch:
|
||||
raise HTTPException(status_code=404, detail="Branch not found")
|
||||
|
||||
# Build cross-branch context
|
||||
sibling_ctx = await manager.build_cross_branch_context(branch_id)
|
||||
|
||||
# Build prompt
|
||||
builder = BranchAwarePromptBuilder()
|
||||
session_context = f"Problem: {session.problem_summary or 'Unknown'}. Domain: {session.problem_domain or 'Unknown'}."
|
||||
prompt_args = builder.build(
|
||||
branch_messages=branch.conversation_messages,
|
||||
sibling_summaries=sibling_ctx,
|
||||
session_context=session_context,
|
||||
attachments=[],
|
||||
new_message=body.message,
|
||||
revival_context=branch.evidence_description if branch.status == "revived" else None,
|
||||
)
|
||||
|
||||
# Call AI
|
||||
ai_content, input_tokens, output_tokens = await _call_ai(**prompt_args)
|
||||
|
||||
# Update branch conversation
|
||||
msgs = list(branch.conversation_messages or [])
|
||||
msgs.append({"role": "user", "content": body.message})
|
||||
msgs.append({"role": "assistant", "content": ai_content})
|
||||
branch.conversation_messages = msgs
|
||||
|
||||
# Update session token counts
|
||||
session.total_input_tokens += input_tokens
|
||||
session.total_output_tokens += output_tokens
|
||||
|
||||
# Resume if paused
|
||||
if session.status == "paused":
|
||||
session.status = "active"
|
||||
|
||||
await db.commit()
|
||||
|
||||
return BranchMessageResponse(
|
||||
content=ai_content,
|
||||
branch_id=branch_id,
|
||||
)
|
||||
@@ -30,6 +30,7 @@ from app.api.endpoints import admin_gallery
|
||||
from app.api.endpoints import uploads
|
||||
from app.api.endpoints import script_builder
|
||||
from app.api.endpoints import beta_feedback
|
||||
from app.api.endpoints import session_branches
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -85,3 +86,4 @@ api_router.include_router(admin_gallery.router)
|
||||
api_router.include_router(uploads.router)
|
||||
api_router.include_router(script_builder.router)
|
||||
api_router.include_router(beta_feedback.router)
|
||||
api_router.include_router(session_branches.router)
|
||||
|
||||
118
backend/tests/test_session_branches_api.py
Normal file
118
backend/tests/test_session_branches_api.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""API endpoint tests for session branches."""
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.ai_session_step import AISessionStep
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_branches_empty(client: AsyncClient, test_user, auth_headers, test_db):
|
||||
"""GET /ai-sessions/{id}/branches returns empty for non-branching session."""
|
||||
session = AISession(
|
||||
user_id=test_user["user_data"]["id"],
|
||||
account_id=test_user["user_data"]["account_id"],
|
||||
session_type="guided",
|
||||
intake_type="free_text",
|
||||
intake_content={"text": "test"},
|
||||
status="active",
|
||||
confidence_tier="discovery",
|
||||
conversation_messages=[],
|
||||
)
|
||||
test_db.add(session)
|
||||
await test_db.commit()
|
||||
|
||||
resp = await client.get(
|
||||
f"/api/v1/ai-sessions/{session.id}/branches",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["branches"] == []
|
||||
assert data["active_branch_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_fork(client: AsyncClient, test_user, auth_headers, test_db):
|
||||
"""POST /ai-sessions/{id}/branches/fork creates branches."""
|
||||
session = AISession(
|
||||
user_id=test_user["user_data"]["id"],
|
||||
account_id=test_user["user_data"]["account_id"],
|
||||
session_type="guided",
|
||||
intake_type="free_text",
|
||||
intake_content={"text": "test"},
|
||||
status="active",
|
||||
confidence_tier="discovery",
|
||||
conversation_messages=[{"role": "user", "content": "help"}],
|
||||
)
|
||||
test_db.add(session)
|
||||
await test_db.flush()
|
||||
|
||||
step = AISessionStep(
|
||||
session_id=session.id, step_order=0, step_type="question",
|
||||
content={"text": "test"}, confidence_at_step=0.5,
|
||||
)
|
||||
test_db.add(step)
|
||||
await test_db.commit()
|
||||
|
||||
resp = await client.post(
|
||||
f"/api/v1/ai-sessions/{session.id}/branches/fork",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"fork_reason": "Two possible causes",
|
||||
"options": [
|
||||
{"label": "Network issue", "description": "Check connectivity"},
|
||||
{"label": "DNS problem", "description": "Check DNS"},
|
||||
],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert len(data["options"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_branch(client: AsyncClient, test_user, auth_headers, test_db):
|
||||
"""POST /ai-sessions/{id}/branches/{bid}/switch changes active branch."""
|
||||
session = AISession(
|
||||
user_id=test_user["user_data"]["id"],
|
||||
account_id=test_user["user_data"]["account_id"],
|
||||
session_type="guided",
|
||||
intake_type="free_text",
|
||||
intake_content={"text": "test"},
|
||||
status="active",
|
||||
confidence_tier="discovery",
|
||||
conversation_messages=[{"role": "user", "content": "help"}],
|
||||
)
|
||||
test_db.add(session)
|
||||
await test_db.flush()
|
||||
|
||||
step = AISessionStep(
|
||||
session_id=session.id, step_order=0, step_type="question",
|
||||
content={"text": "test"}, confidence_at_step=0.5,
|
||||
)
|
||||
test_db.add(step)
|
||||
await test_db.commit()
|
||||
|
||||
# Create fork first
|
||||
fork_resp = await client.post(
|
||||
f"/api/v1/ai-sessions/{session.id}/branches/fork",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"fork_reason": "test",
|
||||
"options": [
|
||||
{"label": "A", "description": "a"},
|
||||
{"label": "B", "description": "b"},
|
||||
],
|
||||
},
|
||||
)
|
||||
fork_data = fork_resp.json()
|
||||
branch_b_id = fork_data["options"][1]["branch_id"]
|
||||
|
||||
# Switch to branch B
|
||||
resp = await client.post(
|
||||
f"/api/v1/ai-sessions/{session.id}/branches/{branch_b_id}/switch",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["active_branch_id"] == branch_b_id
|
||||
Reference in New Issue
Block a user