feat: add AI provider abstraction with Gemini and Anthropic support
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
162
backend/app/core/ai_provider.py
Normal file
162
backend/app/core/ai_provider.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
AI Provider abstraction layer.
|
||||
|
||||
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
|
||||
backends for JSON generation used by the AI Flow Builder.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
"""Abstract base class for AI providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
"""Generate a JSON response from the AI model.
|
||||
|
||||
Args:
|
||||
system_prompt: System-level instruction for the model.
|
||||
messages: List of message dicts with "role" and "content" keys.
|
||||
max_tokens: Maximum output tokens.
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, input_tokens, output_tokens).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class GeminiProvider(AIProvider):
|
||||
"""Google Gemini provider using the google-genai SDK."""
|
||||
|
||||
def __init__(self, api_key: str, model: str) -> None:
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
|
||||
client = genai.Client(api_key=self._api_key)
|
||||
|
||||
# Convert messages to Gemini Content format
|
||||
contents: list[genai_types.Content] = []
|
||||
for msg in messages:
|
||||
role = "model" if msg["role"] == "assistant" else "user"
|
||||
contents.append(
|
||||
genai_types.Content(
|
||||
role=role,
|
||||
parts=[genai_types.Part(text=msg["content"])],
|
||||
)
|
||||
)
|
||||
|
||||
config = genai_types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
max_output_tokens=max_tokens,
|
||||
response_mime_type="application/json",
|
||||
)
|
||||
|
||||
response = await client.models.generate_content_async(
|
||||
model=self._model,
|
||||
contents=contents,
|
||||
config=config,
|
||||
)
|
||||
|
||||
text = response.text or ""
|
||||
input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
|
||||
output_tokens = (
|
||||
getattr(response.usage_metadata, "candidates_token_count", 0) or 0
|
||||
)
|
||||
|
||||
return text, input_tokens, output_tokens
|
||||
|
||||
|
||||
class AnthropicProvider(AIProvider):
|
||||
"""Anthropic Claude provider using the anthropic SDK."""
|
||||
|
||||
def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
self._timeout = timeout
|
||||
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
import anthropic
|
||||
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=self._api_key,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
|
||||
response = await client.messages.create(
|
||||
model=self._model,
|
||||
max_tokens=max_tokens,
|
||||
system=system_prompt,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
text = response.content[0].text
|
||||
input_tokens = response.usage.input_tokens
|
||||
output_tokens = response.usage.output_tokens
|
||||
|
||||
return text, input_tokens, output_tokens
|
||||
|
||||
|
||||
def get_ai_provider() -> AIProvider:
|
||||
"""Factory that returns the configured AI provider.
|
||||
|
||||
Selection logic:
|
||||
1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
|
||||
2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
|
||||
3. Fallback: if preferred provider key missing, try the other one
|
||||
4. If nothing configured -> raise RuntimeError
|
||||
"""
|
||||
provider = settings.AI_PROVIDER
|
||||
|
||||
if provider == "gemini":
|
||||
if settings.GOOGLE_AI_API_KEY:
|
||||
return GeminiProvider(
|
||||
api_key=settings.GOOGLE_AI_API_KEY,
|
||||
model=settings.AI_MODEL_GEMINI,
|
||||
)
|
||||
# Fallback to Anthropic
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
return AnthropicProvider(
|
||||
api_key=settings.ANTHROPIC_API_KEY,
|
||||
model=settings.AI_MODEL_ANTHROPIC,
|
||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
elif provider == "anthropic":
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
return AnthropicProvider(
|
||||
api_key=settings.ANTHROPIC_API_KEY,
|
||||
model=settings.AI_MODEL_ANTHROPIC,
|
||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
# Fallback to Gemini
|
||||
if settings.GOOGLE_AI_API_KEY:
|
||||
return GeminiProvider(
|
||||
api_key=settings.GOOGLE_AI_API_KEY,
|
||||
model=settings.AI_MODEL_GEMINI,
|
||||
)
|
||||
|
||||
raise RuntimeError(
|
||||
"No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
|
||||
)
|
||||
216
backend/tests/test_ai_provider.py
Normal file
216
backend/tests/test_ai_provider.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Tests for the AI provider abstraction layer."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import sys
|
||||
|
||||
from app.core.ai_provider import (
|
||||
AIProvider,
|
||||
AnthropicProvider,
|
||||
GeminiProvider,
|
||||
get_ai_provider,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestGetAIProvider:
|
||||
"""Tests for the get_ai_provider factory function."""
|
||||
|
||||
def test_returns_gemini_when_configured(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_key = settings.GOOGLE_AI_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "gemini"
|
||||
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
|
||||
provider = get_ai_provider()
|
||||
assert isinstance(provider, GeminiProvider)
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.GOOGLE_AI_API_KEY = original_key
|
||||
|
||||
def test_returns_anthropic_when_configured(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_key = settings.ANTHROPIC_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "anthropic"
|
||||
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
|
||||
provider = get_ai_provider()
|
||||
assert isinstance(provider, AnthropicProvider)
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.ANTHROPIC_API_KEY = original_key
|
||||
|
||||
def test_fallback_to_anthropic_when_gemini_key_missing(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_gemini_key = settings.GOOGLE_AI_API_KEY
|
||||
original_anthropic_key = settings.ANTHROPIC_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "gemini"
|
||||
settings.GOOGLE_AI_API_KEY = None
|
||||
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
|
||||
provider = get_ai_provider()
|
||||
assert isinstance(provider, AnthropicProvider)
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.GOOGLE_AI_API_KEY = original_gemini_key
|
||||
settings.ANTHROPIC_API_KEY = original_anthropic_key
|
||||
|
||||
def test_fallback_to_gemini_when_anthropic_key_missing(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_gemini_key = settings.GOOGLE_AI_API_KEY
|
||||
original_anthropic_key = settings.ANTHROPIC_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "anthropic"
|
||||
settings.ANTHROPIC_API_KEY = None
|
||||
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
|
||||
provider = get_ai_provider()
|
||||
assert isinstance(provider, GeminiProvider)
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.GOOGLE_AI_API_KEY = original_gemini_key
|
||||
settings.ANTHROPIC_API_KEY = original_anthropic_key
|
||||
|
||||
def test_raises_when_no_provider_configured(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_gemini_key = settings.GOOGLE_AI_API_KEY
|
||||
original_anthropic_key = settings.ANTHROPIC_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "gemini"
|
||||
settings.GOOGLE_AI_API_KEY = None
|
||||
settings.ANTHROPIC_API_KEY = None
|
||||
with pytest.raises(RuntimeError, match="No AI provider configured"):
|
||||
get_ai_provider()
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.GOOGLE_AI_API_KEY = original_gemini_key
|
||||
settings.ANTHROPIC_API_KEY = original_anthropic_key
|
||||
|
||||
|
||||
class TestAnthropicProvider:
|
||||
"""Tests for AnthropicProvider.generate_json."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_json(self):
|
||||
provider = AnthropicProvider(
|
||||
api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30
|
||||
)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text='{"result": "ok"}')]
|
||||
mock_response.usage = MagicMock(input_tokens=100, output_tokens=50)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.messages.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
|
||||
text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt="You are a helper.",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=1024,
|
||||
)
|
||||
|
||||
assert text == '{"result": "ok"}'
|
||||
assert input_tokens == 100
|
||||
assert output_tokens == 50
|
||||
|
||||
mock_client.messages.create.assert_called_once_with(
|
||||
model="claude-haiku-4-5-20251001",
|
||||
max_tokens=1024,
|
||||
system="You are a helper.",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
)
|
||||
|
||||
|
||||
class TestGeminiProvider:
|
||||
"""Tests for GeminiProvider.generate_json."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_json(self):
|
||||
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
|
||||
|
||||
mock_usage = MagicMock()
|
||||
mock_usage.prompt_token_count = 80
|
||||
mock_usage.candidates_token_count = 40
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '{"answer": 42}'
|
||||
mock_response.usage_metadata = mock_usage
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content_async = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
mock_genai_module = MagicMock()
|
||||
mock_genai_module.Client.return_value = mock_client
|
||||
|
||||
mock_types = MagicMock()
|
||||
mock_types.Content.side_effect = lambda **kw: kw
|
||||
mock_types.Part.side_effect = lambda **kw: kw
|
||||
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
|
||||
|
||||
mock_google = MagicMock()
|
||||
mock_google.genai = mock_genai_module
|
||||
mock_genai_module.types = mock_types
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"google": mock_google,
|
||||
"google.genai": mock_genai_module,
|
||||
"google.genai.types": mock_types,
|
||||
}):
|
||||
text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt="Generate JSON.",
|
||||
messages=[
|
||||
{"role": "user", "content": "Give me data"},
|
||||
{"role": "assistant", "content": "Here it is"},
|
||||
{"role": "user", "content": "More please"},
|
||||
],
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
assert text == '{"answer": 42}'
|
||||
assert input_tokens == 80
|
||||
assert output_tokens == 40
|
||||
|
||||
mock_client.models.generate_content_async.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_json_handles_none_usage(self):
|
||||
"""Token counts default to 0 when usage_metadata attributes are None."""
|
||||
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
|
||||
|
||||
mock_usage = MagicMock(spec=[]) # No attributes at all
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "{}"
|
||||
mock_response.usage_metadata = mock_usage
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content_async = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
mock_genai_module = MagicMock()
|
||||
mock_genai_module.Client.return_value = mock_client
|
||||
|
||||
mock_types = MagicMock()
|
||||
mock_types.Content.side_effect = lambda **kw: kw
|
||||
mock_types.Part.side_effect = lambda **kw: kw
|
||||
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
|
||||
|
||||
mock_google = MagicMock()
|
||||
mock_google.genai = mock_genai_module
|
||||
mock_genai_module.types = mock_types
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"google": mock_google,
|
||||
"google.genai": mock_genai_module,
|
||||
"google.genai.types": mock_types,
|
||||
}):
|
||||
text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt="test",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
)
|
||||
|
||||
assert text == "{}"
|
||||
assert input_tokens == 0
|
||||
assert output_tokens == 0
|
||||
Reference in New Issue
Block a user