Tests cover session create, send message with tree update, get session, abandon, 404 on missing session, and 503 when AI disabled. Fixed: ai_usage.conversation_id has FK to ai_conversations, not ai_chat_sessions. Chat builder now passes conversation_id=None and tracks session reference in extra_data.chat_session_id. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
429 lines
14 KiB
Python
429 lines
14 KiB
Python
"""AI Chat Builder endpoints.

Conversational flow builder:
    POST   /ai/chat/sessions                — Start session, get AI greeting
    POST   /ai/chat/sessions/{id}/messages  — Send message, get AI response
    GET    /ai/chat/sessions/{id}           — Get session state (for resume)
    POST   /ai/chat/sessions/{id}/generate  — Generate final TreeStructure
    POST   /ai/chat/sessions/{id}/import    — Create Tree from generated structure
    DELETE /ai/chat/sessions/{id}           — Abandon session
"""
|
|
import logging
|
|
from typing import Annotated
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.rate_limit import limiter
|
|
from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin
|
|
from app.core.config import settings
|
|
from app.core.ai_chat_service import (
|
|
start_chat_session,
|
|
send_message,
|
|
generate_final_tree,
|
|
get_chat_session,
|
|
MAX_MESSAGES_FREE,
|
|
MAX_MESSAGES_PAID,
|
|
)
|
|
from app.core.ai_quota_service import check_ai_quota, record_ai_usage, get_user_plan
|
|
from app.models.user import User
|
|
from app.models.tree import Tree
|
|
from app.schemas.ai_chat import (
|
|
AIChatStartRequest,
|
|
AIChatStartResponse,
|
|
AIChatMessageRequest,
|
|
AIChatMessageResponse,
|
|
AIChatSessionResponse,
|
|
AIChatGenerateResponse,
|
|
AIChatImportRequest,
|
|
AIChatImportResponse,
|
|
)
|
|
|
|
# Module-level logger for this endpoint module.
logger = logging.getLogger(name=__name__)

# Every route declared below is prefixed with /ai/chat and grouped under the
# "ai-chat-builder" tag.
router = APIRouter(
    tags=["ai-chat-builder"],
    prefix="/ai/chat",
)
|
|
|
|
|
|
def _require_ai_enabled() -> None:
    """Guard for AI-backed endpoints.

    Returns silently when ``settings.ai_enabled`` is truthy; otherwise raises
    an HTTP 503 telling the operator which API keys to configure.
    """
    if settings.ai_enabled:
        return
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
    )
