"""Branch management endpoints for conversational branching.

GET    /ai-sessions/{id}/branches                — List all branches (tree)
POST   /ai-sessions/{id}/branches/fork           — Create fork with N branches
PATCH  /ai-sessions/{id}/branches/{bid}          — Update branch status
POST   /ai-sessions/{id}/branches/{bid}/switch   — Switch active branch
POST   /ai-sessions/{id}/branches/{bid}/revive   — Revive dead-end branch
POST   /ai-sessions/{id}/branches/{bid}/message  — Send message on branch
"""

import logging
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_active_user, get_db
from app.models.user import User
from app.models.ai_session import AISession
from app.models.session_branch import SessionBranch
from app.services.branch_manager import BranchManager
from app.services.branch_aware_prompt_builder import BranchAwarePromptBuilder
from app.services.assistant_chat_service import _call_ai
from app.schemas.session_branch import (
    BranchTreeResponse,
    BranchResponse,
    BranchUpdate,
    ForkCreateRequest,
    ForkPointResponse,
    BranchSwitchResponse,
    ReviveRequest,
    BranchMessageRequest,
    BranchMessageResponse,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/ai-sessions/{session_id}/branches", tags=["session-branches"])

# Session statuses in which the conversation may still be extended
# (forked or messaged).
_WRITABLE_SESSION_STATUSES = ("active", "paused")


async def _get_user_session(
    session_id: UUID, user: User, db: AsyncSession
) -> AISession:
    """Fetch session owned by user, or raise 404."""
    result = await db.execute(
        select(AISession).where(
            AISession.id == session_id,
            AISession.user_id == user.id,
        )
    )
    session = result.scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
    return session


async def _get_session_branch(
    session_id: UUID,
    branch_id: UUID,
    db: AsyncSession,
    *,
    detail: str = "Branch not found",
) -> SessionBranch:
    """Fetch a branch that belongs to ``session_id``, or raise 404.

    Callers must already have verified session ownership with
    ``_get_user_session``. Scoping the branch lookup to that session stops a
    branch id from another user's session from being acted on through the
    URL of a session the caller does own.

    Args:
        session_id: Session the branch must belong to.
        branch_id: Branch to look up.
        db: Database session.
        detail: Error detail used in the 404 response.

    Raises:
        HTTPException: 404 if no such branch exists in this session.
    """
    result = await db.execute(
        select(SessionBranch).where(
            SessionBranch.id == branch_id,
            SessionBranch.session_id == session_id,
        )
    )
    branch = result.scalar_one_or_none()
    if branch is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail)
    return branch


@router.get("", response_model=BranchTreeResponse)
async def list_branches(
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchTreeResponse:
    """Get branch tree for a session."""
    session = await _get_user_session(session_id, current_user, db)
    manager = BranchManager(db)
    branches = await manager.get_branch_tree(session_id)

    branch_responses = [
        BranchResponse(
            id=b.id,
            session_id=b.session_id,
            parent_branch_id=b.parent_branch_id,
            fork_point_step_id=b.fork_point_step_id,
            branch_order=b.branch_order,
            label=b.label,
            status=b.status,
            status_reason=b.status_reason,
            status_changed_at=b.status_changed_at,
            context_summary=b.context_summary,
            evidence_from_branch_id=b.evidence_from_branch_id,
            evidence_description=b.evidence_description,
            created_at=b.created_at,
            updated_at=b.updated_at,
        )
        for b in branches
    ]

    return BranchTreeResponse(
        branches=branch_responses,
        active_branch_id=session.active_branch_id,
    )


@router.post("/fork", response_model=ForkPointResponse, status_code=status.HTTP_201_CREATED)
async def create_fork(
    session_id: UUID,
    body: ForkCreateRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> ForkPointResponse:
    """Create a fork point with N branches off the session's active branch.

    Raises:
        HTTPException: 404 if the session is not the caller's; 400 if the
            session is not active/paused or has no active branch.
    """
    session = await _get_user_session(session_id, current_user, db)
    if session.status not in _WRITABLE_SESSION_STATUSES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Cannot fork a {session.status} session",
        )

    manager = BranchManager(db)

    # Ensure branching is initialized; reload the session so that
    # active_branch_id reflects the newly created root branch.
    if not session.is_branching:
        await manager.create_root_branch(session_id)
        await db.refresh(session)

    # Use the active branch as parent
    parent_branch_id = session.active_branch_id
    if not parent_branch_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No active branch to fork from",
        )

    options = [{"label": o.label, "description": o.description} for o in body.options]
    fork_point, _ = await manager.create_fork(
        session_id=session_id,
        parent_branch_id=parent_branch_id,
        trigger_step_id=None,
        fork_reason=body.fork_reason,
        options=options,
    )

    await db.commit()
    return ForkPointResponse.model_validate(fork_point)


