Harden the Anthropic provider and lay the groundwork for schema-constrained JSON, optimizing the existing claude-sonnet-4-6 / claude-haiku-4-5 usage (no model changes). ai_provider.py: - _extract_text_from_response replaces fragile response.content[0].text: skips non-text leading blocks (e.g. thinking), returns the first text block, logs an anthropic.stop_reason warning on max_tokens/refusal (truncation now observable), and raises ValueError on a no-text response. - generate_json gains an optional `schema` param. Anthropic wires it to output_config.format (structured outputs); schema=None preserves the exact prior call for every existing caller. Gemini accepts-and-ignores it. kb_conversion_service.py: - TROUBLESHOOTING_SCHEMA / PROCEDURAL_SCHEMA + _schema_for_target_type(), modelled as a strict superset of every field the prompts emit. - convert_document passes the schema only when the new AI_KB_CONVERT_STRUCTURED_OUTPUT setting is True (default False). The _try_repair_json fallback stays as belt-and-suspenders. Tests: 14 provider + 7 schema, TDD (red-green). Live constrained-decoding smoke-test still required before enabling the flag in production. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
523 lines
19 KiB
Python
523 lines
19 KiB
Python
"""
|
||
AI Provider abstraction layer.

Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends for JSON generation used by the AI Flow Builder.

## Prompt caching (Anthropic only)

Callers may pass `system_prompt` as either:

- `str` — backward-compatible, uncached.
- `list[SystemBlock]` — Anthropic structured system blocks. Each block is a
  dict of shape `{"type": "text", "text": str, "cache_control": {...}?}`.

Caching policy (policy α, per Phase 0.1 design):
- If any block in the list carries an explicit `cache_control` key, that
  caller-authored configuration is honored verbatim.
- If no block carries `cache_control`, the provider applies
  `cache_control: {"type": "ephemeral"}` to the first block only. First block
  is the common "large static prefix" case (e.g. system prompt, reference data).

Gemini ignores cache_control and concatenates list blocks into one system
string — callers should not rely on Gemini for cache-hit behavior.

TODO(phase0-verify): When a dev environment is available, verify cache-hit
behavior by hitting any FlowPilot endpoint twice within the 5-minute
ephemeral TTL. First call should emit `anthropic.cache` with
`cache_creation_input_tokens > 0`; second call with `cache_read_input_tokens > 0`.
If the second call returns zero reads, inspect the prefix for silent
invalidators (timestamps, unsorted JSON keys, varying tool list ordering).
|
||
"""
|
||
|
||
import logging
|
||
from abc import ABC, abstractmethod
|
||
from collections.abc import AsyncIterator
|
||
from typing import Any
|
||
|
||
from app.core.config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Type alias for one Anthropic structured system block: a plain dict such as
# {"type": "text", "text": ..., "cache_control": {...}}. The module docstring
# describes the caching policy applied to lists of these blocks.
SystemBlock = dict[str, Any]
|
||
|
||
|
||
def _normalize_system_for_anthropic(
|
||
system_prompt: str | list[SystemBlock],
|
||
) -> str | list[SystemBlock]:
|
||
"""Return the value to pass as the `system=` parameter to the Anthropic API.
|
||
|
||
- Plain strings pass through untouched (uncached path).
|
||
- Lists are returned as structured system blocks. If no block in the list
|
||
carries an explicit `cache_control`, `cache_control: {"type": "ephemeral"}`
|
||
is applied to the FIRST block only (policy α).
|
||
- Caller-authored `cache_control` is never overwritten.
|
||
"""
|
||
if isinstance(system_prompt, str):
|
||
return system_prompt
|
||
|
||
if not system_prompt:
|
||
# Empty list is not a meaningful system prompt — pass empty string so
|
||
# Anthropic treats this as "no system prompt" rather than erroring.
|
||
return ""
|
||
|
||
blocks = [dict(b) for b in system_prompt]
|
||
already_cached = any("cache_control" in b for b in blocks)
|
||
|
||
if not already_cached:
|
||
blocks[0]["cache_control"] = {"type": "ephemeral"}
|
||
|
||
return blocks
|
||
|
||
|
||
def _flatten_system_for_gemini(
|
||
system_prompt: str | list[SystemBlock],
|
||
) -> str:
|
||
"""Gemini has no structured system blocks; concatenate list entries."""
|
||
if isinstance(system_prompt, str):
|
||
return system_prompt
|
||
return "\n\n".join(b.get("text", "") for b in system_prompt)
|
||
|
||
|
||
def build_anthropic_chat_messages(
|
||
history: list[dict[str, Any]],
|
||
new_message: str,
|
||
images: list[dict[str, Any]] | None = None,
|
||
format_reminder: str | None = None,
|
||
) -> list[dict[str, Any]]:
|
||
"""Construct the Anthropic `messages` payload for a cached multi-turn chat.
|
||
|
||
Responsibilities:
|
||
- Copy the valid history messages in order.
|
||
- Apply `cache_control: ephemeral` to the LAST history message so the entire
|
||
conversation prefix is cached across turns. The new user message stays
|
||
uncached (it changes each turn).
|
||
- Append `format_reminder` to the new user message if provided. The reminder
|
||
is invisible to storage (caller's concern) but helps enforce structured
|
||
output compliance at generation time.
|
||
- If `images` are provided, render the new user message as a multimodal
|
||
content block list (images first, then text). Otherwise, render it as
|
||
a plain string.
|
||
|
||
This helper is Anthropic-specific: the cache-breakpoint pattern, ephemeral
|
||
cache_control, and multimodal block shape are all Anthropic conventions.
|
||
Do not call it from Gemini code paths.
|
||
"""
|
||
messages: list[dict[str, Any]] = []
|
||
for msg in history:
|
||
messages.append({"role": msg["role"], "content": msg["content"]})
|
||
|
||
# Cache breakpoint on the last existing history message so the entire
|
||
# conversation prefix is cached across turns. Safe only when there IS a
|
||
# history message; otherwise the new message is the only message.
|
||
if messages:
|
||
last = messages[-1]
|
||
messages[-1] = {
|
||
"role": last["role"],
|
||
"content": [
|
||
{
|
||
"type": "text",
|
||
"text": last["content"],
|
||
"cache_control": {"type": "ephemeral"},
|
||
}
|
||
],
|
||
}
|
||
|
||
effective_text = new_message + (format_reminder or "")
|
||
|
||
if images:
|
||
content_blocks: list[dict[str, Any]] = []
|
||
for img in images:
|
||
content_blocks.append(
|
||
{
|
||
"type": "image",
|
||
"source": {
|
||
"type": "base64",
|
||
"media_type": img["media_type"],
|
||
"data": img["data"],
|
||
},
|
||
}
|
||
)
|
||
content_blocks.append({"type": "text", "text": effective_text})
|
||
messages.append({"role": "user", "content": content_blocks})
|
||
else:
|
||
messages.append({"role": "user", "content": effective_text})
|
||
|
||
return messages
|
||
|
||
|
||
def _extract_text_from_response(response: Any, model: str) -> str:
    """Pull the text of the first `text` block out of an Anthropic response.

    Compared with indexing ``response.content[0].text`` directly, this:

    - walks the content blocks and picks the first one whose ``type`` is
      ``"text"``, so a leading non-text block (e.g. ``thinking``) is skipped;
    - logs an ``anthropic.stop_reason`` warning when ``stop_reason`` is
      ``max_tokens`` or ``refusal``, making truncated or refused output
      visible in logs before the text is returned;
    - raises instead of returning something that is not text.

    Args:
        response: Anthropic message response (needs ``content``; a missing
            ``stop_reason`` is treated as ``None``).
        model: Model identifier, used only for the log record.

    Returns:
        The ``text`` attribute of the first text block.

    Raises:
        ValueError: If no block of type ``"text"`` is present.
    """
    stop_reason = getattr(response, "stop_reason", None)

    # Emit the warning first so it is recorded even when the lookup below
    # ends up raising.
    if stop_reason in ("max_tokens", "refusal"):
        log_fields = {
            "event": "anthropic.stop_reason",
            "model": model,
            "stop_reason": stop_reason,
        }
        logger.warning("anthropic.stop_reason", extra=log_fields)

    first_text_block = next(
        (blk for blk in response.content if getattr(blk, "type", None) == "text"),
        None,
    )
    if first_text_block is None:
        raise ValueError(
            f"Anthropic response contained no text block (stop_reason={stop_reason!r})"
        )
    return first_text_block.text
|
||
|
||
|
||
def _log_anthropic_cache_usage(usage: Any, model: str) -> None:
    """Log an `anthropic.cache` record when a response touched the prompt cache.

    Reads the four token counters from `usage`, treating a missing or falsy
    attribute as 0. Nothing is logged unless at least one of the cache-read
    or cache-creation counters is non-zero.

    Args:
        usage: Usage object from an Anthropic response.
        model: Model identifier included in the log record.
    """
    counter_names = (
        "cache_read_input_tokens",
        "cache_creation_input_tokens",
        "input_tokens",
        "output_tokens",
    )
    counts = {name: getattr(usage, name, 0) or 0 for name in counter_names}

    touched_cache = (
        counts["cache_read_input_tokens"] or counts["cache_creation_input_tokens"]
    )
    if not touched_cache:
        return

    logger.info(
        "anthropic.cache",
        extra={"event": "anthropic.cache", "model": model, **counts},
    )
|
||
|
||
|
||
class AIProvider(ABC):
|
||
"""Abstract base class for AI providers."""
|
||
|
||
@abstractmethod
|
||
async def generate_json(
|
||
self,
|
||
system_prompt: str | list[SystemBlock],
|
||
messages: list[dict[str, Any]],
|
||
max_tokens: int = 4096,
|
||
schema: dict[str, Any] | None = None,
|
||
) -> tuple[str, int, int]:
|
||
"""Generate a JSON response from the AI model.
|
||
|
||
Args:
|
||
system_prompt: System-level instruction. Plain `str` is uncached
|
||
(Anthropic) or used as-is (Gemini). `list[SystemBlock]` enables
|
||
Anthropic prompt caching per module-docstring policy.
|
||
messages: List of message dicts with "role" and "content" keys.
|
||
max_tokens: Maximum output tokens.
|
||
schema: Optional JSON Schema constraining the response shape.
|
||
When provided, the Anthropic backend uses structured outputs
|
||
(`output_config.format`) to guarantee valid, parseable JSON —
|
||
no markdown fences, no truncated-brace repair. Must satisfy the
|
||
structured-output schema limits (every object needs
|
||
`additionalProperties: false`; no recursion; numeric/string
|
||
constraints are stripped). `None` preserves the legacy
|
||
prompt-only behavior. The Gemini backend currently ignores this
|
||
argument (it already requests `application/json`).
|
||
|
||
Returns:
|
||
Tuple of (response_text, input_tokens, output_tokens).
|
||
"""
|
||
...
|
||
|
||
@abstractmethod
|
||
async def generate_text(
|
||
self,
|
||
system_prompt: str | list[SystemBlock],
|
||
messages: list[dict[str, Any]],
|
||
max_tokens: int = 4096,
|
||
) -> tuple[str, int, int]:
|
||
"""Generate a text response from the AI model (no JSON constraint).
|
||
|
||
See `generate_json` for argument semantics.
|
||
"""
|
||
...
|
||
|
||
async def generate_text_stream(
|
||
self,
|
||
system_prompt: str | list[SystemBlock],
|
||
messages: list[dict[str, Any]],
|
||
max_tokens: int = 4096,
|
||
) -> "AsyncIterator[str]":
|
||
"""Stream a text response token by token.
|
||
|
||
See `generate_json` for argument semantics.
|
||
"""
|
||
raise NotImplementedError("Streaming not supported for this provider")
|
||
# Make this an async generator to satisfy type checker
|
||
yield "" # pragma: no cover
|
||
|
||
|
||
class GeminiProvider(AIProvider):
|
||
"""Google Gemini provider using the google-genai SDK."""
|
||
|
||
def __init__(self, api_key: str, model: str) -> None:
|
||
self._api_key = api_key
|
||
self._model = model
|
||
|
||
async def generate_json(
|
||
self,
|
||
system_prompt: str | list[SystemBlock],
|
||
messages: list[dict[str, Any]],
|
||
max_tokens: int = 4096,
|
||
schema: dict[str, Any] | None = None,
|
||
) -> tuple[str, int, int]:
|
||
# `schema` is accepted for interface parity but ignored: Gemini already
|
||
# constrains output via response_mime_type="application/json" below.
|
||
# Mapping JSON Schema -> Gemini response_schema is deferred.
|
||
from google import genai
|
||
from google.genai import types as genai_types
|
||
|
||
client = genai.Client(api_key=self._api_key)
|
||
system_text = _flatten_system_for_gemini(system_prompt)
|
||
|
||
# Convert messages to Gemini Content format
|
||
contents: list[genai_types.Content] = []
|
||
for msg in messages:
|
||
role = "model" if msg["role"] == "assistant" else "user"
|
||
contents.append(
|
||
genai_types.Content(
|
||
role=role,
|
||
parts=[genai_types.Part(text=msg["content"])],
|
||
)
|
||
)
|
||
|
||
config = genai_types.GenerateContentConfig(
|
||
system_instruction=system_text,
|
||
max_output_tokens=max_tokens,
|
||
response_mime_type="application/json",
|
||
)
|
||
|
||
response = await client.aio.models.generate_content(
|
||
model=self._model,
|
||
contents=contents,
|
||
config=config,
|
||
)
|
||
|
||
# Log finish reason to detect truncation
|
||
if response.candidates:
|
||
finish_reason = getattr(response.candidates[0], "finish_reason", None)
|
||
logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
|
||
if str(finish_reason) == "MAX_TOKENS":
|
||
logger.warning(
|
||
"Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
|
||
max_tokens,
|
||
)
|
||
|
||
text = response.text or ""
|
||
input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
|
||
output_tokens = (
|
||
getattr(response.usage_metadata, "candidates_token_count", 0) or 0
|
||
)
|
||
|
||
return text, input_tokens, output_tokens
|
||
|
||
async def generate_text(
|
||
self,
|
||
system_prompt: str | list[SystemBlock],
|
||
messages: list[dict[str, Any]],
|
||
max_tokens: int = 4096,
|
||
) -> tuple[str, int, int]:
|
||
from google import genai
|
||
from google.genai import types as genai_types
|
||
|
||
client = genai.Client(api_key=self._api_key)
|
||
system_text = _flatten_system_for_gemini(system_prompt)
|
||
|
||
contents: list[genai_types.Content] = []
|
||
for msg in messages:
|
||
role = "model" if msg["role"] == "assistant" else "user"
|
||
contents.append(
|
||
genai_types.Content(
|
||
role=role,
|
||
parts=[genai_types.Part(text=msg["content"])],
|
||
)
|
||
)
|
||
|
||
config = genai_types.GenerateContentConfig(
|
||
system_instruction=system_text,
|
||
max_output_tokens=max_tokens,
|
||
# No response_mime_type — allow free-form text
|
||
)
|
||
|
||
response = await client.aio.models.generate_content(
|
||
model=self._model,
|
||
contents=contents,
|
||
config=config,
|
||
)
|
||
|
||
if response.candidates:
|
||
finish_reason = getattr(response.candidates[0], "finish_reason", None)
|
||
logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
|
||
if str(finish_reason) == "MAX_TOKENS":
|
||
logger.warning(
|
||
"Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
|
||
max_tokens,
|
||
)
|
||
|
||
text = response.text or ""
|
||
input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
|
||
output_tokens = (
|
||
getattr(response.usage_metadata, "candidates_token_count", 0) or 0
|
||
)
|
||
|
||
return text, input_tokens, output_tokens
|
||
|
||
|
||
# Process-wide client cache — reusing a client avoids creating new HTTP
# connections per call. Keys are produced by _anthropic_client_cache_key().
_anthropic_clients: dict[str, "anthropic.AsyncAnthropic"] = {}


