Widens AIProvider.generate_json / generate_text / generate_text_stream
signatures to accept `system_prompt: str | list[SystemBlock]`:
- `str` (the existing call shape): passes through uncached, unchanged
behavior. Every existing caller stays on the uncached path — no silent
behavior change.
- `list[SystemBlock]`: enables Anthropic prompt caching via structured
system blocks. Caller-authored `cache_control` is honored verbatim
(policy α); if no block carries it, the provider applies
`cache_control: {"type": "ephemeral"}` to the first block only.
Gemini ignores cache_control and concatenates list entries into one
system string — the widened signature is strictly additive on that path.
Adds `anthropic.cache` structured-log telemetry: on every Anthropic
response (streaming included, via `stream.get_final_message()`), logs
`cache_read_input_tokens` and `cache_creation_input_tokens`. Telemetry
failure in streaming is swallowed so the user-facing stream never breaks.
Verification deferred: cannot run from code-server (no Python, no DB,
no dev env). TODO(phase0-verify) left inline in the module docstring.
First verification task on the new dev environment is to hit any
FlowPilot endpoint twice within 5 minutes and confirm the second call
shows cache_read_input_tokens > 0 in the `anthropic.cache` log event.
If verification fails, that's a debug task on the new env — not a
blocker for continuing Phase 0.2/0.3/0.4.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
398 lines
14 KiB
Python
398 lines
14 KiB
Python
"""
AI Provider abstraction layer.

Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends for the JSON and free-form text generation used by the AI Flow
Builder (plus token streaming on the Anthropic backend).

## Prompt caching (Anthropic only)

Callers may pass `system_prompt` as either:

- `str` — backward-compatible, uncached.
- `list[SystemBlock]` — Anthropic structured system blocks. Each block is a
  dict of shape `{"type": "text", "text": str, "cache_control": {...}?}`.

Caching policy (policy α, per Phase 0.1 design):

- If any block in the list carries an explicit `cache_control` key, that
  caller-authored configuration is honored verbatim.
- If no block carries `cache_control`, the provider applies
  `cache_control: {"type": "ephemeral"}` to the first block only. The first
  block is the common "large static prefix" case (e.g. system prompt,
  reference data).

Gemini ignores `cache_control` and concatenates list blocks into one system
string — callers should not rely on Gemini for cache-hit behavior.

TODO(phase0-verify): When a dev environment is available, verify cache-hit
behavior by hitting any FlowPilot endpoint twice within the 5-minute
ephemeral TTL. The first call should emit `anthropic.cache` with
`cache_creation_input_tokens > 0`; the second call should show
`cache_read_input_tokens > 0`. If the second call returns zero reads, inspect
the prefix for silent invalidators (timestamps, unsorted JSON keys, varying
tool list ordering). Also confirm that the cached prefix meets Anthropic's
minimum cacheable length, because shorter prefixes are never cached.
"""
|
||
|
||
import logging
|
||
from abc import ABC, abstractmethod
|
||
from collections.abc import AsyncIterator
|
||
from typing import Any
|
||
|
||
from app.core.config import settings
|
||
|
||
logger = logging.getLogger(__name__)

# Type alias for a single Anthropic structured system block, e.g.
#   {"type": "text", "text": "<prompt text>", "cache_control": {"type": "ephemeral"}}
# where "cache_control" is optional. This is an alias only: the shape is not
# validated at runtime. See the module docstring for the caching policy
# applied to lists of these blocks.
SystemBlock = dict[str, Any]
|
||
|
||
|
||
def _normalize_system_for_anthropic(
    system_prompt: str | list[SystemBlock],
) -> str | list[SystemBlock]:
    """Convert ``system_prompt`` into the value sent as Anthropic's ``system=``.

    Behaviour by input shape:

    * ``str`` — returned unchanged (the legacy, uncached path).
    * empty list — converted to ``""`` so the request carries no system prompt.
    * non-empty list — each block is shallow-copied. If none of the copies
      has a ``cache_control`` key, ``{"type": "ephemeral"}`` is attached to
      the first copy only (policy α). If at least one block already has
      ``cache_control``, the caller's configuration is left exactly as given.

    The caller's block dicts are never mutated.
    """
    if isinstance(system_prompt, str):
        return system_prompt

    if not system_prompt:
        # An empty block list means "no system prompt"; send "" rather than
        # an empty list of blocks.
        return ""

    # Copy each block so the default cache_control below never leaks back
    # into the caller's data.
    copied_blocks: list[SystemBlock] = list(map(dict, system_prompt))

    caller_set_cache = any("cache_control" in block for block in copied_blocks)
    if not caller_set_cache:
        copied_blocks[0]["cache_control"] = {"type": "ephemeral"}

    return copied_blocks
