From 55be033ecb9eb6457d9ea9664335d2af5dd44164 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:16:45 -0500 Subject: [PATCH] feat: add AI provider abstraction with Gemini and Anthropic support Co-Authored-By: Claude Opus 4.6 --- backend/app/core/ai_provider.py | 162 ++++++++++++++++++++++ backend/tests/test_ai_provider.py | 216 ++++++++++++++++++++++++++++++ 2 files changed, 378 insertions(+) create mode 100644 backend/app/core/ai_provider.py create mode 100644 backend/tests/test_ai_provider.py diff --git a/backend/app/core/ai_provider.py b/backend/app/core/ai_provider.py new file mode 100644 index 00000000..b3cf16e4 --- /dev/null +++ b/backend/app/core/ai_provider.py @@ -0,0 +1,162 @@ +""" +AI Provider abstraction layer. + +Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable +backends for JSON generation used by the AI Flow Builder. +""" + +from abc import ABC, abstractmethod + +from app.core.config import settings + + +class AIProvider(ABC): + """Abstract base class for AI providers.""" + + @abstractmethod + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Generate a JSON response from the AI model. + + Args: + system_prompt: System-level instruction for the model. + messages: List of message dicts with "role" and "content" keys. + max_tokens: Maximum output tokens. + + Returns: + Tuple of (response_text, input_tokens, output_tokens). + """ + ... + + +class GeminiProvider(AIProvider): + """Google Gemini provider using the google-genai SDK.""" + + def __init__(self, api_key: str, model: str) -> None: + self._api_key = api_key + self._model = model + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + from google import genai + from google.genai import types as genai_types + + client = genai.Client(api_key=self._api_key) + + # Convert messages to Gemini Content format + contents: list[genai_types.Content] = [] + for msg in messages: + role = "model" if msg["role"] == "assistant" else "user" + contents.append( + genai_types.Content( + role=role, + parts=[genai_types.Part(text=msg["content"])], + ) + ) + + config = genai_types.GenerateContentConfig( + system_instruction=system_prompt, + max_output_tokens=max_tokens, + response_mime_type="application/json", + ) + + response = await client.models.generate_content_async( + model=self._model, + contents=contents, + config=config, + ) + + text = response.text or "" + input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 + output_tokens = ( + getattr(response.usage_metadata, "candidates_token_count", 0) or 0 + ) + + return text, input_tokens, output_tokens + + +class AnthropicProvider(AIProvider): + """Anthropic Claude provider using the anthropic SDK.""" + + def __init__(self, api_key: str, model: str, timeout: int = 45) -> None: + self._api_key = api_key + self._model = model + self._timeout = timeout + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + import anthropic + + client = anthropic.AsyncAnthropic( + api_key=self._api_key, + timeout=self._timeout, + ) + + response = await client.messages.create( + model=self._model, + max_tokens=max_tokens, + system=system_prompt, + messages=messages, + ) + + text = response.content[0].text + input_tokens = response.usage.input_tokens + output_tokens = response.usage.output_tokens + + return text, input_tokens, output_tokens + + +def get_ai_provider() -> AIProvider: + """Factory that returns the configured AI provider. + + Selection logic: + 1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider + 2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider + 3. Fallback: if preferred provider key missing, try the other one + 4. If nothing configured -> raise RuntimeError + """ + provider = settings.AI_PROVIDER + + if provider == "gemini": + if settings.GOOGLE_AI_API_KEY: + return GeminiProvider( + api_key=settings.GOOGLE_AI_API_KEY, + model=settings.AI_MODEL_GEMINI, + ) + # Fallback to Anthropic + if settings.ANTHROPIC_API_KEY: + return AnthropicProvider( + api_key=settings.ANTHROPIC_API_KEY, + model=settings.AI_MODEL_ANTHROPIC, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + + elif provider == "anthropic": + if settings.ANTHROPIC_API_KEY: + return AnthropicProvider( + api_key=settings.ANTHROPIC_API_KEY, + model=settings.AI_MODEL_ANTHROPIC, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + # Fallback to Gemini + if settings.GOOGLE_AI_API_KEY: + return GeminiProvider( + api_key=settings.GOOGLE_AI_API_KEY, + model=settings.AI_MODEL_GEMINI, + ) + + raise RuntimeError( + "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY." + ) diff --git a/backend/tests/test_ai_provider.py b/backend/tests/test_ai_provider.py new file mode 100644 index 00000000..a263d5e3 --- /dev/null +++ b/backend/tests/test_ai_provider.py @@ -0,0 +1,216 @@ +"""Tests for the AI provider abstraction layer.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import sys + +from app.core.ai_provider import ( + AIProvider, + AnthropicProvider, + GeminiProvider, + get_ai_provider, +) +from app.core.config import settings + + +class TestGetAIProvider: + """Tests for the get_ai_provider factory function.""" + + def test_returns_gemini_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.GOOGLE_AI_API_KEY + try: + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = "test-gemini-key" + provider = get_ai_provider() + assert isinstance(provider, GeminiProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_key + + def test_returns_anthropic_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "anthropic" + settings.ANTHROPIC_API_KEY = "test-anthropic-key" + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.ANTHROPIC_API_KEY = original_key + + def test_fallback_to_anthropic_when_gemini_key_missing(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = "test-anthropic-key" + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + def test_fallback_to_gemini_when_anthropic_key_missing(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "anthropic" + settings.ANTHROPIC_API_KEY = None + settings.GOOGLE_AI_API_KEY = "test-gemini-key" + provider = get_ai_provider() + assert isinstance(provider, GeminiProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + def test_raises_when_no_provider_configured(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + with pytest.raises(RuntimeError, match="No AI provider configured"): + get_ai_provider() + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + +class TestAnthropicProvider: + """Tests for AnthropicProvider.generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json(self): + provider = AnthropicProvider( + api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30 + ) + + mock_response = MagicMock() + mock_response.content = [MagicMock(text='{"result": "ok"}')] + mock_response.usage = MagicMock(input_tokens=100, output_tokens=50) + + mock_client = AsyncMock() + mock_client.messages.create = AsyncMock(return_value=mock_response) + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + text, input_tokens, output_tokens = await provider.generate_json( + system_prompt="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=1024, + ) + + assert text == '{"result": "ok"}' + assert input_tokens == 100 + assert output_tokens == 50 + + mock_client.messages.create.assert_called_once_with( + model="claude-haiku-4-5-20251001", + max_tokens=1024, + system="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + ) + + +class TestGeminiProvider: + """Tests for GeminiProvider.generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json(self): + provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") + + mock_usage = MagicMock() + mock_usage.prompt_token_count = 80 + mock_usage.candidates_token_count = 40 + + mock_response = MagicMock() + mock_response.text = '{"answer": 42}' + mock_response.usage_metadata = mock_usage + + mock_client = MagicMock() + mock_client.models.generate_content_async = AsyncMock( + return_value=mock_response + ) + + mock_genai_module = MagicMock() + mock_genai_module.Client.return_value = mock_client + + mock_types = MagicMock() + mock_types.Content.side_effect = lambda **kw: kw + mock_types.Part.side_effect = lambda **kw: kw + mock_types.GenerateContentConfig.side_effect = lambda **kw: kw + + mock_google = MagicMock() + mock_google.genai = mock_genai_module + mock_genai_module.types = mock_types + + with patch.dict(sys.modules, { + "google": mock_google, + "google.genai": mock_genai_module, + "google.genai.types": mock_types, + }): + text, input_tokens, output_tokens = await provider.generate_json( + system_prompt="Generate JSON.", + messages=[ + {"role": "user", "content": "Give me data"}, + {"role": "assistant", "content": "Here it is"}, + {"role": "user", "content": "More please"}, + ], + max_tokens=2048, + ) + + assert text == '{"answer": 42}' + assert input_tokens == 80 + assert output_tokens == 40 + + mock_client.models.generate_content_async.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_json_handles_none_usage(self): + """Token counts default to 0 when usage_metadata attributes are None.""" + provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") + + mock_usage = MagicMock(spec=[]) # No attributes at all + mock_response = MagicMock() + mock_response.text = "{}" + mock_response.usage_metadata = mock_usage + + mock_client = MagicMock() + mock_client.models.generate_content_async = AsyncMock( + return_value=mock_response + ) + + mock_genai_module = MagicMock() + mock_genai_module.Client.return_value = mock_client + + mock_types = MagicMock() + mock_types.Content.side_effect = lambda **kw: kw + mock_types.Part.side_effect = lambda **kw: kw + mock_types.GenerateContentConfig.side_effect = lambda **kw: kw + + mock_google = MagicMock() + mock_google.genai = mock_genai_module + mock_genai_module.types = mock_types + + with patch.dict(sys.modules, { + "google": mock_google, + "google.genai": mock_genai_module, + "google.genai.types": mock_types, + }): + text, input_tokens, output_tokens = await provider.generate_json( + system_prompt="test", + messages=[{"role": "user", "content": "test"}], + ) + + assert text == "{}" + assert input_tokens == 0 + assert output_tokens == 0