feat: add AI provider abstraction with Gemini and Anthropic support

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-26 17:16:45 -05:00
parent be041d0d29
commit 55be033ecb
2 changed files with 378 additions and 0 deletions

View File

@@ -0,0 +1,162 @@
"""
AI Provider abstraction layer.
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends for JSON generation used by the AI Flow Builder.
"""
from abc import ABC, abstractmethod
from app.core.config import settings
class AIProvider(ABC):
"""Abstract base class for AI providers."""
@abstractmethod
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
"""Generate a JSON response from the AI model.
Args:
system_prompt: System-level instruction for the model.
messages: List of message dicts with "role" and "content" keys.
max_tokens: Maximum output tokens.
Returns:
Tuple of (response_text, input_tokens, output_tokens).
"""
...
class GeminiProvider(AIProvider):
"""Google Gemini provider using the google-genai SDK."""
def __init__(self, api_key: str, model: str) -> None:
self._api_key = api_key
self._model = model
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
from google import genai
from google.genai import types as genai_types
client = genai.Client(api_key=self._api_key)
# Convert messages to Gemini Content format
contents: list[genai_types.Content] = []
for msg in messages:
role = "model" if msg["role"] == "assistant" else "user"
contents.append(
genai_types.Content(
role=role,
parts=[genai_types.Part(text=msg["content"])],
)
)
config = genai_types.GenerateContentConfig(
system_instruction=system_prompt,
max_output_tokens=max_tokens,
response_mime_type="application/json",
)
response = await client.models.generate_content_async(
model=self._model,
contents=contents,
config=config,
)
text = response.text or ""
input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
output_tokens = (
getattr(response.usage_metadata, "candidates_token_count", 0) or 0
)
return text, input_tokens, output_tokens
class AnthropicProvider(AIProvider):
"""Anthropic Claude provider using the anthropic SDK."""
def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
self._api_key = api_key
self._model = model
self._timeout = timeout
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
import anthropic
client = anthropic.AsyncAnthropic(
api_key=self._api_key,
timeout=self._timeout,
)
response = await client.messages.create(
model=self._model,
max_tokens=max_tokens,
system=system_prompt,
messages=messages,
)
text = response.content[0].text
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
return text, input_tokens, output_tokens
def get_ai_provider() -> AIProvider:
"""Factory that returns the configured AI provider.
Selection logic:
1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
3. Fallback: if preferred provider key missing, try the other one
4. If nothing configured -> raise RuntimeError
"""
provider = settings.AI_PROVIDER
if provider == "gemini":
if settings.GOOGLE_AI_API_KEY:
return GeminiProvider(
api_key=settings.GOOGLE_AI_API_KEY,
model=settings.AI_MODEL_GEMINI,
)
# Fallback to Anthropic
if settings.ANTHROPIC_API_KEY:
return AnthropicProvider(
api_key=settings.ANTHROPIC_API_KEY,
model=settings.AI_MODEL_ANTHROPIC,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
elif provider == "anthropic":
if settings.ANTHROPIC_API_KEY:
return AnthropicProvider(
api_key=settings.ANTHROPIC_API_KEY,
model=settings.AI_MODEL_ANTHROPIC,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
# Fallback to Gemini
if settings.GOOGLE_AI_API_KEY:
return GeminiProvider(
api_key=settings.GOOGLE_AI_API_KEY,
model=settings.AI_MODEL_GEMINI,
)
raise RuntimeError(
"No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
)

View File

@@ -0,0 +1,216 @@
"""Tests for the AI provider abstraction layer."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import sys
from app.core.ai_provider import (
AIProvider,
AnthropicProvider,
GeminiProvider,
get_ai_provider,
)
from app.core.config import settings
class TestGetAIProvider:
"""Tests for the get_ai_provider factory function."""
def test_returns_gemini_when_configured(self):
original_provider = settings.AI_PROVIDER
original_key = settings.GOOGLE_AI_API_KEY
try:
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
provider = get_ai_provider()
assert isinstance(provider, GeminiProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_key
def test_returns_anthropic_when_configured(self):
original_provider = settings.AI_PROVIDER
original_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "anthropic"
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
provider = get_ai_provider()
assert isinstance(provider, AnthropicProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.ANTHROPIC_API_KEY = original_key
def test_fallback_to_anthropic_when_gemini_key_missing(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
provider = get_ai_provider()
assert isinstance(provider, AnthropicProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
def test_fallback_to_gemini_when_anthropic_key_missing(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "anthropic"
settings.ANTHROPIC_API_KEY = None
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
provider = get_ai_provider()
assert isinstance(provider, GeminiProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
def test_raises_when_no_provider_configured(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = None
with pytest.raises(RuntimeError, match="No AI provider configured"):
get_ai_provider()
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
class TestAnthropicProvider:
"""Tests for AnthropicProvider.generate_json."""
@pytest.mark.asyncio
async def test_generate_json(self):
provider = AnthropicProvider(
api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30
)
mock_response = MagicMock()
mock_response.content = [MagicMock(text='{"result": "ok"}')]
mock_response.usage = MagicMock(input_tokens=100, output_tokens=50)
mock_client = AsyncMock()
mock_client.messages.create = AsyncMock(return_value=mock_response)
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
text, input_tokens, output_tokens = await provider.generate_json(
system_prompt="You are a helper.",
messages=[{"role": "user", "content": "Hello"}],
max_tokens=1024,
)
assert text == '{"result": "ok"}'
assert input_tokens == 100
assert output_tokens == 50
mock_client.messages.create.assert_called_once_with(
model="claude-haiku-4-5-20251001",
max_tokens=1024,
system="You are a helper.",
messages=[{"role": "user", "content": "Hello"}],
)
class TestGeminiProvider:
"""Tests for GeminiProvider.generate_json."""
@pytest.mark.asyncio
async def test_generate_json(self):
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
mock_usage = MagicMock()
mock_usage.prompt_token_count = 80
mock_usage.candidates_token_count = 40
mock_response = MagicMock()
mock_response.text = '{"answer": 42}'
mock_response.usage_metadata = mock_usage
mock_client = MagicMock()
mock_client.models.generate_content_async = AsyncMock(
return_value=mock_response
)
mock_genai_module = MagicMock()
mock_genai_module.Client.return_value = mock_client
mock_types = MagicMock()
mock_types.Content.side_effect = lambda **kw: kw
mock_types.Part.side_effect = lambda **kw: kw
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
mock_google = MagicMock()
mock_google.genai = mock_genai_module
mock_genai_module.types = mock_types
with patch.dict(sys.modules, {
"google": mock_google,
"google.genai": mock_genai_module,
"google.genai.types": mock_types,
}):
text, input_tokens, output_tokens = await provider.generate_json(
system_prompt="Generate JSON.",
messages=[
{"role": "user", "content": "Give me data"},
{"role": "assistant", "content": "Here it is"},
{"role": "user", "content": "More please"},
],
max_tokens=2048,
)
assert text == '{"answer": 42}'
assert input_tokens == 80
assert output_tokens == 40
mock_client.models.generate_content_async.assert_called_once()
@pytest.mark.asyncio
async def test_generate_json_handles_none_usage(self):
"""Token counts default to 0 when usage_metadata attributes are None."""
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
mock_usage = MagicMock(spec=[]) # No attributes at all
mock_response = MagicMock()
mock_response.text = "{}"
mock_response.usage_metadata = mock_usage
mock_client = MagicMock()
mock_client.models.generate_content_async = AsyncMock(
return_value=mock_response
)
mock_genai_module = MagicMock()
mock_genai_module.Client.return_value = mock_client
mock_types = MagicMock()
mock_types.Content.side_effect = lambda **kw: kw
mock_types.Part.side_effect = lambda **kw: kw
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
mock_google = MagicMock()
mock_google.genai = mock_genai_module
mock_genai_module.types = mock_types
with patch.dict(sys.modules, {
"google": mock_google,
"google.genai": mock_genai_module,
"google.genai.types": mock_types,
}):
text, input_tokens, output_tokens = await provider.generate_json(
system_prompt="test",
messages=[{"role": "user", "content": "test"}],
)
assert text == "{}"
assert input_tokens == 0
assert output_tokens == 0