|
||
|
||
|
||
def _flatten_system_for_gemini(
    system_prompt: str | list[SystemBlock],
) -> str:
    """Collapse ``system_prompt`` into the single string Gemini accepts.

    Strings are returned as-is. For a list, the ``"text"`` value of each block
    (``""`` when the key is missing) is joined with a blank line. Every other
    key, including ``cache_control``, is dropped.
    """
    if not isinstance(system_prompt, str):
        block_texts = [block.get("text", "") for block in system_prompt]
        return "\n\n".join(block_texts)
    return system_prompt
|
||
|
||
|
||
def _anthropic_cache_usage_fields(usage: Any, model: str) -> dict[str, Any]:
    """Build the structured ``extra`` payload for one ``anthropic.cache`` event.

    Args:
        usage: The ``usage`` object of an Anthropic response. Attributes are
            read with ``getattr``, so missing attributes, attributes set to
            ``None``, or ``usage`` itself being ``None`` are all reported as 0.
        model: Model ID to tag the event with.

    Returns:
        Dict with keys ``event``, ``model``, ``cache_read_input_tokens``,
        ``cache_creation_input_tokens``, ``input_tokens`` and ``output_tokens``.
    """

    def _count(attr: str) -> int:
        return getattr(usage, attr, 0) or 0

    return {
        "event": "anthropic.cache",
        "model": model,
        "cache_read_input_tokens": _count("cache_read_input_tokens"),
        "cache_creation_input_tokens": _count("cache_creation_input_tokens"),
        "input_tokens": _count("input_tokens"),
        "output_tokens": _count("output_tokens"),
    }


def _log_anthropic_cache_usage(usage: Any, model: str) -> None:
    """Emit an ``anthropic.cache`` structured log line for one response.

    The line is emitted for every response, including cache misses where
    both cache counters are 0. Previously those were skipped, which made a
    miss indistinguishable from "no response logged". Distinguishing them is
    exactly what the TODO(phase0-verify) check needs.

    Args:
        usage: The ``usage`` object of an Anthropic response (may be ``None``).
        model: Model ID to tag the event with.
    """
    logger.info("anthropic.cache", extra=_anthropic_cache_usage_fields(usage, model))
