# File: resolutionflow/backend/app/core/ai_provider.py
# 2026-03-28 23:02:35 +00:00
#
# 309 lines, 9.9 KiB, Python

"""
AI Provider abstraction layer.
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends for JSON and free-form text generation used by the AI Flow Builder.
"""
import logging
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from app.core.config import settings
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class AIProvider(ABC):
    """Common interface implemented by every AI backend.

    Concrete providers must implement :meth:`generate_json` and
    :meth:`generate_text`. :meth:`generate_text_stream` is optional: the
    default implementation raises ``NotImplementedError`` as soon as the
    returned async generator is iterated.
    """

    @abstractmethod
    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for a JSON reply.

        Args:
            system_prompt: System-level instruction given to the model.
            messages: Conversation turns; each dict carries a "role" and a
                "content" key.
            max_tokens: Upper bound on the number of output tokens.

        Returns:
            A ``(response_text, input_tokens, output_tokens)`` tuple.
        """

    @abstractmethod
    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for a plain-text reply (no JSON constraint).

        Args:
            system_prompt: System-level instruction given to the model.
            messages: Conversation turns; each dict carries a "role" and a
                "content" key.
            max_tokens: Upper bound on the number of output tokens.

        Returns:
            A ``(response_text, input_tokens, output_tokens)`` tuple.
        """

    async def generate_text_stream(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> "AsyncIterator[str]":
        """Stream a plain-text reply chunk by chunk.

        Providers without streaming support inherit this default, which
        fails with ``NotImplementedError`` on the first iteration step.

        Args:
            system_prompt: System-level instruction given to the model.
            messages: Conversation turns; each dict carries a "role" and a
                "content" key.
            max_tokens: Upper bound on the number of output tokens.

        Yields:
            Text chunks in the order they are produced.
        """
        # The dead ``yield`` only exists so Python compiles this method as an
        # async generator, matching the signature of real implementations.
        if False:  # pragma: no cover
            yield ""
        raise NotImplementedError("Streaming not supported for this provider")
class GeminiProvider(AIProvider):
    """Google Gemini provider using the google-genai SDK."""

    def __init__(self, api_key: str, model: str) -> None:
        """Store the credentials and model ID used for every request.

        Args:
            api_key: Google AI API key passed to ``genai.Client``.
            model: Gemini model ID sent with each ``generate_content`` call.
        """
        self._api_key = api_key
        self._model = model

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a reply with ``response_mime_type="application/json"``.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        return await self._generate(
            system_prompt, messages, max_tokens, json_mode=True
        )

    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a free-form text reply (no response MIME type is set).

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        return await self._generate(
            system_prompt, messages, max_tokens, json_mode=False
        )

    @staticmethod
    def _is_max_tokens(finish_reason: object) -> bool:
        """Return True when *finish_reason* denotes MAX_TOKENS truncation.

        The finish reason may be an enum member or a plain string. ``str()``
        of a ``(str, Enum)`` member is ``"ClassName.MEMBER"``, so comparing
        ``str(finish_reason)`` with ``"MAX_TOKENS"`` never matches for enum
        values. Compare the member name instead, falling back to the last
        dotted component of the string form.
        """
        name = getattr(finish_reason, "name", None)
        if not isinstance(name, str):
            name = str(finish_reason)
        return name.rsplit(".", 1)[-1] == "MAX_TOKENS"

    async def _generate(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        *,
        json_mode: bool,
    ) -> tuple[str, int, int]:
        """Shared request path for JSON and free-form text generation.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.
            json_mode: When True, request ``application/json`` output.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens). The text
            is ``""`` when the response has none; token counts fall back to
            0 when usage metadata is missing.
        """
        # Imported lazily so the module loads without google-genai installed.
        from google import genai
        from google.genai import types as genai_types

        client = genai.Client(api_key=self._api_key)

        # Convert messages to Gemini Content format: "assistant" becomes
        # "model", every other role is sent as "user".
        contents: list[genai_types.Content] = [
            genai_types.Content(
                role="model" if msg["role"] == "assistant" else "user",
                parts=[genai_types.Part(text=msg["content"])],
            )
            for msg in messages
        ]

        config_kwargs: dict[str, object] = {
            "system_instruction": system_prompt,
            "max_output_tokens": max_tokens,
        }
        if json_mode:
            config_kwargs["response_mime_type"] = "application/json"
        config = genai_types.GenerateContentConfig(**config_kwargs)

        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )

        # Log finish reason to detect truncation
        if response.candidates:
            finish_reason = getattr(response.candidates[0], "finish_reason", None)
            logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
            if self._is_max_tokens(finish_reason):
                logger.warning(
                    "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
                    max_tokens,
                )

        text = response.text or ""
        usage = response.usage_metadata
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0
        return text, input_tokens, output_tokens
# Singleton client cache — avoids creating new HTTP connections per call.
# Keys are "<sha256 of the full API key>:<timeout>".
_anthropic_clients: dict[str, "anthropic.AsyncAnthropic"] = {}


def _get_anthropic_client(api_key: str, timeout: int = 45) -> "anthropic.AsyncAnthropic":
    """Return a cached AsyncAnthropic client, creating one if needed.

    Clients are cached per (API key, timeout) pair. The cache key is derived
    from a digest of the *entire* API key: a short prefix is not enough to
    tell keys apart, because different keys can share their leading
    characters and would then be handed a client bound to someone else's key.

    Args:
        api_key: Anthropic API key the client authenticates with.
        timeout: Request timeout passed to the client.

    Returns:
        The cached client for this key/timeout, created with
        ``max_retries=1`` on first use.
    """
    import hashlib

    key_digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    cache_key = f"{key_digest}:{timeout}"

    client = _anthropic_clients.get(cache_key)
    if client is None:
        # Imported lazily so the module loads without the SDK installed.
        import anthropic

        client = anthropic.AsyncAnthropic(
            api_key=api_key,
            timeout=timeout,
            max_retries=1,
        )
        _anthropic_clients[cache_key] = client
    return client
class AnthropicProvider(AIProvider):
    """Anthropic Claude provider using the anthropic SDK."""

    def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
        """Store the credentials, model ID and timeout used for every request.

        Args:
            api_key: Anthropic API key.
            model: Anthropic model ID sent with each request.
            timeout: Request timeout passed to the shared client.
        """
        self._api_key = api_key
        self._model = model
        self._timeout = timeout

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Send the conversation to the model and return its text reply.

        No JSON-specific option is passed to the API here; any JSON shape
        has to be requested through the prompt.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens). The text
            is the concatenation of every content block that carries text,
            or ``""`` when the response has none.
        """
        client = _get_anthropic_client(self._api_key, self._timeout)
        response = await client.messages.create(
            model=self._model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=messages,
        )

        # Surface truncation the same way the Gemini provider does.
        if getattr(response, "stop_reason", None) == "max_tokens":
            logger.warning(
                "Anthropic output truncated (max_tokens). max_tokens=%d",
                max_tokens,
            )

        # Do not assume content[0] exists or is a text block: the list can be
        # empty, and blocks without a string ``text`` attribute are skipped.
        text = "".join(
            part
            for part in (
                getattr(block, "text", None) for block in (response.content or [])
            )
            if isinstance(part, str)
        )
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
        return text, input_tokens, output_tokens

    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a text reply; identical to :meth:`generate_json`.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        # Anthropic doesn't differentiate between JSON and text mode
        return await self.generate_json(system_prompt, messages, max_tokens)

    async def generate_text_stream(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> AsyncIterator[str]:
        """Stream the model's reply as text chunks.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Yields:
            Text chunks from the SDK's ``text_stream`` as they arrive.
        """
        client = _get_anthropic_client(self._api_key, self._timeout)
        async with client.messages.stream(
            model=self._model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=messages,
        ) as stream:
            async for text in stream.text_stream:
                yield text
def get_ai_provider(model: str | None = None) -> AIProvider:
    """Build the AI provider selected by the application settings.

    Args:
        model: Optional Anthropic model ID override. It is ignored for
            Gemini, which always uses ``settings.AI_MODEL_GEMINI``.

    Returns:
        The provider named by ``settings.AI_PROVIDER`` when its API key is
        set, otherwise the other provider if that one has a key.

    Raises:
        RuntimeError: If ``settings.AI_PROVIDER`` is neither "gemini" nor
            "anthropic", or if neither API key is set.
    """

    def _build_gemini() -> AIProvider | None:
        # None signals "no key configured" so the caller can fall through.
        google_key = settings.GOOGLE_AI_API_KEY
        if not google_key:
            return None
        return GeminiProvider(api_key=google_key, model=settings.AI_MODEL_GEMINI)

    def _build_anthropic() -> AIProvider | None:
        anthropic_key = settings.ANTHROPIC_API_KEY
        if not anthropic_key:
            return None
        return AnthropicProvider(
            api_key=anthropic_key,
            model=model or settings.AI_MODEL_ANTHROPIC,
            timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
        )

    # Preferred backend first, the other one as fallback.
    preferred = settings.AI_PROVIDER
    if preferred == "gemini":
        builders = (_build_gemini, _build_anthropic)
    elif preferred == "anthropic":
        builders = (_build_anthropic, _build_gemini)
    else:
        builders = ()

    for build in builders:
        candidate = build()
        if candidate is not None:
            return candidate

    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )