Files
resolutionflow/backend/tests/test_ai_provider.py
2026-02-26 17:16:45 -05:00

217 lines
7.8 KiB
Python

"""Tests for the AI provider abstraction layer."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import sys
from app.core.ai_provider import (
AIProvider,
AnthropicProvider,
GeminiProvider,
get_ai_provider,
)
from app.core.config import settings
class TestGetAIProvider:
    """Tests for the get_ai_provider factory function.

    Each test overrides attributes on the shared ``settings`` object through
    pytest's ``monkeypatch`` fixture, which puts the original values back at
    teardown even when the test fails or raises, so no manual save/restore
    bookkeeping is needed.
    """

    def test_returns_gemini_when_configured(self, monkeypatch):
        monkeypatch.setattr(settings, "AI_PROVIDER", "gemini")
        monkeypatch.setattr(settings, "GOOGLE_AI_API_KEY", "test-gemini-key")

        provider = get_ai_provider()

        assert isinstance(provider, GeminiProvider)

    def test_returns_anthropic_when_configured(self, monkeypatch):
        monkeypatch.setattr(settings, "AI_PROVIDER", "anthropic")
        monkeypatch.setattr(settings, "ANTHROPIC_API_KEY", "test-anthropic-key")

        provider = get_ai_provider()

        assert isinstance(provider, AnthropicProvider)

    def test_fallback_to_anthropic_when_gemini_key_missing(self, monkeypatch):
        # Gemini is requested but has no key; only the Anthropic key is set.
        monkeypatch.setattr(settings, "AI_PROVIDER", "gemini")
        monkeypatch.setattr(settings, "GOOGLE_AI_API_KEY", None)
        monkeypatch.setattr(settings, "ANTHROPIC_API_KEY", "test-anthropic-key")

        provider = get_ai_provider()

        assert isinstance(provider, AnthropicProvider)

    def test_fallback_to_gemini_when_anthropic_key_missing(self, monkeypatch):
        # Anthropic is requested but has no key; only the Gemini key is set.
        monkeypatch.setattr(settings, "AI_PROVIDER", "anthropic")
        monkeypatch.setattr(settings, "ANTHROPIC_API_KEY", None)
        monkeypatch.setattr(settings, "GOOGLE_AI_API_KEY", "test-gemini-key")

        provider = get_ai_provider()

        assert isinstance(provider, GeminiProvider)

    def test_raises_when_no_provider_configured(self, monkeypatch):
        # Neither provider has a key, so the factory has nothing to fall back to.
        monkeypatch.setattr(settings, "AI_PROVIDER", "gemini")
        monkeypatch.setattr(settings, "GOOGLE_AI_API_KEY", None)
        monkeypatch.setattr(settings, "ANTHROPIC_API_KEY", None)

        with pytest.raises(RuntimeError, match="No AI provider configured"):
            get_ai_provider()
class TestAnthropicProvider:
    """Tests for AnthropicProvider.generate_json."""

    @pytest.mark.asyncio
    async def test_generate_json(self):
        """The provider forwards the request to ``messages.create`` and
        returns ``(text, input_tokens, output_tokens)`` from the response."""
        model_name = "claude-haiku-4-5-20251001"
        provider = AnthropicProvider(api_key="test-key", model=model_name, timeout=30)

        # Canned SDK response: one content block carrying the JSON text,
        # plus a usage record with the token counts.
        fake_usage = MagicMock(input_tokens=100, output_tokens=50)
        fake_block = MagicMock(text='{"result": "ok"}')
        fake_response = MagicMock()
        fake_response.content = [fake_block]
        fake_response.usage = fake_usage

        # The patched AsyncAnthropic constructor hands back this client.
        fake_client = AsyncMock()
        fake_client.messages.create = AsyncMock(return_value=fake_response)

        with patch("anthropic.AsyncAnthropic", return_value=fake_client):
            result = await provider.generate_json(
                system_prompt="You are a helper.",
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=1024,
            )

        text, input_tokens, output_tokens = result
        assert (text, input_tokens, output_tokens) == ('{"result": "ok"}', 100, 50)

        # A fresh literal (not the list passed above) so that any in-place
        # mutation by the provider would be detected.
        fake_client.messages.create.assert_called_once_with(
            model=model_name,
            max_tokens=1024,
            system="You are a helper.",
            messages=[{"role": "user", "content": "Hello"}],
        )
class TestGeminiProvider:
    """Tests for GeminiProvider.generate_json.

    During each ``generate_json`` call, the ``google``, ``google.genai`` and
    ``google.genai.types`` entries of ``sys.modules`` are replaced with
    MagicMocks built by ``_fake_genai_modules``.
    """

    @staticmethod
    def _fake_genai_modules(mock_response):
        """Build stand-ins for the ``google.genai`` import tree.

        Returns ``(mock_client, modules)``:

        * ``mock_client`` is what ``genai.Client(...)`` returns; its
          ``models.generate_content_async`` is an AsyncMock resolving to
          *mock_response*.
        * ``modules`` is a mapping intended for
          ``patch.dict(sys.modules, modules)``.
        """
        # NOTE(review): the google-genai SDK documents async generation as
        # ``client.aio.models.generate_content``; ``generate_content_async``
        # looks like a name from the older ``google-generativeai`` package.
        # An unconstrained MagicMock accepts any attribute, so these tests
        # would still pass if the real SDK lacks this method. Verify against
        # GeminiProvider and the installed SDK version.
        mock_client = MagicMock()
        mock_client.models.generate_content_async = AsyncMock(
            return_value=mock_response
        )

        mock_genai_module = MagicMock()
        mock_genai_module.Client.return_value = mock_client

        # The types constructors return their keyword arguments as plain
        # dicts instead of SDK objects.
        mock_types = MagicMock()
        mock_types.Content.side_effect = lambda **kw: kw
        mock_types.Part.side_effect = lambda **kw: kw
        mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
        mock_genai_module.types = mock_types

        # Wire the parent package too, so both ``import google.genai`` and
        # ``from google import genai`` resolve to the same mock module.
        mock_google = MagicMock()
        mock_google.genai = mock_genai_module

        modules = {
            "google": mock_google,
            "google.genai": mock_genai_module,
            "google.genai.types": mock_types,
        }
        return mock_client, modules

    @pytest.mark.asyncio
    async def test_generate_json(self):
        provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")

        mock_usage = MagicMock()
        mock_usage.prompt_token_count = 80
        mock_usage.candidates_token_count = 40
        mock_response = MagicMock()
        mock_response.text = '{"answer": 42}'
        mock_response.usage_metadata = mock_usage
        mock_client, modules = self._fake_genai_modules(mock_response)

        with patch.dict(sys.modules, modules):
            text, input_tokens, output_tokens = await provider.generate_json(
                system_prompt="Generate JSON.",
                messages=[
                    {"role": "user", "content": "Give me data"},
                    {"role": "assistant", "content": "Here it is"},
                    {"role": "user", "content": "More please"},
                ],
                max_tokens=2048,
            )

        assert text == '{"answer": 42}'
        assert input_tokens == 80
        assert output_tokens == 40
        mock_client.models.generate_content_async.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_json_handles_none_usage(self):
        """Token counts default to 0 when usage_metadata lacks the
        prompt/candidates token-count attributes entirely."""
        provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")

        # spec=[] makes every (non-magic) attribute access raise AttributeError,
        # so the token-count fields are missing rather than set to None.
        mock_usage = MagicMock(spec=[])
        mock_response = MagicMock()
        mock_response.text = "{}"
        mock_response.usage_metadata = mock_usage
        _, modules = self._fake_genai_modules(mock_response)

        with patch.dict(sys.modules, modules):
            text, input_tokens, output_tokens = await provider.generate_json(
                system_prompt="test",
                messages=[{"role": "user", "content": "test"}],
            )

        assert text == "{}"
        assert input_tokens == 0
        assert output_tokens == 0