|
||
|
||
|
||
class AIProvider(ABC):
    """Interface shared by every AI backend.

    Subclasses must implement ``generate_json`` and ``generate_text``.
    ``generate_text_stream`` is optional; unless overridden it raises
    ``NotImplementedError`` the first time it is iterated.
    """

    @abstractmethod
    async def generate_json(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for a JSON-formatted reply.

        Args:
            system_prompt: System-level instruction. A plain ``str`` is sent
                uncached to Anthropic and used verbatim by Gemini. A
                ``list[SystemBlock]`` opts into Anthropic prompt caching as
                described in the module docstring.
            messages: Conversation turns, each a dict with ``"role"`` and
                ``"content"`` keys.
            max_tokens: Upper bound on generated (output) tokens.

        Returns:
            ``(response_text, input_tokens, output_tokens)``.
        """
        ...

    @abstractmethod
    async def generate_text(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for an unconstrained (non-JSON) text reply.

        Arguments and return value have the same meaning as in
        ``generate_json``.
        """
        ...

    async def generate_text_stream(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> "AsyncIterator[str]":
        """Yield the model's reply incrementally as text chunks.

        Arguments have the same meaning as in ``generate_json``. This default
        implementation yields nothing; iterating it raises
        ``NotImplementedError``.
        """
        if False:  # pragma: no cover - the bare yield makes this an async generator
            yield ""
        raise NotImplementedError("Streaming not supported for this provider")
|
||
|
||
|
||
class GeminiProvider(AIProvider):
    """Google Gemini provider using the google-genai SDK.

    Gemini has no structured system blocks. List-form system prompts are
    flattened by ``_flatten_system_for_gemini``, and ``cache_control``
    markers are ignored. Streaming is not overridden here, so
    ``generate_text_stream`` uses the base-class implementation.
    """

    def __init__(self, api_key: str, model: str) -> None:
        self._api_key = api_key
        self._model = model

    async def generate_json(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a response constrained to ``application/json``.

        Returns ``(text, input_tokens, output_tokens)``; see ``_generate``.
        """
        return await self._generate(
            system_prompt,
            messages,
            max_tokens,
            response_mime_type="application/json",
        )

    async def generate_text(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a free-form text response (no ``response_mime_type``).

        Returns ``(text, input_tokens, output_tokens)``; see ``_generate``.
        """
        return await self._generate(system_prompt, messages, max_tokens)

    async def _generate(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int,
        response_mime_type: str | None = None,
    ) -> tuple[str, int, int]:
        """Shared request path for ``generate_json`` and ``generate_text``.

        Args:
            system_prompt: Flattened into a single system instruction.
            messages: Dicts with ``"role"`` and ``"content"``. The
                ``"assistant"`` role maps to Gemini's ``"model"``; every other
                role maps to ``"user"``.
            max_tokens: Passed as ``max_output_tokens``.
            response_mime_type: Set on ``GenerateContentConfig`` only when not
                ``None``; otherwise the field is left unset.

        Returns:
            ``(text, input_tokens, output_tokens)``. ``text`` is
            ``response.text`` or ``""``. The token counts come from
            ``usage_metadata.prompt_token_count`` and ``candidates_token_count``
            and default to 0 when absent.
        """
        # Imported lazily (presumably so the SDK is only needed when Gemini
        # is actually selected).
        from google import genai
        from google.genai import types as genai_types

        client = genai.Client(api_key=self._api_key)

        contents = [
            genai_types.Content(
                role="model" if msg["role"] == "assistant" else "user",
                parts=[genai_types.Part(text=msg["content"])],
            )
            for msg in messages
        ]

        config_kwargs: dict[str, Any] = {
            "system_instruction": _flatten_system_for_gemini(system_prompt),
            "max_output_tokens": max_tokens,
        }
        if response_mime_type is not None:
            config_kwargs["response_mime_type"] = response_mime_type
        config = genai_types.GenerateContentConfig(**config_kwargs)

        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )

        self._log_finish_reason(response, max_tokens)

        text = response.text or ""
        usage = response.usage_metadata
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0

        return text, input_tokens, output_tokens

    def _log_finish_reason(self, response: Any, max_tokens: int) -> None:
        """Log the first candidate's finish reason and warn on truncation."""
        if not response.candidates:
            return
        finish_reason = getattr(response.candidates[0], "finish_reason", None)
        logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
        if self._is_max_tokens(finish_reason):
            logger.warning(
                "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
                max_tokens,
            )

    @staticmethod
    def _is_max_tokens(finish_reason: Any) -> bool:
        """Return True if ``finish_reason`` denotes MAX_TOKENS truncation.

        The SDK reports an enum member. ``str()`` of a ``(str, Enum)`` member
        is ``"FinishReason.MAX_TOKENS"``, so the old ``str(x) == "MAX_TOKENS"``
        test never matched. The enum ``name`` is therefore checked first, then
        the last dotted segment of ``str()`` for plain strings.
        """
        if finish_reason is None:
            return False
        name = getattr(finish_reason, "name", None)
        if isinstance(name, str):
            return name == "MAX_TOKENS"
        return str(finish_reason).rsplit(".", 1)[-1] == "MAX_TOKENS"
|
||
|
||
|
||
# Process-wide client cache — avoids creating new HTTP connections per call.
# Keys come from _anthropic_client_cache_key (a digest of the full API key
# plus the timeout), so distinct keys never share a client.
_anthropic_clients: dict[str, "anthropic.AsyncAnthropic"] = {}


def _anthropic_client_cache_key(api_key: str, timeout: int) -> str:
    """Return the ``_anthropic_clients`` key for an ``(api_key, timeout)`` pair.

    The full key is hashed rather than truncated. Anthropic API keys share a
    common ``sk-ant-`` prefix, so the previous 8-character prefix mapped every
    key to the same cache slot, and a rotated or second key silently reused
    the first client's credentials. Hashing also keeps raw key material out
    of the dict keys.
    """
    import hashlib

    digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    return f"{digest}:{timeout}"


def _get_anthropic_client(api_key: str, timeout: int = 45) -> "anthropic.AsyncAnthropic":
    """Return a cached AsyncAnthropic client, creating one if needed.

    Args:
        api_key: Anthropic API key the client authenticates with.
        timeout: Request timeout passed to the client. Clients with different
            timeouts are cached separately.

    Returns:
        An ``anthropic.AsyncAnthropic`` with ``max_retries=1``, shared by all
        callers using the same key and timeout.
    """
    import anthropic

    cache_key = _anthropic_client_cache_key(api_key, timeout)
    client = _anthropic_clients.get(cache_key)
    if client is None:
        client = anthropic.AsyncAnthropic(
            api_key=api_key,
            timeout=timeout,
            max_retries=1,
        )
        _anthropic_clients[cache_key] = client
    return client
|
||
|
||
|
||
class AnthropicProvider(AIProvider):
    """Anthropic Claude provider using the anthropic SDK.

    System prompts are converted by ``_normalize_system_for_anthropic``, which
    handles prompt caching for list-form prompts. Each response's usage object
    is passed to ``_log_anthropic_cache_usage``.
    """

    def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
        self._api_key = api_key
        self._model = model
        self._timeout = timeout

    async def generate_json(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Send one non-streaming Messages API request.

        Returns:
            ``(text, input_tokens, output_tokens)``.

            * ``text`` concatenates every ``type == "text"`` content block,
              and is ``""`` when there are none.
            * The counts are ``response.usage.input_tokens`` and
              ``response.usage.output_tokens``.

            NOTE(review): per Anthropic's prompt-caching docs, ``input_tokens``
            excludes cache-read and cache-creation tokens. Those are reported
            in separate ``usage`` fields.
        """
        client = _get_anthropic_client(self._api_key, self._timeout)
        normalized_system = _normalize_system_for_anthropic(system_prompt)

        response = await client.messages.create(
            model=self._model,
            max_tokens=max_tokens,
            system=normalized_system,
            messages=messages,
        )

        self._warn_if_truncated(getattr(response, "stop_reason", None), max_tokens)
        _log_anthropic_cache_usage(response.usage, self._model)

        text = self._extract_text(response.content)
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens

        return text, input_tokens, output_tokens

    async def generate_text(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Same request as ``generate_json``; this provider has no JSON mode."""
        # Anthropic doesn't differentiate between JSON and text mode
        return await self.generate_json(system_prompt, messages, max_tokens)

    async def generate_text_stream(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> AsyncIterator[str]:
        """Yield text deltas from a streaming Messages API request.

        After the text stream is exhausted, the final message is fetched for
        best-effort telemetry (cache usage and truncation). Any telemetry
        failure is logged at DEBUG and never reaches the consumer.
        """
        client = _get_anthropic_client(self._api_key, self._timeout)
        normalized_system = _normalize_system_for_anthropic(system_prompt)

        async with client.messages.stream(
            model=self._model,
            max_tokens=max_tokens,
            system=normalized_system,
            messages=messages,
        ) as stream:
            async for text in stream.text_stream:
                yield text
            # Per Anthropic SDK, get_final_message() resolves the stream's
            # final usage object (including cache_read/cache_creation tokens).
            try:
                final = await stream.get_final_message()
                _log_anthropic_cache_usage(final.usage, self._model)
                self._warn_if_truncated(getattr(final, "stop_reason", None), max_tokens)
            except Exception as exc:  # best-effort telemetry, never fail the stream
                logger.debug("anthropic.cache streaming usage unavailable: %s", exc)

    @staticmethod
    def _extract_text(content: Any) -> str:
        """Concatenate the ``text`` of every ``type == "text"`` content block.

        Non-text blocks are skipped. Empty or missing content yields ``""``
        instead of the IndexError/AttributeError that ``content[0].text``
        raised. This matches the Gemini provider's ``response.text or ""``.
        """
        return "".join(
            getattr(block, "text", "") or ""
            for block in content or ()
            if getattr(block, "type", None) == "text"
        )

    def _warn_if_truncated(self, stop_reason: Any, max_tokens: int) -> None:
        """Log a warning when generation stopped because of ``max_tokens``."""
        if stop_reason == "max_tokens":
            logger.warning(
                "Anthropic output truncated (max_tokens). max_tokens=%d model=%s",
                max_tokens,
                self._model,
            )
|
||
|
||
|
||
def get_ai_provider(model: str | None = None) -> AIProvider:
    """Build the AI provider selected by application settings.

    Args:
        model: Optional Anthropic model ID. It overrides
            ``settings.AI_MODEL_ANTHROPIC`` and is ignored when a Gemini
            provider is built, which always uses ``settings.AI_MODEL_GEMINI``.

    Selection:
        * ``AI_PROVIDER == "gemini"``: Gemini if ``GOOGLE_AI_API_KEY`` is set,
          otherwise Anthropic if ``ANTHROPIC_API_KEY`` is set.
        * ``AI_PROVIDER == "anthropic"``: Anthropic first, then Gemini.
        * Any other ``AI_PROVIDER`` value gets no fallback.

    Raises:
        RuntimeError: No usable provider could be built.
    """

    def _build_gemini() -> AIProvider | None:
        if not settings.GOOGLE_AI_API_KEY:
            return None
        return GeminiProvider(
            api_key=settings.GOOGLE_AI_API_KEY,
            model=settings.AI_MODEL_GEMINI,
        )

    def _build_anthropic() -> AIProvider | None:
        if not settings.ANTHROPIC_API_KEY:
            return None
        return AnthropicProvider(
            api_key=settings.ANTHROPIC_API_KEY,
            model=model or settings.AI_MODEL_ANTHROPIC,
            timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
        )

    # Preferred builder first, fallback second; unknown providers try nothing.
    preference_order = {
        "gemini": (_build_gemini, _build_anthropic),
        "anthropic": (_build_anthropic, _build_gemini),
    }.get(settings.AI_PROVIDER, ())

    for build in preference_order:
        candidate = build()
        if candidate is not None:
            return candidate

    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )
|