diff --git a/backend/app/api/endpoints/session_branches.py b/backend/app/api/endpoints/session_branches.py new file mode 100644 index 00000000..e946fefc --- /dev/null +++ b/backend/app/api/endpoints/session_branches.py @@ -0,0 +1,274 @@ +"""Branch management endpoints for conversational branching. + + GET /ai-sessions/{id}/branches — List all branches (tree) + POST /ai-sessions/{id}/branches/fork — Create fork with N branches + PATCH /ai-sessions/{id}/branches/{bid} — Update branch status + POST /ai-sessions/{id}/branches/{bid}/switch — Switch active branch + POST /ai-sessions/{id}/branches/{bid}/revive — Revive dead-end branch + POST /ai-sessions/{id}/branches/{bid}/message — Send message on branch +""" +import logging +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_active_user, get_db +from app.models.user import User +from app.models.ai_session import AISession +from app.models.session_branch import SessionBranch +from app.services.branch_manager import BranchManager +from app.services.branch_aware_prompt_builder import BranchAwarePromptBuilder +from app.services.assistant_chat_service import _call_ai +from app.schemas.session_branch import ( + BranchTreeResponse, + BranchResponse, + BranchUpdate, + ForkCreateRequest, + ForkPointResponse, + BranchSwitchResponse, + ReviveRequest, + BranchMessageRequest, + BranchMessageResponse, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ai-sessions/{session_id}/branches", tags=["session-branches"]) + + +async def _get_user_session( + session_id: UUID, user: User, db: AsyncSession +) -> AISession: + """Fetch session owned by user, or raise 404.""" + result = await db.execute( + select(AISession).where( + AISession.id == session_id, + AISession.user_id == user.id, + ) + ) + session = result.scalar_one_or_none() + if not session: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found") + return session + + +@router.get("", response_model=BranchTreeResponse) +async def list_branches( + session_id: UUID, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> BranchTreeResponse: + """Get branch tree for a session.""" + session = await _get_user_session(session_id, current_user, db) + manager = BranchManager(db) + branches = await manager.get_branch_tree(session_id) + + branch_responses = [] + for b in branches: + branch_responses.append(BranchResponse( + id=b.id, + session_id=b.session_id, + parent_branch_id=b.parent_branch_id, + fork_point_step_id=b.fork_point_step_id, + branch_order=b.branch_order, + label=b.label, + status=b.status, + status_reason=b.status_reason, + status_changed_at=b.status_changed_at, + context_summary=b.context_summary, + evidence_from_branch_id=b.evidence_from_branch_id, + evidence_description=b.evidence_description, + created_at=b.created_at, + updated_at=b.updated_at, + )) + + return BranchTreeResponse( + branches=branch_responses, + active_branch_id=session.active_branch_id, + ) + + +@router.post("/fork", response_model=ForkPointResponse, status_code=status.HTTP_201_CREATED) +async def create_fork( + session_id: UUID, + body: ForkCreateRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> ForkPointResponse: + """Create a fork point with N branches.""" + session = await _get_user_session(session_id, current_user, db) + + if session.status not in ("active", "paused"): + raise HTTPException(status_code=400, detail=f"Cannot fork a {session.status} session") + + manager = BranchManager(db) + + # Ensure branching is initialized + if not session.is_branching: + await manager.create_root_branch(session_id) + await db.refresh(session) + + # Use the active branch as parent + parent_branch_id = session.active_branch_id + if not parent_branch_id: + raise HTTPException(status_code=400, detail="No active branch to fork from") + + options = [{"label": o.label, "description": o.description} for o in body.options] + + fork_point, branches = await manager.create_fork( + session_id=session_id, + parent_branch_id=parent_branch_id, + trigger_step_id=None, + fork_reason=body.fork_reason, + options=options, + ) + + await db.commit() + return ForkPointResponse.model_validate(fork_point) + + +@router.patch("/{branch_id}", response_model=BranchResponse) +async def update_branch_status( + session_id: UUID, + branch_id: UUID, + body: BranchUpdate, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> BranchResponse: + """Update a branch's status.""" + await _get_user_session(session_id, current_user, db) + manager = BranchManager(db) + + try: + branch = await manager.mark_branch_status( + branch_id=branch_id, + status=body.status, + reason=body.status_reason, + user_id=current_user.id, + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + await db.commit() + return BranchResponse.model_validate(branch) + + +@router.post("/{branch_id}/switch", response_model=BranchSwitchResponse) +async def switch_branch( + session_id: UUID, + branch_id: UUID, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> BranchSwitchResponse: + """Switch the active branch.""" + await _get_user_session(session_id, current_user, db) + manager = BranchManager(db) + + try: + branch = await manager.switch_branch(session_id, branch_id) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + await db.commit() + return BranchSwitchResponse( + active_branch_id=branch.id, + branch=BranchResponse.model_validate(branch), + conversation_messages=branch.conversation_messages, + ) + + +@router.post("/{branch_id}/revive", response_model=BranchResponse) +async def revive_branch( + session_id: UUID, + branch_id: UUID, + body: ReviveRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> BranchResponse: + """Revive a dead-end branch with new evidence.""" + await _get_user_session(session_id, current_user, db) + manager = BranchManager(db) + + try: + branch = await manager.revive_branch( + branch_id=branch_id, + evidence_from_branch_id=body.evidence_from_branch_id, + evidence_description=body.evidence_description, + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + await db.commit() + return BranchResponse.model_validate(branch) + + +@router.post("/{branch_id}/message", response_model=BranchMessageResponse) +async def send_branch_message( + session_id: UUID, + branch_id: UUID, + body: BranchMessageRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> BranchMessageResponse: + """Send a message on a specific branch.""" + session = await _get_user_session(session_id, current_user, db) + + if session.status not in ("active", "paused"): + raise HTTPException(status_code=400, detail=f"Cannot message a {session.status} session") + + manager = BranchManager(db) + + # Switch to branch if not already active + if session.active_branch_id != branch_id: + await manager.switch_branch(session_id, branch_id) + await db.refresh(session) + + # Get branch + result = await db.execute( + select(SessionBranch).where(SessionBranch.id == branch_id) + ) + branch = result.scalar_one_or_none() + if not branch: + raise HTTPException(status_code=404, detail="Branch not found") + + # Build cross-branch context + sibling_ctx = await manager.build_cross_branch_context(branch_id) + + # Build prompt + builder = BranchAwarePromptBuilder() + session_context = f"Problem: {session.problem_summary or 'Unknown'}. Domain: {session.problem_domain or 'Unknown'}." + prompt_args = builder.build( + branch_messages=branch.conversation_messages, + sibling_summaries=sibling_ctx, + session_context=session_context, + attachments=[], + new_message=body.message, + revival_context=branch.evidence_description if branch.status == "revived" else None, + ) + + # Call AI + ai_content, input_tokens, output_tokens = await _call_ai(**prompt_args) + + # Update branch conversation + msgs = list(branch.conversation_messages or []) + msgs.append({"role": "user", "content": body.message}) + msgs.append({"role": "assistant", "content": ai_content}) + branch.conversation_messages = msgs + + # Update session token counts + session.total_input_tokens += input_tokens + session.total_output_tokens += output_tokens + + # Resume if paused + if session.status == "paused": + session.status = "active" + + await db.commit() + + return BranchMessageResponse( + content=ai_content, + branch_id=branch_id, + ) diff --git a/backend/app/api/router.py b/backend/app/api/router.py index d98042b7..1ff7448c 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -30,6 +30,7 @@ from app.api.endpoints import admin_gallery from app.api.endpoints import uploads from app.api.endpoints import script_builder from app.api.endpoints import beta_feedback +from app.api.endpoints import session_branches api_router = APIRouter() @@ -85,3 +86,4 @@ api_router.include_router(admin_gallery.router) api_router.include_router(uploads.router) api_router.include_router(script_builder.router) api_router.include_router(beta_feedback.router) +api_router.include_router(session_branches.router) diff --git a/backend/tests/test_session_branches_api.py b/backend/tests/test_session_branches_api.py new file mode 100644 index 00000000..74920f05 --- /dev/null +++ b/backend/tests/test_session_branches_api.py @@ -0,0 +1,118 @@ +"""API endpoint tests for session branches.""" +import pytest +from httpx import AsyncClient + +from app.models.ai_session import AISession +from app.models.ai_session_step import AISessionStep + + +@pytest.mark.asyncio +async def test_list_branches_empty(client: AsyncClient, test_user, auth_headers, test_db): + """GET /ai-sessions/{id}/branches returns empty for non-branching session.""" + session = AISession( + user_id=test_user["user_data"]["id"], + account_id=test_user["user_data"]["account_id"], + session_type="guided", + intake_type="free_text", + intake_content={"text": "test"}, + status="active", + confidence_tier="discovery", + conversation_messages=[], + ) + test_db.add(session) + await test_db.commit() + + resp = await client.get( + f"/api/v1/ai-sessions/{session.id}/branches", + headers=auth_headers, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["branches"] == [] + assert data["active_branch_id"] is None + + +@pytest.mark.asyncio +async def test_create_fork(client: AsyncClient, test_user, auth_headers, test_db): + """POST /ai-sessions/{id}/branches/fork creates branches.""" + session = AISession( + user_id=test_user["user_data"]["id"], + account_id=test_user["user_data"]["account_id"], + session_type="guided", + intake_type="free_text", + intake_content={"text": "test"}, + status="active", + confidence_tier="discovery", + conversation_messages=[{"role": "user", "content": "help"}], + ) + test_db.add(session) + await test_db.flush() + + step = AISessionStep( + session_id=session.id, step_order=0, step_type="question", + content={"text": "test"}, confidence_at_step=0.5, + ) + test_db.add(step) + await test_db.commit() + + resp = await client.post( + f"/api/v1/ai-sessions/{session.id}/branches/fork", + headers=auth_headers, + json={ + "fork_reason": "Two possible causes", + "options": [ + {"label": "Network issue", "description": "Check connectivity"}, + {"label": "DNS problem", "description": "Check DNS"}, + ], + }, + ) + assert resp.status_code == 201 + data = resp.json() + assert len(data["options"]) == 2 + + +@pytest.mark.asyncio +async def test_switch_branch(client: AsyncClient, test_user, auth_headers, test_db): + """POST /ai-sessions/{id}/branches/{bid}/switch changes active branch.""" + session = AISession( + user_id=test_user["user_data"]["id"], + account_id=test_user["user_data"]["account_id"], + session_type="guided", + intake_type="free_text", + intake_content={"text": "test"}, + status="active", + confidence_tier="discovery", + conversation_messages=[{"role": "user", "content": "help"}], + ) + test_db.add(session) + await test_db.flush() + + step = AISessionStep( + session_id=session.id, step_order=0, step_type="question", + content={"text": "test"}, confidence_at_step=0.5, + ) + test_db.add(step) + await test_db.commit() + + # Create fork first + fork_resp = await client.post( + f"/api/v1/ai-sessions/{session.id}/branches/fork", + headers=auth_headers, + json={ + "fork_reason": "test", + "options": [ + {"label": "A", "description": "a"}, + {"label": "B", "description": "b"}, + ], + }, + ) + fork_data = fork_resp.json() + branch_b_id = fork_data["options"][1]["branch_id"] + + # Switch to branch B + resp = await client.post( + f"/api/v1/ai-sessions/{session.id}/branches/{branch_b_id}/switch", + headers=auth_headers, + ) + assert resp.status_code == 200 + assert resp.json()["active_branch_id"] == branch_b_id