def _anthropic_client_cache_key(api_key: str, timeout: int) -> str:
    """Return the cache key for a given (api_key, timeout) pair.

    The key is derived from a SHA-256 digest of the WHOLE api key. A short
    prefix of the key is not enough: keys that share a common leading
    segment would all map to the same entry, and a caller could be handed a
    client bound to a different key. Hashing also keeps the raw secret out
    of the dict's keys.
    """
    import hashlib

    digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    return f"{digest}:{timeout}"


def _get_anthropic_client(api_key: str, timeout: int = 45) -> "anthropic.AsyncAnthropic":
    """Return a cached AsyncAnthropic client, creating one if needed.

    Args:
        api_key: Anthropic API key the client authenticates with.
        timeout: Request timeout passed to the client.

    Returns:
        The client cached for this exact (api_key, timeout) pair. A new
        client is built with `max_retries=1` on the first request for a pair.
    """
    cache_key = _anthropic_client_cache_key(api_key, timeout)
    client = _anthropic_clients.get(cache_key)
    if client is None:
        # Imported on a cache miss only, so the SDK is needed just when a
        # client actually has to be constructed.
        import anthropic

        client = anthropic.AsyncAnthropic(
            api_key=api_key,
            timeout=timeout,
            max_retries=1,
        )
        _anthropic_clients[cache_key] = client
    return client
