Six REST endpoints for branch lifecycle management (list, fork, update status, switch, revive, branch-message) with BranchManager + BranchAwarePromptBuilder integration. Registered session_branches router in api/router.py. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
275 lines
9.2 KiB
Python
275 lines
9.2 KiB
Python
"""Branch management endpoints for conversational branching.

GET /ai-sessions/{id}/branches — List all branches (tree)
POST /ai-sessions/{id}/branches/fork — Create fork with N branches
PATCH /ai-sessions/{id}/branches/{bid} — Update branch status
POST /ai-sessions/{id}/branches/{bid}/switch — Switch active branch
POST /ai-sessions/{id}/branches/{bid}/revive — Revive dead-end branch
POST /ai-sessions/{id}/branches/{bid}/message — Send message on branch
"""
|
|
import logging
|
|
from typing import Annotated
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_active_user, get_db
|
|
from app.models.user import User
|
|
from app.models.ai_session import AISession
|
|
from app.models.session_branch import SessionBranch
|
|
from app.services.branch_manager import BranchManager
|
|
from app.services.branch_aware_prompt_builder import BranchAwarePromptBuilder
|
|
from app.services.assistant_chat_service import _call_ai
|
|
from app.schemas.session_branch import (
|
|
BranchTreeResponse,
|
|
BranchResponse,
|
|
BranchUpdate,
|
|
ForkCreateRequest,
|
|
ForkPointResponse,
|
|
BranchSwitchResponse,
|
|
ReviveRequest,
|
|
BranchMessageRequest,
|
|
BranchMessageResponse,
|
|
)
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Every route below is mounted under a single session's branch collection;
# the {session_id} path parameter is therefore shared by all handlers.
router = APIRouter(
    prefix="/ai-sessions/{session_id}/branches",
    tags=["session-branches"],
)
|
|
|
|
|
|
async def _get_user_session(
    session_id: UUID, user: User, db: AsyncSession
) -> AISession:
    """Load the session identified by ``session_id`` that belongs to ``user``.

    The owner id is part of the query filter, so a session that exists but
    carries a different ``user_id`` is reported exactly like a missing one.

    Args:
        session_id: Primary key of the session to load.
        user: The requesting user; only ``user.id`` is read.
        db: Async database session used to run the query.

    Returns:
        The matching ``AISession`` row.

    Raises:
        HTTPException: 404 with detail "Session not found" when no row matches.
    """
    owned_by_user = select(AISession).where(
        AISession.id == session_id,
        AISession.user_id == user.id,
    )
    found = (await db.execute(owned_by_user)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
|
|
|
|
|
|
@router.get("", response_model=BranchTreeResponse)
async def list_branches(
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchTreeResponse:
    """Return every branch of a session together with the active branch id.

    The session is first resolved through ``_get_user_session`` (404 when it
    is missing or not owned by the caller). Each branch yielded by
    ``BranchManager.get_branch_tree`` is then copied field-by-field into a
    ``BranchResponse``.

    Returns:
        A ``BranchTreeResponse`` holding the branch list and the session's
        ``active_branch_id``.
    """
    session = await _get_user_session(session_id, current_user, db)
    manager = BranchManager(db)
    tree = await manager.get_branch_tree(session_id)

    # Attributes copied verbatim from each branch object onto the response
    # schema; anything not named here is deliberately left out of the payload.
    copied_fields = (
        "id",
        "session_id",
        "parent_branch_id",
        "fork_point_step_id",
        "branch_order",
        "label",
        "status",
        "status_reason",
        "status_changed_at",
        "context_summary",
        "evidence_from_branch_id",
        "evidence_description",
        "created_at",
        "updated_at",
    )

    return BranchTreeResponse(
        branches=[
            BranchResponse(**{name: getattr(node, name) for name in copied_fields})
            for node in tree
        ],
        active_branch_id=session.active_branch_id,
    )
|
|
|
|
|
|
@router.post("/fork", response_model=ForkPointResponse, status_code=status.HTTP_201_CREATED)
async def create_fork(
    session_id: UUID,
    body: ForkCreateRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> ForkPointResponse:
    """Create a fork point on the session's active branch, one child per option.

    Steps, in order:
      1. Resolve the session for the caller (404 if missing / not owned).
      2. Reject sessions whose status is neither "active" nor "paused" (400).
      3. If the session is not branching yet, create its root branch and
         refresh the session so ``active_branch_id`` is re-read.
      4. Fork from the active branch using ``body.fork_reason`` and
         ``body.options``; commit.

    Raises:
        HTTPException: 400 when the session status forbids forking, or when
            the session has no active branch after initialisation.
    """
    session = await _get_user_session(session_id, current_user, db)

    forkable = session.status in ("active", "paused")
    if not forkable:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Cannot fork a {session.status} session",
        )

    branch_manager = BranchManager(db)

    # First fork on a linear session: the root branch has to exist before
    # anything can be forked from it.
    if not session.is_branching:
        await branch_manager.create_root_branch(session_id)
        await db.refresh(session)

    # New branches always hang off whichever branch is currently active.
    active_id = session.active_branch_id
    if not active_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No active branch to fork from",
        )

    option_dicts = [
        dict(label=opt.label, description=opt.description) for opt in body.options
    ]

    # The manager returns (fork_point, branches); only the fork point is
    # serialised back to the client.
    fork_point, _ = await branch_manager.create_fork(
        session_id=session_id,
        parent_branch_id=active_id,
        trigger_step_id=None,
        fork_reason=body.fork_reason,
        options=option_dicts,
    )

    await db.commit()
    return ForkPointResponse.model_validate(fork_point)
|
|
|
|
|
|
@router.patch("/{branch_id}", response_model=BranchResponse)
async def update_branch_status(
    session_id: UUID,
    branch_id: UUID,
    body: BranchUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchResponse:
    """Update a branch's status.

    The caller must own the session named in the path, and the branch must
    belong to that same session. The status change itself is delegated to
    ``BranchManager.mark_branch_status`` and committed on success.

    Args:
        session_id: Session from the URL; ownership is checked against the caller.
        branch_id: Branch whose status is being changed.
        body: Carries the new ``status`` and optional ``status_reason``.
        current_user: Authenticated user (injected).
        db: Async database session (injected).

    Returns:
        The updated branch serialised as ``BranchResponse``.

    Raises:
        HTTPException: 404 when the session is not found for this user, when
            the branch is not part of that session, or when the manager
            raises ``ValueError``.
    """
    await _get_user_session(session_id, current_user, db)

    # The manager call below is not given session_id, so nothing downstream
    # can tie branch_id to the session whose ownership was just verified.
    # Check the pairing here; otherwise a caller could combine a session they
    # own with a branch id taken from a different session.
    membership = await db.execute(
        select(SessionBranch.id).where(
            SessionBranch.id == branch_id,
            SessionBranch.session_id == session_id,
        )
    )
    if membership.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Branch not found")

    manager = BranchManager(db)

    try:
        branch = await manager.mark_branch_status(
            branch_id=branch_id,
            status=body.status,
            reason=body.status_reason,
            user_id=current_user.id,
        )
    except ValueError as e:
        # Chain the cause so the original manager error stays in the traceback.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e

    await db.commit()
    # NOTE(review): attributes are read after commit; this assumes the session
    # factory does not expire instances on commit — confirm in get_db.
    return BranchResponse.model_validate(branch)
|
|
|
|
|
|
@router.post("/{branch_id}/switch", response_model=BranchSwitchResponse)
async def switch_branch(
    session_id: UUID,
    branch_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchSwitchResponse:
    """Make ``branch_id`` the active branch of the caller's session.

    After the ownership check, the switch is delegated to
    ``BranchManager.switch_branch`` (which receives both the session id and
    the branch id) and then committed.

    Returns:
        A ``BranchSwitchResponse`` with the new active branch id, the branch
        itself, and that branch's stored conversation messages.

    Raises:
        HTTPException: 404 when the session is not found for this user or
            when the manager raises ``ValueError``.
    """
    await _get_user_session(session_id, current_user, db)
    branch_manager = BranchManager(db)

    try:
        target = await branch_manager.switch_branch(session_id, branch_id)
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))

    await db.commit()

    payload = {
        "active_branch_id": target.id,
        "branch": BranchResponse.model_validate(target),
        "conversation_messages": target.conversation_messages,
    }
    return BranchSwitchResponse(**payload)
|
|
|
|
|
|
@router.post("/{branch_id}/revive", response_model=BranchResponse)
async def revive_branch(
    session_id: UUID,
    branch_id: UUID,
    body: ReviveRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchResponse:
    """Revive a dead-end branch with new evidence.

    The caller must own the session named in the path, and the branch must
    belong to that same session. The revival itself is delegated to
    ``BranchManager.revive_branch`` and committed on success.

    Args:
        session_id: Session from the URL; ownership is checked against the caller.
        branch_id: Branch to revive.
        body: Carries ``evidence_from_branch_id`` and ``evidence_description``.
        current_user: Authenticated user (injected).
        db: Async database session (injected).

    Returns:
        The revived branch serialised as ``BranchResponse``.

    Raises:
        HTTPException: 404 when the session is not found for this user, when
            the branch is not part of that session, or when the manager
            raises ``ValueError``.
    """
    await _get_user_session(session_id, current_user, db)

    # revive_branch() receives neither session_id nor a user id, so this is
    # the only place where branch_id can be tied to the verified session.
    # Without the check, any authenticated user could revive a branch in a
    # session they do not own.
    membership = await db.execute(
        select(SessionBranch.id).where(
            SessionBranch.id == branch_id,
            SessionBranch.session_id == session_id,
        )
    )
    if membership.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Branch not found")

    manager = BranchManager(db)

    try:
        # NOTE(review): evidence_from_branch_id is passed through unchecked;
        # if evidence must come from the same session, validate it here too.
        branch = await manager.revive_branch(
            branch_id=branch_id,
            evidence_from_branch_id=body.evidence_from_branch_id,
            evidence_description=body.evidence_description,
        )
    except ValueError as e:
        # Chain the cause so the original manager error stays in the traceback.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e

    await db.commit()
    return BranchResponse.model_validate(branch)
|
|
|
|
|
|
@router.post("/{branch_id}/message", response_model=BranchMessageResponse)
async def send_branch_message(
    session_id: UUID,
    branch_id: UUID,
    body: BranchMessageRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchMessageResponse:
    """Send a message on a specific branch and store the AI reply.

    Steps, in order:
      1. Resolve the session for the caller and require status "active" or
         "paused".
      2. Confirm the branch belongs to that session.
      3. Make the branch active if it is not already.
      4. Build the prompt from the branch's messages, the cross-branch
         context and a short session summary, then call the AI.
      5. Append the user message and the AI reply to the branch, add the
         token usage to the session totals, resume a paused session, commit.

    Args:
        session_id: Session from the URL; ownership is checked against the caller.
        branch_id: Branch the message is sent on.
        body: Carries the user's ``message`` text.
        current_user: Authenticated user (injected).
        db: Async database session (injected).

    Returns:
        A ``BranchMessageResponse`` with the AI reply and the branch id.

    Raises:
        HTTPException: 404 when the session or branch is not found (including
            a ``ValueError`` from the branch switch); 400 when the session
            status does not allow messaging.
    """
    session = await _get_user_session(session_id, current_user, db)

    if session.status not in ("active", "paused"):
        raise HTTPException(status_code=400, detail=f"Cannot message a {session.status} session")

    # Tie branch_id to the verified session before anything is mutated, so a
    # branch id from another session is rejected up front rather than being
    # switched to, read from and written to.
    membership = await db.execute(
        select(SessionBranch.id).where(
            SessionBranch.id == branch_id,
            SessionBranch.session_id == session_id,
        )
    )
    if membership.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Branch not found")

    manager = BranchManager(db)

    # Switch to branch if not already active
    if session.active_branch_id != branch_id:
        try:
            await manager.switch_branch(session_id, branch_id)
        except ValueError as e:
            # Same mapping as the dedicated /switch endpoint; previously this
            # escaped as an unhandled error.
            raise HTTPException(status_code=404, detail=str(e)) from e
        await db.refresh(session)

    # Load the branch entity after the switch, scoped to the session as well
    # as the id.
    result = await db.execute(
        select(SessionBranch).where(
            SessionBranch.id == branch_id,
            SessionBranch.session_id == session_id,
        )
    )
    branch = result.scalar_one_or_none()
    if not branch:
        raise HTTPException(status_code=404, detail="Branch not found")

    # Build cross-branch context
    sibling_ctx = await manager.build_cross_branch_context(branch_id)

    # Build prompt
    builder = BranchAwarePromptBuilder()
    session_context = f"Problem: {session.problem_summary or 'Unknown'}. Domain: {session.problem_domain or 'Unknown'}."
    prompt_args = builder.build(
        branch_messages=branch.conversation_messages,
        sibling_summaries=sibling_ctx,
        session_context=session_context,
        attachments=[],  # this endpoint never forwards attachments
        new_message=body.message,
        # Revival evidence is only passed on while the branch is in "revived" status.
        revival_context=branch.evidence_description if branch.status == "revived" else None,
    )

    # Call AI; the builder's output is splatted straight into the call.
    ai_content, input_tokens, output_tokens = await _call_ai(**prompt_args)

    # Update branch conversation. A fresh list is assigned rather than
    # mutating the stored one — presumably so the ORM registers the change
    # on the column; confirm against the SessionBranch model.
    msgs = list(branch.conversation_messages or [])
    msgs.append({"role": "user", "content": body.message})
    msgs.append({"role": "assistant", "content": ai_content})
    branch.conversation_messages = msgs

    # Update session token counts
    session.total_input_tokens += input_tokens
    session.total_output_tokens += output_tokens

    # Resume if paused
    if session.status == "paused":
        session.status = "active"

    await db.commit()

    return BranchMessageResponse(
        content=ai_content,
        branch_id=branch_id,
    )
|