Files
resolutionflow/backend/app/core/ai_provider.py
2026-03-28 23:02:07 +00:00

271 lines
8.7 KiB
Python

"""
AI Provider abstraction layer.

Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends. JSON generation is used by the AI Flow Builder; free-form text
generation is available through the same provider interface.
"""
import logging
from abc import ABC, abstractmethod
from app.core.config import settings
logger = logging.getLogger(__name__)
class AIProvider(ABC):
    """Interface that every AI backend in this module implements."""

    @abstractmethod
    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for a JSON reply.

        Args:
            system_prompt: Instruction applied at the system level.
            messages: Conversation turns; each dict carries a "role" and a
                "content" entry.
            max_tokens: Upper bound on the number of output tokens.

        Returns:
            A ``(response_text, input_tokens, output_tokens)`` tuple.
        """

    @abstractmethod
    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for a plain-text reply (no JSON constraint).

        Args:
            system_prompt: Instruction applied at the system level.
            messages: Conversation turns; each dict carries a "role" and a
                "content" entry.
            max_tokens: Upper bound on the number of output tokens.

        Returns:
            A ``(response_text, input_tokens, output_tokens)`` tuple.
        """
class GeminiProvider(AIProvider):
    """Google Gemini provider using the google-genai SDK.

    ``generate_json`` and ``generate_text`` share a single request path
    (``_generate``); the only difference is whether the response is
    constrained to the ``application/json`` MIME type.
    """

    def __init__(self, api_key: str, model: str) -> None:
        """Store the credentials and model ID used for every request."""
        self._api_key = api_key
        self._model = model

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a reply with ``response_mime_type="application/json"``.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        return await self._generate(
            system_prompt, messages, max_tokens, json_mode=True
        )

    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a free-form text reply (no response MIME type is set).

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        return await self._generate(
            system_prompt, messages, max_tokens, json_mode=False
        )

    @staticmethod
    def _is_truncated(finish_reason: object) -> bool:
        """Return True when *finish_reason* denotes a MAX_TOKENS stop.

        ``str()`` of an ordinary ``Enum`` member (including ``str``-mixin
        enums) is ``"ClassName.MEMBER"``, so comparing ``str(finish_reason)``
        with the bare ``"MAX_TOKENS"`` never matches for such values. The
        member name is used when present, and any ``Class.`` qualifier is
        stripped, so enum members, bare strings and qualified strings are all
        recognised.
        NOTE(review): google-genai's FinishReason is presumably such an
        enum — confirm against the installed SDK version.
        """
        label = str(getattr(finish_reason, "name", None) or finish_reason)
        return label.rsplit(".", 1)[-1] == "MAX_TOKENS"

    async def _generate(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        *,
        json_mode: bool,
    ) -> tuple[str, int, int]:
        """Send one generate_content request and unpack text and token usage.

        Args:
            system_prompt: Passed as the config's ``system_instruction``.
            messages: Message dicts with "role" and "content" keys.
            max_tokens: Passed as the config's ``max_output_tokens``.
            json_mode: When True, request ``application/json`` output.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens). Missing
            text or usage counts are reported as ``""`` / ``0``.
        """
        # Function-scope import: google-genai is only needed once a Gemini
        # call is actually made.
        from google import genai
        from google.genai import types as genai_types

        client = genai.Client(api_key=self._api_key)

        # Convert messages to Gemini Content format: "assistant" turns become
        # role "model"; every other role is sent as "user".
        contents: list[genai_types.Content] = [
            genai_types.Content(
                role="model" if msg["role"] == "assistant" else "user",
                parts=[genai_types.Part(text=msg["content"])],
            )
            for msg in messages
        ]

        config_kwargs: dict[str, object] = {
            "system_instruction": system_prompt,
            "max_output_tokens": max_tokens,
        }
        if json_mode:
            config_kwargs["response_mime_type"] = "application/json"
        # Otherwise no response_mime_type — allow free-form text.
        config = genai_types.GenerateContentConfig(**config_kwargs)

        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )

        # Log finish reason to detect truncation
        if response.candidates:
            finish_reason = getattr(response.candidates[0], "finish_reason", None)
            logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
            if self._is_truncated(finish_reason):
                logger.warning(
                    "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
                    max_tokens,
                )

        text = response.text or ""
        usage = response.usage_metadata
        # getattr + "or 0" covers both a missing usage object and None counts.
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0
        return text, input_tokens, output_tokens
