Harden the Anthropic provider and lay the groundwork for schema-constrained JSON, optimizing the existing claude-sonnet-4-6 / claude-haiku-4-5 usage (no model changes). ai_provider.py: - _extract_text_from_response replaces fragile response.content[0].text: skips non-text leading blocks (e.g. thinking), returns the first text block, logs an anthropic.stop_reason warning on max_tokens/refusal (truncation now observable), and raises ValueError on a no-text response. - generate_json gains an optional `schema` param. Anthropic wires it to output_config.format (structured outputs); schema=None preserves the exact prior call for every existing caller. Gemini accepts-and-ignores it. kb_conversion_service.py: - TROUBLESHOOTING_SCHEMA / PROCEDURAL_SCHEMA + _schema_for_target_type(), modelled as a strict superset of every field the prompts emit. - convert_document passes the schema only when the new AI_KB_CONVERT_STRUCTURED_OUTPUT setting is True (default False). The _try_repair_json fallback stays as belt-and-suspenders. Tests: 14 provider + 7 schema, TDD (red-green). Live constrained-decoding smoke-test still required before enabling the flag in production. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
424 lines
16 KiB
Python
424 lines
16 KiB
Python
"""Tests for the AI provider abstraction layer."""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
import sys
|
|
|
|
from app.core.ai_provider import (
|
|
AIProvider,
|
|
AnthropicProvider,
|
|
GeminiProvider,
|
|
get_ai_provider,
|
|
)
|
|
from app.core.config import settings
|
|
|
|
|
|
class TestGetAIProvider:
    """Exercise provider selection and fallback in ``get_ai_provider``.

    Every test mutates attributes on the shared ``settings`` object, so each
    one snapshots the attributes it touches first and puts them back in a
    ``finally`` clause, regardless of whether the assertions pass.
    """

    @staticmethod
    def _snapshot(*names):
        """Return a dict of the current ``settings`` values for *names*."""
        return {name: getattr(settings, name) for name in names}

    @staticmethod
    def _apply(**values):
        """Assign each keyword onto ``settings``, in the order given."""
        for name, value in values.items():
            setattr(settings, name, value)

    def test_returns_gemini_when_configured(self):
        """AI_PROVIDER="gemini" with a Google key yields a GeminiProvider."""
        saved = self._snapshot("AI_PROVIDER", "GOOGLE_AI_API_KEY")
        try:
            self._apply(
                AI_PROVIDER="gemini",
                GOOGLE_AI_API_KEY="test-gemini-key",
            )
            chosen = get_ai_provider()
            assert isinstance(chosen, GeminiProvider)
        finally:
            self._apply(**saved)

    def test_returns_anthropic_when_configured(self):
        """AI_PROVIDER="anthropic" with an Anthropic key yields an AnthropicProvider."""
        saved = self._snapshot("AI_PROVIDER", "ANTHROPIC_API_KEY")
        try:
            self._apply(
                AI_PROVIDER="anthropic",
                ANTHROPIC_API_KEY="test-anthropic-key",
            )
            chosen = get_ai_provider()
            assert isinstance(chosen, AnthropicProvider)
        finally:
            self._apply(**saved)

    def test_fallback_to_anthropic_when_gemini_key_missing(self):
        """Gemini requested without a Google key falls back to Anthropic."""
        saved = self._snapshot(
            "AI_PROVIDER", "GOOGLE_AI_API_KEY", "ANTHROPIC_API_KEY"
        )
        try:
            self._apply(
                AI_PROVIDER="gemini",
                GOOGLE_AI_API_KEY=None,
                ANTHROPIC_API_KEY="test-anthropic-key",
            )
            chosen = get_ai_provider()
            assert isinstance(chosen, AnthropicProvider)
        finally:
            self._apply(**saved)

    def test_fallback_to_gemini_when_anthropic_key_missing(self):
        """Anthropic requested without an Anthropic key falls back to Gemini."""
        saved = self._snapshot(
            "AI_PROVIDER", "GOOGLE_AI_API_KEY", "ANTHROPIC_API_KEY"
        )
        try:
            self._apply(
                AI_PROVIDER="anthropic",
                ANTHROPIC_API_KEY=None,
                GOOGLE_AI_API_KEY="test-gemini-key",
            )
            chosen = get_ai_provider()
            assert isinstance(chosen, GeminiProvider)
        finally:
            self._apply(**saved)

    def test_raises_when_no_provider_configured(self):
        """With neither API key set, the factory raises RuntimeError."""
        saved = self._snapshot(
            "AI_PROVIDER", "GOOGLE_AI_API_KEY", "ANTHROPIC_API_KEY"
        )
        try:
            self._apply(
                AI_PROVIDER="gemini",
                GOOGLE_AI_API_KEY=None,
                ANTHROPIC_API_KEY=None,
            )
            with pytest.raises(RuntimeError, match="No AI provider configured"):
                get_ai_provider()
        finally:
            self._apply(**saved)
|
|
|
|
|
|
class TestAnthropicProvider:
    """Tests for AnthropicProvider.generate_json.

    All tests replace ``anthropic.AsyncAnthropic`` with a mock whose
    ``messages.create`` returns a canned response, then check what
    ``generate_json`` returns, raises, logs, or forwards to the client.

    The module-level ``ai_provider._anthropic_clients`` container is emptied
    before and after every test by an autouse fixture, so no test depends on
    (or leaks into) whatever another test left there.
    NOTE(review): ``_anthropic_clients`` is presumably a cache of constructed
    clients; this file only relies on it having ``.clear()`` - confirm against
    ``app.core.ai_provider``.
    """

    @pytest.fixture(autouse=True)
    def _reset_client_cache(self):
        """Clear ``ai_provider._anthropic_clients`` around each test."""
        from app.core import ai_provider

        ai_provider._anthropic_clients.clear()
        yield
        # Also clear on the way out so mock clients never outlive the test.
        ai_provider._anthropic_clients.clear()

    @staticmethod
    def _response(blocks, stop_reason="end_turn", input_tokens=10, output_tokens=5):
        """Build a mock Messages API response.

        Args:
            blocks: list assigned to ``response.content``.
            stop_reason: value assigned to ``response.stop_reason``.
            input_tokens: value exposed as ``response.usage.input_tokens``.
            output_tokens: value exposed as ``response.usage.output_tokens``.
        """
        response = MagicMock()
        response.content = blocks
        response.stop_reason = stop_reason
        response.usage = MagicMock(
            input_tokens=input_tokens, output_tokens=output_tokens
        )
        return response

    @staticmethod
    def _client_returning(response):
        """Build a mock client whose ``messages.create`` resolves to *response*."""
        client = AsyncMock()
        client.messages.create = AsyncMock(return_value=response)
        return client

    @pytest.mark.asyncio
    async def test_generate_json(self):
        """Text and token counts come from the response; the call is forwarded as-is."""
        provider = AnthropicProvider(
            api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30
        )
        response = self._response(
            [MagicMock(type="text", text='{"result": "ok"}')],
            input_tokens=100,
            output_tokens=50,
        )
        client = self._client_returning(response)

        with patch("anthropic.AsyncAnthropic", return_value=client):
            text, input_tokens, output_tokens = await provider.generate_json(
                system_prompt="You are a helper.",
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=1024,
            )

        assert text == '{"result": "ok"}'
        assert input_tokens == 100
        assert output_tokens == 50

        client.messages.create.assert_called_once_with(
            model="claude-haiku-4-5-20251001",
            max_tokens=1024,
            system="You are a helper.",
            messages=[{"role": "user", "content": "Hello"}],
        )

    @pytest.mark.asyncio
    async def test_generate_json_skips_non_text_blocks(self):
        """A leading non-text block (e.g. thinking) is skipped; the first
        text block's text is returned instead of content[0].text."""
        # api_key/timeout values are kept distinct per test, as in the
        # original; presumably they feed the client cache key - confirm.
        provider = AnthropicProvider(
            api_key="skip-key", model="claude-sonnet-4-6", timeout=31
        )
        response = self._response(
            [
                MagicMock(type="thinking", thinking="hmm..."),
                MagicMock(type="text", text='{"ok": 1}'),
            ]
        )
        client = self._client_returning(response)

        with patch("anthropic.AsyncAnthropic", return_value=client):
            text, _, _ = await provider.generate_json(
                system_prompt="You are a helper.",
                messages=[{"role": "user", "content": "Hi"}],
            )

        assert text == '{"ok": 1}'

    @pytest.mark.asyncio
    async def test_generate_json_raises_when_no_text_block(self):
        """A response with no text block (e.g. a bare refusal) raises a clear
        error instead of returning a non-text block's attributes."""
        provider = AnthropicProvider(
            api_key="empty-key", model="claude-sonnet-4-6", timeout=32
        )
        response = self._response(
            [MagicMock(type="thinking", thinking="...")],
            stop_reason="refusal",
            input_tokens=10,
            output_tokens=0,
        )
        client = self._client_returning(response)

        with patch("anthropic.AsyncAnthropic", return_value=client):
            with pytest.raises(ValueError, match="no text block"):
                await provider.generate_json(
                    system_prompt="You are a helper.",
                    messages=[{"role": "user", "content": "Hi"}],
                )

    @pytest.mark.asyncio
    async def test_generate_json_logs_warning_on_truncation(self, caplog):
        """When stop_reason is max_tokens, a warning is logged (truncation
        signal) and the partial text is still returned."""
        import logging

        provider = AnthropicProvider(
            api_key="trunc-key", model="claude-sonnet-4-6", timeout=33
        )
        response = self._response(
            [MagicMock(type="text", text='{"partial": tr')],
            stop_reason="max_tokens",
            input_tokens=10,
            output_tokens=4096,
        )
        client = self._client_returning(response)

        with patch("anthropic.AsyncAnthropic", return_value=client):
            with caplog.at_level(logging.WARNING, logger="app.core.ai_provider"):
                text, _, _ = await provider.generate_json(
                    system_prompt="You are a helper.",
                    messages=[{"role": "user", "content": "Hi"}],
                )

        assert text == '{"partial": tr'
        # The record is matched on a ``stop_reason`` attribute rather than on
        # message text, so rewording the log message does not break the test.
        truncation_records = [
            r for r in caplog.records if getattr(r, "stop_reason", None) == "max_tokens"
        ]
        assert truncation_records, "expected a warning record for max_tokens truncation"

    @pytest.mark.asyncio
    async def test_generate_json_passes_output_config_when_schema_given(self):
        """When a JSON schema is supplied, it is forwarded as
        output_config.format so the API constrains the response shape."""
        provider = AnthropicProvider(
            api_key="schema-key", model="claude-sonnet-4-6", timeout=34
        )
        response = self._response([MagicMock(type="text", text='{"title": "x"}')])
        client = self._client_returning(response)

        schema = {
            "type": "object",
            "properties": {"title": {"type": "string"}},
            "required": ["title"],
            "additionalProperties": False,
        }

        with patch("anthropic.AsyncAnthropic", return_value=client):
            await provider.generate_json(
                system_prompt="You are a helper.",
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=512,
                schema=schema,
            )

        client.messages.create.assert_called_once_with(
            model="claude-sonnet-4-6",
            max_tokens=512,
            system="You are a helper.",
            messages=[{"role": "user", "content": "Hi"}],
            output_config={"format": {"type": "json_schema", "schema": schema}},
        )

    @pytest.mark.asyncio
    async def test_generate_json_no_output_config_when_schema_none(self):
        """With no schema, output_config is not sent (backward compatible)."""
        provider = AnthropicProvider(
            api_key="noschema-key", model="claude-sonnet-4-6", timeout=35
        )
        response = self._response(
            [MagicMock(type="text", text="{}")],
            input_tokens=1,
            output_tokens=1,
        )
        client = self._client_returning(response)

        with patch("anthropic.AsyncAnthropic", return_value=client):
            await provider.generate_json(
                system_prompt="You are a helper.",
                messages=[{"role": "user", "content": "Hi"}],
            )

        _, call_kwargs = client.messages.create.call_args
        assert "output_config" not in call_kwargs
