271 lines
8.7 KiB
Python
271 lines
8.7 KiB
Python
"""
|
|
AI Provider abstraction layer.
|
|
|
|
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
|
|
backends for JSON generation used by the AI Flow Builder.
|
|
"""
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
|
|
from app.core.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AIProvider(ABC):
    """Contract that every AI backend implements.

    Both entry points are coroutines with the same parameters and the same
    ``(text, input_tokens, output_tokens)`` return shape, so a caller can hold
    any concrete provider behind this type.
    """

    @abstractmethod
    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for a JSON-formatted reply.

        Args:
            system_prompt: System-level instruction for the model.
            messages: Conversation turns; each dict carries a "role" and a
                "content" key.
            max_tokens: Upper bound on the number of output tokens.

        Returns:
            A ``(response_text, input_tokens, output_tokens)`` tuple.
        """

    @abstractmethod
    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for a free-form text reply (no JSON constraint).

        Args:
            system_prompt: System-level instruction for the model.
            messages: Conversation turns; each dict carries a "role" and a
                "content" key.
            max_tokens: Upper bound on the number of output tokens.

        Returns:
            A ``(response_text, input_tokens, output_tokens)`` tuple.
        """
|
|
|
|
|
|
class GeminiProvider(AIProvider):
    """Google Gemini provider using the google-genai SDK."""

    def __init__(self, api_key: str, model: str) -> None:
        """Store the credentials and model ID used for every request.

        Args:
            api_key: Google AI API key passed to ``genai.Client``.
            model: Gemini model ID sent with each ``generate_content`` call.
        """
        self._api_key = api_key
        self._model = model

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a reply with the response MIME type set to JSON.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        return await self._generate(
            system_prompt,
            messages,
            max_tokens,
            response_mime_type="application/json",
        )

    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a free-form text reply (no response MIME type is set).

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        return await self._generate(system_prompt, messages, max_tokens)

    async def _generate(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        *,
        response_mime_type: str | None = None,
    ) -> tuple[str, int, int]:
        """Shared request path behind ``generate_json`` and ``generate_text``.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.
            response_mime_type: Forwarded to ``GenerateContentConfig`` when
                given; omitted from the config entirely when ``None``.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens). The text is
            ``""`` when the response carries none, and a token count is ``0``
            when the usage metadata does not provide it.
        """
        # Imported lazily so the SDK is only required when Gemini is used.
        from google import genai
        from google.genai import types as genai_types

        client = genai.Client(api_key=self._api_key)

        # Convert messages to Gemini Content format: "assistant" turns become
        # role "model", every other role is sent as "user".
        contents = [
            genai_types.Content(
                role="model" if msg["role"] == "assistant" else "user",
                parts=[genai_types.Part(text=msg["content"])],
            )
            for msg in messages
        ]

        config_kwargs = {
            "system_instruction": system_prompt,
            "max_output_tokens": max_tokens,
        }
        if response_mime_type is not None:
            config_kwargs["response_mime_type"] = response_mime_type
        config = genai_types.GenerateContentConfig(**config_kwargs)

        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )

        # Log finish reason to detect truncation.
        if response.candidates:
            finish_reason = getattr(response.candidates[0], "finish_reason", None)
            logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
            # str() of an Enum member renders as "FinishReason.MAX_TOKENS", so
            # comparing str(finish_reason) to "MAX_TOKENS" never matches.
            # Compare the member name instead; a plain string has no ``.name``
            # and is compared as-is.
            if getattr(finish_reason, "name", finish_reason) == "MAX_TOKENS":
                logger.warning(
                    "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
                    max_tokens,
                )

        text = response.text or ""
        usage = response.usage_metadata
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0

        return text, input_tokens, output_tokens
|
|
|
|
|
|
# Singleton client cache — avoids creating new HTTP connections per call.
# Keys have the form "<sha256 hex digest of the API key>:<timeout>".
_anthropic_clients: dict[str, "anthropic.AsyncAnthropic"] = {}