# Singleton client cache — avoids creating new HTTP connections per call.
# Keys are produced by _anthropic_cache_key(): "<sha256 of api key>:<timeout>".
_anthropic_clients: dict[str, "anthropic.AsyncAnthropic"] = {}


def _anthropic_cache_key(api_key: str, timeout: int) -> str:
    """Build the client-cache key for an (api_key, timeout) pair.

    The whole key is hashed rather than truncated: a short prefix can be
    identical for two different API keys, which would hand the second caller
    a client bound to the first caller's credentials. Hashing also keeps the
    raw secret out of the cache's dictionary keys.
    """
    import hashlib

    digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    return f"{digest}:{timeout}"


def _get_anthropic_client(api_key: str, timeout: int = 45) -> "anthropic.AsyncAnthropic":
    """Return a cached AsyncAnthropic client, creating one if needed.

    Args:
        api_key: Anthropic API key the client authenticates with.
        timeout: Request timeout passed to the client constructor.

    Returns:
        The client cached for this exact (api_key, timeout) combination;
        a new one is built with ``max_retries=1`` on first use.
    """
    import anthropic

    cache_key = _anthropic_cache_key(api_key, timeout)
    client = _anthropic_clients.get(cache_key)
    if client is None:
        client = anthropic.AsyncAnthropic(
            api_key=api_key,
            timeout=timeout,
            max_retries=1,
        )
        _anthropic_clients[cache_key] = client
    return client
class AnthropicProvider(AIProvider):
    """Anthropic Claude provider using the anthropic SDK."""

    def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
        """Store the credentials, model ID and request timeout."""
        self._api_key = api_key
        self._model = model
        self._timeout = timeout

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Send a Messages API request and return its text and token usage.

        No JSON-specific option is passed to the API here; the request is a
        plain ``messages.create`` call.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens). The text
            is the concatenation of every text block in the response, or
            ``""`` when the response carries no text block.
        """
        client = _get_anthropic_client(self._api_key, self._timeout)
        response = await client.messages.create(
            model=self._model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=messages,
        )

        # Surface truncation the same way the Gemini provider does.
        if getattr(response, "stop_reason", None) == "max_tokens":
            logger.warning(
                "Anthropic output truncated (max_tokens). max_tokens=%d",
                max_tokens,
            )

        # Indexing content[0].text fails on an empty content list and on a
        # leading non-text block, so collect the text blocks explicitly.
        text = "".join(
            block.text
            for block in (response.content or [])
            if getattr(block, "type", None) == "text"
        )
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
        return text, input_tokens, output_tokens

    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate free-form text; delegates to ``generate_json``.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        # Anthropic doesn't differentiate between JSON and text mode
        return await self.generate_json(system_prompt, messages, max_tokens)
def get_ai_provider(model: str | None = None) -> AIProvider:
    """Build the AI provider selected by application settings.

    Args:
        model: Optional Anthropic model ID override. It is ignored for
            Gemini, which always uses ``settings.AI_MODEL_GEMINI``.

    Returns:
        The preferred provider when its API key is set, otherwise the other
        provider when that one's key is set.

    Raises:
        RuntimeError: If no provider could be built — neither key is set, or
            ``settings.AI_PROVIDER`` is neither "gemini" nor "anthropic".
    """

    def build_gemini() -> AIProvider | None:
        # Only available when a Google AI key is configured.
        if not settings.GOOGLE_AI_API_KEY:
            return None
        return GeminiProvider(
            api_key=settings.GOOGLE_AI_API_KEY,
            model=settings.AI_MODEL_GEMINI,
        )

    def build_anthropic() -> AIProvider | None:
        # Only available when an Anthropic key is configured.
        if not settings.ANTHROPIC_API_KEY:
            return None
        return AnthropicProvider(
            api_key=settings.ANTHROPIC_API_KEY,
            model=model or settings.AI_MODEL_ANTHROPIC,
            timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
        )

    preferred = settings.AI_PROVIDER

    # Preferred backend first, the other one as fallback.
    if preferred == "gemini":
        builders = (build_gemini, build_anthropic)
    elif preferred == "anthropic":
        builders = (build_anthropic, build_gemini)
    else:
        builders = ()

    for build in builders:
        candidate = build()
        if candidate is not None:
            return candidate

    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )