refactor: migrate AI tree generator to provider abstraction

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-26 17:20:48 -05:00
parent 55be033ecb
commit eb7ea7ddd9
3 changed files with 76 additions and 106 deletions

View File

@@ -10,7 +10,6 @@
import logging import logging
from typing import Annotated from typing import Annotated
import anthropic
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -52,7 +51,7 @@ def _require_ai_enabled() -> None:
if not settings.ai_enabled: if not settings.ai_enabled:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.", detail="AI flow builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
) )
@@ -174,27 +173,6 @@ async def scaffold(
branches, input_tokens, output_tokens, cost = await scaffold_branches( branches, input_tokens, output_tokens, cost = await scaffold_branches(
conversation.wizard_state, conversation.wizard_state,
) )
except anthropic.APIError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="scaffold",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e)},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
except ValueError as e: except ValueError as e:
await record_ai_usage( await record_ai_usage(
user_id=current_user.id, user_id=current_user.id,
@@ -216,6 +194,27 @@ async def scaffold(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"AI returned invalid output: {e}", detail=f"AI returned invalid output: {e}",
) )
except Exception as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="scaffold",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e)},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
# Record successful usage # Record successful usage
await record_ai_usage( await record_ai_usage(
@@ -293,27 +292,6 @@ async def branch_detail(
existing_branches, existing_branches,
) )
) )
except anthropic.APIError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="branch_detail",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e), "branch_name": data.branch_name},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
except ValueError as e: except ValueError as e:
await record_ai_usage( await record_ai_usage(
user_id=current_user.id, user_id=current_user.id,
@@ -335,6 +313,27 @@ async def branch_detail(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"AI returned invalid output: {e}", detail=f"AI returned invalid output: {e}",
) )
except Exception as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="branch_detail",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e), "branch_name": data.branch_name},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
# Record successful usage # Record successful usage
await record_ai_usage( await record_ai_usage(

View File

@@ -1,11 +1,11 @@
"""AI-powered tree generation service using Anthropic Claude API. """AI-powered tree generation service.
Implements the 4-stage wizard flow: Implements the 4-stage wizard flow:
Stage 2 (scaffold): AI suggests 4-7 top-level branches Stage 2 (scaffold): AI suggests 4-7 top-level branches
Stage 3 (branch_detail): AI generates detailed nodes per branch Stage 3 (branch_detail): AI generates detailed nodes per branch
Stage 4 (assemble): Pure assembly logic — zero AI calls Stage 4 (assemble): Pure assembly logic — zero AI calls
System prompts are static constants to enable Anthropic prompt caching. Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic).
""" """
import json import json
import logging import logging
@@ -13,8 +13,7 @@ import re
import uuid import uuid
from typing import Any from typing import Any
import anthropic from app.core.ai_provider import get_ai_provider
from app.core.config import settings from app.core.config import settings
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
@@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str:
return text return text
def _get_client() -> anthropic.AsyncAnthropic:
"""Get configured async Anthropic client."""
if not settings.ANTHROPIC_API_KEY:
raise RuntimeError("ANTHROPIC_API_KEY not configured")
return anthropic.AsyncAnthropic(
api_key=settings.ANTHROPIC_API_KEY,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
def _estimate_cost(input_tokens: int, output_tokens: int) -> float: def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
"""Estimate USD cost from token counts.""" """Estimate USD cost from token counts."""
@@ -146,7 +136,7 @@ async def scaffold_branches(
Returns (branches, input_tokens, output_tokens, estimated_cost). Returns (branches, input_tokens, output_tokens, estimated_cost).
Raises ValueError on invalid response. Raises ValueError on invalid response.
""" """
client = _get_client() provider = get_ai_provider()
flow_type = wizard_state.get("flow_type", "troubleshooting") flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "") name = wizard_state.get("name", "")
@@ -161,16 +151,13 @@ async def scaffold_branches(
if tags: if tags:
user_message += f"Environment: {', '.join(tags)}\n" user_message += f"Environment: {', '.join(tags)}\n"
response = await client.messages.create( raw_text, input_tokens, output_tokens = await provider.generate_json(
model=settings.AI_MODEL, system_prompt=SCAFFOLD_SYSTEM_PROMPT,
max_tokens=1024,
system=SCAFFOLD_SYSTEM_PROMPT,
messages=[{"role": "user", "content": user_message}], messages=[{"role": "user", "content": user_message}],
max_tokens=1024,
) )
raw_text = _strip_markdown_fences(response.content[0].text) raw_text = _strip_markdown_fences(raw_text)
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
cost = _estimate_cost(input_tokens, output_tokens) cost = _estimate_cost(input_tokens, output_tokens)
try: try:
@@ -196,7 +183,7 @@ async def generate_branch_detail(
On validation failure, retries once with corrective prompt. On validation failure, retries once with corrective prompt.
Raises ValueError if both attempts fail. Raises ValueError if both attempts fail.
""" """
client = _get_client() provider = get_ai_provider()
flow_type = wizard_state.get("flow_type", "troubleshooting") flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "") name = wizard_state.get("name", "")
@@ -217,31 +204,22 @@ async def generate_branch_detail(
total_output = 0 total_output = 0
for attempt in range(3): for attempt in range(3):
response = await client.messages.create( raw_text, input_tokens, output_tokens = await provider.generate_json(
model=settings.AI_MODEL, system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT,
max_tokens=8192,
system=BRANCH_DETAIL_SYSTEM_PROMPT,
messages=messages, messages=messages,
max_tokens=8192,
) )
total_input += response.usage.input_tokens total_input += input_tokens
total_output += response.usage.output_tokens total_output += output_tokens
logger.debug( logger.debug(
"branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d", "branch_detail attempt=%d output_tokens=%d",
attempt, attempt,
response.stop_reason, output_tokens,
len(response.content),
response.usage.output_tokens,
) )
if response.stop_reason == "max_tokens": raw_text = _strip_markdown_fences(raw_text) if raw_text else ""
logger.warning(
"branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated",
attempt,
response.usage.output_tokens,
)
raw_text = _strip_markdown_fences(response.content[0].text) if response.content else ""
if not raw_text: if not raw_text:
logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason) logger.warning("branch_detail attempt=%d returned empty text", attempt)
try: try:
branch_tree = json.loads(raw_text) branch_tree = json.loads(raw_text)

View File

@@ -1,6 +1,6 @@
"""Integration tests for AI Flow Builder endpoints. """Integration tests for AI Flow Builder endpoints.
All Anthropic API calls are mocked — zero real API spend. All AI provider calls are mocked — zero real API spend.
""" """
import json import json
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
@@ -64,12 +64,11 @@ BRANCH_DETAIL_JSON = json.dumps({
}) })
def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200): def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200):
"""Create a mock Anthropic API response.""" """Create a mock AI provider whose generate_json returns the given text and token counts."""
response = MagicMock() provider = MagicMock()
response.content = [MagicMock(text=text)] provider.generate_json = AsyncMock(return_value=(text, input_tokens, output_tokens))
response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens) return provider
return response
@pytest.fixture @pytest.fixture
@@ -194,11 +193,9 @@ async def test_scaffold_success(client, auth_headers, enable_ai):
) )
conversation_id = start_resp.json()["conversation_id"] conversation_id = start_resp.json()["conversation_id"]
# Mock Anthropic # Mock AI provider
mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client: with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=mock_provider):
mock_client.return_value.messages.create = AsyncMock(return_value=mock_response)
response = await client.post( response = await client.post(
"/api/v1/ai/scaffold", "/api/v1/ai/scaffold",
json={"conversation_id": conversation_id}, json={"conversation_id": conversation_id},
@@ -241,9 +238,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
) )
conversation_id = start_resp.json()["conversation_id"] conversation_id = start_resp.json()["conversation_id"]
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client: with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
await client.post( await client.post(
"/api/v1/ai/scaffold", "/api/v1/ai/scaffold",
json={"conversation_id": conversation_id}, json={"conversation_id": conversation_id},
@@ -251,10 +247,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
) )
# Now generate branch detail # Now generate branch detail
detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON) detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client: with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=detail_provider):
mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock)
response = await client.post( response = await client.post(
"/api/v1/ai/branch-detail", "/api/v1/ai/branch-detail",
json={ json={
@@ -290,9 +284,8 @@ async def test_assemble_success(client, auth_headers, enable_ai):
conversation_id = start_resp.json()["conversation_id"] conversation_id = start_resp.json()["conversation_id"]
# Scaffold # Scaffold
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client: with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
await client.post( await client.post(
"/api/v1/ai/scaffold", "/api/v1/ai/scaffold",
json={"conversation_id": conversation_id}, json={"conversation_id": conversation_id},