Files
resolutionflow/backend/app/api/endpoints/ai_chat.py
chihlasm 41b7cd86b8 feat: add action_type and focal_node_id to AI chat message API
- Add VALID_ACTION_TYPES literal and action_type/focal_node_id fields to
  AIChatMessageRequest schema
- Add tree_id field to AIChatStartRequest schema for editor-embedded sessions
- Update send_message() signature with action_type and focal_node_id params
- Update start_chat_session() signature with tree_id param
- Pass new params through endpoints to service functions
- All new params have defaults so existing behavior is unchanged

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 00:40:36 -05:00

437 lines
14 KiB
Python

"""AI Chat Builder endpoints.
Conversational flow builder:
POST /ai/chat/sessions — Start session, get AI greeting
POST /ai/chat/sessions/{id}/messages — Send message, get AI response
GET /ai/chat/sessions/{id} — Get session state (for resume)
POST /ai/chat/sessions/{id}/generate — Generate final TreeStructure
POST /ai/chat/sessions/{id}/import — Create Tree from generated structure
DELETE /ai/chat/sessions/{id} — Abandon session
"""
import logging
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.rate_limit import limiter
from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin
from app.core.config import settings
from app.core.ai_chat_service import (
start_chat_session,
send_message,
generate_final_tree,
get_chat_session,
MAX_MESSAGES_FREE,
MAX_MESSAGES_PAID,
)
from app.core.ai_quota_service import check_ai_quota, record_ai_usage, get_user_plan
from app.models.user import User
from app.models.tree import Tree
from app.schemas.ai_chat import (
AIChatStartRequest,
AIChatStartResponse,
AIChatMessageRequest,
AIChatMessageResponse,
AIChatSessionResponse,
AIChatGenerateResponse,
AIChatImportRequest,
AIChatImportResponse,
)
# Router for the conversational flow builder; every route in this module is
# mounted under /ai/chat and grouped under the "ai-chat-builder" tag.
router = APIRouter(tags=["ai-chat-builder"], prefix="/ai/chat")
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
def _require_ai_enabled() -> None:
    """Guard for AI-backed endpoints: fail fast when AI is switched off.

    Raises:
        HTTPException: 503 Service Unavailable when ``settings.ai_enabled``
            is falsy.
    """
    if settings.ai_enabled:
        return
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
    )
