"""
AI Provider abstraction layer.

Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends for JSON generation used by the AI Flow Builder.
"""

import logging
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any

from app.core.config import settings

if TYPE_CHECKING:  # annotations only; the SDKs are imported lazily at runtime
    import anthropic
    from google import genai

logger = logging.getLogger(__name__)


def _finish_reason_name(finish_reason: object) -> str:
    """Return the bare name of a model finish/stop reason.

    Enum members are reduced to their ``.name`` (e.g. ``"MAX_TOKENS"``).
    This matters because ``str()`` of a ``(str, Enum)`` member includes the
    class name (``"FinishReason.MAX_TOKENS"``). Plain strings pass through
    unchanged, and ``None`` becomes ``""``.
    """
    if finish_reason is None:
        return ""
    name = getattr(finish_reason, "name", None)
    if isinstance(name, str):
        return name
    return str(finish_reason)


class AIProvider(ABC):
    """Abstract base class for AI providers."""

    @abstractmethod
    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a JSON response from the AI model.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        ...

    @abstractmethod
    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a text response from the AI model (no JSON constraint).

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        ...

    async def generate_text_stream(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> AsyncIterator[str]:
        """Stream a text response token by token.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Yields:
            Text chunks as they are generated.

        Raises:
            NotImplementedError: Raised when iteration starts, for providers
                that do not override this method.
        """
        raise NotImplementedError("Streaming not supported for this provider")
        # Unreachable, but the ``yield`` makes this an async generator so the
        # base method has the same shape as overriding implementations.
        yield ""  # pragma: no cover


class GeminiProvider(AIProvider):
    """Google Gemini provider using the google-genai SDK."""

    def __init__(self, api_key: str, model: str) -> None:
        self._api_key = api_key
        self._model = model
        # Built on first use and then reused for every call on this instance,
        # rather than creating a new client for each request.
        self._client: "genai.Client | None" = None

    def _get_client(self) -> "genai.Client":
        """Return this provider's google-genai client, creating it on first use."""
        if self._client is None:
            from google import genai

            self._client = genai.Client(api_key=self._api_key)
        return self._client

    async def _generate(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        *,
        json_mode: bool,
    ) -> tuple[str, int, int]:
        """Run one non-streaming Gemini request.

        Args:
            system_prompt: System-level instruction for the model.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.
            json_mode: If True, request ``application/json`` output;
                otherwise free-form text is allowed.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        from google.genai import types as genai_types

        client = self._get_client()

        # Gemini names the assistant role "model"; every other role is sent as "user".
        contents = [
            genai_types.Content(
                role="model" if msg["role"] == "assistant" else "user",
                parts=[genai_types.Part(text=msg["content"])],
            )
            for msg in messages
        ]

        config_kwargs: dict[str, Any] = {
            "system_instruction": system_prompt,
            "max_output_tokens": max_tokens,
        }
        if json_mode:
            config_kwargs["response_mime_type"] = "application/json"
        config = genai_types.GenerateContentConfig(**config_kwargs)

        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )

        # Log the finish reason so truncated output can be detected.
        if response.candidates:
            finish_reason = getattr(response.candidates[0], "finish_reason", None)
            logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
            if _finish_reason_name(finish_reason) == "MAX_TOKENS":
                logger.warning(
                    "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
                    max_tokens,
                )

        text = response.text or ""
        usage = response.usage_metadata
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0
        return text, input_tokens, output_tokens

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a response constrained to ``application/json``."""
        return await self._generate(
            system_prompt, messages, max_tokens, json_mode=True
        )

    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a free-form text response (no MIME-type constraint)."""
        return await self._generate(
            system_prompt, messages, max_tokens, json_mode=False
        )


# Module-level client cache so repeated calls reuse HTTP connections.
_anthropic_clients: dict[tuple[str, float], "anthropic.AsyncAnthropic"] = {}


def _get_anthropic_client(
    api_key: str, timeout: float = 45
) -> "anthropic.AsyncAnthropic":
    """Return a cached AsyncAnthropic client, creating one if needed.

    Args:
        api_key: Anthropic API key the client authenticates with.
        timeout: Request timeout in seconds passed to the client.

    Returns:
        An ``AsyncAnthropic`` client configured with ``max_retries=1``.
    """
    import anthropic

    # Key on the full API key: Anthropic keys share a common prefix
    # ("sk-ant-..."), so a prefix-based key could hand back a client that is
    # bound to a different key.
    cache_key = (api_key, timeout)
    client = _anthropic_clients.get(cache_key)
    if client is None:
        client = anthropic.AsyncAnthropic(
            api_key=api_key,
            timeout=timeout,
            max_retries=1,
        )
        _anthropic_clients[cache_key] = client
    return client


class AnthropicProvider(AIProvider):
    """Anthropic Claude provider using the anthropic SDK."""

    def __init__(self, api_key: str, model: str, timeout: float = 45) -> None:
        self._api_key = api_key
        self._model = model
        self._timeout = timeout

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a response via the Messages API.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens), where
            ``response_text`` concatenates every text content block ("" if
            there are none).
        """
        client = _get_anthropic_client(self._api_key, self._timeout)
        response = await client.messages.create(
            model=self._model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=messages,
        )

        if response.stop_reason == "max_tokens":
            logger.warning(
                "Anthropic output truncated (max_tokens). max_tokens=%d model=%s",
                max_tokens,
                self._model,
            )

        # content may be empty or contain non-text blocks, so only text
        # blocks are joined instead of assuming content[0] is text.
        text = "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )
        return text, response.usage.input_tokens, response.usage.output_tokens

    async def generate_text(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        # Anthropic doesn't differentiate between JSON and text mode
        return await self.generate_json(system_prompt, messages, max_tokens)

    async def generate_text_stream(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> AsyncIterator[str]:
        """Yield text chunks from the Messages streaming API as they arrive."""
        client = _get_anthropic_client(self._api_key, self._timeout)
        async with client.messages.stream(
            model=self._model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=messages,
        ) as stream:
            async for text in stream.text_stream:
                yield text


def get_ai_provider(model: str | None = None) -> AIProvider:
    """Factory that returns the configured AI provider.

    Args:
        model: Optional model override (Anthropic model ID). Only applied to
            AnthropicProvider; Gemini always uses settings.AI_MODEL_GEMINI.

    Selection logic:
        1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
        2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
        3. Fallback: if preferred provider key missing, try the other one
        4. If nothing configured -> raise RuntimeError

    AI_PROVIDER is compared case-insensitively, ignoring surrounding whitespace.

    Raises:
        RuntimeError: If AI_PROVIDER is not "gemini"/"anthropic", or if
            neither API key is set.
    """

    def build_gemini() -> AIProvider | None:
        if not settings.GOOGLE_AI_API_KEY:
            return None
        return GeminiProvider(
            api_key=settings.GOOGLE_AI_API_KEY,
            model=settings.AI_MODEL_GEMINI,
        )

    def build_anthropic() -> AIProvider | None:
        if not settings.ANTHROPIC_API_KEY:
            return None
        return AnthropicProvider(
            api_key=settings.ANTHROPIC_API_KEY,
            model=model or settings.AI_MODEL_ANTHROPIC,
            timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
        )

    # Preferred backend first, the other one as fallback.
    search_orders = {
        "gemini": [("gemini", build_gemini), ("anthropic", build_anthropic)],
        "anthropic": [("anthropic", build_anthropic), ("gemini", build_gemini)],
    }

    preferred = (settings.AI_PROVIDER or "").strip().lower()
    order = search_orders.get(preferred)
    if order is None:
        raise RuntimeError(
            f"Unknown AI_PROVIDER {settings.AI_PROVIDER!r}; "
            "expected 'gemini' or 'anthropic'."
        )

    for index, (name, build) in enumerate(order):
        provider = build()
        if provider is None:
            continue
        if index > 0:
            logger.warning(
                "AI_PROVIDER=%s has no API key configured; falling back to %s",
                preferred,
                name,
            )
        return provider

    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )