diff --git a/backend/app/api/endpoints/ai_chat.py b/backend/app/api/endpoints/ai_chat.py
new file mode 100644
index 00000000..326290d9
--- /dev/null
+++ b/backend/app/api/endpoints/ai_chat.py
@@ -0,0 +1,471 @@
+"""AI Chat Builder endpoints.
+
+Conversational flow builder:
+  POST   /ai/chat/sessions                — Start session, get AI greeting
+  POST   /ai/chat/sessions/{id}/messages  — Send message, get AI response
+  GET    /ai/chat/sessions/{id}           — Get session state (for resume)
+  POST   /ai/chat/sessions/{id}/generate  — Generate final TreeStructure
+  POST   /ai/chat/sessions/{id}/import    — Create Tree from generated structure
+  DELETE /ai/chat/sessions/{id}           — Abandon session
+"""
+import logging
+from typing import Annotated
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.core.rate_limit import limiter
+from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin
+from app.core.config import settings
+from app.core.ai_chat_service import (
+    start_chat_session,
+    send_message,
+    generate_final_tree,
+    get_chat_session,
+    MAX_MESSAGES_FREE,
+    MAX_MESSAGES_PAID,
+)
+from app.core.ai_quota_service import check_ai_quota, record_ai_usage, get_user_plan
+from app.models.user import User
+from app.models.tree import Tree
+from app.schemas.ai_chat import (
+    AIChatStartRequest,
+    AIChatStartResponse,
+    AIChatMessageRequest,
+    AIChatMessageResponse,
+    AIChatSessionResponse,
+    AIChatGenerateResponse,
+    AIChatImportRequest,
+    AIChatImportResponse,
+)
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/ai/chat", tags=["ai-chat-builder"])
+
+# Per-million-token rates used for the estimated_cost stored with each
+# successful usage record.
+_INPUT_COST_PER_MILLION = 1.0
+_OUTPUT_COST_PER_MILLION = 5.0
+
+# Per-session message cap applied to super admins in post_message.
+_SUPER_ADMIN_MAX_MESSAGES = 999
+
+
+def _require_ai_enabled() -> None:
+    """Raise a 503 HTTPException when ``settings.ai_enabled`` is falsy."""
+    if not settings.ai_enabled:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
+        )
+
+
+def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
+    """Return the estimated cost of a call from its token counts."""
+    return (
+        input_tokens * _INPUT_COST_PER_MILLION / 1_000_000
+        + output_tokens * _OUTPUT_COST_PER_MILLION / 1_000_000
+    )
+
+
+@router.post("/sessions", response_model=AIChatStartResponse, status_code=201)
+@limiter.limit("10/minute")
+async def create_session(
+    request: Request,
+    data: AIChatStartRequest,
+    current_user: Annotated[User, Depends(get_current_active_user)],
+    db: Annotated[AsyncSession, Depends(get_db)],
+    _: None = Depends(require_engineer_or_admin),
+):
+    """Start a new AI chat builder session."""
+    # Flow: check the caller's AI quota, start a session through the chat
+    # service, record the session's token totals as a usage row, commit.
+    # Responds 503 when AI is disabled, 429 when the quota check denies the
+    # request, and 502 when starting the session fails.
+    _require_ai_enabled()
+
+    allowed, quota_status = await check_ai_quota(
+        user_id=current_user.id,
+        account_id=current_user.account_id,
+        db=db,
+        billing_anchor=current_user.ai_billing_cycle_anchor_at,
+        is_super_admin=current_user.is_super_admin,
+    )
+    if not allowed:
+        # Read deny_reason once via .get() so a missing key cannot turn the
+        # intended 429 into an unhandled KeyError.
+        deny_reason = quota_status.get("deny_reason")
+        reset_key = "daily_reset_at" if deny_reason == "daily" else "monthly_reset_at"
+        raise HTTPException(
+            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+            detail={
+                "message": f"AI build limit exceeded ({deny_reason})",
+                "reset_at": quota_status.get(reset_key),
+                "quota": quota_status,
+            },
+        )
+
+    plan = await get_user_plan(current_user.account_id, db)
+
+    try:
+        session, greeting = await start_chat_session(
+            flow_type=data.flow_type,
+            user_id=current_user.id,
+            account_id=current_user.account_id,
+            db=db,
+        )
+    except Exception as e:
+        logger.exception("AI chat session start failed: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail=f"AI provider error ({type(e).__name__}). Please try again.",
+        ) from e
+
+    await record_ai_usage(
+        user_id=current_user.id,
+        account_id=current_user.account_id,
+        conversation_id=session.id,
+        generation_type="chat_message",
+        tier=plan,
+        input_tokens=session.total_input_tokens,
+        output_tokens=session.total_output_tokens,
+        estimated_cost=_estimate_cost(
+            session.total_input_tokens, session.total_output_tokens
+        ),
+        succeeded=True,
+        counts_toward_quota=False,
+        error_code=None,
+        extra_data={"phase": "scoping"},
+        db=db,
+    )
+
+    await db.commit()
+
+    return AIChatStartResponse(
+        session_id=session.id,
+        greeting=greeting,
+        current_phase=session.current_phase,
+    )
+
+
+@router.post("/sessions/{session_id}/messages", response_model=AIChatMessageResponse)
+@limiter.limit("10/minute")
+async def post_message(
+    request: Request,
+    session_id: UUID,
+    data: AIChatMessageRequest,
+    current_user: Annotated[User, Depends(get_current_active_user)],
+    db: Annotated[AsyncSession, Depends(get_db)],
+    _: None = Depends(require_engineer_or_admin),
+):
+    """Send a user message and get AI response."""
+    # Enforces the per-session message cap for the caller's plan and records
+    # the token delta of the call as a usage row (a failed row on error).
+    # Responds 503 when AI is disabled, 400 when the session is not active,
+    # 429 when the message cap is reached, and 502 when the chat service
+    # raises.
+    _require_ai_enabled()
+
+    session = await get_chat_session(session_id, current_user.id, db)
+
+    if session.status != "active":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Session is {session.status}, cannot send messages",
+        )
+
+    plan = await get_user_plan(current_user.account_id, db)
+    if current_user.is_super_admin:
+        max_messages = _SUPER_ADMIN_MAX_MESSAGES
+    elif plan != "free":
+        max_messages = MAX_MESSAGES_PAID
+    else:
+        max_messages = MAX_MESSAGES_FREE
+
+    if session.message_count >= max_messages:
+        raise HTTPException(
+            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+            detail=f"Maximum messages per session reached ({max_messages}). Generate your tree or start a new session.",
+        )
+
+    prev_input = session.total_input_tokens
+    prev_output = session.total_output_tokens
+
+    try:
+        # The response below is built from the session object, so only the
+        # reply text of the returned tuple is kept.
+        ai_content, _tree_update, _new_phase, _metadata = await send_message(
+            session, data.content, db
+        )
+    except Exception as e:
+        logger.exception("AI chat message failed: %s", e)
+        await record_ai_usage(
+            user_id=current_user.id,
+            account_id=current_user.account_id,
+            conversation_id=session.id,
+            generation_type="chat_message",
+            tier=plan,
+            input_tokens=0,
+            output_tokens=0,
+            estimated_cost=0,
+            succeeded=False,
+            counts_toward_quota=False,
+            error_code=type(e).__name__,
+            extra_data=None,
+            db=db,
+        )
+        await db.commit()
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail=f"AI provider error ({type(e).__name__}). Please try again.",
+        ) from e
+
+    input_delta = session.total_input_tokens - prev_input
+    output_delta = session.total_output_tokens - prev_output
+    await record_ai_usage(
+        user_id=current_user.id,
+        account_id=current_user.account_id,
+        conversation_id=session.id,
+        generation_type="chat_message",
+        tier=plan,
+        input_tokens=input_delta,
+        output_tokens=output_delta,
+        estimated_cost=_estimate_cost(input_delta, output_delta),
+        succeeded=True,
+        counts_toward_quota=False,
+        error_code=None,
+        extra_data={"phase": session.current_phase},
+        db=db,
+    )
+
+    await db.commit()
+
+    return AIChatMessageResponse(
+        content=ai_content,
+        current_phase=session.current_phase,
+        working_tree=session.working_tree,
+        tree_metadata=session.tree_metadata if session.tree_metadata else None,
+    )
+
+
+@router.get("/sessions/{session_id}", response_model=AIChatSessionResponse)
+async def get_session(
+    session_id: UUID,
+    current_user: Annotated[User, Depends(get_current_active_user)],
+    db: Annotated[AsyncSession, Depends(get_db)],
+):
+    """Get full session state for resume after page reload."""
+    # Messages flagged "hidden" are left out of the returned history, and
+    # generated_tree is only populated for completed sessions.
+    session = await get_chat_session(session_id, current_user.id, db)
+
+    # ``or []`` keeps a session whose history is still None readable.
+    visible_history = [
+        msg for msg in (session.conversation_history or [])
+        if not msg.get("hidden")
+    ]
+
+    return AIChatSessionResponse(
+        session_id=session.id,
+        status=session.status,
+        current_phase=session.current_phase,
+        flow_type=session.flow_type,
+        conversation_history=visible_history,
+        working_tree=session.working_tree,
+        tree_metadata=session.tree_metadata if session.tree_metadata else None,
+        message_count=session.message_count,
+        generated_tree=session.working_tree if session.status == "completed" else None,
+    )
+
+
+@router.post("/sessions/{session_id}/generate", response_model=AIChatGenerateResponse)
+@limiter.limit("10/minute")
+async def generate_tree(
+    request: Request,
+    session_id: UUID,
+    current_user: Annotated[User, Depends(get_current_active_user)],
+    db: Annotated[AsyncSession, Depends(get_db)],
+    _: None = Depends(require_engineer_or_admin),
+):
+    """Generate final TreeStructure JSON from conversation."""
+    # A session that is already completed returns its stored tree without
+    # calling the chat service again. Otherwise the token delta of the call
+    # is recorded as a usage row; only a successful generation is recorded
+    # with counts_toward_quota=True.
+    # Responds 503 when AI is disabled, 400 when the session is not active,
+    # 422 when generation raises ValueError, 504 when the failure looks like
+    # a timeout, and 502 for any other failure.
+    _require_ai_enabled()
+
+    session = await get_chat_session(session_id, current_user.id, db)
+
+    if session.status == "completed" and session.working_tree:
+        return AIChatGenerateResponse(
+            tree_structure=session.working_tree,
+            tree_metadata=session.tree_metadata,
+            status="completed",
+        )
+
+    if session.status != "active":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Session is {session.status}, cannot generate",
+        )
+
+    plan = await get_user_plan(current_user.account_id, db)
+    prev_input = session.total_input_tokens
+    prev_output = session.total_output_tokens
+
+    try:
+        tree_structure, metadata = await generate_final_tree(session, db)
+    except ValueError as e:
+        await record_ai_usage(
+            user_id=current_user.id,
+            account_id=current_user.account_id,
+            conversation_id=session.id,
+            generation_type="chat_generate",
+            tier=plan,
+            input_tokens=session.total_input_tokens - prev_input,
+            output_tokens=session.total_output_tokens - prev_output,
+            estimated_cost=0,
+            succeeded=False,
+            counts_toward_quota=False,
+            error_code="invalid_output",
+            extra_data={"error": str(e)},
+            db=db,
+        )
+        await db.commit()
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=f"Tree generation failed: {e}",
+        ) from e
+    except Exception as e:
+        logger.exception("AI chat generate failed: %s", e)
+        input_delta = session.total_input_tokens - prev_input
+        output_delta = session.total_output_tokens - prev_output
+        await record_ai_usage(
+            user_id=current_user.id,
+            account_id=current_user.account_id,
+            conversation_id=session.id,
+            generation_type="chat_generate",
+            tier=plan,
+            input_tokens=input_delta,
+            output_tokens=output_delta,
+            estimated_cost=0,
+            succeeded=False,
+            counts_toward_quota=False,
+            error_code=type(e).__name__,
+            extra_data={"error": str(e)},
+            db=db,
+        )
+        await db.commit()
+
+        error_name = type(e).__name__
+        # Compare case-insensitively on both the class name and the message
+        # so "timeout", "Timeout" and "TIMEOUT" are all treated alike.
+        if "timeout" in error_name.lower() or "timeout" in str(e).lower():
+            raise HTTPException(
+                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+                detail="Tree generation timed out. Please try again.",
+            ) from e
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail=f"AI provider error ({error_name}). Please try again.",
+        ) from e
+
+    input_delta = session.total_input_tokens - prev_input
+    output_delta = session.total_output_tokens - prev_output
+    await record_ai_usage(
+        user_id=current_user.id,
+        account_id=current_user.account_id,
+        conversation_id=session.id,
+        generation_type="chat_generate",
+        tier=plan,
+        input_tokens=input_delta,
+        output_tokens=output_delta,
+        estimated_cost=_estimate_cost(input_delta, output_delta),
+        succeeded=True,
+        counts_toward_quota=True,
+        error_code=None,
+        extra_data=None,
+        db=db,
+    )
+
+    session.status = "completed"
+    await db.commit()
+
+    return AIChatGenerateResponse(
+        tree_structure=tree_structure,
+        tree_metadata=metadata,
+        status="completed",
+    )
+
+
+@router.post("/sessions/{session_id}/import", response_model=AIChatImportResponse)
+@limiter.limit("10/minute")
+async def import_tree(
+    request: Request,
+    session_id: UUID,
+    data: AIChatImportRequest,
+    current_user: Annotated[User, Depends(get_current_active_user)],
+    db: Annotated[AsyncSession, Depends(get_db)],
+    _: None = Depends(require_engineer_or_admin),
+):
+    """Create a Tree record from the generated tree structure."""
+    # A session that already has generated_tree_id set returns that id
+    # instead of creating a second Tree.
+    # Responds 400 when the session is not completed or has no working tree.
+    session = await get_chat_session(session_id, current_user.id, db)
+
+    if session.status != "completed" or not session.working_tree:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Session must be completed with a generated tree before importing",
+        )
+
+    if session.generated_tree_id:
+        return AIChatImportResponse(
+            tree_id=session.generated_tree_id,
+            tree_type=session.flow_type,
+        )
+
+    metadata = session.tree_metadata or {}
+    tree = Tree(
+        # Chain with ``or`` so a metadata value that is present but empty or
+        # None still falls through to the default.
+        name=data.name or metadata.get("name") or "AI-Generated Flow",
+        description=data.description or metadata.get("description") or "",
+        tree_type=session.flow_type,
+        tree_structure=session.working_tree,
+        author_id=current_user.id,
+        account_id=current_user.account_id,
+        category_id=data.category_id,
+        is_public=False,
+    )
+    db.add(tree)
+    await db.flush()
+
+    session.generated_tree_id = tree.id
+    await db.commit()
+
+    return AIChatImportResponse(
+        tree_id=tree.id,
+        tree_type=session.flow_type,
+    )
+
+
+@router.delete("/sessions/{session_id}", status_code=204)
+async def abandon_session(
+    session_id: UUID,
+    current_user: Annotated[User, Depends(get_current_active_user)],
+    db: Annotated[AsyncSession, Depends(get_db)],
+):
+    """Abandon a chat session."""
+    # NOTE(review): the status is overwritten whatever its current value, so
+    # abandoning a completed session makes import_tree reject it afterwards.
+    # Confirm that this is intended.
+    session = await get_chat_session(session_id, current_user.id, db)
+    session.status = "abandoned"
+    await db.commit()
diff --git a/backend/app/api/router.py b/backend/app/api/router.py
index 27963a1f..41fdb0b2 100644
--- a/backend/app/api/router.py
+++ b/backend/app/api/router.py
@@ -7,6 +7,7 @@ from app.api.endpoints import maintenance_schedules
 from app.api.endpoints import feedback
 from app.api.endpoints import ai_builder
 from app.api.endpoints import ai_fix
+from app.api.endpoints import ai_chat
 
 api_router = APIRouter()
 
@@ -38,3 +39,4 @@ api_router.include_router(maintenance_schedules.router)
 api_router.include_router(feedback.router)
 api_router.include_router(ai_builder.router)
 api_router.include_router(ai_fix.router)
+api_router.include_router(ai_chat.router)
diff --git a/backend/app/core/ai_quota_service.py b/backend/app/core/ai_quota_service.py
index 67eed9e5..49264caa 100644
--- a/backend/app/core/ai_quota_service.py
+++ b/backend/app/core/ai_quota_service.py
@@ -115,7 +115,7 @@ async def check_ai_quota(
         select(func.count(AIUsage.id)).where(
             AIUsage.user_id == user_id,
             AIUsage.succeeded == True,  # noqa: E712
-            AIUsage.generation_type.in_(["scaffold", "branch_detail"]),
+            AIUsage.generation_type.in_(["scaffold", "branch_detail", "chat_message", "chat_generate"]),
             AIUsage.created_at >= day_start,
         )
     ) or 0