@router.post("/sessions", response_model=AIChatStartResponse, status_code=201)
@limiter.limit("10/minute")
async def create_session(
    request: Request,
    data: AIChatStartRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Start a new AI chat builder session.

    Checks the caller's AI quota, then asks the chat service to start a
    session of ``data.flow_type``. ``data.tree_id`` is forwarded as well, for
    editor-embedded sessions. The greeting's token usage is recorded but
    does not count toward quota.

    Raises:
        HTTPException 503: AI is not configured.
        HTTPException 429: The quota check denied the request.
        HTTPException 502: The AI provider call failed.
        Any HTTPException raised by the chat service is propagated unchanged.
    """
    _require_ai_enabled()
    allowed, quota_status = await check_ai_quota(
        user_id=current_user.id,
        account_id=current_user.account_id,
        db=db,
        billing_anchor=current_user.ai_billing_cycle_anchor_at,
        is_super_admin=current_user.is_super_admin,
    )
    if not allowed:
        # Read deny_reason once with .get(). Indexing it directly would turn a
        # quota status that lacks the key into an unhandled KeyError (HTTP 500).
        deny_reason = quota_status.get("deny_reason")
        reset_key = "daily_reset_at" if deny_reason == "daily" else "monthly_reset_at"
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail={
                "message": f"AI build limit exceeded ({deny_reason})",
                "reset_at": quota_status.get(reset_key),
                "quota": quota_status,
            },
        )
    plan = await get_user_plan(current_user.account_id, db)
    try:
        session, greeting = await start_chat_session(
            flow_type=data.flow_type,
            user_id=current_user.id,
            account_id=current_user.account_id,
            db=db,
            tree_id=data.tree_id,
        )
    except HTTPException:
        # An HTTPException raised inside the service is an intentional HTTP
        # response, not a provider failure. Do not mask it as a 502.
        raise
    except Exception as e:
        logger.exception("AI chat session start failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"AI provider error ({type(e).__name__}). Please try again.",
        ) from e
    await record_ai_usage(
        user_id=current_user.id,
        account_id=current_user.account_id,
        conversation_id=None,
        generation_type="chat_message",
        tier=plan,
        input_tokens=session.total_input_tokens,
        output_tokens=session.total_output_tokens,
        # Estimate at 1.0 per 1M input and 5.0 per 1M output tokens
        # (presumably USD; confirm against provider pricing).
        estimated_cost=(
            session.total_input_tokens * 1.0 / 1_000_000
            + session.total_output_tokens * 5.0 / 1_000_000
        ),
        succeeded=True,
        counts_toward_quota=False,
        error_code=None,
        extra_data={"phase": "scoping", "chat_session_id": str(session.id)},
        db=db,
    )
    await db.commit()
    return AIChatStartResponse(
        session_id=session.id,
        greeting=greeting,
        current_phase=session.current_phase,
    )
@router.post("/sessions/{session_id}/messages", response_model=AIChatMessageResponse)
@limiter.limit("10/minute")
async def post_message(
    request: Request,
    session_id: UUID,
    data: AIChatMessageRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Send a user message to an active session and return the AI's reply.

    ``data.action_type`` falls back to ``"open_chat"`` when it is empty.
    ``data.focal_node_id`` is forwarded unchanged to the chat service. Token
    usage is recorded for both successful and failed exchanges and never
    counts toward quota.

    Raises:
        HTTPException 503: AI is not configured.
        HTTPException 400: The session is not ``active``.
        HTTPException 429: The plan's per-session message cap is reached.
        HTTPException 502: The AI provider call failed.
        Any HTTPException raised by the chat service is propagated unchanged.
    """
    _require_ai_enabled()
    session = await get_chat_session(session_id, current_user.id, db)
    if session.status != "active":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Session is {session.status}, cannot send messages",
        )
    plan = await get_user_plan(current_user.account_id, db)
    max_messages = MAX_MESSAGES_PAID if plan != "free" else MAX_MESSAGES_FREE
    if current_user.is_super_admin:
        # Effectively uncapped for super admins.
        max_messages = 999
    if session.message_count >= max_messages:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Maximum messages per session reached ({max_messages}). Generate your tree or start a new session.",
        )
    # Snapshot the running token counters so only this exchange is billed.
    prev_input = session.total_input_tokens
    prev_output = session.total_output_tokens
    try:
        # Only the reply text is used directly. The response below is built
        # from the session's state rather than the other returned values.
        ai_content, _tree_update, _new_phase, _metadata = await send_message(
            session, data.content, db,
            action_type=data.action_type or "open_chat",
            focal_node_id=data.focal_node_id,
        )
    except HTTPException:
        # Intentional HTTP errors from the service are not provider failures.
        raise
    except Exception as e:
        logger.exception("AI chat message failed: %s", e)
        # Record whatever the failed call already consumed (0 if the counters
        # were not advanced) instead of a hard-coded 0.
        failed_input = session.total_input_tokens - prev_input
        failed_output = session.total_output_tokens - prev_output
        await record_ai_usage(
            user_id=current_user.id,
            account_id=current_user.account_id,
            conversation_id=None,
            generation_type="chat_message",
            tier=plan,
            input_tokens=failed_input,
            output_tokens=failed_output,
            estimated_cost=(
                failed_input * 1.0 / 1_000_000
                + failed_output * 5.0 / 1_000_000
            ),
            succeeded=False,
            counts_toward_quota=False,
            error_code=type(e).__name__,
            extra_data={"chat_session_id": str(session.id)},
            db=db,
        )
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"AI provider error ({type(e).__name__}). Please try again.",
        ) from e
    input_delta = session.total_input_tokens - prev_input
    output_delta = session.total_output_tokens - prev_output
    await record_ai_usage(
        user_id=current_user.id,
        account_id=current_user.account_id,
        conversation_id=None,
        generation_type="chat_message",
        tier=plan,
        input_tokens=input_delta,
        output_tokens=output_delta,
        # Estimate at 1.0 per 1M input and 5.0 per 1M output tokens
        # (presumably USD; confirm against provider pricing).
        estimated_cost=(
            input_delta * 1.0 / 1_000_000
            + output_delta * 5.0 / 1_000_000
        ),
        succeeded=True,
        counts_toward_quota=False,
        error_code=None,
        extra_data={"phase": session.current_phase, "chat_session_id": str(session.id)},
        db=db,
    )
    await db.commit()
    return AIChatMessageResponse(
        content=ai_content,
        current_phase=session.current_phase,
        working_tree=session.working_tree,
        tree_metadata=session.tree_metadata if session.tree_metadata else None,
    )