|
|
|
|
|
|
@router.post("/sessions", response_model=AIChatStartResponse, status_code=201)
@limiter.limit("10/minute")
async def create_session(
    request: Request,
    data: AIChatStartRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Start a new AI chat builder session.

    Checks the caller's AI quota, opens a session via ``start_chat_session``,
    records the greeting's token usage (phase "scoping", not counted toward
    quota), commits, and returns the new session id with the AI greeting.

    ``request`` is not read in the body; presumably it is required by the
    ``limiter.limit`` decorator.

    Raises:
        HTTPException: 503 when AI is not configured; 429 when the quota check
            denies the request; 502 when the AI provider call fails.
    """
    _require_ai_enabled()

    allowed, quota_status = await check_ai_quota(
        user_id=current_user.id,
        account_id=current_user.account_id,
        db=db,
        billing_anchor=current_user.ai_billing_cycle_anchor_at,
        is_super_admin=current_user.is_super_admin,
    )
    if not allowed:
        # Read deny_reason with .get() everywhere: indexing it directly would
        # turn a missing key into a KeyError (HTTP 500) instead of a 429.
        deny_reason = quota_status.get("deny_reason")
        reset_key = "daily_reset_at" if deny_reason == "daily" else "monthly_reset_at"
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail={
                "message": f"AI build limit exceeded ({deny_reason})",
                "reset_at": quota_status.get(reset_key),
                "quota": quota_status,
            },
        )

    plan = await get_user_plan(current_user.account_id, db)

    try:
        session, greeting = await start_chat_session(
            flow_type=data.flow_type,
            user_id=current_user.id,
            account_id=current_user.account_id,
            db=db,
        )
    except HTTPException:
        # Service helpers may raise deliberate HTTP errors themselves; pass
        # them through instead of masking them as a provider failure.
        raise
    except Exception as e:
        logger.exception("AI chat session start failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"AI provider error ({type(e).__name__}). Please try again.",
        ) from e

    # The session was just created, so its cumulative counters equal the
    # greeting's usage. Cost uses $1 / $5 per million input / output tokens.
    await record_ai_usage(
        user_id=current_user.id,
        account_id=current_user.account_id,
        conversation_id=None,
        generation_type="chat_message",
        tier=plan,
        input_tokens=session.total_input_tokens,
        output_tokens=session.total_output_tokens,
        estimated_cost=(
            session.total_input_tokens * 1.0 / 1_000_000
            + session.total_output_tokens * 5.0 / 1_000_000
        ),
        succeeded=True,
        counts_toward_quota=False,
        error_code=None,
        extra_data={"phase": "scoping", "chat_session_id": str(session.id)},
        db=db,
    )

    await db.commit()

    return AIChatStartResponse(
        session_id=session.id,
        greeting=greeting,
        current_phase=session.current_phase,
    )
|
|
|
|
|
|
@router.post("/sessions/{session_id}/messages", response_model=AIChatMessageResponse)
@limiter.limit("10/minute")
async def post_message(
    request: Request,
    session_id: UUID,
    data: AIChatMessageRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Send a user message and get the AI response.

    Enforces a per-session message cap: MAX_MESSAGES_PAID for any plan other
    than "free", MAX_MESSAGES_FREE for "free", and 999 for super admins. It
    then forwards the message via ``send_message`` and records this turn's
    token delta as non-quota AI usage, on success and on failure.

    Raises:
        HTTPException: 503 when AI is not configured; 400 when the session is
            not "active"; 429 when the message cap is reached; 502 when the
            AI provider call fails.
    """
    _require_ai_enabled()

    session = await get_chat_session(session_id, current_user.id, db)

    if session.status != "active":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Session is {session.status}, cannot send messages",
        )

    plan = await get_user_plan(current_user.account_id, db)
    max_messages = MAX_MESSAGES_PAID if plan != "free" else MAX_MESSAGES_FREE
    if current_user.is_super_admin:
        max_messages = 999

    if session.message_count >= max_messages:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Maximum messages per session reached ({max_messages}). Generate your tree or start a new session.",
        )

    # Session token counters are cumulative; snapshot them so only this
    # turn's usage is recorded.
    prev_input = session.total_input_tokens
    prev_output = session.total_output_tokens

    try:
        ai_content, _tree_update, _new_phase, _metadata = await send_message(
            session, data.content, db
        )
    except HTTPException:
        # Deliberate HTTP errors from the service layer pass through unchanged.
        raise
    except Exception as e:
        logger.exception("AI chat message failed: %s", e)
        # Record whatever tokens the failed call already added to the session
        # (0 if none), matching how generate_tree accounts for failures.
        input_delta = session.total_input_tokens - prev_input
        output_delta = session.total_output_tokens - prev_output
        await record_ai_usage(
            user_id=current_user.id,
            account_id=current_user.account_id,
            conversation_id=None,
            generation_type="chat_message",
            tier=plan,
            input_tokens=input_delta,
            output_tokens=output_delta,
            estimated_cost=(
                input_delta * 1.0 / 1_000_000
                + output_delta * 5.0 / 1_000_000
            ),
            succeeded=False,
            counts_toward_quota=False,
            error_code=type(e).__name__,
            extra_data={"chat_session_id": str(session.id)},
            db=db,
        )
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"AI provider error ({type(e).__name__}). Please try again.",
        ) from e

    input_delta = session.total_input_tokens - prev_input
    output_delta = session.total_output_tokens - prev_output
    # Cost uses $1 / $5 per million input / output tokens.
    await record_ai_usage(
        user_id=current_user.id,
        account_id=current_user.account_id,
        conversation_id=None,
        generation_type="chat_message",
        tier=plan,
        input_tokens=input_delta,
        output_tokens=output_delta,
        estimated_cost=(
            input_delta * 1.0 / 1_000_000
            + output_delta * 5.0 / 1_000_000
        ),
        succeeded=True,
        counts_toward_quota=False,
        error_code=None,
        extra_data={"phase": session.current_phase, "chat_session_id": str(session.id)},
        db=db,
    )

    await db.commit()

    return AIChatMessageResponse(
        content=ai_content,
        current_phase=session.current_phase,
        working_tree=session.working_tree,
        tree_metadata=session.tree_metadata if session.tree_metadata else None,
    )
|
|
|
|
|
|
@router.get("/sessions/{session_id}", response_model=AIChatSessionResponse)
async def get_session(
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Return the full state of a chat session, e.g. to resume after a page reload.

    Conversation entries flagged ``hidden`` are left out of the returned
    history. ``generated_tree`` mirrors ``working_tree`` only once the session
    status is "completed". Unlike the mutating endpoints, this route needs only
    an active user: there is no engineer/admin dependency and no AI-enabled check.
    """
    chat = await get_chat_session(session_id, current_user.id, db)

    shown_messages = [
        entry
        for entry in chat.conversation_history
        if not entry.get("hidden")
    ]
    finished = chat.status == "completed"

    return AIChatSessionResponse(
        session_id=chat.id,
        status=chat.status,
        current_phase=chat.current_phase,
        flow_type=chat.flow_type,
        conversation_history=shown_messages,
        working_tree=chat.working_tree,
        tree_metadata=chat.tree_metadata or None,
        message_count=chat.message_count,
        generated_tree=chat.working_tree if finished else None,
    )
|
|
|
|
|
|
@router.post("/sessions/{session_id}/generate", response_model=AIChatGenerateResponse)
@limiter.limit("10/minute")
async def generate_tree(
    request: Request,
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Generate the final TreeStructure JSON from the conversation.

    A session that is already "completed" and has a working tree returns that
    tree without another AI call. Otherwise the session must be "active". A
    successful generation is recorded as quota-counting usage, and the session
    is marked "completed".

    Raises:
        HTTPException: 503 when AI is not configured; 400 when the session is
            neither active nor completed-with-tree; 422 when generation
            raises ValueError; 504 on timeout-looking errors; 502 on any
            other provider error.
    """
    _require_ai_enabled()

    session = await get_chat_session(session_id, current_user.id, db)

    if session.status == "completed" and session.working_tree:
        return AIChatGenerateResponse(
            tree_structure=session.working_tree,
            tree_metadata=session.tree_metadata,
            status="completed",
        )

    if session.status != "active":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Session is {session.status}, cannot generate",
        )

    # NOTE(review): this is the only usage recorded with
    # counts_toward_quota=True, yet check_ai_quota is only called when a
    # session is created, not here. Several sessions opened under quota can
    # all be generated afterwards. Confirm this is intended.
    plan = await get_user_plan(current_user.account_id, db)
    prev_input = session.total_input_tokens
    prev_output = session.total_output_tokens

    async def _record_usage(*, succeeded, error_code, extra_data):
        """Record this call's token delta; only successes count toward quota."""
        input_delta = session.total_input_tokens - prev_input
        output_delta = session.total_output_tokens - prev_output
        await record_ai_usage(
            user_id=current_user.id,
            account_id=current_user.account_id,
            conversation_id=None,
            generation_type="chat_generate",
            tier=plan,
            input_tokens=input_delta,
            output_tokens=output_delta,
            # $1 / $5 per million input / output tokens. Failed calls that
            # consumed tokens are costed too, rather than reported as 0.
            estimated_cost=(
                input_delta * 1.0 / 1_000_000
                + output_delta * 5.0 / 1_000_000
            ),
            succeeded=succeeded,
            counts_toward_quota=succeeded,
            error_code=error_code,
            extra_data=extra_data,
            db=db,
        )

    try:
        tree_structure, metadata = await generate_final_tree(session, db)
    except HTTPException:
        # Deliberate HTTP errors from the service layer pass through unchanged.
        raise
    except ValueError as e:
        # Presumably generate_final_tree raises ValueError for unusable model
        # output; this is recorded as "invalid_output" and surfaced as a 422.
        logger.warning("AI chat generate produced invalid output: %s", e)
        await _record_usage(
            succeeded=False,
            error_code="invalid_output",
            extra_data={"error": str(e), "chat_session_id": str(session.id)},
        )
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Tree generation failed: {e}",
        ) from e
    except Exception as e:
        logger.exception("AI chat generate failed: %s", e)
        error_name = type(e).__name__
        await _record_usage(
            succeeded=False,
            error_code=error_name,
            extra_data={"error": str(e), "chat_session_id": str(session.id)},
        )
        await db.commit()

        # Match "timeout" case-insensitively in both the exception type name
        # and its message.
        if "timeout" in error_name.lower() or "timeout" in str(e).lower():
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Tree generation timed out. Please try again.",
            ) from e
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"AI provider error ({error_name}). Please try again.",
        ) from e

    await _record_usage(
        succeeded=True,
        error_code=None,
        extra_data={"chat_session_id": str(session.id)},
    )

    session.status = "completed"
    await db.commit()

    return AIChatGenerateResponse(
        tree_structure=tree_structure,
        tree_metadata=metadata,
        status="completed",
    )
