diff --git a/backend/app/core/ai_provider.py b/backend/app/core/ai_provider.py index 84d4c33b..71d376ad 100644 --- a/backend/app/core/ai_provider.py +++ b/backend/app/core/ai_provider.py @@ -3,16 +3,102 @@ AI Provider abstraction layer. Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable backends for JSON generation used by the AI Flow Builder. + +## Prompt caching (Anthropic only) + +Callers may pass `system_prompt` as either: + +- `str` — backward-compatible, uncached. +- `list[SystemBlock]` — Anthropic structured system blocks. Each block is a + dict of shape `{"type": "text", "text": str, "cache_control": {...}?}`. + +Caching policy (policy α, per Phase 0.1 design): +- If any block in the list carries an explicit `cache_control` key, that + caller-authored configuration is honored verbatim. +- If no block carries `cache_control`, the provider applies + `cache_control: {"type": "ephemeral"}` to the first block only. First block + is the common "large static prefix" case (e.g. system prompt, reference data). + +Gemini ignores cache_control and concatenates list blocks into one system +string — callers should not rely on Gemini for cache-hit behavior. + +TODO(phase0-verify): When a dev environment is available, verify cache-hit +behavior by hitting any FlowPilot endpoint twice within the 5-minute +ephemeral TTL. First call should emit `anthropic.cache` with +`cache_creation_input_tokens > 0`; second call with `cache_read_input_tokens > 0`. +If the second call returns zero reads, inspect the prefix for silent +invalidators (timestamps, unsorted JSON keys, varying tool list ordering). """ import logging from abc import ABC, abstractmethod from collections.abc import AsyncIterator +from typing import Any from app.core.config import settings logger = logging.getLogger(__name__) +# Anthropic structured system block. See module docstring for caching policy. +SystemBlock = dict[str, Any] + + +def _normalize_system_for_anthropic( + system_prompt: str | list[SystemBlock], +) -> str | list[SystemBlock]: + """Return the value to pass as the `system=` parameter to the Anthropic API. + + - Plain strings pass through untouched (uncached path). + - Lists are returned as structured system blocks. If no block in the list + carries an explicit `cache_control`, `cache_control: {"type": "ephemeral"}` + is applied to the FIRST block only (policy α). + - Caller-authored `cache_control` is never overwritten. + """ + if isinstance(system_prompt, str): + return system_prompt + + if not system_prompt: + # Empty list is not a meaningful system prompt — pass empty string so + # Anthropic treats this as "no system prompt" rather than erroring. + return "" + + blocks = [dict(b) for b in system_prompt] + already_cached = any("cache_control" in b for b in blocks) + + if not already_cached: + blocks[0]["cache_control"] = {"type": "ephemeral"} + + return blocks + + +def _flatten_system_for_gemini( + system_prompt: str | list[SystemBlock], +) -> str: + """Gemini has no structured system blocks; concatenate list entries.""" + if isinstance(system_prompt, str): + return system_prompt + return "\n\n".join(b.get("text", "") for b in system_prompt) + + +def _log_anthropic_cache_usage(usage: Any, model: str) -> None: + """Emit a structured log line capturing cache_read / cache_creation tokens.""" + cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0 + cache_creation = getattr(usage, "cache_creation_input_tokens", 0) or 0 + input_tokens = getattr(usage, "input_tokens", 0) or 0 + output_tokens = getattr(usage, "output_tokens", 0) or 0 + if cache_read or cache_creation: + logger.info( + "anthropic.cache", + extra={ + "event": "anthropic.cache", + "model": model, + "cache_read_input_tokens": cache_read, + "cache_creation_input_tokens": cache_creation, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + ) + class AIProvider(ABC): """Abstract base class for AI providers.""" @@ -20,14 +106,16 @@ class AIProvider(ABC): @abstractmethod async def generate_json( self, - system_prompt: str, - messages: list[dict[str, str]], + system_prompt: str | list[SystemBlock], + messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: """Generate a JSON response from the AI model. Args: - system_prompt: System-level instruction for the model. + system_prompt: System-level instruction. Plain `str` is uncached + (Anthropic) or used as-is (Gemini). `list[SystemBlock]` enables + Anthropic prompt caching per module-docstring policy. messages: List of message dicts with "role" and "content" keys. max_tokens: Maximum output tokens. @@ -39,37 +127,25 @@ class AIProvider(ABC): @abstractmethod async def generate_text( self, - system_prompt: str, - messages: list[dict[str, str]], + system_prompt: str | list[SystemBlock], + messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: """Generate a text response from the AI model (no JSON constraint). - Args: - system_prompt: System-level instruction for the model. - messages: List of message dicts with "role" and "content" keys. - max_tokens: Maximum output tokens. - - Returns: - Tuple of (response_text, input_tokens, output_tokens). + See `generate_json` for argument semantics. """ ... async def generate_text_stream( self, - system_prompt: str, - messages: list[dict[str, str]], + system_prompt: str | list[SystemBlock], + messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> "AsyncIterator[str]": """Stream a text response token by token. - Args: - system_prompt: System-level instruction for the model. - messages: List of message dicts with "role" and "content" keys. - max_tokens: Maximum output tokens. - - Yields: - Text chunks as they are generated. + See `generate_json` for argument semantics. """ raise NotImplementedError("Streaming not supported for this provider") # Make this an async generator to satisfy type checker @@ -85,14 +161,15 @@ class GeminiProvider(AIProvider): async def generate_json( self, - system_prompt: str, - messages: list[dict[str, str]], + system_prompt: str | list[SystemBlock], + messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: from google import genai from google.genai import types as genai_types client = genai.Client(api_key=self._api_key) + system_text = _flatten_system_for_gemini(system_prompt) # Convert messages to Gemini Content format contents: list[genai_types.Content] = [] @@ -106,7 +183,7 @@ class GeminiProvider(AIProvider): ) config = genai_types.GenerateContentConfig( - system_instruction=system_prompt, + system_instruction=system_text, max_output_tokens=max_tokens, response_mime_type="application/json", ) @@ -137,14 +214,15 @@ class GeminiProvider(AIProvider): async def generate_text( self, - system_prompt: str, - messages: list[dict[str, str]], + system_prompt: str | list[SystemBlock], + messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: from google import genai from google.genai import types as genai_types client = genai.Client(api_key=self._api_key) + system_text = _flatten_system_for_gemini(system_prompt) contents: list[genai_types.Content] = [] for msg in messages: @@ -157,7 +235,7 @@ class GeminiProvider(AIProvider): ) config = genai_types.GenerateContentConfig( - system_instruction=system_prompt, + system_instruction=system_text, max_output_tokens=max_tokens, # No response_mime_type — allow free-form text ) @@ -214,16 +292,17 @@ class AnthropicProvider(AIProvider): async def generate_json( self, - system_prompt: str, - messages: list[dict[str, str]], + system_prompt: str | list[SystemBlock], + messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: client = _get_anthropic_client(self._api_key, self._timeout) + normalized_system = _normalize_system_for_anthropic(system_prompt) response = await client.messages.create( model=self._model, max_tokens=max_tokens, - system=system_prompt, + system=normalized_system, messages=messages, ) @@ -231,12 +310,14 @@ class AnthropicProvider(AIProvider): input_tokens = response.usage.input_tokens output_tokens = response.usage.output_tokens + _log_anthropic_cache_usage(response.usage, self._model) + return text, input_tokens, output_tokens async def generate_text( self, - system_prompt: str, - messages: list[dict[str, str]], + system_prompt: str | list[SystemBlock], + messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> tuple[str, int, int]: # Anthropic doesn't differentiate between JSON and text mode @@ -244,20 +325,28 @@ class AnthropicProvider(AIProvider): async def generate_text_stream( self, - system_prompt: str, - messages: list[dict[str, str]], + system_prompt: str | list[SystemBlock], + messages: list[dict[str, Any]], max_tokens: int = 4096, ) -> AsyncIterator[str]: client = _get_anthropic_client(self._api_key, self._timeout) + normalized_system = _normalize_system_for_anthropic(system_prompt) async with client.messages.stream( model=self._model, max_tokens=max_tokens, - system=system_prompt, + system=normalized_system, messages=messages, ) as stream: async for text in stream.text_stream: yield text + # Per Anthropic SDK, get_final_message() resolves the stream's + # final usage object (including cache_read/cache_creation tokens). + try: + final = await stream.get_final_message() + _log_anthropic_cache_usage(final.usage, self._model) + except Exception as exc: # best-effort telemetry, never fail the stream + logger.debug("anthropic.cache streaming usage unavailable: %s", exc) def get_ai_provider(model: str | None = None) -> AIProvider: