"""Tests for the AI provider abstraction layer.""" import pytest from unittest.mock import AsyncMock, MagicMock, patch import sys from app.core.ai_provider import ( AIProvider, AnthropicProvider, GeminiProvider, get_ai_provider, ) from app.core.config import settings class TestGetAIProvider: """Tests for the get_ai_provider factory function.""" def test_returns_gemini_when_configured(self): original_provider = settings.AI_PROVIDER original_key = settings.GOOGLE_AI_API_KEY try: settings.AI_PROVIDER = "gemini" settings.GOOGLE_AI_API_KEY = "test-gemini-key" provider = get_ai_provider() assert isinstance(provider, GeminiProvider) finally: settings.AI_PROVIDER = original_provider settings.GOOGLE_AI_API_KEY = original_key def test_returns_anthropic_when_configured(self): original_provider = settings.AI_PROVIDER original_key = settings.ANTHROPIC_API_KEY try: settings.AI_PROVIDER = "anthropic" settings.ANTHROPIC_API_KEY = "test-anthropic-key" provider = get_ai_provider() assert isinstance(provider, AnthropicProvider) finally: settings.AI_PROVIDER = original_provider settings.ANTHROPIC_API_KEY = original_key def test_fallback_to_anthropic_when_gemini_key_missing(self): original_provider = settings.AI_PROVIDER original_gemini_key = settings.GOOGLE_AI_API_KEY original_anthropic_key = settings.ANTHROPIC_API_KEY try: settings.AI_PROVIDER = "gemini" settings.GOOGLE_AI_API_KEY = None settings.ANTHROPIC_API_KEY = "test-anthropic-key" provider = get_ai_provider() assert isinstance(provider, AnthropicProvider) finally: settings.AI_PROVIDER = original_provider settings.GOOGLE_AI_API_KEY = original_gemini_key settings.ANTHROPIC_API_KEY = original_anthropic_key def test_fallback_to_gemini_when_anthropic_key_missing(self): original_provider = settings.AI_PROVIDER original_gemini_key = settings.GOOGLE_AI_API_KEY original_anthropic_key = settings.ANTHROPIC_API_KEY try: settings.AI_PROVIDER = "anthropic" settings.ANTHROPIC_API_KEY = None settings.GOOGLE_AI_API_KEY = "test-gemini-key" provider = get_ai_provider() assert isinstance(provider, GeminiProvider) finally: settings.AI_PROVIDER = original_provider settings.GOOGLE_AI_API_KEY = original_gemini_key settings.ANTHROPIC_API_KEY = original_anthropic_key def test_raises_when_no_provider_configured(self): original_provider = settings.AI_PROVIDER original_gemini_key = settings.GOOGLE_AI_API_KEY original_anthropic_key = settings.ANTHROPIC_API_KEY try: settings.AI_PROVIDER = "gemini" settings.GOOGLE_AI_API_KEY = None settings.ANTHROPIC_API_KEY = None with pytest.raises(RuntimeError, match="No AI provider configured"): get_ai_provider() finally: settings.AI_PROVIDER = original_provider settings.GOOGLE_AI_API_KEY = original_gemini_key settings.ANTHROPIC_API_KEY = original_anthropic_key class TestAnthropicProvider: """Tests for AnthropicProvider.generate_json.""" @pytest.mark.asyncio async def test_generate_json(self): provider = AnthropicProvider( api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30 ) mock_response = MagicMock() mock_response.content = [MagicMock(text='{"result": "ok"}')] mock_response.usage = MagicMock(input_tokens=100, output_tokens=50) mock_client = AsyncMock() mock_client.messages.create = AsyncMock(return_value=mock_response) with patch("anthropic.AsyncAnthropic", return_value=mock_client): text, input_tokens, output_tokens = await provider.generate_json( system_prompt="You are a helper.", messages=[{"role": "user", "content": "Hello"}], max_tokens=1024, ) assert text == '{"result": "ok"}' assert input_tokens == 100 assert output_tokens == 50 mock_client.messages.create.assert_called_once_with( model="claude-haiku-4-5-20251001", max_tokens=1024, system="You are a helper.", messages=[{"role": "user", "content": "Hello"}], ) class TestGeminiProvider: """Tests for GeminiProvider.generate_json.""" @pytest.mark.asyncio async def test_generate_json(self): provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") mock_usage = MagicMock() mock_usage.prompt_token_count = 80 mock_usage.candidates_token_count = 40 mock_response = MagicMock() mock_response.text = '{"answer": 42}' mock_response.usage_metadata = mock_usage mock_client = MagicMock() mock_client.aio.models.generate_content = AsyncMock( return_value=mock_response ) mock_genai_module = MagicMock() mock_genai_module.Client.return_value = mock_client mock_types = MagicMock() mock_types.Content.side_effect = lambda **kw: kw mock_types.Part.side_effect = lambda **kw: kw mock_types.GenerateContentConfig.side_effect = lambda **kw: kw mock_google = MagicMock() mock_google.genai = mock_genai_module mock_genai_module.types = mock_types with patch.dict(sys.modules, { "google": mock_google, "google.genai": mock_genai_module, "google.genai.types": mock_types, }): text, input_tokens, output_tokens = await provider.generate_json( system_prompt="Generate JSON.", messages=[ {"role": "user", "content": "Give me data"}, {"role": "assistant", "content": "Here it is"}, {"role": "user", "content": "More please"}, ], max_tokens=2048, ) assert text == '{"answer": 42}' assert input_tokens == 80 assert output_tokens == 40 mock_client.aio.models.generate_content.assert_called_once() @pytest.mark.asyncio async def test_generate_json_handles_none_usage(self): """Token counts default to 0 when usage_metadata attributes are None.""" provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") mock_usage = MagicMock(spec=[]) # No attributes at all mock_response = MagicMock() mock_response.text = "{}" mock_response.usage_metadata = mock_usage mock_client = MagicMock() mock_client.aio.models.generate_content = AsyncMock( return_value=mock_response ) mock_genai_module = MagicMock() mock_genai_module.Client.return_value = mock_client mock_types = MagicMock() mock_types.Content.side_effect = lambda **kw: kw mock_types.Part.side_effect = lambda **kw: kw mock_types.GenerateContentConfig.side_effect = lambda **kw: kw mock_google = MagicMock() mock_google.genai = mock_genai_module mock_genai_module.types = mock_types with patch.dict(sys.modules, { "google": mock_google, "google.genai": mock_genai_module, "google.genai.types": mock_types, }): text, input_tokens, output_tokens = await provider.generate_json( system_prompt="test", messages=[{"role": "user", "content": "test"}], ) assert text == "{}" assert input_tokens == 0 assert output_tokens == 0