Files
resolutionflow/backend/app/api/endpoints/session_branches.py
chihlasm 395f157578 feat: add branch API endpoints with integration tests
Six REST endpoints for branch lifecycle management (list, fork, update
status, switch, revive, branch-message) with BranchManager + BranchAwarePromptBuilder
integration. Registered session_branches router in api/router.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-24 08:41:17 +00:00

275 lines
9.2 KiB
Python

"""Branch management endpoints for conversational branching.
GET /ai-sessions/{id}/branches — List all branches (tree)
POST /ai-sessions/{id}/branches/fork — Create fork with N branches
PATCH /ai-sessions/{id}/branches/{bid} — Update branch status
POST /ai-sessions/{id}/branches/{bid}/switch — Switch active branch
POST /ai-sessions/{id}/branches/{bid}/revive — Revive dead-end branch
POST /ai-sessions/{id}/branches/{bid}/message — Send message on branch
"""
import logging
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_active_user, get_db
from app.models.user import User
from app.models.ai_session import AISession
from app.models.session_branch import SessionBranch
from app.services.branch_manager import BranchManager
from app.services.branch_aware_prompt_builder import BranchAwarePromptBuilder
from app.services.assistant_chat_service import _call_ai
from app.schemas.session_branch import (
BranchTreeResponse,
BranchResponse,
BranchUpdate,
ForkCreateRequest,
ForkPointResponse,
BranchSwitchResponse,
ReviveRequest,
BranchMessageRequest,
BranchMessageResponse,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Every route below is registered under the session-scoped prefix, so each
# handler receives ``session_id`` as a path parameter.
router = APIRouter(prefix="/ai-sessions/{session_id}/branches", tags=["session-branches"])
async def _get_user_session(
    session_id: UUID, user: User, db: AsyncSession
) -> AISession:
    """Load the AI session with the given id that belongs to ``user``.

    Args:
        session_id: Primary key of the session to look up.
        user: The requesting user; only sessions whose ``user_id`` matches
            ``user.id`` are considered.
        db: Async database session used to run the query.

    Returns:
        The matching ``AISession`` row.

    Raises:
        HTTPException: 404 "Session not found" when no row matches both the
            session id and the user's id.
    """
    owned_session_query = select(AISession).where(
        AISession.id == session_id,
        AISession.user_id == user.id,
    )
    found = (await db.execute(owned_session_query)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
@router.get("", response_model=BranchTreeResponse)
async def list_branches(
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchTreeResponse:
    """Return every branch of a session together with the active branch id.

    The session is first resolved through ``_get_user_session``, which raises
    404 when it does not exist or is not owned by the caller.
    """
    owned_session = await _get_user_session(session_id, current_user, db)
    tree = await BranchManager(db).get_branch_tree(session_id)

    # Map each branch row onto the response schema field by field.
    serialized = [
        BranchResponse(
            id=node.id,
            session_id=node.session_id,
            parent_branch_id=node.parent_branch_id,
            fork_point_step_id=node.fork_point_step_id,
            branch_order=node.branch_order,
            label=node.label,
            status=node.status,
            status_reason=node.status_reason,
            status_changed_at=node.status_changed_at,
            context_summary=node.context_summary,
            evidence_from_branch_id=node.evidence_from_branch_id,
            evidence_description=node.evidence_description,
            created_at=node.created_at,
            updated_at=node.updated_at,
        )
        for node in tree
    ]

    return BranchTreeResponse(
        branches=serialized,
        active_branch_id=owned_session.active_branch_id,
    )
@router.post("/fork", response_model=ForkPointResponse, status_code=status.HTTP_201_CREATED)
async def create_fork(
    session_id: UUID,
    body: ForkCreateRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> ForkPointResponse:
    """Create a fork point on the session's active branch.

    One child branch is requested per entry in ``body.options``.

    Raises:
        HTTPException: 404 when the session is missing or not owned by the
            caller (via ``_get_user_session``); 400 when the session status is
            neither "active" nor "paused", or when the session has no active
            branch to fork from.
    """
    session = await _get_user_session(session_id, current_user, db)

    forkable_states = ("active", "paused")
    if session.status not in forkable_states:
        raise HTTPException(status_code=400, detail=f"Cannot fork a {session.status} session")

    manager = BranchManager(db)

    # A session that has never branched gets a root branch first; the session
    # is reloaded afterwards so its branch fields are read fresh below.
    if not session.is_branching:
        await manager.create_root_branch(session_id)
        await db.refresh(session)

    # The currently active branch becomes the parent of the new fork.
    parent_branch_id = session.active_branch_id
    if not parent_branch_id:
        raise HTTPException(status_code=400, detail="No active branch to fork from")

    option_payloads = [
        {"label": choice.label, "description": choice.description}
        for choice in body.options
    ]

    # The manager returns (fork_point, branches); only the fork point is
    # serialized in the response.
    fork_point, _ = await manager.create_fork(
        session_id=session_id,
        parent_branch_id=parent_branch_id,
        trigger_step_id=None,
        fork_reason=body.fork_reason,
        options=option_payloads,
    )
    await db.commit()

    return ForkPointResponse.model_validate(fork_point)
@router.patch("/{branch_id}", response_model=BranchResponse)
async def update_branch_status(
    session_id: UUID,
    branch_id: UUID,
    body: BranchUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchResponse:
    """Update the status of a branch that belongs to the caller's session.

    Args:
        session_id: Session from the URL path; must be owned by the caller.
        branch_id: Branch from the URL path; must belong to ``session_id``.
        body: New status and optional reason.
        current_user: Authenticated user making the request.
        db: Async database session.

    Returns:
        The updated branch serialized as ``BranchResponse``.

    Raises:
        HTTPException: 404 when the session is missing or not owned by the
            caller, when the branch does not belong to that session, or when
            the branch manager rejects the update with ``ValueError``.
    """
    await _get_user_session(session_id, current_user, db)

    # Owning the session in the URL proves nothing about the branch in the
    # URL: the manager call below receives only ``branch_id``. Without this
    # check a caller could pair their own session id with a branch id from
    # someone else's session and modify it.
    membership = await db.execute(
        select(SessionBranch.id).where(
            SessionBranch.id == branch_id,
            SessionBranch.session_id == session_id,
        )
    )
    if membership.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Branch not found")

    manager = BranchManager(db)
    try:
        branch = await manager.mark_branch_status(
            branch_id=branch_id,
            status=body.status,
            reason=body.status_reason,
            user_id=current_user.id,
        )
    except ValueError as e:
        # Kept as 404 for backward compatibility with existing clients.
        # NOTE(review): if the manager also raises ValueError for invalid
        # status transitions, 400/409 would be more accurate. Confirm.
        raise HTTPException(status_code=404, detail=str(e)) from e

    await db.commit()
    return BranchResponse.model_validate(branch)
@router.post("/{branch_id}/switch", response_model=BranchSwitchResponse)
async def switch_branch(
    session_id: UUID,
    branch_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchSwitchResponse:
    """Make ``branch_id`` the active branch of the caller's session.

    Returns the newly active branch together with its conversation messages.

    Raises:
        HTTPException: 404 when the session is missing or not owned by the
            caller, or when the branch manager raises ``ValueError`` for the
            requested switch.
    """
    await _get_user_session(session_id, current_user, db)

    try:
        target = await BranchManager(db).switch_branch(session_id, branch_id)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc))

    await db.commit()

    payload = BranchSwitchResponse(
        active_branch_id=target.id,
        branch=BranchResponse.model_validate(target),
        conversation_messages=target.conversation_messages,
    )
    return payload
@router.post("/{branch_id}/revive", response_model=BranchResponse)
async def revive_branch(
    session_id: UUID,
    branch_id: UUID,
    body: ReviveRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchResponse:
    """Revive a dead-end branch of the caller's session with new evidence.

    Args:
        session_id: Session from the URL path; must be owned by the caller.
        branch_id: Branch from the URL path; must belong to ``session_id``.
        body: Evidence source branch id and description.
        current_user: Authenticated user making the request.
        db: Async database session.

    Returns:
        The revived branch serialized as ``BranchResponse``.

    Raises:
        HTTPException: 404 when the session is missing or not owned by the
            caller, when the branch does not belong to that session, or when
            the branch manager rejects the revival with ``ValueError``.
    """
    await _get_user_session(session_id, current_user, db)

    # The manager call below is keyed on ``branch_id`` alone, so the link
    # between the branch and the caller's session has to be enforced here.
    # Otherwise any user with a session of their own could revive a branch
    # that lives in another user's session.
    membership = await db.execute(
        select(SessionBranch.id).where(
            SessionBranch.id == branch_id,
            SessionBranch.session_id == session_id,
        )
    )
    if membership.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Branch not found")

    manager = BranchManager(db)
    try:
        branch = await manager.revive_branch(
            branch_id=branch_id,
            evidence_from_branch_id=body.evidence_from_branch_id,
            evidence_description=body.evidence_description,
        )
    except ValueError as e:
        # Kept as 404 for backward compatibility with existing clients.
        raise HTTPException(status_code=404, detail=str(e)) from e

    await db.commit()
    return BranchResponse.model_validate(branch)
@router.post("/{branch_id}/message", response_model=BranchMessageResponse)
async def send_branch_message(
    session_id: UUID,
    branch_id: UUID,
    body: BranchMessageRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchMessageResponse:
    """Send a user message on a specific branch and store the AI reply.

    Steps: make the branch active if it is not already, build a prompt from
    the branch history plus cross-branch context, call the AI, append the
    user/assistant turn to the branch's conversation, add the token usage to
    the session totals, and resume the session if it was paused.

    Args:
        session_id: Session from the URL path; must be owned by the caller.
        branch_id: Branch from the URL path; must belong to ``session_id``.
        body: Request carrying the user's message text.
        current_user: Authenticated user making the request.
        db: Async database session.

    Returns:
        The assistant's reply and the id of the branch it was recorded on.

    Raises:
        HTTPException: 404 when the session is missing or not owned by the
            caller, when switching to the branch fails with ``ValueError``,
            or when the branch does not belong to the session; 400 when the
            session status is neither "active" nor "paused".
    """
    session = await _get_user_session(session_id, current_user, db)
    if session.status not in ("active", "paused"):
        raise HTTPException(status_code=400, detail=f"Cannot message a {session.status} session")

    manager = BranchManager(db)

    # Switch to the branch if it is not already active. The manager signals a
    # bad switch with ValueError (the /switch endpoint maps it to 404);
    # letting it escape here would surface as a 500.
    if session.active_branch_id != branch_id:
        try:
            await manager.switch_branch(session_id, branch_id)
        except ValueError as e:
            raise HTTPException(status_code=404, detail=str(e)) from e
        await db.refresh(session)

    # Scope the lookup to this session so a branch id from another session
    # can never be read or written through this endpoint.
    result = await db.execute(
        select(SessionBranch).where(
            SessionBranch.id == branch_id,
            SessionBranch.session_id == session_id,
        )
    )
    branch = result.scalar_one_or_none()
    if branch is None:
        raise HTTPException(status_code=404, detail="Branch not found")

    # Build cross-branch context
    sibling_ctx = await manager.build_cross_branch_context(branch_id)

    # Build prompt. A branch with no history may hold None rather than an
    # empty list, so the builder gets the same guard the update below uses.
    builder = BranchAwarePromptBuilder()
    session_context = f"Problem: {session.problem_summary or 'Unknown'}. Domain: {session.problem_domain or 'Unknown'}."
    prompt_args = builder.build(
        branch_messages=branch.conversation_messages or [],
        sibling_summaries=sibling_ctx,
        session_context=session_context,
        attachments=[],
        new_message=body.message,
        revival_context=branch.evidence_description if branch.status == "revived" else None,
    )

    # Call AI
    ai_content, input_tokens, output_tokens = await _call_ai(**prompt_args)

    # Assign a new list instead of appending in place, so the attribute is
    # rebound to a fresh object.
    branch.conversation_messages = [
        *(branch.conversation_messages or []),
        {"role": "user", "content": body.message},
        {"role": "assistant", "content": ai_content},
    ]

    # Update session token counts; treat an unset counter as zero rather than
    # failing with TypeError on ``None + int``.
    session.total_input_tokens = (session.total_input_tokens or 0) + input_tokens
    session.total_output_tokens = (session.total_output_tokens or 0) + output_tokens

    # Resume if paused
    if session.status == "paused":
        session.status = "active"

    await db.commit()

    return BranchMessageResponse(
        content=ai_content,
        branch_id=branch_id,
    )