feat: add AI provider abstraction with Gemini and Anthropic support
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
162
backend/app/core/ai_provider.py
Normal file
162
backend/app/core/ai_provider.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
AI Provider abstraction layer.
|
||||||
|
|
||||||
|
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
|
||||||
|
backends for JSON generation used by the AI Flow Builder.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class AIProvider(ABC):
    """Common interface every AI backend must implement."""

    @abstractmethod
    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Ask the model for a JSON response.

        Args:
            system_prompt: System-level instruction given to the model.
            messages: Conversation so far; each entry is a dict carrying
                "role" and "content" keys.
            max_tokens: Upper bound on the number of output tokens.

        Returns:
            A ``(response_text, input_tokens, output_tokens)`` tuple.
        """
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiProvider(AIProvider):
    """Google Gemini provider using the google-genai SDK."""

    def __init__(self, api_key: str, model: str) -> None:
        """Store the credentials and model name used for every request.

        Args:
            api_key: API key handed to ``genai.Client``.
            model: Model identifier passed as ``model=`` on each call.
        """
        self._api_key = api_key
        self._model = model

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a JSON response with Gemini.

        Args:
            system_prompt: Sent as the request's ``system_instruction``.
            messages: Message dicts with "role" and "content" keys. The
                "assistant" role is mapped to Gemini's "model" role; every
                other role is sent as "user".
            max_tokens: Sent as ``max_output_tokens``.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens). The text
            falls back to "" and the counts to 0 when the response does not
            provide them.
        """
        # Imported inside the method, so the SDK is only loaded when a
        # Gemini request is actually made.
        from google import genai
        from google.genai import types as genai_types

        client = genai.Client(api_key=self._api_key)

        # Convert messages to Gemini Content format.
        contents: list[genai_types.Content] = [
            genai_types.Content(
                role="model" if msg["role"] == "assistant" else "user",
                parts=[genai_types.Part(text=msg["content"])],
            )
            for msg in messages
        ]

        config = genai_types.GenerateContentConfig(
            system_instruction=system_prompt,
            max_output_tokens=max_tokens,
            response_mime_type="application/json",
        )

        # The google-genai SDK exposes its async API under ``client.aio``;
        # ``client.models`` has no ``generate_content_async`` method (that
        # name belongs to the legacy google-generativeai package).
        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )

        text = response.text or ""
        usage = response.usage_metadata
        # getattr + "or 0" covers a missing usage object, missing attributes,
        # and attributes that are present but None.
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0

        return text, input_tokens, output_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(AIProvider):
    """Anthropic Claude provider using the anthropic SDK."""

    def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
        """Store the credentials, model name and request timeout.

        Args:
            api_key: API key handed to ``anthropic.AsyncAnthropic``.
            model: Model identifier passed as ``model=`` on each call.
            timeout: Passed as ``timeout=`` to the SDK client
                (presumably seconds - confirm against the SDK docs).
        """
        self._api_key = api_key
        self._model = model
        self._timeout = timeout

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a JSON response with Claude.

        Args:
            system_prompt: Sent as the request's ``system`` parameter.
            messages: Message dicts with "role" and "content" keys,
                forwarded to the SDK unchanged.
            max_tokens: Maximum output tokens.

        Returns:
            Tuple of (response_text, input_tokens, output_tokens). The text
            is the concatenation of every text-bearing content block, or ""
            when the response carries none.
        """
        # Imported inside the method, so the SDK is only loaded when an
        # Anthropic request is actually made.
        import anthropic

        client = anthropic.AsyncAnthropic(
            api_key=self._api_key,
            timeout=self._timeout,
        )

        response = await client.messages.create(
            model=self._model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=messages,
        )

        # Do not assume content[0] exists or is a text block: the list can be
        # empty, and some block kinds carry no ``text`` attribute.
        text = "".join(
            block.text
            for block in (response.content or [])
            if isinstance(getattr(block, "text", None), str)
        )
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens

        return text, input_tokens, output_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def get_ai_provider() -> AIProvider:
    """Factory that returns the configured AI provider.

    Selection logic:
    1. AI_PROVIDER is compared case-insensitively, ignoring surrounding
       whitespace.
    2. If AI_PROVIDER is "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
    3. If AI_PROVIDER is "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
    4. Fallback: if the preferred provider's key is missing, try the other one
    5. Otherwise -> raise RuntimeError

    Returns:
        A ready-to-use provider instance.

    Raises:
        RuntimeError: If AI_PROVIDER is not a supported name, or if neither
            API key is set.
    """

    def build_gemini():
        # Settings are only read past this guard, so an unused provider's
        # model setting is never touched.
        if not settings.GOOGLE_AI_API_KEY:
            return None
        return GeminiProvider(
            api_key=settings.GOOGLE_AI_API_KEY,
            model=settings.AI_MODEL_GEMINI,
        )

    def build_anthropic():
        if not settings.ANTHROPIC_API_KEY:
            return None
        return AnthropicProvider(
            api_key=settings.ANTHROPIC_API_KEY,
            model=settings.AI_MODEL_ANTHROPIC,
            timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
        )

    # Preferred builder first, fallback second.
    preference = {
        "gemini": (build_gemini, build_anthropic),
        "anthropic": (build_anthropic, build_gemini),
    }

    provider = settings.AI_PROVIDER
    if isinstance(provider, str):
        # Tolerate values such as "Gemini" or " anthropic " from env files.
        provider = provider.strip().lower()

    if provider not in preference:
        # Keep the original message prefix, but name the real problem instead
        # of telling the operator to set keys that may already be set.
        raise RuntimeError(
            "No AI provider configured. "
            f"Unsupported AI_PROVIDER {settings.AI_PROVIDER!r}; "
            "expected 'gemini' or 'anthropic'."
        )

    for build in preference[provider]:
        instance = build()
        if instance is not None:
            return instance

    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )
|
||||||
216
backend/tests/test_ai_provider.py
Normal file
216
backend/tests/test_ai_provider.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""Tests for the AI provider abstraction layer."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from app.core.ai_provider import (
|
||||||
|
AIProvider,
|
||||||
|
AnthropicProvider,
|
||||||
|
GeminiProvider,
|
||||||
|
get_ai_provider,
|
||||||
|
)
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAIProvider:
    """Exercises provider selection in the get_ai_provider factory."""

    # patch.multiple swaps the listed settings attributes for the duration of
    # the ``with`` block and restores the previous values on exit, including
    # when an assertion fails.

    def test_returns_gemini_when_configured(self):
        with patch.multiple(
            settings,
            AI_PROVIDER="gemini",
            GOOGLE_AI_API_KEY="test-gemini-key",
        ):
            assert isinstance(get_ai_provider(), GeminiProvider)

    def test_returns_anthropic_when_configured(self):
        with patch.multiple(
            settings,
            AI_PROVIDER="anthropic",
            ANTHROPIC_API_KEY="test-anthropic-key",
        ):
            assert isinstance(get_ai_provider(), AnthropicProvider)

    def test_fallback_to_anthropic_when_gemini_key_missing(self):
        with patch.multiple(
            settings,
            AI_PROVIDER="gemini",
            GOOGLE_AI_API_KEY=None,
            ANTHROPIC_API_KEY="test-anthropic-key",
        ):
            assert isinstance(get_ai_provider(), AnthropicProvider)

    def test_fallback_to_gemini_when_anthropic_key_missing(self):
        with patch.multiple(
            settings,
            AI_PROVIDER="anthropic",
            ANTHROPIC_API_KEY=None,
            GOOGLE_AI_API_KEY="test-gemini-key",
        ):
            assert isinstance(get_ai_provider(), GeminiProvider)

    def test_raises_when_no_provider_configured(self):
        with patch.multiple(
            settings,
            AI_PROVIDER="gemini",
            GOOGLE_AI_API_KEY=None,
            ANTHROPIC_API_KEY=None,
        ):
            with pytest.raises(RuntimeError, match="No AI provider configured"):
                get_ai_provider()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnthropicProvider:
    """Checks AnthropicProvider.generate_json against a mocked SDK client."""

    @pytest.mark.asyncio
    async def test_generate_json(self):
        # Canned SDK response: one text block plus usage counters.
        fake_response = MagicMock(
            content=[MagicMock(text='{"result": "ok"}')],
            usage=MagicMock(input_tokens=100, output_tokens=50),
        )

        fake_client = AsyncMock()
        fake_client.messages.create = AsyncMock(return_value=fake_response)

        provider = AnthropicProvider(
            api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30
        )

        # Patch the SDK constructor so no real client is built.
        with patch("anthropic.AsyncAnthropic", return_value=fake_client):
            result = await provider.generate_json(
                system_prompt="You are a helper.",
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=1024,
            )

        text, input_tokens, output_tokens = result
        assert text == '{"result": "ok"}'
        assert input_tokens == 100
        assert output_tokens == 50

        # The provider must forward its arguments to the SDK unchanged.
        expected_call = {
            "model": "claude-haiku-4-5-20251001",
            "max_tokens": 1024,
            "system": "You are a helper.",
            "messages": [{"role": "user", "content": "Hello"}],
        }
        fake_client.messages.create.assert_called_once_with(**expected_call)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeminiProvider:
    """Tests for GeminiProvider.generate_json."""

    @staticmethod
    def _fake_genai(mock_response):
        """Build stand-ins for the google-genai modules.

        Args:
            mock_response: Object the fake generate call resolves to.

        Returns:
            ``(modules, generate)`` where ``modules`` is a mapping suitable
            for ``patch.dict(sys.modules, ...)`` and ``generate`` is the
            AsyncMock standing in for the SDK's generate call.
        """
        generate = AsyncMock(return_value=mock_response)

        mock_client = MagicMock()
        # Bind the same mock to both async entry points so the assertions
        # hold whichever one the provider awaits.
        mock_client.aio.models.generate_content = generate
        mock_client.models.generate_content_async = generate

        # The type constructors echo their kwargs as plain dicts, which makes
        # the request payload directly comparable in assertions.
        mock_types = MagicMock()
        mock_types.Content.side_effect = lambda **kw: kw
        mock_types.Part.side_effect = lambda **kw: kw
        mock_types.GenerateContentConfig.side_effect = lambda **kw: kw

        mock_genai_module = MagicMock()
        mock_genai_module.Client.return_value = mock_client
        mock_genai_module.types = mock_types

        mock_google = MagicMock()
        mock_google.genai = mock_genai_module

        modules = {
            "google": mock_google,
            "google.genai": mock_genai_module,
            "google.genai.types": mock_types,
        }
        return modules, generate

    @pytest.mark.asyncio
    async def test_generate_json(self):
        provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")

        mock_usage = MagicMock()
        mock_usage.prompt_token_count = 80
        mock_usage.candidates_token_count = 40

        mock_response = MagicMock()
        mock_response.text = '{"answer": 42}'
        mock_response.usage_metadata = mock_usage

        modules, generate = self._fake_genai(mock_response)

        with patch.dict(sys.modules, modules):
            text, input_tokens, output_tokens = await provider.generate_json(
                system_prompt="Generate JSON.",
                messages=[
                    {"role": "user", "content": "Give me data"},
                    {"role": "assistant", "content": "Here it is"},
                    {"role": "user", "content": "More please"},
                ],
                max_tokens=2048,
            )

        assert text == '{"answer": 42}'
        assert input_tokens == 80
        assert output_tokens == 40

        # Verify the payload too: "assistant" must be sent as "model", and
        # the config must request JSON output.
        generate.assert_called_once_with(
            model="gemini-2.5-flash",
            contents=[
                {"role": "user", "parts": [{"text": "Give me data"}]},
                {"role": "model", "parts": [{"text": "Here it is"}]},
                {"role": "user", "parts": [{"text": "More please"}]},
            ],
            config={
                "system_instruction": "Generate JSON.",
                "max_output_tokens": 2048,
                "response_mime_type": "application/json",
            },
        )

    @pytest.mark.asyncio
    async def test_generate_json_handles_none_usage(self):
        """Token counts default to 0 when usage_metadata lacks the attributes."""
        provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")

        mock_usage = MagicMock(spec=[])  # No attributes at all
        mock_response = MagicMock()
        mock_response.text = "{}"
        mock_response.usage_metadata = mock_usage

        modules, _ = self._fake_genai(mock_response)

        with patch.dict(sys.modules, modules):
            text, input_tokens, output_tokens = await provider.generate_json(
                system_prompt="test",
                messages=[{"role": "user", "content": "test"}],
            )

        assert text == "{}"
        assert input_tokens == 0
        assert output_tokens == 0

    @pytest.mark.asyncio
    async def test_generate_json_handles_missing_text_and_usage(self):
        """A response with no text and no usage_metadata yields ("", 0, 0)."""
        provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")

        mock_response = MagicMock()
        mock_response.text = None
        mock_response.usage_metadata = None

        modules, _ = self._fake_genai(mock_response)

        with patch.dict(sys.modules, modules):
            result = await provider.generate_json(
                system_prompt="test",
                messages=[{"role": "user", "content": "test"}],
            )

        assert result == ("", 0, 0)
|
||||||
Reference in New Issue
Block a user