|
||
|
||
|
||
class AnthropicProvider(AIProvider):
|
||
"""Anthropic Claude provider using the anthropic SDK."""
|
||
|
||
def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
|
||
self._api_key = api_key
|
||
self._model = model
|
||
self._timeout = timeout
|
||
|
||
async def generate_json(
|
||
self,
|
||
system_prompt: str | list[SystemBlock],
|
||
messages: list[dict[str, Any]],
|
||
max_tokens: int = 4096,
|
||
schema: dict[str, Any] | None = None,
|
||
) -> tuple[str, int, int]:
|
||
client = _get_anthropic_client(self._api_key, self._timeout)
|
||
normalized_system = _normalize_system_for_anthropic(system_prompt)
|
||
|
||
create_kwargs: dict[str, Any] = {
|
||
"model": self._model,
|
||
"max_tokens": max_tokens,
|
||
"system": normalized_system,
|
||
"messages": messages,
|
||
}
|
||
if schema is not None:
|
||
# Structured outputs: constrain the response to valid JSON matching
|
||
# the schema (Sonnet 4.6 / Haiku 4.5). Removes the need for
|
||
# markdown-fence stripping and truncated-JSON repair downstream.
|
||
create_kwargs["output_config"] = {
|
||
"format": {"type": "json_schema", "schema": schema}
|
||
}
|
||
|
||
response = await client.messages.create(**create_kwargs)
|
||
|
||
text = _extract_text_from_response(response, self._model)
|
||
input_tokens = response.usage.input_tokens
|
||
output_tokens = response.usage.output_tokens
|
||
|
||
_log_anthropic_cache_usage(response.usage, self._model)
|
||
|
||
return text, input_tokens, output_tokens
|
||
|
||
async def generate_text(
|
||
self,
|
||
system_prompt: str | list[SystemBlock],
|
||
messages: list[dict[str, Any]],
|
||
max_tokens: int = 4096,
|
||
) -> tuple[str, int, int]:
|
||
# Anthropic doesn't differentiate between JSON and text mode
|
||
return await self.generate_json(system_prompt, messages, max_tokens)
|
||
|
||
async def generate_text_stream(
|
||
self,
|
||
system_prompt: str | list[SystemBlock],
|
||
messages: list[dict[str, Any]],
|
||
max_tokens: int = 4096,
|
||
) -> AsyncIterator[str]:
|
||
client = _get_anthropic_client(self._api_key, self._timeout)
|
||
normalized_system = _normalize_system_for_anthropic(system_prompt)
|
||
|
||
async with client.messages.stream(
|
||
model=self._model,
|
||
max_tokens=max_tokens,
|
||
system=normalized_system,
|
||
messages=messages,
|
||
) as stream:
|
||
async for text in stream.text_stream:
|
||
yield text
|
||
# Per Anthropic SDK, get_final_message() resolves the stream's
|
||
# final usage object (including cache_read/cache_creation tokens).
|
||
try:
|
||
final = await stream.get_final_message()
|
||
_log_anthropic_cache_usage(final.usage, self._model)
|
||
except Exception as exc: # best-effort telemetry, never fail the stream
|
||
logger.debug("anthropic.cache streaming usage unavailable: %s", exc)
|
||
|
||
|
||
def get_ai_provider(model: str | None = None) -> AIProvider:
    """Build the AI provider selected by application settings.

    Args:
        model: Optional Anthropic model ID override. It is used only when an
            AnthropicProvider is built; Gemini always uses
            settings.AI_MODEL_GEMINI.

    Returns:
        A provider instance, chosen as follows:
        1. AI_PROVIDER == "gemini": GeminiProvider if GOOGLE_AI_API_KEY is
           set, otherwise AnthropicProvider if ANTHROPIC_API_KEY is set.
        2. AI_PROVIDER == "anthropic": AnthropicProvider if ANTHROPIC_API_KEY
           is set, otherwise GeminiProvider if GOOGLE_AI_API_KEY is set.

    Raises:
        RuntimeError: If AI_PROVIDER is neither of the two values above, or
            no usable API key is configured.
    """

    def _gemini() -> AIProvider | None:
        # Only built when the Google key is present.
        if not settings.GOOGLE_AI_API_KEY:
            return None
        return GeminiProvider(
            api_key=settings.GOOGLE_AI_API_KEY,
            model=settings.AI_MODEL_GEMINI,
        )

    def _anthropic() -> AIProvider | None:
        # Only built when the Anthropic key is present.
        if not settings.ANTHROPIC_API_KEY:
            return None
        return AnthropicProvider(
            api_key=settings.ANTHROPIC_API_KEY,
            model=model or settings.AI_MODEL_ANTHROPIC,
            timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
        )

    provider = settings.AI_PROVIDER

    # Preferred backend first, the other one as fallback. Equality checks
    # (not a dict lookup) keep the comparison semantics of the setting value.
    if provider == "gemini":
        attempts = (_gemini, _anthropic)
    elif provider == "anthropic":
        attempts = (_anthropic, _gemini)
    else:
        attempts = ()

    for build in attempts:
        instance = build()
        if instance is not None:
            return instance

    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )
|