def _get_anthropic_client(api_key: str, timeout: int = 45) -> "anthropic.AsyncAnthropic":
    """Return a cached AsyncAnthropic client, creating one if needed.

    Clients are cached per (API key, timeout) pair. The cache key is derived
    from a digest of the *entire* API key: keying on only the first few
    characters would hand the same client to two different keys that share a
    prefix, sending requests under the wrong credentials.

    Args:
        api_key: Anthropic API key the client authenticates with.
        timeout: Request timeout passed to the client constructor.

    Returns:
        The ``anthropic.AsyncAnthropic`` instance cached for this key/timeout
        (created with ``max_retries=1`` on first use).
    """
    import hashlib

    import anthropic

    # Hash instead of embedding the raw secret in the dict key.
    key_digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    cache_key = f"{key_digest}:{timeout}"

    client = _anthropic_clients.get(cache_key)
    if client is None:
        client = anthropic.AsyncAnthropic(
            api_key=api_key,
            timeout=timeout,
            max_retries=1,
        )
        _anthropic_clients[cache_key] = client
    return client
|
|
|
|
|
|
class AnthropicProvider(AIProvider):
    """Anthropic Claude provider using the anthropic SDK."""

    def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
        """Store the credentials, model ID and timeout used for every request.

        Args:
            api_key: Anthropic API key.
            model: Model ID sent with each ``messages.create`` call.
            timeout: Request timeout handed to the shared client factory.
        """
        self._api_key = api_key
        self._model = model
        self._timeout = timeout

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Send the conversation to Anthropic and return the reply text.

        No JSON mode is requested from the API here; any JSON formatting has
        to come from the prompt itself.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys,
                forwarded unchanged.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens). The text is
            the concatenation of all text blocks in the response, or ``""``
            when the response contains none.
        """
        client = _get_anthropic_client(self._api_key, self._timeout)

        response = await client.messages.create(
            model=self._model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=messages,
        )

        # Log stop reason to detect truncation (mirrors the Gemini provider).
        stop_reason = getattr(response, "stop_reason", None)
        logger.info("Anthropic stop_reason=%s model=%s", stop_reason, self._model)
        if stop_reason == "max_tokens":
            logger.warning(
                "Anthropic output truncated (max_tokens). max_tokens=%d",
                max_tokens,
            )

        # Indexing content[0].text raises when the content list is empty or
        # its first block is not a text block, so collect text blocks instead.
        text = "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens

        return text, input_tokens, output_tokens

    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a free-form text reply.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        # Anthropic doesn't differentiate between JSON and text mode
        return await self.generate_json(system_prompt, messages, max_tokens)
|
|
|
|
|
|
def get_ai_provider(model: str | None = None) -> AIProvider:
    """Build the AI provider selected by the application settings.

    Args:
        model: Optional Anthropic model ID that takes precedence over
            settings.AI_MODEL_ANTHROPIC. It has no effect on Gemini, which
            always uses settings.AI_MODEL_GEMINI.

    Returns:
        The first provider, in preference order, whose API key is set.

    Raises:
        RuntimeError: If no provider could be built.

    Preference order:
        * AI_PROVIDER == "gemini":    Gemini first, Anthropic as fallback.
        * AI_PROVIDER == "anthropic": Anthropic first, Gemini as fallback.
        * Any other value: no provider is tried and RuntimeError is raised.
    """

    def _build_gemini() -> "AIProvider | None":
        # Only available when a Google AI key is configured.
        if not settings.GOOGLE_AI_API_KEY:
            return None
        return GeminiProvider(
            api_key=settings.GOOGLE_AI_API_KEY,
            model=settings.AI_MODEL_GEMINI,
        )

    def _build_anthropic() -> "AIProvider | None":
        # Only available when an Anthropic key is configured.
        if not settings.ANTHROPIC_API_KEY:
            return None
        return AnthropicProvider(
            api_key=settings.ANTHROPIC_API_KEY,
            model=model or settings.AI_MODEL_ANTHROPIC,
            timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
        )

    preferred = settings.AI_PROVIDER
    if preferred == "gemini":
        builders = (_build_gemini, _build_anthropic)
    elif preferred == "anthropic":
        builders = (_build_anthropic, _build_gemini)
    else:
        builders = ()

    for build in builders:
        candidate = build()
        if candidate is not None:
            return candidate

    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )
|