"""Tests for the AI provider abstraction layer.""" import pytest from unittest.mock import AsyncMock, MagicMock, patch import sys from app.core.ai_provider import ( AIProvider, AnthropicProvider, GeminiProvider, get_ai_provider, ) from app.core.config import settings class TestGetAIProvider: """Tests for the get_ai_provider factory function.""" def test_returns_gemini_when_configured(self): original_provider = settings.AI_PROVIDER original_key = settings.GOOGLE_AI_API_KEY try: settings.AI_PROVIDER = "gemini" settings.GOOGLE_AI_API_KEY = "test-gemini-key" provider = get_ai_provider() assert isinstance(provider, GeminiProvider) finally: settings.AI_PROVIDER = original_provider settings.GOOGLE_AI_API_KEY = original_key def test_returns_anthropic_when_configured(self): original_provider = settings.AI_PROVIDER original_key = settings.ANTHROPIC_API_KEY try: settings.AI_PROVIDER = "anthropic" settings.ANTHROPIC_API_KEY = "test-anthropic-key" provider = get_ai_provider() assert isinstance(provider, AnthropicProvider) finally: settings.AI_PROVIDER = original_provider settings.ANTHROPIC_API_KEY = original_key def test_fallback_to_anthropic_when_gemini_key_missing(self): original_provider = settings.AI_PROVIDER original_gemini_key = settings.GOOGLE_AI_API_KEY original_anthropic_key = settings.ANTHROPIC_API_KEY try: settings.AI_PROVIDER = "gemini" settings.GOOGLE_AI_API_KEY = None settings.ANTHROPIC_API_KEY = "test-anthropic-key" provider = get_ai_provider() assert isinstance(provider, AnthropicProvider) finally: settings.AI_PROVIDER = original_provider settings.GOOGLE_AI_API_KEY = original_gemini_key settings.ANTHROPIC_API_KEY = original_anthropic_key def test_fallback_to_gemini_when_anthropic_key_missing(self): original_provider = settings.AI_PROVIDER original_gemini_key = settings.GOOGLE_AI_API_KEY original_anthropic_key = settings.ANTHROPIC_API_KEY try: settings.AI_PROVIDER = "anthropic" settings.ANTHROPIC_API_KEY = None settings.GOOGLE_AI_API_KEY = "test-gemini-key" provider = get_ai_provider() assert isinstance(provider, GeminiProvider) finally: settings.AI_PROVIDER = original_provider settings.GOOGLE_AI_API_KEY = original_gemini_key settings.ANTHROPIC_API_KEY = original_anthropic_key def test_raises_when_no_provider_configured(self): original_provider = settings.AI_PROVIDER original_gemini_key = settings.GOOGLE_AI_API_KEY original_anthropic_key = settings.ANTHROPIC_API_KEY try: settings.AI_PROVIDER = "gemini" settings.GOOGLE_AI_API_KEY = None settings.ANTHROPIC_API_KEY = None with pytest.raises(RuntimeError, match="No AI provider configured"): get_ai_provider() finally: settings.AI_PROVIDER = original_provider settings.GOOGLE_AI_API_KEY = original_gemini_key settings.ANTHROPIC_API_KEY = original_anthropic_key class TestAnthropicProvider: """Tests for AnthropicProvider.generate_json.""" @pytest.mark.asyncio async def test_generate_json(self): provider = AnthropicProvider( api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30 ) mock_response = MagicMock() mock_response.content = [MagicMock(type="text", text='{"result": "ok"}')] mock_response.stop_reason = "end_turn" mock_response.usage = MagicMock(input_tokens=100, output_tokens=50) mock_client = AsyncMock() mock_client.messages.create = AsyncMock(return_value=mock_response) with patch("anthropic.AsyncAnthropic", return_value=mock_client): text, input_tokens, output_tokens = await provider.generate_json( system_prompt="You are a helper.", messages=[{"role": "user", "content": "Hello"}], max_tokens=1024, ) assert text == '{"result": "ok"}' assert input_tokens == 100 assert output_tokens == 50 mock_client.messages.create.assert_called_once_with( model="claude-haiku-4-5-20251001", max_tokens=1024, system="You are a helper.", messages=[{"role": "user", "content": "Hello"}], ) @pytest.mark.asyncio async def test_generate_json_skips_non_text_blocks(self): """A leading non-text block (e.g. thinking) is skipped; the first text block's text is returned instead of content[0].text.""" from app.core import ai_provider ai_provider._anthropic_clients.clear() provider = AnthropicProvider( api_key="skip-key", model="claude-sonnet-4-6", timeout=31 ) thinking_block = MagicMock(type="thinking", thinking="hmm...") text_block = MagicMock(type="text", text='{"ok": 1}') mock_response = MagicMock() mock_response.content = [thinking_block, text_block] mock_response.stop_reason = "end_turn" mock_response.usage = MagicMock(input_tokens=10, output_tokens=5) mock_client = AsyncMock() mock_client.messages.create = AsyncMock(return_value=mock_response) with patch("anthropic.AsyncAnthropic", return_value=mock_client): text, _, _ = await provider.generate_json( system_prompt="You are a helper.", messages=[{"role": "user", "content": "Hi"}], ) assert text == '{"ok": 1}' @pytest.mark.asyncio async def test_generate_json_raises_when_no_text_block(self): """A response with no text block (e.g. a bare refusal) raises a clear error instead of returning a non-text block's attributes.""" from app.core import ai_provider ai_provider._anthropic_clients.clear() provider = AnthropicProvider( api_key="empty-key", model="claude-sonnet-4-6", timeout=32 ) mock_response = MagicMock() mock_response.content = [MagicMock(type="thinking", thinking="...")] mock_response.stop_reason = "refusal" mock_response.usage = MagicMock(input_tokens=10, output_tokens=0) mock_client = AsyncMock() mock_client.messages.create = AsyncMock(return_value=mock_response) with patch("anthropic.AsyncAnthropic", return_value=mock_client): with pytest.raises(ValueError, match="no text block"): await provider.generate_json( system_prompt="You are a helper.", messages=[{"role": "user", "content": "Hi"}], ) @pytest.mark.asyncio async def test_generate_json_logs_warning_on_truncation(self, caplog): """When stop_reason is max_tokens, a warning is logged (truncation signal) and the partial text is still returned.""" import logging from app.core import ai_provider ai_provider._anthropic_clients.clear() provider = AnthropicProvider( api_key="trunc-key", model="claude-sonnet-4-6", timeout=33 ) text_block = MagicMock(type="text", text='{"partial": tr') mock_response = MagicMock() mock_response.content = [text_block] mock_response.stop_reason = "max_tokens" mock_response.usage = MagicMock(input_tokens=10, output_tokens=4096) mock_client = AsyncMock() mock_client.messages.create = AsyncMock(return_value=mock_response) with patch("anthropic.AsyncAnthropic", return_value=mock_client): with caplog.at_level(logging.WARNING, logger="app.core.ai_provider"): text, _, _ = await provider.generate_json( system_prompt="You are a helper.", messages=[{"role": "user", "content": "Hi"}], ) assert text == '{"partial": tr' truncation_records = [ r for r in caplog.records if getattr(r, "stop_reason", None) == "max_tokens" ] assert truncation_records, "expected a warning record for max_tokens truncation" @pytest.mark.asyncio async def test_generate_json_passes_output_config_when_schema_given(self): """When a JSON schema is supplied, it is forwarded as output_config.format so the API constrains the response shape.""" from app.core import ai_provider ai_provider._anthropic_clients.clear() provider = AnthropicProvider( api_key="schema-key", model="claude-sonnet-4-6", timeout=34 ) mock_response = MagicMock() mock_response.content = [MagicMock(type="text", text='{"title": "x"}')] mock_response.stop_reason = "end_turn" mock_response.usage = MagicMock(input_tokens=10, output_tokens=5) mock_client = AsyncMock() mock_client.messages.create = AsyncMock(return_value=mock_response) schema = { "type": "object", "properties": {"title": {"type": "string"}}, "required": ["title"], "additionalProperties": False, } with patch("anthropic.AsyncAnthropic", return_value=mock_client): await provider.generate_json( system_prompt="You are a helper.", messages=[{"role": "user", "content": "Hi"}], max_tokens=512, schema=schema, ) mock_client.messages.create.assert_called_once_with( model="claude-sonnet-4-6", max_tokens=512, system="You are a helper.", messages=[{"role": "user", "content": "Hi"}], output_config={"format": {"type": "json_schema", "schema": schema}}, ) @pytest.mark.asyncio async def test_generate_json_no_output_config_when_schema_none(self): """With no schema, output_config is not sent (backward compatible).""" from app.core import ai_provider ai_provider._anthropic_clients.clear() provider = AnthropicProvider( api_key="noschema-key", model="claude-sonnet-4-6", timeout=35 ) mock_response = MagicMock() mock_response.content = [MagicMock(type="text", text="{}")] mock_response.stop_reason = "end_turn" mock_response.usage = MagicMock(input_tokens=1, output_tokens=1) mock_client = AsyncMock() mock_client.messages.create = AsyncMock(return_value=mock_response) with patch("anthropic.AsyncAnthropic", return_value=mock_client): await provider.generate_json( system_prompt="You are a helper.", messages=[{"role": "user", "content": "Hi"}], ) _, call_kwargs = mock_client.messages.create.call_args assert "output_config" not in call_kwargs class TestGeminiProvider: """Tests for GeminiProvider.generate_json.""" @pytest.mark.asyncio async def test_generate_json(self): provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") mock_usage = MagicMock() mock_usage.prompt_token_count = 80 mock_usage.candidates_token_count = 40 mock_response = MagicMock() mock_response.text = '{"answer": 42}' mock_response.usage_metadata = mock_usage mock_client = MagicMock() mock_client.aio.models.generate_content = AsyncMock( return_value=mock_response ) mock_genai_module = MagicMock() mock_genai_module.Client.return_value = mock_client mock_types = MagicMock() mock_types.Content.side_effect = lambda **kw: kw mock_types.Part.side_effect = lambda **kw: kw mock_types.GenerateContentConfig.side_effect = lambda **kw: kw mock_google = MagicMock() mock_google.genai = mock_genai_module mock_genai_module.types = mock_types with patch.dict(sys.modules, { "google": mock_google, "google.genai": mock_genai_module, "google.genai.types": mock_types, }): text, input_tokens, output_tokens = await provider.generate_json( system_prompt="Generate JSON.", messages=[ {"role": "user", "content": "Give me data"}, {"role": "assistant", "content": "Here it is"}, {"role": "user", "content": "More please"}, ], max_tokens=2048, ) assert text == '{"answer": 42}' assert input_tokens == 80 assert output_tokens == 40 mock_client.aio.models.generate_content.assert_called_once() @pytest.mark.asyncio async def test_generate_json_accepts_and_ignores_schema(self): """Gemini accepts the schema kwarg (interface parity) and still returns JSON; it does not error on the param.""" provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") mock_usage = MagicMock() mock_usage.prompt_token_count = 5 mock_usage.candidates_token_count = 3 mock_response = MagicMock() mock_response.text = '{"answer": 1}' mock_response.usage_metadata = mock_usage mock_client = MagicMock() mock_client.aio.models.generate_content = AsyncMock(return_value=mock_response) mock_genai_module = MagicMock() mock_genai_module.Client.return_value = mock_client mock_types = MagicMock() mock_types.Content.side_effect = lambda **kw: kw mock_types.Part.side_effect = lambda **kw: kw mock_types.GenerateContentConfig.side_effect = lambda **kw: kw mock_google = MagicMock() mock_google.genai = mock_genai_module mock_genai_module.types = mock_types with patch.dict(sys.modules, { "google": mock_google, "google.genai": mock_genai_module, "google.genai.types": mock_types, }): text, _, _ = await provider.generate_json( system_prompt="Generate JSON.", messages=[{"role": "user", "content": "data"}], schema={"type": "object"}, ) assert text == '{"answer": 1}' @pytest.mark.asyncio async def test_generate_json_handles_none_usage(self): """Token counts default to 0 when usage_metadata attributes are None.""" provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") mock_usage = MagicMock(spec=[]) # No attributes at all mock_response = MagicMock() mock_response.text = "{}" mock_response.usage_metadata = mock_usage mock_client = MagicMock() mock_client.aio.models.generate_content = AsyncMock( return_value=mock_response ) mock_genai_module = MagicMock() mock_genai_module.Client.return_value = mock_client mock_types = MagicMock() mock_types.Content.side_effect = lambda **kw: kw mock_types.Part.side_effect = lambda **kw: kw mock_types.GenerateContentConfig.side_effect = lambda **kw: kw mock_google = MagicMock() mock_google.genai = mock_genai_module mock_genai_module.types = mock_types with patch.dict(sys.modules, { "google": mock_google, "google.genai": mock_genai_module, "google.genai.types": mock_types, }): text, input_tokens, output_tokens = await provider.generate_json( system_prompt="test", messages=[{"role": "user", "content": "test"}], ) assert text == "{}" assert input_tokens == 0 assert output_tokens == 0