@router.get("/sessions/{session_id}", response_model=AIChatSessionResponse)
async def get_session(
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Return a session's full state so the client can resume after a reload.

    Messages flagged ``hidden`` are dropped from the returned history. The
    working tree is echoed as ``generated_tree`` only once the session's
    status is ``completed``.
    """
    chat = await get_chat_session(session_id, current_user.id, db)
    shown_messages = list(
        filter(lambda entry: not entry.get("hidden"), chat.conversation_history)
    )
    is_completed = chat.status == "completed"
    return AIChatSessionResponse(
        session_id=chat.id,
        status=chat.status,
        current_phase=chat.current_phase,
        flow_type=chat.flow_type,
        conversation_history=shown_messages,
        working_tree=chat.working_tree,
        tree_metadata=chat.tree_metadata or None,
        message_count=chat.message_count,
        generated_tree=chat.working_tree if is_completed else None,
    )
@router.post("/sessions/{session_id}/generate", response_model=AIChatGenerateResponse)
@limiter.limit("10/minute")
async def generate_tree(
    request: Request,
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Generate the final TreeStructure JSON from the conversation.

    If the session is already ``completed`` with a working tree, the stored
    tree is returned without another AI call. Otherwise the session must be
    ``active``. On success, usage is recorded (counting toward quota) and the
    session is marked ``completed``.

    Raises:
        HTTPException 503: AI is not configured.
        HTTPException 400: The session is neither active nor completed with a tree.
        HTTPException 422: Generation raised ValueError (invalid output).
        HTTPException 504: The provider failure looks like a timeout.
        HTTPException 502: Any other provider failure.
        Any HTTPException raised by the chat service is propagated unchanged.
    """
    _require_ai_enabled()
    session = await get_chat_session(session_id, current_user.id, db)
    if session.status == "completed" and session.working_tree:
        return AIChatGenerateResponse(
            tree_structure=session.working_tree,
            tree_metadata=session.tree_metadata,
            status="completed",
        )
    if session.status != "active":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Session is {session.status}, cannot generate",
        )
    plan = await get_user_plan(current_user.account_id, db)
    # Snapshot the running token counters so only this generation is billed.
    prev_input = session.total_input_tokens
    prev_output = session.total_output_tokens

    async def _record_usage(*, succeeded, error_code=None, extra_data):
        """Record this generation's token delta; only successes count toward quota."""
        input_delta = session.total_input_tokens - prev_input
        output_delta = session.total_output_tokens - prev_output
        await record_ai_usage(
            user_id=current_user.id,
            account_id=current_user.account_id,
            conversation_id=None,
            generation_type="chat_generate",
            tier=plan,
            input_tokens=input_delta,
            output_tokens=output_delta,
            # Estimate at 1.0 per 1M input and 5.0 per 1M output tokens
            # (presumably USD; confirm). Failed calls still consumed the tokens
            # recorded above, so they are priced too rather than logged as 0.
            estimated_cost=(
                input_delta * 1.0 / 1_000_000
                + output_delta * 5.0 / 1_000_000
            ),
            succeeded=succeeded,
            counts_toward_quota=succeeded,
            error_code=error_code,
            extra_data=extra_data,
            db=db,
        )

    try:
        tree_structure, metadata = await generate_final_tree(session, db)
    except HTTPException:
        # Intentional HTTP errors from the service are not provider failures.
        raise
    except ValueError as e:
        logger.warning("AI chat generate produced invalid output: %s", e)
        await _record_usage(
            succeeded=False,
            error_code="invalid_output",
            extra_data={"error": str(e), "chat_session_id": str(session.id)},
        )
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Tree generation failed: {e}",
        ) from e
    except Exception as e:
        logger.exception("AI chat generate failed: %s", e)
        error_name = type(e).__name__
        await _record_usage(
            succeeded=False,
            error_code=error_name,
            extra_data={"error": str(e), "chat_session_id": str(session.id)},
        )
        await db.commit()
        # Match timeouts case-insensitively on both the exception type name
        # and its message. Also accept any TimeoutError subclass, whatever its
        # name (asyncio.TimeoutError is an alias of TimeoutError on 3.11+).
        message = str(e).lower()
        is_timeout = (
            isinstance(e, TimeoutError)
            or "timeout" in error_name.lower()
            or "timeout" in message
            or "timed out" in message
        )
        if is_timeout:
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Tree generation timed out. Please try again.",
            ) from e
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"AI provider error ({error_name}). Please try again.",
        ) from e
    await _record_usage(
        succeeded=True,
        extra_data={"chat_session_id": str(session.id)},
    )
    session.status = "completed"
    await db.commit()
    return AIChatGenerateResponse(
        tree_structure=tree_structure,
        tree_metadata=metadata,
        status="completed",
    )
@router.post("/sessions/{session_id}/import", response_model=AIChatImportResponse)
@limiter.limit("10/minute")
async def import_tree(
    request: Request,
    session_id: UUID,
    data: AIChatImportRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
    _: None = Depends(require_engineer_or_admin),
):
    """Create a Tree record from the session's generated tree structure.

    Name and description come from the request when given. Otherwise they
    come from the session's ``tree_metadata``, with built-in defaults as the
    last fallback. The new tree is private and owned by the caller's account.
    The session is linked to the new tree via ``generated_tree_id``.

    Raises:
        HTTPException 400: The session is not completed or has no working tree.
    """
    session = await get_chat_session(session_id, current_user.id, db)
    if session.status != "completed" or not session.working_tree:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Session must be completed with a generated tree before importing",
        )
    # Always create a new Tree record (no duplicate check — user may
    # want multiple copies or re-import after edits)
    metadata = session.tree_metadata or {}
    # Extract intake form from metadata if present (procedural flows).
    # Read it with .get() instead of .pop(): popping would mutate the
    # session's own tree_metadata dict, so a later re-import of this session
    # could lose the intake form.
    raw_intake_form = metadata.get("intake_form")
    intake_form = raw_intake_form if isinstance(raw_intake_form, list) else None
    tree = Tree(
        # `or` fallbacks also cover metadata values stored as None or "".
        name=data.name or metadata.get("name") or "AI-Generated Flow",
        description=data.description or metadata.get("description") or "",
        tree_type=session.flow_type,
        tree_structure=session.working_tree,
        intake_form=intake_form,
        author_id=current_user.id,
        account_id=current_user.account_id,
        category_id=data.category_id,
        is_public=False,
    )
    db.add(tree)
    # Flush so the database assigns tree.id before it is linked to the session.
    await db.flush()
    session.generated_tree_id = tree.id
    await db.commit()
    return AIChatImportResponse(
        tree_id=tree.id,
        tree_type=session.flow_type,
    )
@router.delete("/sessions/{session_id}", status_code=204)
@limiter.limit("10/minute")
async def abandon_session(
    request: Request,
    session_id: UUID,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Mark the caller's chat session as abandoned and persist the change.

    Responds with 204 and no body. ``request`` is not read here; it is
    presumably kept for the ``@limiter.limit`` decorator.
    """
    chat = await get_chat_session(session_id, current_user.id, db)
    chat.status = "abandoned"
    await db.commit()