|
|
|
|
|
|
@router.post("/sessions/{session_id}/import", response_model=AIChatImportResponse)
@limiter.limit("10/minute")
async def import_tree(
    request: Request,
    session_id: UUID,
    data: AIChatImportRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Create a private Tree record from the session's generated tree structure.

    Name and description come from the request first, then from the session's
    tree metadata, then from fixed defaults. Repeat calls are idempotent: once
    ``generated_tree_id`` is set, it is returned without inserting another row.

    Raises:
        HTTPException: 400 when the session is not "completed" or has no
            working tree.
    """
    session = await get_chat_session(session_id, current_user.id, db)

    if session.status != "completed" or not session.working_tree:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Session must be completed with a generated tree before importing",
        )

    if session.generated_tree_id:
        return AIChatImportResponse(
            tree_id=session.generated_tree_id,
            tree_type=session.flow_type,
        )

    metadata = session.tree_metadata or {}
    tree = Tree(
        # Fall back to the default when metadata holds an empty or None
        # value, not only when the key is missing.
        name=data.name or metadata.get("name") or "AI-Generated Flow",
        description=data.description or metadata.get("description") or "",
        tree_type=session.flow_type,
        tree_structure=session.working_tree,
        author_id=current_user.id,
        account_id=current_user.account_id,
        category_id=data.category_id,
        is_public=False,
    )
    db.add(tree)
    await db.flush()  # flush before reading tree.id for the back-reference

    session.generated_tree_id = tree.id
    # Capture response values before commit so the response does not rely on
    # attributes staying loaded after commit.
    tree_id = tree.id
    flow_type = session.flow_type
    await db.commit()

    return AIChatImportResponse(
        tree_id=tree_id,
        tree_type=flow_type,
    )
|
|
|
|
|
|
@router.delete("/sessions/{session_id}", status_code=204)
async def abandon_session(
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Mark the caller's chat session as "abandoned" and commit (HTTP 204, no body).

    The current status is not checked before it is overwritten.
    """
    # NOTE(review): a "completed" session is also flipped to "abandoned".
    # After that, import_tree (which requires "completed") and the generated
    # tree in get_session are no longer available. Confirm this is intended.
    chat = await get_chat_session(session_id, current_user.id, db)
    chat.status = "abandoned"
    await db.commit()
|