From eb7ea7ddd9008cfea3694fca0abe07bdc1ac22a0 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:20:48 -0500 Subject: [PATCH] refactor: migrate AI tree generator to provider abstraction Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/ai_builder.py | 87 +++++++++---------- backend/app/core/ai_tree_generator_service.py | 58 ++++--------- backend/tests/test_ai_endpoints.py | 37 ++++---- 3 files changed, 76 insertions(+), 106 deletions(-) diff --git a/backend/app/api/endpoints/ai_builder.py b/backend/app/api/endpoints/ai_builder.py index 5ec8d55a..dcb0a966 100644 --- a/backend/app/api/endpoints/ai_builder.py +++ b/backend/app/api/endpoints/ai_builder.py @@ -10,7 +10,6 @@ import logging from typing import Annotated -import anthropic from fastapi import APIRouter, Depends, HTTPException, Request, status from sqlalchemy.ext.asyncio import AsyncSession @@ -52,7 +51,7 @@ def _require_ai_enabled() -> None: if not settings.ai_enabled: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.", + detail="AI flow builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", ) @@ -174,27 +173,6 @@ async def scaffold( branches, input_tokens, output_tokens, cost = await scaffold_branches( conversation.wizard_state, ) - except anthropic.APIError as e: - await record_ai_usage( - user_id=current_user.id, - account_id=current_user.account_id, - conversation_id=conversation.id, - generation_type="scaffold", - tier=plan, - input_tokens=0, - output_tokens=0, - estimated_cost=0, - succeeded=False, - counts_toward_quota=False, - error_code=type(e).__name__, - extra_data={"error": str(e)}, - db=db, - ) - await db.commit() - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail="AI provider error. Please try again.", - ) except ValueError as e: await record_ai_usage( user_id=current_user.id, @@ -216,6 +194,27 @@ async def scaffold( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"AI returned invalid output: {e}", ) + except Exception as e: + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="scaffold", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"error": str(e)}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="AI provider error. Please try again.", + ) # Record successful usage await record_ai_usage( @@ -293,27 +292,6 @@ async def branch_detail( existing_branches, ) ) - except anthropic.APIError as e: - await record_ai_usage( - user_id=current_user.id, - account_id=current_user.account_id, - conversation_id=conversation.id, - generation_type="branch_detail", - tier=plan, - input_tokens=0, - output_tokens=0, - estimated_cost=0, - succeeded=False, - counts_toward_quota=False, - error_code=type(e).__name__, - extra_data={"error": str(e), "branch_name": data.branch_name}, - db=db, - ) - await db.commit() - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail="AI provider error. Please try again.", - ) except ValueError as e: await record_ai_usage( user_id=current_user.id, @@ -335,6 +313,27 @@ async def branch_detail( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"AI returned invalid output: {e}", ) + except Exception as e: + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="branch_detail", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"error": str(e), "branch_name": data.branch_name}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="AI provider error. Please try again.", + ) # Record successful usage await record_ai_usage( diff --git a/backend/app/core/ai_tree_generator_service.py b/backend/app/core/ai_tree_generator_service.py index 4d40e257..7a562d1c 100644 --- a/backend/app/core/ai_tree_generator_service.py +++ b/backend/app/core/ai_tree_generator_service.py @@ -1,11 +1,11 @@ -"""AI-powered tree generation service using Anthropic Claude API. +"""AI-powered tree generation service. Implements the 4-stage wizard flow: Stage 2 (scaffold): AI suggests 4-7 top-level branches Stage 3 (branch_detail): AI generates detailed nodes per branch Stage 4 (assemble): Pure assembly logic — zero AI calls -System prompts are static constants to enable Anthropic prompt caching. +Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic). """ import json import logging @@ -13,8 +13,7 @@ import re import uuid from typing import Any -import anthropic - +from app.core.ai_provider import get_ai_provider from app.core.config import settings from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats @@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str: return text -def _get_client() -> anthropic.AsyncAnthropic: - """Get configured async Anthropic client.""" - if not settings.ANTHROPIC_API_KEY: - raise RuntimeError("ANTHROPIC_API_KEY not configured") - return anthropic.AsyncAnthropic( - api_key=settings.ANTHROPIC_API_KEY, - timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, - ) - def _estimate_cost(input_tokens: int, output_tokens: int) -> float: """Estimate USD cost from token counts.""" @@ -146,7 +136,7 @@ async def scaffold_branches( Returns (branches, input_tokens, output_tokens, estimated_cost). Raises ValueError on invalid response. """ - client = _get_client() + provider = get_ai_provider() flow_type = wizard_state.get("flow_type", "troubleshooting") name = wizard_state.get("name", "") @@ -161,16 +151,13 @@ async def scaffold_branches( if tags: user_message += f"Environment: {', '.join(tags)}\n" - response = await client.messages.create( - model=settings.AI_MODEL, - max_tokens=1024, - system=SCAFFOLD_SYSTEM_PROMPT, + raw_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=SCAFFOLD_SYSTEM_PROMPT, messages=[{"role": "user", "content": user_message}], + max_tokens=1024, ) - raw_text = _strip_markdown_fences(response.content[0].text) - input_tokens = response.usage.input_tokens - output_tokens = response.usage.output_tokens + raw_text = _strip_markdown_fences(raw_text) cost = _estimate_cost(input_tokens, output_tokens) try: @@ -196,7 +183,7 @@ async def generate_branch_detail( On validation failure, retries once with corrective prompt. Raises ValueError if both attempts fail. """ - client = _get_client() + provider = get_ai_provider() flow_type = wizard_state.get("flow_type", "troubleshooting") name = wizard_state.get("name", "") @@ -217,31 +204,22 @@ async def generate_branch_detail( total_output = 0 for attempt in range(3): - response = await client.messages.create( - model=settings.AI_MODEL, - max_tokens=8192, - system=BRANCH_DETAIL_SYSTEM_PROMPT, + raw_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT, messages=messages, + max_tokens=8192, ) - total_input += response.usage.input_tokens - total_output += response.usage.output_tokens + total_input += input_tokens + total_output += output_tokens logger.debug( - "branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d", + "branch_detail attempt=%d output_tokens=%d", attempt, - response.stop_reason, - len(response.content), - response.usage.output_tokens, + output_tokens, ) - if response.stop_reason == "max_tokens": - logger.warning( - "branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated", - attempt, - response.usage.output_tokens, - ) - raw_text = _strip_markdown_fences(response.content[0].text) if response.content else "" + raw_text = _strip_markdown_fences(raw_text) if raw_text else "" if not raw_text: - logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason) + logger.warning("branch_detail attempt=%d returned empty text", attempt) try: branch_tree = json.loads(raw_text) diff --git a/backend/tests/test_ai_endpoints.py b/backend/tests/test_ai_endpoints.py index 339448dd..1f91514e 100644 --- a/backend/tests/test_ai_endpoints.py +++ b/backend/tests/test_ai_endpoints.py @@ -1,6 +1,6 @@ """Integration tests for AI Flow Builder endpoints. -All Anthropic API calls are mocked — zero real API spend. +All AI provider calls are mocked — zero real API spend. """ import json from unittest.mock import AsyncMock, patch, MagicMock @@ -64,12 +64,11 @@ BRANCH_DETAIL_JSON = json.dumps({ }) -def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200): - """Create a mock Anthropic API response.""" - response = MagicMock() - response.content = [MagicMock(text=text)] - response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens) - return response +def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200): + """Create a mock AI provider whose generate_json returns the given text and token counts.""" + provider = MagicMock() + provider.generate_json = AsyncMock(return_value=(text, input_tokens, output_tokens)) + return provider @pytest.fixture @@ -194,11 +193,9 @@ async def test_scaffold_success(client, auth_headers, enable_ai): ) conversation_id = start_resp.json()["conversation_id"] - # Mock Anthropic - mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=mock_response) - + # Mock AI provider + mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=mock_provider): response = await client.post( "/api/v1/ai/scaffold", json={"conversation_id": conversation_id}, @@ -241,9 +238,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai): ) conversation_id = start_resp.json()["conversation_id"] - scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock) + scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider): await client.post( "/api/v1/ai/scaffold", json={"conversation_id": conversation_id}, @@ -251,10 +247,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai): ) # Now generate branch detail - detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock) - + detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=detail_provider): response = await client.post( "/api/v1/ai/branch-detail", json={ @@ -290,9 +284,8 @@ async def test_assemble_success(client, auth_headers, enable_ai): conversation_id = start_resp.json()["conversation_id"] # Scaffold - scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock) + scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider): await client.post( "/api/v1/ai/scaffold", json={"conversation_id": conversation_id},