Files
resolutionflow/backend/app/core/ai_provider.py
Michael Chihlas 067574ad6a feat(ai): robust response extraction + structured-output foundation
Harden the Anthropic provider and lay the groundwork for schema-constrained
JSON, optimizing the existing claude-sonnet-4-6 / claude-haiku-4-5 usage
(no model changes).

ai_provider.py:
- _extract_text_from_response replaces fragile response.content[0].text:
  skips non-text leading blocks (e.g. thinking), returns the first text
  block, logs an anthropic.stop_reason warning on max_tokens/refusal
  (truncation now observable), and raises ValueError on a no-text response.
- generate_json gains an optional `schema` param. Anthropic wires it to
  output_config.format (structured outputs); schema=None preserves the exact
  prior call for every existing caller. Gemini accepts-and-ignores it.

kb_conversion_service.py:
- TROUBLESHOOTING_SCHEMA / PROCEDURAL_SCHEMA + _schema_for_target_type(),
  modelled as a strict superset of every field the prompts emit.
- convert_document passes the schema only when the new
  AI_KB_CONVERT_STRUCTURED_OUTPUT setting is True (default False). The
  _try_repair_json fallback stays as belt-and-suspenders.

Tests: 14 provider + 7 schema, TDD (red-green). Live constrained-decoding
smoke-test still required before enabling the flag in production.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-28 21:48:49 -04:00

523 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI Provider abstraction layer.
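Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends for JSON and free-form text generation used by the AI Flow Builder,
FlowPilot, and other callers. Token streaming (`generate_text_stream`) is
implemented only by the Anthropic backend.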
## Prompt caching (Anthropic only)
Callers may pass `system_prompt` as either:
- `str` — backward-compatible, uncached.
- `list[SystemBlock]` — Anthropic structured system blocks. Each block is a
dict of shape `{"type": "text", "text": str, "cache_control": {...}?}`.
Caching policy (policy α, per Phase 0.1 design):
- If any block in the list carries an explicit `cache_control` key, that
caller-authored configuration is honored verbatim.
- If no block carries `cache_control`, the provider applies
`cache_control: {"type": "ephemeral"}` to the first block only. First block
is the common "large static prefix" case (e.g. system prompt, reference data).
Gemini ignores cache_control and concatenates list blocks into one system
string — callers should not rely on Gemini for cache-hit behavior.
TODO(phase0-verify): When a dev environment is available, verify cache-hit
behavior by hitting any FlowPilot endpoint twice within the 5-minute
ephemeral TTL. First call should emit `anthropic.cache` with
`cache_creation_input_tokens > 0`; second call with `cache_read_input_tokens > 0`.
If the second call returns zero reads, inspect the prefix for silent
invalidators (timestamps, unsorted JSON keys, varying tool list ordering).
"""
import logging
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
from app.core.config import settings
# Module logger. Structured log lines emitted in this module pass an
# `event` field through `extra` alongside the message.
logger = logging.getLogger(__name__)
# Anthropic structured system block. See module docstring for caching policy.
# Expected shape: {"type": "text", "text": str, "cache_control": {...}?}.
# The alias is a plain dict[str, Any], so the shape is not enforced at runtime.
SystemBlock = dict[str, Any]
def _normalize_system_for_anthropic(
    system_prompt: str | list[SystemBlock],
) -> str | list[SystemBlock]:
    """Build the value passed as ``system=`` to the Anthropic Messages API.

    * ``str`` input is returned unchanged (the uncached path).
    * An empty list collapses to ``""`` so Anthropic sees "no system prompt"
      instead of an empty block list.
    * A non-empty list is shallow-copied block by block, so the caller's dicts
      are never mutated. If none of the blocks carries ``cache_control``, the
      first copy is marked ``{"type": "ephemeral"}`` (policy α). If any block
      already has one, the caller's configuration is kept verbatim.
    """
    if isinstance(system_prompt, str):
        return system_prompt
    if not system_prompt:
        # An empty list is not a meaningful system prompt; "" makes Anthropic
        # treat the request as having no system prompt instead of erroring.
        return ""

    copied: list[SystemBlock] = list(map(dict, system_prompt))

    caller_configured_cache = False
    for block in copied:
        if "cache_control" in block:
            caller_configured_cache = True
            break

    if not caller_configured_cache:
        copied[0]["cache_control"] = {"type": "ephemeral"}
    return copied
def _flatten_system_for_gemini(
    system_prompt: str | list[SystemBlock],
) -> str:
    """Collapse a system prompt into the single string Gemini expects.

    Gemini has no structured system blocks. A list is therefore reduced to
    each block's ``text`` (a missing ``text`` counts as ``""``), joined by a
    blank line; any ``cache_control`` keys are dropped. Strings pass through.
    """
    if not isinstance(system_prompt, str):
        texts = [block.get("text", "") for block in system_prompt]
        return "\n\n".join(texts)
    return system_prompt
def build_anthropic_chat_messages(
history: list[dict[str, Any]],
new_message: str,
images: list[dict[str, Any]] | None = None,
format_reminder: str | None = None,
) -> list[dict[str, Any]]:
"""Construct the Anthropic `messages` payload for a cached multi-turn chat.
Responsibilities:
- Copy the valid history messages in order.
- Apply `cache_control: ephemeral` to the LAST history message so the entire
conversation prefix is cached across turns. The new user message stays
uncached (it changes each turn).
- Append `format_reminder` to the new user message if provided. The reminder
is invisible to storage (caller's concern) but helps enforce structured
output compliance at generation time.
- If `images` are provided, render the new user message as a multimodal
content block list (images first, then text). Otherwise, render it as
a plain string.
This helper is Anthropic-specific: the cache-breakpoint pattern, ephemeral
cache_control, and multimodal block shape are all Anthropic conventions.
Do not call it from Gemini code paths.
"""
messages: list[dict[str, Any]] = []
for msg in history:
messages.append({"role": msg["role"], "content": msg["content"]})
# Cache breakpoint on the last existing history message so the entire
# conversation prefix is cached across turns. Safe only when there IS a
# history message; otherwise the new message is the only message.
if messages:
last = messages[-1]
messages[-1] = {
"role": last["role"],
"content": [
{
"type": "text",
"text": last["content"],
"cache_control": {"type": "ephemeral"},
}
],
}
effective_text = new_message + (format_reminder or "")
if images:
content_blocks: list[dict[str, Any]] = []
for img in images:
content_blocks.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": img["media_type"],
"data": img["data"],
},
}
)
content_blocks.append({"type": "text", "text": effective_text})
messages.append({"role": "user", "content": content_blocks})
else:
messages.append({"role": "user", "content": effective_text})
return messages
def _extract_text_from_response(response: Any, model: str) -> str:
    """Pull the text out of an Anthropic message response.

    Safer than ``response.content[0].text``:

    - Leading non-text blocks (such as ``thinking``) are skipped. The text of
      the first block whose ``type`` is ``"text"`` is returned.
    - A ``stop_reason`` of ``max_tokens`` or ``refusal`` is logged as an
      ``anthropic.stop_reason`` warning, so truncated or refused output shows
      up in telemetry instead of silently flowing downstream.
    - If there is no text block at all, ``ValueError`` is raised rather than
      reading an attribute off some other block type.
    """
    reason = getattr(response, "stop_reason", None)
    if reason == "max_tokens" or reason == "refusal":
        logger.warning(
            "anthropic.stop_reason",
            extra={
                "event": "anthropic.stop_reason",
                "model": model,
                "stop_reason": reason,
            },
        )

    first_text_block = next(
        (blk for blk in response.content if getattr(blk, "type", None) == "text"),
        None,
    )
    if first_text_block is None:
        raise ValueError(
            f"Anthropic response contained no text block (stop_reason={reason!r})"
        )
    return first_text_block.text
def _log_anthropic_cache_usage(usage: Any, model: str) -> None:
    """Log prompt-cache token counts for a single Anthropic call.

    Each counter is read defensively: a missing or ``None`` attribute counts
    as 0. A structured ``anthropic.cache`` info line is emitted only when the
    call read from or wrote to the cache; otherwise nothing is logged.
    """
    counts = {
        field: getattr(usage, field, 0) or 0
        for field in (
            "cache_read_input_tokens",
            "cache_creation_input_tokens",
            "input_tokens",
            "output_tokens",
        )
    }
    if not (counts["cache_read_input_tokens"] or counts["cache_creation_input_tokens"]):
        return
    logger.info(
        "anthropic.cache",
        extra={"event": "anthropic.cache", "model": model, **counts},
    )
class AIProvider(ABC):
    """Interface implemented by every AI backend.

    Concrete providers must supply ``generate_json`` and ``generate_text``.
    ``generate_text_stream`` is optional: the default is an async generator
    that raises ``NotImplementedError`` as soon as it is iterated.
    """

    @abstractmethod
    async def generate_json(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
        schema: dict[str, Any] | None = None,
    ) -> tuple[str, int, int]:
        """Ask the model for a JSON response.

        Args:
            system_prompt: System-level instruction. A plain ``str`` is sent
                uncached (Anthropic) or as-is (Gemini). A
                ``list[SystemBlock]`` turns on Anthropic prompt caching per
                the module-docstring policy.
            messages: Message dicts, each with "role" and "content" keys.
            max_tokens: Upper bound on output tokens.
            schema: Optional JSON Schema for the response shape. When given,
                the Anthropic backend uses structured outputs
                (``output_config.format``) to guarantee valid, parseable JSON,
                so no markdown fences and no truncated-brace repair. The schema
                must respect the structured-output limits: every object needs
                ``additionalProperties: false``, no recursion, and
                numeric/string constraints are stripped. ``None`` keeps the
                legacy prompt-only behavior. The Gemini backend currently
                ignores this argument (it already requests
                ``application/json``).

        Returns:
            ``(response_text, input_tokens, output_tokens)``.
        """

    @abstractmethod
    async def generate_text(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for an unconstrained text response.

        Arguments and return value follow ``generate_json`` (minus ``schema``).
        """

    async def generate_text_stream(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> "AsyncIterator[str]":
        """Yield the text response incrementally, token by token.

        Arguments follow ``generate_json`` (minus ``schema``). This default
        refuses on first iteration; providers that can stream override it.
        """
        raise NotImplementedError("Streaming not supported for this provider")
        # Unreachable on purpose: the yield makes this an async generator, so
        # the default has the same calling shape as real overrides.
        yield ""  # pragma: no cover
class GeminiProvider(AIProvider):
    """Google Gemini provider using the google-genai SDK.

    The SDK is imported lazily inside each request, and a fresh
    ``genai.Client`` is built per request. ``generate_json`` and
    ``generate_text`` share one request path. They differ only in whether
    ``response_mime_type="application/json"`` is set on the config.
    """

    def __init__(self, api_key: str, model: str) -> None:
        self._api_key = api_key
        self._model = model

    async def generate_json(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
        schema: dict[str, Any] | None = None,
    ) -> tuple[str, int, int]:
        """Generate a JSON response (``response_mime_type="application/json"``).

        ``schema`` is accepted for interface parity but ignored. Gemini output
        is already constrained via the JSON mime type, and mapping JSON Schema
        onto Gemini's ``response_schema`` is deferred.
        """
        return await self._generate(
            system_prompt,
            messages,
            max_tokens,
            response_mime_type="application/json",
        )

    async def generate_text(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a free-form text response (no ``response_mime_type``)."""
        return await self._generate(
            system_prompt, messages, max_tokens, response_mime_type=None
        )

    async def _generate(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int,
        response_mime_type: str | None,
    ) -> tuple[str, int, int]:
        """Run one ``generate_content`` call and return (text, input, output tokens).

        Assistant turns map to Gemini's ``"model"`` role; every other role is
        sent as ``"user"``. A missing ``response.text`` becomes ``""`` and
        missing token counts become 0.
        """
        from google import genai
        from google.genai import types as genai_types

        client = genai.Client(api_key=self._api_key)
        contents = [
            genai_types.Content(
                role="model" if msg["role"] == "assistant" else "user",
                parts=[genai_types.Part(text=msg["content"])],
            )
            for msg in messages
        ]

        config_kwargs: dict[str, Any] = {
            "system_instruction": _flatten_system_for_gemini(system_prompt),
            "max_output_tokens": max_tokens,
        }
        # Only set the mime type when requested, so free-form text calls leave
        # the field unset exactly as before (rather than explicitly None).
        if response_mime_type is not None:
            config_kwargs["response_mime_type"] = response_mime_type
        config = genai_types.GenerateContentConfig(**config_kwargs)

        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )
        self._log_finish_reason(response, max_tokens)

        usage = response.usage_metadata
        text = response.text or ""
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0
        return text, input_tokens, output_tokens

    def _log_finish_reason(self, response: Any, max_tokens: int) -> None:
        """Log the first candidate's finish reason and warn on truncation."""
        if not response.candidates:
            return
        finish_reason = getattr(response.candidates[0], "finish_reason", None)
        logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
        if self._is_max_tokens(finish_reason):
            logger.warning(
                "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
                max_tokens,
            )

    @staticmethod
    def _is_max_tokens(finish_reason: Any) -> bool:
        """Return True when ``finish_reason`` denotes MAX_TOKENS truncation.

        ``str()`` of a plain ``(str, Enum)`` member renders as
        ``"FinishReason.MAX_TOKENS"``, not ``"MAX_TOKENS"``. The previous exact
        ``str(finish_reason) == "MAX_TOKENS"`` comparison therefore never
        matched an enum member, and the truncation warning was lost.
        NOTE(review): google-genai reports a ``FinishReason`` enum here;
        confirm against the installed SDK version.

        This check compares the member ``name`` when present and otherwise the
        value itself (e.g. a plain string). It also tolerates a ``Class.``
        prefix.
        """
        if finish_reason is None:
            return False
        label = str(getattr(finish_reason, "name", finish_reason))
        return label.rsplit(".", 1)[-1] == "MAX_TOKENS"
# Singleton client cache — avoids creating new HTTP connections per call.
# Keys come from `_anthropic_client_cache_key` (API-key digest + timeout).
_anthropic_clients: dict[str, "anthropic.AsyncAnthropic"] = {}


def _anthropic_client_cache_key(api_key: str, timeout: int) -> str:
    """Return the `_anthropic_clients` key for an (api_key, timeout) pair.

    The FULL key is hashed. Anthropic API keys share a long common prefix
    (e.g. ``sk-ant-api03-``), so the previous ``api_key[:8]`` key made
    distinct keys collide. A second key then silently reused a client
    authenticated with the first one. Hashing also keeps the raw secret out
    of the cache's keys.
    """
    import hashlib

    digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    return f"{digest}:{timeout}"


def _get_anthropic_client(api_key: str, timeout: int = 45) -> "anthropic.AsyncAnthropic":
    """Return a cached AsyncAnthropic client, creating one if needed.

    One client is kept per distinct (api_key, timeout). New clients are
    created with ``max_retries=1``.
    """
    import anthropic

    cache_key = _anthropic_client_cache_key(api_key, timeout)
    client = _anthropic_clients.get(cache_key)
    if client is None:
        client = anthropic.AsyncAnthropic(
            api_key=api_key,
            timeout=timeout,
            max_retries=1,
        )
        _anthropic_clients[cache_key] = client
    return client
class AnthropicProvider(AIProvider):
    """Anthropic Claude provider using the anthropic SDK.

    Clients come from `_get_anthropic_client` (shared per api_key/timeout).
    System prompts pass through `_normalize_system_for_anthropic`, so
    list-form prompts get prompt caching per the module-docstring policy.
    """

    def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
        self._api_key = api_key
        self._model = model
        self._timeout = timeout

    async def generate_json(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
        schema: dict[str, Any] | None = None,
    ) -> tuple[str, int, int]:
        """Call ``messages.create`` and return (text, input_tokens, output_tokens).

        With ``schema``, the request adds ``output_config.format`` of type
        ``json_schema`` (structured outputs). Without it, the request is the
        plain prompt-only call. Cache usage is logged before the text is
        extracted, so it is still recorded when extraction raises
        ``ValueError`` (no text block in the response).
        """
        client = _get_anthropic_client(self._api_key, self._timeout)
        create_kwargs: dict[str, Any] = {
            "model": self._model,
            "max_tokens": max_tokens,
            "system": _normalize_system_for_anthropic(system_prompt),
            "messages": messages,
        }
        if schema is not None:
            # Structured outputs: constrain the response to valid JSON matching
            # the schema (Sonnet 4.6 / Haiku 4.5). Removes the need for
            # markdown-fence stripping and truncated-JSON repair downstream.
            create_kwargs["output_config"] = {
                "format": {"type": "json_schema", "schema": schema}
            }
        response = await client.messages.create(**create_kwargs)
        usage = response.usage
        # Log telemetry first: the extraction below may raise on a no-text
        # (e.g. refusal) response, and the usage is still worth recording.
        _log_anthropic_cache_usage(usage, self._model)
        text = _extract_text_from_response(response, self._model)
        return text, usage.input_tokens, usage.output_tokens

    async def generate_text(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Text generation; same request as ``generate_json`` without a schema."""
        # Anthropic doesn't differentiate between JSON and text mode
        return await self.generate_json(system_prompt, messages, max_tokens)

    async def generate_text_stream(
        self,
        system_prompt: str | list[SystemBlock],
        messages: list[dict[str, Any]],
        max_tokens: int = 4096,
    ) -> AsyncIterator[str]:
        """Yield text deltas as they arrive, then record end-of-stream telemetry.

        After the text stream is exhausted, the final message is fetched for
        two purposes:
        - to log cache usage;
        - like the non-streaming path, to emit an ``anthropic.stop_reason``
          warning on ``max_tokens``/``refusal``, so a truncated stream is
          observable.
        Telemetry is best-effort and never fails the stream.
        """
        client = _get_anthropic_client(self._api_key, self._timeout)
        normalized_system = _normalize_system_for_anthropic(system_prompt)
        async with client.messages.stream(
            model=self._model,
            max_tokens=max_tokens,
            system=normalized_system,
            messages=messages,
        ) as stream:
            async for text in stream.text_stream:
                yield text
            # Per Anthropic SDK, get_final_message() resolves the stream's
            # final usage object (including cache_read/cache_creation tokens).
            try:
                final = await stream.get_final_message()
                _log_anthropic_cache_usage(final.usage, self._model)
                stop_reason = getattr(final, "stop_reason", None)
                if stop_reason in ("max_tokens", "refusal"):
                    logger.warning(
                        "anthropic.stop_reason",
                        extra={
                            "event": "anthropic.stop_reason",
                            "model": self._model,
                            "stop_reason": stop_reason,
                        },
                    )
            except Exception as exc:  # best-effort telemetry, never fail the stream
                logger.debug("anthropic.cache streaming usage unavailable: %s", exc)
def get_ai_provider(model: str | None = None) -> AIProvider:
    """Factory that returns the configured AI provider.

    Args:
        model: Optional model override (Anthropic model ID). Only applied to
            AnthropicProvider; Gemini always uses settings.AI_MODEL_GEMINI.

    Selection logic:
        1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
        2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
        3. Fallback: if the preferred provider's key is missing, try the other one
        4. If nothing usable is configured -> raise RuntimeError
    AI_PROVIDER is matched case-insensitively, ignoring surrounding whitespace.

    Raises:
        RuntimeError: AI_PROVIDER is not "gemini"/"anthropic", or neither
            API key is set.
    """

    def _gemini() -> AIProvider | None:
        # Gemini never receives the `model` override.
        if not settings.GOOGLE_AI_API_KEY:
            return None
        return GeminiProvider(
            api_key=settings.GOOGLE_AI_API_KEY,
            model=settings.AI_MODEL_GEMINI,
        )

    def _anthropic() -> AIProvider | None:
        if not settings.ANTHROPIC_API_KEY:
            return None
        return AnthropicProvider(
            api_key=settings.ANTHROPIC_API_KEY,
            model=model or settings.AI_MODEL_ANTHROPIC,
            timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
        )

    raw_provider = settings.AI_PROVIDER
    provider = (
        raw_provider.strip().lower() if isinstance(raw_provider, str) else raw_provider
    )
    # Preferred backend first, the other one as fallback.
    preference_order = {
        "gemini": (_gemini, _anthropic),
        "anthropic": (_anthropic, _gemini),
    }
    factories = preference_order.get(provider)
    if factories is None:
        # Previously an unknown value fell through to the "no API key" error,
        # which misdirects debugging when keys ARE set.
        raise RuntimeError(
            f"Unknown AI_PROVIDER {raw_provider!r}; expected 'gemini' or 'anthropic'."
        )

    for build in factories:
        instance = build()
        if instance is not None:
            return instance
    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )