Renames the chat caller to a name that signals its actual purpose, and
factors the reusable cached-system-block + cached-history + cache-usage-log
primitives out to app.core.ai_provider so they can be shared with the
provider-generic path without pulling MCP/beta/images into the abstract
interface.

Helpers added to ai_provider.py:
- `build_anthropic_chat_messages(history, new_message, images, format_reminder)`,
  which owns: copy history, apply cache_control to last history message,
  append format reminder to new message, render images as multimodal blocks.
  Anthropic-shaped by design; do not call from Gemini paths.

chat_call_cached keeps exactly the concerns that are unique to the one
MCP/beta/multimodal chat caller:
- Anthropic beta endpoint invocation
- Microsoft Learn MCP server wiring (ENABLE_MCP_MICROSOFT_LEARN)
- Retry-without-MCP fallback
- Format-reminder content string (declared as module constant)
- Phase 0.5 telemetry (mcp.turn, mcp.fallback)

Documents in the module docstring AND at the function site that this is the
ONE MCP/beta chat caller and should not become the general provider path.
MCP/beta/images are features of exactly one optional Anthropic beta endpoint;
routing them through AnthropicProvider would leak a provider-specific concern
into the abstract interface that also serves Gemini.

Behavior change: chat_call_cached now reuses the singleton AnthropicProvider
HTTP client via `_get_anthropic_client(...)` instead of instantiating a new
`anthropic.AsyncAnthropic(...)` per call. Matches the provider's own pattern
and avoids burning connections per-turn. No user-visible difference.

No runtime verification from code-server. TODO(phase0-verify) in
ai_provider.py tracks the cache-hit verification owed on the new dev env.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
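The message shape that `build_anthropic_chat_messages` is described as producing can be illustrated with a condensed, text-only sketch. The function name `sketch_chat_messages` and the sample conversation are hypothetical and exist only for this illustration; image handling is left out, and the real helper in ai_provider.py remains the reference.

```python
from typing import Any, Optional


def sketch_chat_messages(
    history: list[dict[str, Any]],
    new_message: str,
    format_reminder: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Hypothetical, text-only restatement of the helper's documented behavior."""
    # Copy history so the caller's list and dicts are never mutated.
    messages = [{"role": m["role"], "content": m["content"]} for m in history]
    if messages:
        last = messages[-1]
        # Cache breakpoint: the last history message becomes a text block that
        # carries cache_control, so the prefix up to and including it is cacheable.
        messages[-1] = {
            "role": last["role"],
            "content": [
                {
                    "type": "text",
                    "text": last["content"],
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        }
    # The new user turn stays uncached; the reminder is appended verbatim.
    messages.append(
        {"role": "user", "content": new_message + (format_reminder or "")}
    )
    return messages


# Made-up conversation, used only to show the resulting payload shape.
history = [
    {"role": "user", "content": "Create a flow that emails me daily."},
    {"role": "assistant", "content": "Which mailbox should it use?"},
]
payload = sketch_chat_messages(history, "My work mailbox.", "\n\nReply in JSON.")
```

In this sketch only the assistant turn (the last history entry) is rewritten into block form; earlier turns and the new user turn remain plain strings, and the input `history` is left unchanged.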
465 lines
16 KiB
Python
"""
AI Provider abstraction layer.

Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends for JSON generation used by the AI Flow Builder.

## Prompt caching (Anthropic only)

Callers may pass `system_prompt` as either:

- `str`: backward-compatible, uncached.
- `list[SystemBlock]`: Anthropic structured system blocks. Each block is a
  dict of shape `{"type": "text", "text": str, "cache_control": {...}?}`.

Caching policy (policy α, per Phase 0.1 design):
- If any block in the list carries an explicit `cache_control` key, that
  caller-authored configuration is honored verbatim.
- If no block carries `cache_control`, the provider applies
  `cache_control: {"type": "ephemeral"}` to the first block only. First block
  is the common "large static prefix" case (e.g. system prompt, reference data).

Gemini ignores cache_control and concatenates list blocks into one system
string, so callers should not rely on Gemini for cache-hit behavior.

TODO(phase0-verify): When a dev environment is available, verify cache-hit
behavior by hitting any FlowPilot endpoint twice within the 5-minute
ephemeral TTL. First call should emit `anthropic.cache` with
`cache_creation_input_tokens > 0`; second call with `cache_read_input_tokens > 0`.
If the second call returns zero reads, inspect the prefix for silent
invalidators (timestamps, unsorted JSON keys, varying tool list ordering).
"""

import logging
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any

from app.core.config import settings

logger = logging.getLogger(__name__)

# Anthropic structured system block. See module docstring for caching policy.
SystemBlock = dict[str, Any]


def _normalize_system_for_anthropic(
    system_prompt: str | list[SystemBlock],
) -> str | list[SystemBlock]:
    """Return the value to pass as the `system=` parameter to the Anthropic API.

    - Plain strings pass through untouched (uncached path).
    - Lists are returned as structured system blocks. If no block in the list
      carries an explicit `cache_control`, `cache_control: {"type": "ephemeral"}`
      is applied to the FIRST block only (policy α).
    - Caller-authored `cache_control` is never overwritten.
    """
    if isinstance(system_prompt, str):
        return system_prompt

    if not system_prompt:
        # Empty list is not a meaningful system prompt; pass empty string so
        # Anthropic treats this as "no system prompt" rather than erroring.
        return ""

    blocks = [dict(b) for b in system_prompt]
    already_cached = any("cache_control" in b for b in blocks)

    if not already_cached:
        blocks[0]["cache_control"] = {"type": "ephemeral"}

    return blocks


def _flatten_system_for_gemini(
    system_prompt: str | list[SystemBlock],
) -> str:
    """Gemini has no structured system blocks; concatenate list entries."""
    if isinstance(system_prompt, str):
        return system_prompt
    return "\n\n".join(b.get("text", "") for b in system_prompt)


def build_anthropic_chat_messages(
    history: list[dict[str, Any]],
    new_message: str,
    images: list[dict[str, Any]] | None = None,
    format_reminder: str | None = None,
) -> list[dict[str, Any]]:
    """Construct the Anthropic `messages` payload for a cached multi-turn chat.

    Responsibilities:
    - Copy the valid history messages in order.
    - Apply `cache_control: ephemeral` to the LAST history message so the entire
      conversation prefix is cached across turns. The new user message stays
      uncached (it changes each turn).
    - Append `format_reminder` to the new user message if provided. The reminder
      is invisible to storage (caller's concern) but helps enforce structured
      output compliance at generation time.
    - If `images` are provided, render the new user message as a multimodal
      content block list (images first, then text). Otherwise, render it as
      a plain string.

    This helper is Anthropic-specific: the cache-breakpoint pattern, ephemeral
    cache_control, and multimodal block shape are all Anthropic conventions.
    Do not call it from Gemini code paths.
    """
    messages: list[dict[str, Any]] = []
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})

    # Cache breakpoint on the last existing history message so the entire
    # conversation prefix is cached across turns. Safe only when there IS a
    # history message; otherwise the new message is the only message.
    if messages:
        last = messages[-1]
        messages[-1] = {
            "role": last["role"],
            "content": [
                {
                    "type": "text",
                    "text": last["content"],
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        }

    effective_text = new_message + (format_reminder or "")

    if images:
        content_blocks: list[dict[str, Any]] = []
        for img in images:
            content_blocks.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": img["media_type"],
                        "data": img["data"],
                    },
                }
            )
        content_blocks.append({"type": "text", "text": effective_text})
        messages.append({"role": "user", "content": content_blocks})
    else:
        messages.append({"role": "user", "content": effective_text})

    return messages


def _log_anthropic_cache_usage(usage: Any, model: str) -> None:
    """Emit a structured log line capturing cache_read / cache_creation tokens."""
    cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0
    cache_creation = getattr(usage, "cache_creation_input_tokens", 0) or 0
    input_tokens = getattr(usage, "input_tokens", 0) or 0
    output_tokens = getattr(usage, "output_tokens", 0) or 0
    if cache_read or cache_creation:
        logger.info(
            "anthropic.cache",
            extra={
                "event": "anthropic.cache",
                "model": model,
                "cache_read_input_tokens": cache_read,
                "cache_creation_input_tokens": cache_creation,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
            },
        )


class AIProvider(ABC):
    """Abstract base class for AI providers."""

    @abstractmethod
    async def generate_json(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a JSON response from the AI model.

        Args:
            system_prompt: System-level instruction. Plain `str` is uncached
                (Anthropic) or used as-is (Gemini). `list[SystemBlock]` enables
                Anthropic prompt caching per module-docstring policy.
            messages: List of message dicts with "role" and "content" keys.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens).
        """
        ...

    @abstractmethod
    async def generate_text(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a text response from the AI model (no JSON constraint).

        See `generate_json` for argument semantics.
        """
        ...

    async def generate_text_stream(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> "AsyncIterator[str]":
        """Stream a text response token by token.

        See `generate_json` for argument semantics.
        """
        raise NotImplementedError("Streaming not supported for this provider")
        # Make this an async generator to satisfy type checker
        yield ""  # pragma: no cover


class GeminiProvider(AIProvider):
    """Google Gemini provider using the google-genai SDK."""

    def __init__(self, api_key: str, model: str) -> None:
        self._api_key = api_key
        self._model = model

    async def generate_json(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        from google import genai
        from google.genai import types as genai_types

        client = genai.Client(api_key=self._api_key)
        system_text = _flatten_system_for_gemini(system_prompt)

        # Convert messages to Gemini Content format
        contents: list[genai_types.Content] = []
        for msg in messages:
            role = "model" if msg["role"] == "assistant" else "user"
            contents.append(
                genai_types.Content(
                    role=role,
                    parts=[genai_types.Part(text=msg["content"])],
                )
            )

        config = genai_types.GenerateContentConfig(
            system_instruction=system_text,
            max_output_tokens=max_tokens,
            response_mime_type="application/json",
        )

        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )

        # Log finish reason to detect truncation
        if response.candidates:
            finish_reason = getattr(response.candidates[0], "finish_reason", None)
            logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
            # finish_reason is normally an enum member whose str() may render as
            # "FinishReason.MAX_TOKENS", so compare on the member name and fall
            # back to str() for plain-string values.
            if getattr(finish_reason, "name", str(finish_reason)) == "MAX_TOKENS":
                logger.warning(
                    "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
                    max_tokens,
                )

        text = response.text or ""
        input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
        output_tokens = (
            getattr(response.usage_metadata, "candidates_token_count", 0) or 0
        )

        return text, input_tokens, output_tokens

    async def generate_text(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        from google import genai
        from google.genai import types as genai_types

        client = genai.Client(api_key=self._api_key)
        system_text = _flatten_system_for_gemini(system_prompt)

        contents: list[genai_types.Content] = []
        for msg in messages:
            role = "model" if msg["role"] == "assistant" else "user"
            contents.append(
                genai_types.Content(
                    role=role,
                    parts=[genai_types.Part(text=msg["content"])],
                )
            )

        config = genai_types.GenerateContentConfig(
            system_instruction=system_text,
            max_output_tokens=max_tokens,
            # No response_mime_type: allow free-form text
        )

        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )

        if response.candidates:
            finish_reason = getattr(response.candidates[0], "finish_reason", None)
            logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
            # Same enum-name comparison as in generate_json.
            if getattr(finish_reason, "name", str(finish_reason)) == "MAX_TOKENS":
                logger.warning(
                    "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
                    max_tokens,
                )

        text = response.text or ""
        input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
        output_tokens = (
            getattr(response.usage_metadata, "candidates_token_count", 0) or 0
        )

        return text, input_tokens, output_tokens


# Singleton client cache: avoids creating new HTTP connections per call.
# Keyed on the full (api_key, timeout) pair. A key prefix is not a safe cache
# key: distinct API keys can share their leading characters, in which case a
# rotated or second key would be handed a client bound to the old key.
_anthropic_clients: dict[tuple[str, int], "anthropic.AsyncAnthropic"] = {}


def _get_anthropic_client(api_key: str, timeout: int = 45) -> "anthropic.AsyncAnthropic":
    """Return a cached AsyncAnthropic client, creating one if needed."""
    import anthropic

    cache_key = (api_key, timeout)
    if cache_key not in _anthropic_clients:
        _anthropic_clients[cache_key] = anthropic.AsyncAnthropic(
            api_key=api_key,
            timeout=timeout,
            max_retries=1,
        )
    return _anthropic_clients[cache_key]


class AnthropicProvider(AIProvider):
    """Anthropic Claude provider using the anthropic SDK."""

    def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
        self._api_key = api_key
        self._model = model
        self._timeout = timeout

    async def generate_json(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        client = _get_anthropic_client(self._api_key, self._timeout)
        normalized_system = _normalize_system_for_anthropic(system_prompt)

        response = await client.messages.create(
            model=self._model,
            max_tokens=max_tokens,
            system=normalized_system,
            messages=messages,
        )

        text = response.content[0].text
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens

        _log_anthropic_cache_usage(response.usage, self._model)

        return text, input_tokens, output_tokens

    async def generate_text(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        # Anthropic doesn't differentiate between JSON and text mode
        return await self.generate_json(system_prompt, messages, max_tokens)

    async def generate_text_stream(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> AsyncIterator[str]:
        client = _get_anthropic_client(self._api_key, self._timeout)
        normalized_system = _normalize_system_for_anthropic(system_prompt)

        async with client.messages.stream(
            model=self._model,
            max_tokens=max_tokens,
            system=normalized_system,
            messages=messages,
        ) as stream:
            async for text in stream.text_stream:
                yield text
            # Per Anthropic SDK, get_final_message() resolves the stream's
            # final usage object (including cache_read/cache_creation tokens).
            try:
                final = await stream.get_final_message()
                _log_anthropic_cache_usage(final.usage, self._model)
            except Exception as exc:  # best-effort telemetry, never fail the stream
                logger.debug("anthropic.cache streaming usage unavailable: %s", exc)


def get_ai_provider(model: str | None = None) -> AIProvider:
    """Factory that returns the configured AI provider.

    Args:
        model: Optional model override (Anthropic model ID). Only applied to
            AnthropicProvider; Gemini always uses settings.AI_MODEL_GEMINI.

    Selection logic:
    1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
    2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
    3. Fallback: if preferred provider key missing, try the other one
    4. If nothing configured -> raise RuntimeError
    """
    provider = settings.AI_PROVIDER

    if provider == "gemini":
        if settings.GOOGLE_AI_API_KEY:
            return GeminiProvider(
                api_key=settings.GOOGLE_AI_API_KEY,
                model=settings.AI_MODEL_GEMINI,
            )
        # Fallback to Anthropic
        if settings.ANTHROPIC_API_KEY:
            return AnthropicProvider(
                api_key=settings.ANTHROPIC_API_KEY,
                model=model or settings.AI_MODEL_ANTHROPIC,
                timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
            )

    elif provider == "anthropic":
        if settings.ANTHROPIC_API_KEY:
            return AnthropicProvider(
                api_key=settings.ANTHROPIC_API_KEY,
                model=model or settings.AI_MODEL_ANTHROPIC,
                timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
            )
        # Fallback to Gemini
        if settings.GOOGLE_AI_API_KEY:
            return GeminiProvider(
                api_key=settings.GOOGLE_AI_API_KEY,
                model=settings.AI_MODEL_GEMINI,
            )

    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )
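
The cache-hit check described in the module's TODO(phase0-verify) note could be automated along the following lines. This is only a sketch: `CacheEventCollector`, `second_call_hit_cache`, the logger name `cache_check_sketch`, and the token counts are all hypothetical, and the two log calls merely simulate the `anthropic.cache` records that two real requests inside the ephemeral TTL would be expected to emit.

```python
import logging
from typing import Any

# Hypothetical logger; a real check would attach the handler to the logger
# used by the provider module instead.
logger = logging.getLogger("cache_check_sketch")
logger.setLevel(logging.INFO)
logger.propagate = False


class CacheEventCollector(logging.Handler):
    """Collects records whose `event` extra equals "anthropic.cache"."""

    def __init__(self) -> None:
        super().__init__()
        self.events: list[dict[str, Any]] = []

    def emit(self, record: logging.LogRecord) -> None:
        # Keys passed via `extra=` become attributes on the LogRecord.
        if getattr(record, "event", None) == "anthropic.cache":
            self.events.append(
                {
                    "cache_read": getattr(record, "cache_read_input_tokens", 0),
                    "cache_creation": getattr(
                        record, "cache_creation_input_tokens", 0
                    ),
                }
            )


def second_call_hit_cache(events: list[dict[str, Any]]) -> bool:
    """True when the first call wrote the cache and the second read from it."""
    if len(events) < 2:
        return False
    first, second = events[0], events[1]
    return first["cache_creation"] > 0 and second["cache_read"] > 0


collector = CacheEventCollector()
logger.addHandler(collector)

# Simulated log lines with made-up token counts, standing in for two real calls.
logger.info(
    "anthropic.cache",
    extra={
        "event": "anthropic.cache",
        "cache_read_input_tokens": 0,
        "cache_creation_input_tokens": 1800,
    },
)
logger.info(
    "anthropic.cache",
    extra={
        "event": "anthropic.cache",
        "cache_read_input_tokens": 1800,
        "cache_creation_input_tokens": 0,
    },
)

verified = second_call_hit_cache(collector.events)
```

If `verified` came back false against real traffic, the TODO's advice applies: inspect the cached prefix for silent invalidators such as timestamps, unsorted JSON keys, or varying tool list ordering.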