""" AI Provider abstraction layer. Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable backends for JSON generation used by the AI Flow Builder. ## Prompt caching (Anthropic only) Callers may pass `system_prompt` as either: - `str` — backward-compatible, uncached. - `list[SystemBlock]` — Anthropic structured system blocks. Each block is a dict of shape `{"type": "text", "text": str, "cache_control": {...}?}`. Caching policy (policy α, per Phase 0.1 design): - If any block in the list carries an explicit `cache_control` key, that caller-authored configuration is honored verbatim. - If no block carries `cache_control`, the provider applies `cache_control: {"type": "ephemeral"}` to the first block only. First block is the common "large static prefix" case (e.g. system prompt, reference data). Gemini ignores cache_control and concatenates list blocks into one system string — callers should not rely on Gemini for cache-hit behavior. TODO(phase0-verify): When a dev environment is available, verify cache-hit behavior by hitting any FlowPilot endpoint twice within the 5-minute ephemeral TTL. First call should emit `anthropic.cache` with `cache_creation_input_tokens > 0`; second call with `cache_read_input_tokens > 0`. If the second call returns zero reads, inspect the prefix for silent invalidators (timestamps, unsorted JSON keys, varying tool list ordering). """ import logging from abc import ABC, abstractmethod from collections.abc import AsyncIterator from typing import Any from app.core.config import settings logger = logging.getLogger(__name__) # Anthropic structured system block. See module docstring for caching policy. SystemBlock = dict[str, Any] def _normalize_system_for_anthropic( system_prompt: str | list[SystemBlock], ) -> str | list[SystemBlock]: """Return the value to pass as the `system=` parameter to the Anthropic API. - Plain strings pass through untouched (uncached path). - Lists are returned as structured system blocks. If no block in the list carries an explicit `cache_control`, `cache_control: {"type": "ephemeral"}` is applied to the FIRST block only (policy α). - Caller-authored `cache_control` is never overwritten. """ if isinstance(system_prompt, str): return system_prompt if not system_prompt: # Empty list is not a meaningful system prompt — pass empty string so # Anthropic treats this as "no system prompt" rather than erroring. return "" blocks = [dict(b) for b in system_prompt] already_cached = any("cache_control" in b for b in blocks) if not already_cached: blocks[0]["cache_control"] = {"type": "ephemeral"} return blocks def _flatten_system_for_gemini( system_prompt: str | list[SystemBlock], ) -> str: """Gemini has no structured system blocks; concatenate list entries.""" if isinstance(system_prompt, str): return system_prompt return "\n\n".join(b.get("text", "") for b in system_prompt) def build_anthropic_chat_messages( history: list[dict[str, Any]], new_message: str, images: list[dict[str, Any]] | None = None, format_reminder: str | None = None, ) -> list[dict[str, Any]]: """Construct the Anthropic `messages` payload for a cached multi-turn chat. Responsibilities: - Copy the valid history messages in order. - Apply `cache_control: ephemeral` to the LAST history message so the entire conversation prefix is cached across turns. The new user message stays uncached (it changes each turn). - Append `format_reminder` to the new user message if provided. The reminder is invisible to storage (caller's concern) but helps enforce structured output compliance at generation time. - If `images` are provided, render the new user message as a multimodal content block list (images first, then text). Otherwise, render it as a plain string. This helper is Anthropic-specific: the cache-breakpoint pattern, ephemeral cache_control, and multimodal block shape are all Anthropic conventions. Do not call it from Gemini code paths. """ messages: list[dict[str, Any]] = [] for msg in history: messages.append({"role": msg["role"], "content": msg["content"]}) # Cache breakpoint on the last existing history message so the entire # conversation prefix is cached across turns. Safe only when there IS a # history message; otherwise the new message is the only message. if messages: last = messages[-1] messages[-1] = { "role": last["role"], "content": [ { "type": "text", "text": last["content"], "cache_control": {"type": "ephemeral"}, } ], } effective_text = new_message + (format_reminder or "") if images: content_blocks: list[dict[str, Any]] = [] for img in images: content_blocks.append( { "type": "image", "source": { "type": "base64", "media_type": img["media_type"], "data": img["data"], }, } ) content_blocks.append({"type": "text", "text": effective_text}) messages.append({"role": "user", "content": content_blocks}) else: messages.append({"role": "user", "content": effective_text}) return messages def _log_anthropic_cache_usage(usage: Any, model: str) -> None: """Emit a structured log line capturing cache_read / cache_creation tokens.""" cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0 cache_creation = getattr(usage, "cache_creation_input_tokens", 0) or 0 input_tokens = getattr(usage, "input_tokens", 0) or 0 output_tokens = getattr(usage, "output_tokens", 0) or 0 if cache_read or cache_creation: logger.info( "anthropic.cache", extra={ "event": "anthropic.cache", "model": model, "cache_read_input_tokens": cache_read, "cache_creation_input_tokens": cache_creation, "input_tokens": input_tokens, "output_tokens": output_tokens, }, ) class AIProvider(ABC): """Abstract base class for AI providers.""" @abstractmethod async def generate_json( self, system_prompt: str | list[SystemBlock], messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: """Generate a JSON response from the AI model. Args: system_prompt: System-level instruction. Plain `str` is uncached (Anthropic) or used as-is (Gemini). `list[SystemBlock]` enables Anthropic prompt caching per module-docstring policy. messages: List of message dicts with "role" and "content" keys. max_tokens: Maximum output tokens. Returns: Tuple of (response_text, input_tokens, output_tokens). """ ... @abstractmethod async def generate_text( self, system_prompt: str | list[SystemBlock], messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: """Generate a text response from the AI model (no JSON constraint). See `generate_json` for argument semantics. """ ... async def generate_text_stream( self, system_prompt: str | list[SystemBlock], messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> "AsyncIterator[str]": """Stream a text response token by token. See `generate_json` for argument semantics. """ raise NotImplementedError("Streaming not supported for this provider") # Make this an async generator to satisfy type checker yield "" # pragma: no cover class GeminiProvider(AIProvider): """Google Gemini provider using the google-genai SDK.""" def __init__(self, api_key: str, model: str) -> None: self._api_key = api_key self._model = model async def generate_json( self, system_prompt: str | list[SystemBlock], messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: from google import genai from google.genai import types as genai_types client = genai.Client(api_key=self._api_key) system_text = _flatten_system_for_gemini(system_prompt) # Convert messages to Gemini Content format contents: list[genai_types.Content] = [] for msg in messages: role = "model" if msg["role"] == "assistant" else "user" contents.append( genai_types.Content( role=role, parts=[genai_types.Part(text=msg["content"])], ) ) config = genai_types.GenerateContentConfig( system_instruction=system_text, max_output_tokens=max_tokens, response_mime_type="application/json", ) response = await client.aio.models.generate_content( model=self._model, contents=contents, config=config, ) # Log finish reason to detect truncation if response.candidates: finish_reason = getattr(response.candidates[0], "finish_reason", None) logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model) if str(finish_reason) == "MAX_TOKENS": logger.warning( "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d", max_tokens, ) text = response.text or "" input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 output_tokens = ( getattr(response.usage_metadata, "candidates_token_count", 0) or 0 ) return text, input_tokens, output_tokens async def generate_text( self, system_prompt: str | list[SystemBlock], messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: from google import genai from google.genai import types as genai_types client = genai.Client(api_key=self._api_key) system_text = _flatten_system_for_gemini(system_prompt) contents: list[genai_types.Content] = [] for msg in messages: role = "model" if msg["role"] == "assistant" else "user" contents.append( genai_types.Content( role=role, parts=[genai_types.Part(text=msg["content"])], ) ) config = genai_types.GenerateContentConfig( system_instruction=system_text, max_output_tokens=max_tokens, # No response_mime_type — allow free-form text ) response = await client.aio.models.generate_content( model=self._model, contents=contents, config=config, ) if response.candidates: finish_reason = getattr(response.candidates[0], "finish_reason", None) logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model) if str(finish_reason) == "MAX_TOKENS": logger.warning( "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d", max_tokens, ) text = response.text or "" input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 output_tokens = ( getattr(response.usage_metadata, "candidates_token_count", 0) or 0 ) return text, input_tokens, output_tokens # Singleton client cache — avoids creating new HTTP connections per call _anthropic_clients: dict[str, "anthropic.AsyncAnthropic"] = {} def _get_anthropic_client(api_key: str, timeout: int = 45) -> "anthropic.AsyncAnthropic": """Return a cached AsyncAnthropic client, creating one if needed.""" import anthropic cache_key = f"{api_key[:8]}:{timeout}" if cache_key not in _anthropic_clients: _anthropic_clients[cache_key] = anthropic.AsyncAnthropic( api_key=api_key, timeout=timeout, max_retries=1, ) return _anthropic_clients[cache_key] class AnthropicProvider(AIProvider): """Anthropic Claude provider using the anthropic SDK.""" def __init__(self, api_key: str, model: str, timeout: int = 45) -> None: self._api_key = api_key self._model = model self._timeout = timeout async def generate_json( self, system_prompt: str | list[SystemBlock], messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: client = _get_anthropic_client(self._api_key, self._timeout) normalized_system = _normalize_system_for_anthropic(system_prompt) response = await client.messages.create( model=self._model, max_tokens=max_tokens, system=normalized_system, messages=messages, ) text = response.content[0].text input_tokens = response.usage.input_tokens output_tokens = response.usage.output_tokens _log_anthropic_cache_usage(response.usage, self._model) return text, input_tokens, output_tokens async def generate_text( self, system_prompt: str | list[SystemBlock], messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: # Anthropic doesn't differentiate between JSON and text mode return await self.generate_json(system_prompt, messages, max_tokens) async def generate_text_stream( self, system_prompt: str | list[SystemBlock], messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> AsyncIterator[str]: client = _get_anthropic_client(self._api_key, self._timeout) normalized_system = _normalize_system_for_anthropic(system_prompt) async with client.messages.stream( model=self._model, max_tokens=max_tokens, system=normalized_system, messages=messages, ) as stream: async for text in stream.text_stream: yield text # Per Anthropic SDK, get_final_message() resolves the stream's # final usage object (including cache_read/cache_creation tokens). try: final = await stream.get_final_message() _log_anthropic_cache_usage(final.usage, self._model) except Exception as exc: # best-effort telemetry, never fail the stream logger.debug("anthropic.cache streaming usage unavailable: %s", exc) def get_ai_provider(model: str | None = None) -> AIProvider: """Factory that returns the configured AI provider. Args: model: Optional model override (Anthropic model ID). Only applied to AnthropicProvider; Gemini always uses settings.AI_MODEL_GEMINI. Selection logic: 1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider 2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider 3. Fallback: if preferred provider key missing, try the other one 4. If nothing configured -> raise RuntimeError """ provider = settings.AI_PROVIDER if provider == "gemini": if settings.GOOGLE_AI_API_KEY: return GeminiProvider( api_key=settings.GOOGLE_AI_API_KEY, model=settings.AI_MODEL_GEMINI, ) # Fallback to Anthropic if settings.ANTHROPIC_API_KEY: return AnthropicProvider( api_key=settings.ANTHROPIC_API_KEY, model=model or settings.AI_MODEL_ANTHROPIC, timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, ) elif provider == "anthropic": if settings.ANTHROPIC_API_KEY: return AnthropicProvider( api_key=settings.ANTHROPIC_API_KEY, model=model or settings.AI_MODEL_ANTHROPIC, timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, ) # Fallback to Gemini if settings.GOOGLE_AI_API_KEY: return GeminiProvider( api_key=settings.GOOGLE_AI_API_KEY, model=settings.AI_MODEL_GEMINI, ) raise RuntimeError( "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY." )