"""AI Chat Builder endpoints.

Conversational flow builder:
    POST   /ai/chat/sessions                 — Start session, get AI greeting
    POST   /ai/chat/sessions/{id}/messages   — Send message, get AI response
    GET    /ai/chat/sessions/{id}            — Get session state (for resume)
    POST   /ai/chat/sessions/{id}/generate   — Generate final TreeStructure
    POST   /ai/chat/sessions/{id}/import     — Create Tree from generated structure
    DELETE /ai/chat/sessions/{id}            — Abandon session
"""

import logging
from typing import Annotated, Any, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.rate_limit import limiter
from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin
from app.core.config import settings
from app.core.ai_chat_service import (
    start_chat_session,
    send_message,
    generate_final_tree,
    get_chat_session,
    MAX_MESSAGES_FREE,
    MAX_MESSAGES_PAID,
)
from app.core.ai_quota_service import check_ai_quota, record_ai_usage, get_user_plan
from app.models.user import User
from app.models.tree import Tree
from app.schemas.ai_chat import (
    AIChatStartRequest,
    AIChatStartResponse,
    AIChatMessageRequest,
    AIChatMessageResponse,
    AIChatSessionResponse,
    AIChatGenerateResponse,
    AIChatImportRequest,
    AIChatImportResponse,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/ai/chat", tags=["ai-chat-builder"])

# Pricing used for the ``estimated_cost`` of usage records, in USD per
# million tokens (input / output).
_INPUT_COST_PER_MTOK = 1.0
_OUTPUT_COST_PER_MTOK = 5.0

# Per-session message cap applied to super admins instead of the plan cap.
_SUPER_ADMIN_MAX_MESSAGES = 999


def _require_ai_enabled() -> None:
    """Raise 503 unless an AI provider is configured (``settings.ai_enabled``)."""
    if not settings.ai_enabled:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
        )


def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
    """Return the estimated USD cost for the given input/output token counts."""
    return (
        input_tokens * _INPUT_COST_PER_MTOK / 1_000_000
        + output_tokens * _OUTPUT_COST_PER_MTOK / 1_000_000
    )


def _is_timeout(exc: Exception) -> bool:
    """Heuristically decide whether *exc* represents a provider timeout.

    Matches ``TimeoutError`` subclasses, or any exception whose class name or
    message contains "timeout" (case-insensitive).
    """
    if isinstance(exc, TimeoutError):
        return True
    return "timeout" in type(exc).__name__.lower() or "timeout" in str(exc).lower()


async def _enforce_ai_quota(current_user: User, db: AsyncSession) -> None:
    """Raise 429 if the user's AI build quota is exhausted.

    Raises:
        HTTPException: 429 with ``message``, ``reset_at`` (daily or monthly
            reset, depending on ``deny_reason``) and the full ``quota`` status.
    """
    allowed, quota_status = await check_ai_quota(
        user_id=current_user.id,
        account_id=current_user.account_id,
        db=db,
        billing_anchor=current_user.ai_billing_cycle_anchor_at,
        is_super_admin=current_user.is_super_admin,
    )
    if allowed:
        return

    deny_reason = quota_status.get("deny_reason")
    reset_key = "daily_reset_at" if deny_reason == "daily" else "monthly_reset_at"
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail={
            "message": f"AI build limit exceeded ({deny_reason})",
            "reset_at": quota_status.get(reset_key),
            "quota": quota_status,
        },
    )


async def _record_chat_usage(
    db: AsyncSession,
    user: User,
    *,
    plan: str,
    generation_type: str,
    input_tokens: int,
    output_tokens: int,
    succeeded: bool,
    extra_data: dict[str, Any],
    counts_toward_quota: bool = False,
    error_code: Optional[str] = None,
    estimated_cost: Optional[float] = None,
) -> None:
    """Write one chat-builder usage record via ``record_ai_usage``.

    Does not commit; callers commit together with their other changes.

    Args:
        db: Active database session.
        user: User the usage is attributed to (user and account ids).
        plan: The account plan, stored as the usage ``tier``.
        generation_type: Usage category, e.g. "chat_message" or "chat_generate".
        input_tokens: Input tokens consumed by this call.
        output_tokens: Output tokens produced by this call.
        succeeded: Whether the AI call succeeded.
        extra_data: Free-form metadata stored with the record.
        counts_toward_quota: Whether this record counts toward the AI quota.
        error_code: Error identifier for failed calls.
        estimated_cost: Explicit cost to record; when ``None`` it is derived
            from the token counts with ``_estimate_cost``.
    """
    if estimated_cost is None:
        estimated_cost = _estimate_cost(input_tokens, output_tokens)
    await record_ai_usage(
        user_id=user.id,
        account_id=user.account_id,
        conversation_id=None,
        generation_type=generation_type,
        tier=plan,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        estimated_cost=estimated_cost,
        succeeded=succeeded,
        counts_toward_quota=counts_toward_quota,
        error_code=error_code,
        extra_data=extra_data,
        db=db,
    )


@router.post("/sessions", response_model=AIChatStartResponse, status_code=201)
@limiter.limit("10/minute")
async def create_session(
    request: Request,
    data: AIChatStartRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Start a new AI chat builder session and return the AI greeting.

    Raises:
        HTTPException: 503 if AI is disabled, 429 if the AI quota is
            exhausted, 502 if the AI provider call fails.
    """
    _require_ai_enabled()
    await _enforce_ai_quota(current_user, db)

    plan = await get_user_plan(current_user.account_id, db)
    try:
        session, greeting = await start_chat_session(
            flow_type=data.flow_type,
            user_id=current_user.id,
            account_id=current_user.account_id,
            db=db,
            tree_id=data.tree_id,
        )
    except Exception as e:
        logger.exception("AI chat session start failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"AI provider error ({type(e).__name__}). Please try again.",
        ) from e

    # A freshly started session's running totals cover only the calls made
    # by start_chat_session, so they are recorded as-is.
    await _record_chat_usage(
        db,
        current_user,
        plan=plan,
        generation_type="chat_message",
        input_tokens=session.total_input_tokens,
        output_tokens=session.total_output_tokens,
        succeeded=True,
        extra_data={"phase": "scoping", "chat_session_id": str(session.id)},
    )
    await db.commit()

    return AIChatStartResponse(
        session_id=session.id,
        greeting=greeting,
        current_phase=session.current_phase,
    )


@router.post("/sessions/{session_id}/messages", response_model=AIChatMessageResponse)
@limiter.limit("10/minute")
async def post_message(
    request: Request,
    session_id: UUID,
    data: AIChatMessageRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Send a user message and return the AI response.

    Raises:
        HTTPException: 503 if AI is disabled; 400 if the session is not
            active; 429 if the per-session message cap is reached; 502 if
            the AI provider call fails.
    """
    _require_ai_enabled()
    session = await get_chat_session(session_id, current_user.id, db)
    if session.status != "active":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Session is {session.status}, cannot send messages",
        )

    plan = await get_user_plan(current_user.account_id, db)
    if current_user.is_super_admin:
        max_messages = _SUPER_ADMIN_MAX_MESSAGES
    elif plan != "free":
        max_messages = MAX_MESSAGES_PAID
    else:
        max_messages = MAX_MESSAGES_FREE
    if session.message_count >= max_messages:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=(
                f"Maximum messages per session reached ({max_messages}). "
                "Generate your tree or start a new session."
            ),
        )

    # Usage for this call is the delta of the session's running token totals.
    prev_input = session.total_input_tokens
    prev_output = session.total_output_tokens
    try:
        ai_content, _tree_update, _new_phase, _metadata = await send_message(
            session,
            data.content,
            db,
            action_type=data.action_type or "open_chat",
            focal_node_id=data.focal_node_id,
            flow_context=data.flow_context,
        )
    except Exception as e:
        logger.exception("AI chat message failed: %s", e)
        # Record whatever tokens were consumed before the failure, matching
        # the failure accounting in generate_tree.
        await _record_chat_usage(
            db,
            current_user,
            plan=plan,
            generation_type="chat_message",
            input_tokens=session.total_input_tokens - prev_input,
            output_tokens=session.total_output_tokens - prev_output,
            estimated_cost=0,
            succeeded=False,
            error_code=type(e).__name__,
            extra_data={"chat_session_id": str(session.id)},
        )
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"AI provider error ({type(e).__name__}). Please try again.",
        ) from e

    await _record_chat_usage(
        db,
        current_user,
        plan=plan,
        generation_type="chat_message",
        input_tokens=session.total_input_tokens - prev_input,
        output_tokens=session.total_output_tokens - prev_output,
        succeeded=True,
        extra_data={"phase": session.current_phase, "chat_session_id": str(session.id)},
    )
    await db.commit()

    return AIChatMessageResponse(
        content=ai_content,
        current_phase=session.current_phase,
        working_tree=session.working_tree,
        tree_metadata=session.tree_metadata or None,
    )


