refactor: migrate AI tree generator to provider abstraction
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,7 +10,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import anthropic
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -52,7 +51,7 @@ def _require_ai_enabled() -> None:
|
|||||||
if not settings.ai_enabled:
|
if not settings.ai_enabled:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.",
|
detail="AI flow builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -174,27 +173,6 @@ async def scaffold(
|
|||||||
branches, input_tokens, output_tokens, cost = await scaffold_branches(
|
branches, input_tokens, output_tokens, cost = await scaffold_branches(
|
||||||
conversation.wizard_state,
|
conversation.wizard_state,
|
||||||
)
|
)
|
||||||
except anthropic.APIError as e:
|
|
||||||
await record_ai_usage(
|
|
||||||
user_id=current_user.id,
|
|
||||||
account_id=current_user.account_id,
|
|
||||||
conversation_id=conversation.id,
|
|
||||||
generation_type="scaffold",
|
|
||||||
tier=plan,
|
|
||||||
input_tokens=0,
|
|
||||||
output_tokens=0,
|
|
||||||
estimated_cost=0,
|
|
||||||
succeeded=False,
|
|
||||||
counts_toward_quota=False,
|
|
||||||
error_code=type(e).__name__,
|
|
||||||
extra_data={"error": str(e)},
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
||||||
detail="AI provider error. Please try again.",
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -216,6 +194,27 @@ async def scaffold(
|
|||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
detail=f"AI returned invalid output: {e}",
|
detail=f"AI returned invalid output: {e}",
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
await record_ai_usage(
|
||||||
|
user_id=current_user.id,
|
||||||
|
account_id=current_user.account_id,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
generation_type="scaffold",
|
||||||
|
tier=plan,
|
||||||
|
input_tokens=0,
|
||||||
|
output_tokens=0,
|
||||||
|
estimated_cost=0,
|
||||||
|
succeeded=False,
|
||||||
|
counts_toward_quota=False,
|
||||||
|
error_code=type(e).__name__,
|
||||||
|
extra_data={"error": str(e)},
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail="AI provider error. Please try again.",
|
||||||
|
)
|
||||||
|
|
||||||
# Record successful usage
|
# Record successful usage
|
||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
@@ -293,27 +292,6 @@ async def branch_detail(
|
|||||||
existing_branches,
|
existing_branches,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except anthropic.APIError as e:
|
|
||||||
await record_ai_usage(
|
|
||||||
user_id=current_user.id,
|
|
||||||
account_id=current_user.account_id,
|
|
||||||
conversation_id=conversation.id,
|
|
||||||
generation_type="branch_detail",
|
|
||||||
tier=plan,
|
|
||||||
input_tokens=0,
|
|
||||||
output_tokens=0,
|
|
||||||
estimated_cost=0,
|
|
||||||
succeeded=False,
|
|
||||||
counts_toward_quota=False,
|
|
||||||
error_code=type(e).__name__,
|
|
||||||
extra_data={"error": str(e), "branch_name": data.branch_name},
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
||||||
detail="AI provider error. Please try again.",
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -335,6 +313,27 @@ async def branch_detail(
|
|||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
detail=f"AI returned invalid output: {e}",
|
detail=f"AI returned invalid output: {e}",
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
await record_ai_usage(
|
||||||
|
user_id=current_user.id,
|
||||||
|
account_id=current_user.account_id,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
generation_type="branch_detail",
|
||||||
|
tier=plan,
|
||||||
|
input_tokens=0,
|
||||||
|
output_tokens=0,
|
||||||
|
estimated_cost=0,
|
||||||
|
succeeded=False,
|
||||||
|
counts_toward_quota=False,
|
||||||
|
error_code=type(e).__name__,
|
||||||
|
extra_data={"error": str(e), "branch_name": data.branch_name},
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail="AI provider error. Please try again.",
|
||||||
|
)
|
||||||
|
|
||||||
# Record successful usage
|
# Record successful usage
|
||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""AI-powered tree generation service using Anthropic Claude API.
|
"""AI-powered tree generation service.
|
||||||
|
|
||||||
Implements the 4-stage wizard flow:
|
Implements the 4-stage wizard flow:
|
||||||
Stage 2 (scaffold): AI suggests 4-7 top-level branches
|
Stage 2 (scaffold): AI suggests 4-7 top-level branches
|
||||||
Stage 3 (branch_detail): AI generates detailed nodes per branch
|
Stage 3 (branch_detail): AI generates detailed nodes per branch
|
||||||
Stage 4 (assemble): Pure assembly logic — zero AI calls
|
Stage 4 (assemble): Pure assembly logic — zero AI calls
|
||||||
|
|
||||||
System prompts are static constants to enable Anthropic prompt caching.
|
Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic).
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -13,8 +13,7 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import anthropic
|
from app.core.ai_provider import get_ai_provider
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
|
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
|
||||||
|
|
||||||
@@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def _get_client() -> anthropic.AsyncAnthropic:
|
|
||||||
"""Get configured async Anthropic client."""
|
|
||||||
if not settings.ANTHROPIC_API_KEY:
|
|
||||||
raise RuntimeError("ANTHROPIC_API_KEY not configured")
|
|
||||||
return anthropic.AsyncAnthropic(
|
|
||||||
api_key=settings.ANTHROPIC_API_KEY,
|
|
||||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
|
def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
|
||||||
"""Estimate USD cost from token counts."""
|
"""Estimate USD cost from token counts."""
|
||||||
@@ -146,7 +136,7 @@ async def scaffold_branches(
|
|||||||
Returns (branches, input_tokens, output_tokens, estimated_cost).
|
Returns (branches, input_tokens, output_tokens, estimated_cost).
|
||||||
Raises ValueError on invalid response.
|
Raises ValueError on invalid response.
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
provider = get_ai_provider()
|
||||||
|
|
||||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||||
name = wizard_state.get("name", "")
|
name = wizard_state.get("name", "")
|
||||||
@@ -161,16 +151,13 @@ async def scaffold_branches(
|
|||||||
if tags:
|
if tags:
|
||||||
user_message += f"Environment: {', '.join(tags)}\n"
|
user_message += f"Environment: {', '.join(tags)}\n"
|
||||||
|
|
||||||
response = await client.messages.create(
|
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||||
model=settings.AI_MODEL,
|
system_prompt=SCAFFOLD_SYSTEM_PROMPT,
|
||||||
max_tokens=1024,
|
|
||||||
system=SCAFFOLD_SYSTEM_PROMPT,
|
|
||||||
messages=[{"role": "user", "content": user_message}],
|
messages=[{"role": "user", "content": user_message}],
|
||||||
|
max_tokens=1024,
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_text = _strip_markdown_fences(response.content[0].text)
|
raw_text = _strip_markdown_fences(raw_text)
|
||||||
input_tokens = response.usage.input_tokens
|
|
||||||
output_tokens = response.usage.output_tokens
|
|
||||||
cost = _estimate_cost(input_tokens, output_tokens)
|
cost = _estimate_cost(input_tokens, output_tokens)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -196,7 +183,7 @@ async def generate_branch_detail(
|
|||||||
On validation failure, retries once with corrective prompt.
|
On validation failure, retries once with corrective prompt.
|
||||||
Raises ValueError if both attempts fail.
|
Raises ValueError if both attempts fail.
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
provider = get_ai_provider()
|
||||||
|
|
||||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||||
name = wizard_state.get("name", "")
|
name = wizard_state.get("name", "")
|
||||||
@@ -217,31 +204,22 @@ async def generate_branch_detail(
|
|||||||
total_output = 0
|
total_output = 0
|
||||||
|
|
||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
response = await client.messages.create(
|
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||||
model=settings.AI_MODEL,
|
system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT,
|
||||||
max_tokens=8192,
|
|
||||||
system=BRANCH_DETAIL_SYSTEM_PROMPT,
|
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
max_tokens=8192,
|
||||||
)
|
)
|
||||||
|
|
||||||
total_input += response.usage.input_tokens
|
total_input += input_tokens
|
||||||
total_output += response.usage.output_tokens
|
total_output += output_tokens
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d",
|
"branch_detail attempt=%d output_tokens=%d",
|
||||||
attempt,
|
attempt,
|
||||||
response.stop_reason,
|
output_tokens,
|
||||||
len(response.content),
|
|
||||||
response.usage.output_tokens,
|
|
||||||
)
|
)
|
||||||
if response.stop_reason == "max_tokens":
|
raw_text = _strip_markdown_fences(raw_text) if raw_text else ""
|
||||||
logger.warning(
|
|
||||||
"branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated",
|
|
||||||
attempt,
|
|
||||||
response.usage.output_tokens,
|
|
||||||
)
|
|
||||||
raw_text = _strip_markdown_fences(response.content[0].text) if response.content else ""
|
|
||||||
if not raw_text:
|
if not raw_text:
|
||||||
logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason)
|
logger.warning("branch_detail attempt=%d returned empty text", attempt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
branch_tree = json.loads(raw_text)
|
branch_tree = json.loads(raw_text)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Integration tests for AI Flow Builder endpoints.
|
"""Integration tests for AI Flow Builder endpoints.
|
||||||
|
|
||||||
All Anthropic API calls are mocked — zero real API spend.
|
All AI provider calls are mocked — zero real API spend.
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
@@ -64,12 +64,11 @@ BRANCH_DETAIL_JSON = json.dumps({
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200):
|
def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200):
|
||||||
"""Create a mock Anthropic API response."""
|
"""Create a mock AI provider whose generate_json returns the given text and token counts."""
|
||||||
response = MagicMock()
|
provider = MagicMock()
|
||||||
response.content = [MagicMock(text=text)]
|
provider.generate_json = AsyncMock(return_value=(text, input_tokens, output_tokens))
|
||||||
response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens)
|
return provider
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -194,11 +193,9 @@ async def test_scaffold_success(client, auth_headers, enable_ai):
|
|||||||
)
|
)
|
||||||
conversation_id = start_resp.json()["conversation_id"]
|
conversation_id = start_resp.json()["conversation_id"]
|
||||||
|
|
||||||
# Mock Anthropic
|
# Mock AI provider
|
||||||
mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
|
||||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=mock_provider):
|
||||||
mock_client.return_value.messages.create = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/v1/ai/scaffold",
|
"/api/v1/ai/scaffold",
|
||||||
json={"conversation_id": conversation_id},
|
json={"conversation_id": conversation_id},
|
||||||
@@ -241,9 +238,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
|
|||||||
)
|
)
|
||||||
conversation_id = start_resp.json()["conversation_id"]
|
conversation_id = start_resp.json()["conversation_id"]
|
||||||
|
|
||||||
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
|
||||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
|
||||||
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
|
|
||||||
await client.post(
|
await client.post(
|
||||||
"/api/v1/ai/scaffold",
|
"/api/v1/ai/scaffold",
|
||||||
json={"conversation_id": conversation_id},
|
json={"conversation_id": conversation_id},
|
||||||
@@ -251,10 +247,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Now generate branch detail
|
# Now generate branch detail
|
||||||
detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON)
|
detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON)
|
||||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=detail_provider):
|
||||||
mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock)
|
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/v1/ai/branch-detail",
|
"/api/v1/ai/branch-detail",
|
||||||
json={
|
json={
|
||||||
@@ -290,9 +284,8 @@ async def test_assemble_success(client, auth_headers, enable_ai):
|
|||||||
conversation_id = start_resp.json()["conversation_id"]
|
conversation_id = start_resp.json()["conversation_id"]
|
||||||
|
|
||||||
# Scaffold
|
# Scaffold
|
||||||
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
|
||||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
|
||||||
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
|
|
||||||
await client.post(
|
await client.post(
|
||||||
"/api/v1/ai/scaffold",
|
"/api/v1/ai/scaffold",
|
||||||
json={"conversation_id": conversation_id},
|
json={"conversation_id": conversation_id},
|
||||||
|
|||||||
Reference in New Issue
Block a user