refactor: migrate AI tree generator to provider abstraction
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
"""AI-powered tree generation service using Anthropic Claude API.
|
||||
"""AI-powered tree generation service.
|
||||
|
||||
Implements the 4-stage wizard flow:
|
||||
Stage 2 (scaffold): AI suggests 4-7 top-level branches
|
||||
Stage 3 (branch_detail): AI generates detailed nodes per branch
|
||||
Stage 4 (assemble): Pure assembly logic — zero AI calls
|
||||
|
||||
System prompts are static constants to enable Anthropic prompt caching.
|
||||
Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic).
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
@@ -13,8 +13,7 @@ import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import anthropic
|
||||
|
||||
from app.core.ai_provider import get_ai_provider
|
||||
from app.core.config import settings
|
||||
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
|
||||
|
||||
@@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def _get_client() -> anthropic.AsyncAnthropic:
|
||||
"""Get configured async Anthropic client."""
|
||||
if not settings.ANTHROPIC_API_KEY:
|
||||
raise RuntimeError("ANTHROPIC_API_KEY not configured")
|
||||
return anthropic.AsyncAnthropic(
|
||||
api_key=settings.ANTHROPIC_API_KEY,
|
||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
|
||||
"""Estimate USD cost from token counts."""
|
||||
@@ -146,7 +136,7 @@ async def scaffold_branches(
|
||||
Returns (branches, input_tokens, output_tokens, estimated_cost).
|
||||
Raises ValueError on invalid response.
|
||||
"""
|
||||
client = _get_client()
|
||||
provider = get_ai_provider()
|
||||
|
||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||
name = wizard_state.get("name", "")
|
||||
@@ -161,16 +151,13 @@ async def scaffold_branches(
|
||||
if tags:
|
||||
user_message += f"Environment: {', '.join(tags)}\n"
|
||||
|
||||
response = await client.messages.create(
|
||||
model=settings.AI_MODEL,
|
||||
max_tokens=1024,
|
||||
system=SCAFFOLD_SYSTEM_PROMPT,
|
||||
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt=SCAFFOLD_SYSTEM_PROMPT,
|
||||
messages=[{"role": "user", "content": user_message}],
|
||||
max_tokens=1024,
|
||||
)
|
||||
|
||||
raw_text = _strip_markdown_fences(response.content[0].text)
|
||||
input_tokens = response.usage.input_tokens
|
||||
output_tokens = response.usage.output_tokens
|
||||
raw_text = _strip_markdown_fences(raw_text)
|
||||
cost = _estimate_cost(input_tokens, output_tokens)
|
||||
|
||||
try:
|
||||
@@ -196,7 +183,7 @@ async def generate_branch_detail(
|
||||
On validation failure, retries once with corrective prompt.
|
||||
Raises ValueError if both attempts fail.
|
||||
"""
|
||||
client = _get_client()
|
||||
provider = get_ai_provider()
|
||||
|
||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||
name = wizard_state.get("name", "")
|
||||
@@ -217,31 +204,22 @@ async def generate_branch_detail(
|
||||
total_output = 0
|
||||
|
||||
for attempt in range(3):
|
||||
response = await client.messages.create(
|
||||
model=settings.AI_MODEL,
|
||||
max_tokens=8192,
|
||||
system=BRANCH_DETAIL_SYSTEM_PROMPT,
|
||||
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT,
|
||||
messages=messages,
|
||||
max_tokens=8192,
|
||||
)
|
||||
|
||||
total_input += response.usage.input_tokens
|
||||
total_output += response.usage.output_tokens
|
||||
total_input += input_tokens
|
||||
total_output += output_tokens
|
||||
logger.debug(
|
||||
"branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d",
|
||||
"branch_detail attempt=%d output_tokens=%d",
|
||||
attempt,
|
||||
response.stop_reason,
|
||||
len(response.content),
|
||||
response.usage.output_tokens,
|
||||
output_tokens,
|
||||
)
|
||||
if response.stop_reason == "max_tokens":
|
||||
logger.warning(
|
||||
"branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated",
|
||||
attempt,
|
||||
response.usage.output_tokens,
|
||||
)
|
||||
raw_text = _strip_markdown_fences(response.content[0].text) if response.content else ""
|
||||
raw_text = _strip_markdown_fences(raw_text) if raw_text else ""
|
||||
if not raw_text:
|
||||
logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason)
|
||||
logger.warning("branch_detail attempt=%d returned empty text", attempt)
|
||||
|
||||
try:
|
||||
branch_tree = json.loads(raw_text)
|
||||
|
||||
Reference in New Issue
Block a user