@router.get("/sessions/{session_id}", response_model=AIChatSessionResponse)
async def get_session(
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Return the full session state so the client can resume after a reload.

    Messages flagged ``hidden`` are omitted from the returned history, and
    ``generated_tree`` is only populated once the session is completed.
    """
    session = await get_chat_session(session_id, current_user.id, db)
    visible_history = [
        msg
        for msg in (session.conversation_history or [])
        if not msg.get("hidden")
    ]
    return AIChatSessionResponse(
        session_id=session.id,
        status=session.status,
        current_phase=session.current_phase,
        flow_type=session.flow_type,
        conversation_history=visible_history,
        working_tree=session.working_tree,
        tree_metadata=session.tree_metadata or None,
        message_count=session.message_count,
        generated_tree=session.working_tree if session.status == "completed" else None,
    )


@router.post("/sessions/{session_id}/generate", response_model=AIChatGenerateResponse)
@limiter.limit("10/minute")
async def generate_tree(
    request: Request,
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Generate the final TreeStructure JSON from the conversation.

    A completed session's stored tree is returned without another AI call.
    A successful generation is recorded as counting toward the AI quota and
    marks the session "completed".

    Raises:
        HTTPException: 503 if AI is disabled; 400 if the session is not
            active; 429 if the AI quota is exhausted; 422 if the generated
            structure is invalid; 504 on provider timeout; 502 on other
            provider errors.
    """
    _require_ai_enabled()
    session = await get_chat_session(session_id, current_user.id, db)

    if session.status == "completed" and session.working_tree:
        return AIChatGenerateResponse(
            tree_structure=session.working_tree,
            tree_metadata=session.tree_metadata,
            status="completed",
        )
    if session.status != "active":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Session is {session.status}, cannot generate",
        )

    # Generation is the step that counts toward the quota, so the quota is
    # re-checked here: the check at session creation alone could be bypassed
    # by opening several sessions while under the limit and generating later.
    await _enforce_ai_quota(current_user, db)

    plan = await get_user_plan(current_user.account_id, db)
    prev_input = session.total_input_tokens
    prev_output = session.total_output_tokens
    try:
        tree_structure, metadata = await generate_final_tree(session, db)
    except ValueError as e:
        # The model answered, but its output could not be turned into a tree.
        await _record_chat_usage(
            db,
            current_user,
            plan=plan,
            generation_type="chat_generate",
            input_tokens=session.total_input_tokens - prev_input,
            output_tokens=session.total_output_tokens - prev_output,
            estimated_cost=0,
            succeeded=False,
            error_code="invalid_output",
            extra_data={"error": str(e), "chat_session_id": str(session.id)},
        )
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Tree generation failed: {e}",
        ) from e
    except Exception as e:
        logger.exception("AI chat generate failed: %s", e)
        error_name = type(e).__name__
        await _record_chat_usage(
            db,
            current_user,
            plan=plan,
            generation_type="chat_generate",
            input_tokens=session.total_input_tokens - prev_input,
            output_tokens=session.total_output_tokens - prev_output,
            estimated_cost=0,
            succeeded=False,
            error_code=error_name,
            extra_data={"error": str(e), "chat_session_id": str(session.id)},
        )
        await db.commit()
        if _is_timeout(e):
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Tree generation timed out. Please try again.",
            ) from e
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"AI provider error ({error_name}). Please try again.",
        ) from e

    await _record_chat_usage(
        db,
        current_user,
        plan=plan,
        generation_type="chat_generate",
        input_tokens=session.total_input_tokens - prev_input,
        output_tokens=session.total_output_tokens - prev_output,
        succeeded=True,
        counts_toward_quota=True,
        extra_data={"chat_session_id": str(session.id)},
    )
    session.status = "completed"
    await db.commit()

    return AIChatGenerateResponse(
        tree_structure=tree_structure,
        tree_metadata=metadata,
        status="completed",
    )


@router.post("/sessions/{session_id}/import", response_model=AIChatImportResponse)
@limiter.limit("10/minute")
async def import_tree(
    request: Request,
    session_id: UUID,
    data: AIChatImportRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Create a private Tree record from the session's generated tree.

    Raises:
        HTTPException: 400 if the session is not completed or has no tree.
    """
    session = await get_chat_session(session_id, current_user.id, db)
    if session.status != "completed" or not session.working_tree:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Session must be completed with a generated tree before importing",
        )

    # Always create a new Tree record (no duplicate check — user may
    # want multiple copies or re-import after edits).
    # Work on a copy: the session's stored metadata must stay intact so a
    # later re-import still sees the intake form.
    metadata = dict(session.tree_metadata or {})

    # Intake form is only present for procedural flows.
    intake_form = metadata.get("intake_form")
    if not isinstance(intake_form, list):
        intake_form = None

    tree = Tree(
        name=data.name or metadata.get("name") or "AI-Generated Flow",
        description=data.description or metadata.get("description") or "",
        tree_type=session.flow_type,
        tree_structure=session.working_tree,
        intake_form=intake_form,
        author_id=current_user.id,
        account_id=current_user.account_id,
        category_id=data.category_id,
        is_public=False,
    )
    db.add(tree)
    await db.flush()  # assigns tree.id before it is linked to the session

    session.generated_tree_id = tree.id
    await db.commit()

    return AIChatImportResponse(
        tree_id=tree.id,
        tree_type=session.flow_type,
    )


@router.delete("/sessions/{session_id}", status_code=204)
@limiter.limit("10/minute")
async def abandon_session(
    request: Request,
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Mark a chat session as abandoned.

    The status is overwritten unconditionally, whatever its current value.
    """
    session = await get_chat_session(session_id, current_user.id, db)
    session.status = "abandoned"
    await db.commit()