@router.patch("/{branch_id}", response_model=BranchResponse)
async def update_branch_status(
    session_id: UUID,
    branch_id: UUID,
    body: BranchUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchResponse:
    """Update a branch's status.

    Raises:
        HTTPException: 404 if the session is not the caller's, the branch is
            not in this session, or the manager rejects the update.
    """
    await _get_user_session(session_id, current_user, db)
    # Authorize the branch itself, not just the session in the URL.
    await _get_session_branch(session_id, branch_id, db)
    manager = BranchManager(db)

    try:
        branch = await manager.mark_branch_status(
            branch_id=branch_id,
            status=body.status,
            reason=body.status_reason,
            user_id=current_user.id,
        )
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e

    await db.commit()
    return BranchResponse.model_validate(branch)


@router.post("/{branch_id}/switch", response_model=BranchSwitchResponse)
async def switch_branch(
    session_id: UUID,
    branch_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchSwitchResponse:
    """Switch the active branch and return its conversation.

    Raises:
        HTTPException: 404 if the session is not the caller's, the branch is
            not in this session, or the manager rejects the switch.
    """
    await _get_user_session(session_id, current_user, db)
    await _get_session_branch(session_id, branch_id, db)
    manager = BranchManager(db)

    try:
        branch = await manager.switch_branch(session_id, branch_id)
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e

    await db.commit()
    return BranchSwitchResponse(
        active_branch_id=branch.id,
        branch=BranchResponse.model_validate(branch),
        conversation_messages=branch.conversation_messages,
    )


@router.post("/{branch_id}/revive", response_model=BranchResponse)
async def revive_branch(
    session_id: UUID,
    branch_id: UUID,
    body: ReviveRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchResponse:
    """Revive a dead-end branch with new evidence.

    Both the revived branch and the evidence-source branch (if given) must
    belong to this session.

    Raises:
        HTTPException: 404 if the session is not the caller's, either branch
            is not in this session, or the manager rejects the revival.
    """
    await _get_user_session(session_id, current_user, db)
    await _get_session_branch(session_id, branch_id, db)
    # Prevent pulling evidence text from a branch in someone else's session.
    if body.evidence_from_branch_id is not None:
        await _get_session_branch(
            session_id,
            body.evidence_from_branch_id,
            db,
            detail="Evidence branch not found",
        )
    manager = BranchManager(db)

    try:
        branch = await manager.revive_branch(
            branch_id=branch_id,
            evidence_from_branch_id=body.evidence_from_branch_id,
            evidence_description=body.evidence_description,
        )
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e

    await db.commit()
    return BranchResponse.model_validate(branch)


@router.post("/{branch_id}/message", response_model=BranchMessageResponse)
async def send_branch_message(
    session_id: UUID,
    branch_id: UUID,
    body: BranchMessageRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> BranchMessageResponse:
    """Send a message on a specific branch and return the AI reply.

    Switches the session to ``branch_id`` first if it is not already active.
    It then builds a prompt from the branch history, sibling-branch context,
    and session summary, calls the AI, and appends both turns to the branch.
    It also adds the token usage to the session and resumes the session if
    it was paused.

    Raises:
        HTTPException: 404 if the session or branch is not found or the
            switch is rejected; 400 if the session is not active/paused.
    """
    session = await _get_user_session(session_id, current_user, db)
    if session.status not in _WRITABLE_SESSION_STATUSES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Cannot message a {session.status} session",
        )

    # Resolve the branch (scoped to this session) before switching, so an
    # unknown or foreign id is a clean 404 rather than a half-done switch.
    branch = await _get_session_branch(session_id, branch_id, db)

    manager = BranchManager(db)

    # Switch to branch if not already active
    if session.active_branch_id != branch_id:
        try:
            await manager.switch_branch(session_id, branch_id)
        except ValueError as e:
            # Same mapping as the /switch endpoint; previously this was a 500.
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
        await db.refresh(session)

    # Build cross-branch context
    sibling_ctx = await manager.build_cross_branch_context(branch_id)

    # Build prompt
    builder = BranchAwarePromptBuilder()
    session_context = f"Problem: {session.problem_summary or 'Unknown'}. Domain: {session.problem_domain or 'Unknown'}."

    prompt_args = builder.build(
        branch_messages=branch.conversation_messages,
        sibling_summaries=sibling_ctx,
        session_context=session_context,
        attachments=[],
        new_message=body.message,
        revival_context=branch.evidence_description if branch.status == "revived" else None,
    )

    # Call AI
    ai_content, input_tokens, output_tokens = await _call_ai(**prompt_args)

    # Assign a new list rather than mutating in place, so the ORM sees the
    # attribute as changed (presumably a JSON column without mutation
    # tracking — confirm against the model).
    msgs = list(branch.conversation_messages or [])
    msgs.append({"role": "user", "content": body.message})
    msgs.append({"role": "assistant", "content": ai_content})
    branch.conversation_messages = msgs

    # Update session token counts (treat unset counters as zero)
    session.total_input_tokens = (session.total_input_tokens or 0) + input_tokens
    session.total_output_tokens = (session.total_output_tokens or 0) + output_tokens

    # Resume if paused
    if session.status == "paused":
        session.status = "active"

    await db.commit()

    return BranchMessageResponse(
        content=ai_content,
        branch_id=branch_id,
    )