|
|
|
|
|
|
class TestGeminiProvider:
    """Tests for GeminiProvider.generate_json.

    The ``google.genai`` package is never imported for real: each test
    installs stand-in modules into ``sys.modules`` for the duration of the
    call, with a mock client whose ``aio.models.generate_content`` returns a
    canned response.
    """

    @staticmethod
    def _reply(text, usage):
        """Build a mock response exposing ``text`` and ``usage_metadata``."""
        reply = MagicMock()
        reply.text = text
        reply.usage_metadata = usage
        return reply

    @staticmethod
    def _fake_genai(reply):
        """Return ``(client, modules)`` for patching ``sys.modules``.

        ``client`` is the mock that ``genai.Client(...)`` returns; its async
        ``aio.models.generate_content`` resolves to *reply*. ``modules`` maps
        the ``google``, ``google.genai`` and ``google.genai.types`` names to
        the stand-in module objects.
        """

        def kwargs_as_dict(**kw):
            # Stand-in constructor: hands back its keyword arguments so the
            # Content / Part / GenerateContentConfig values are plain dicts.
            return kw

        client = MagicMock()
        client.aio.models.generate_content = AsyncMock(return_value=reply)

        types_module = MagicMock()
        types_module.Content.side_effect = kwargs_as_dict
        types_module.Part.side_effect = kwargs_as_dict
        types_module.GenerateContentConfig.side_effect = kwargs_as_dict

        genai_module = MagicMock()
        genai_module.Client.return_value = client
        genai_module.types = types_module

        google_package = MagicMock()
        google_package.genai = genai_module

        modules = {
            "google": google_package,
            "google.genai": genai_module,
            "google.genai.types": types_module,
        }
        return client, modules

    @pytest.mark.asyncio
    async def test_generate_json(self):
        """Text and token counts are read off the response for a multi-turn chat."""
        provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")

        usage = MagicMock()
        usage.prompt_token_count = 80
        usage.candidates_token_count = 40
        client, modules = self._fake_genai(self._reply('{"answer": 42}', usage))

        with patch.dict(sys.modules, modules):
            text, input_tokens, output_tokens = await provider.generate_json(
                system_prompt="Generate JSON.",
                messages=[
                    {"role": "user", "content": "Give me data"},
                    {"role": "assistant", "content": "Here it is"},
                    {"role": "user", "content": "More please"},
                ],
                max_tokens=2048,
            )

        assert text == '{"answer": 42}'
        assert input_tokens == 80
        assert output_tokens == 40

        client.aio.models.generate_content.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_json_accepts_and_ignores_schema(self):
        """Gemini accepts the schema kwarg (interface parity) and still
        returns JSON; it does not error on the param."""
        provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")

        usage = MagicMock()
        usage.prompt_token_count = 5
        usage.candidates_token_count = 3
        _, modules = self._fake_genai(self._reply('{"answer": 1}', usage))

        with patch.dict(sys.modules, modules):
            text, _, _ = await provider.generate_json(
                system_prompt="Generate JSON.",
                messages=[{"role": "user", "content": "data"}],
                schema={"type": "object"},
            )

        assert text == '{"answer": 1}'

    @pytest.mark.asyncio
    async def test_generate_json_handles_none_usage(self):
        """Token counts come back as 0 when usage_metadata exposes no
        token-count attributes."""
        provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")

        # spec=[] gives a mock with no attributes at all, so any token-count
        # lookup on it raises AttributeError rather than returning a mock.
        usage = MagicMock(spec=[])
        _, modules = self._fake_genai(self._reply("{}", usage))

        with patch.dict(sys.modules, modules):
            text, input_tokens, output_tokens = await provider.generate_json(
                system_prompt="test",
                messages=[{"role": "user", "content": "test"}],
            )

        assert text == "{}"
        assert input_tokens == 0
        assert output_tokens == 0
|