diff --git a/backend/app/core/ai_chat_service.py b/backend/app/core/ai_chat_service.py index 4202ac31..b87a3698 100644 --- a/backend/app/core/ai_chat_service.py +++ b/backend/app/core/ai_chat_service.py @@ -242,7 +242,7 @@ async def start_chat_session( provider_name = settings.AI_PROVIDER messages = [{"role": "user", "content": primer}] - response_text, input_tokens, output_tokens = await provider.generate_json( + response_text, input_tokens, output_tokens = await provider.generate_text( system_prompt=system_prompt, messages=messages, max_tokens=1500, @@ -291,7 +291,7 @@ async def send_message( ] provider = get_ai_provider() - response_text, input_tokens, output_tokens = await provider.generate_json( + response_text, input_tokens, output_tokens = await provider.generate_text( system_prompt=system_prompt, messages=provider_messages, max_tokens=2000, @@ -371,7 +371,7 @@ Also provide metadata as a separate JSON object after the tree: provider = get_ai_provider() for attempt in range(2): # One try + one retry - response_text, input_tokens, output_tokens = await provider.generate_json( + response_text, input_tokens, output_tokens = await provider.generate_text( system_prompt=system_prompt, messages=provider_messages, max_tokens=8000, diff --git a/backend/app/core/ai_provider.py b/backend/app/core/ai_provider.py index cb3f7178..993012c6 100644 --- a/backend/app/core/ai_provider.py +++ b/backend/app/core/ai_provider.py @@ -35,6 +35,25 @@ class AIProvider(ABC): """ ... + @abstractmethod + async def generate_text( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Generate a text response from the AI model (no JSON constraint). + + Args: + system_prompt: System-level instruction for the model. + messages: List of message dicts with "role" and "content" keys. + max_tokens: Maximum output tokens. + + Returns: + Tuple of (response_text, input_tokens, output_tokens). + """ + ... + class GeminiProvider(AIProvider): """Google Gemini provider using the google-genai SDK.""" @@ -95,6 +114,56 @@ class GeminiProvider(AIProvider): return text, input_tokens, output_tokens + async def generate_text( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + from google import genai + from google.genai import types as genai_types + + client = genai.Client(api_key=self._api_key) + + contents: list[genai_types.Content] = [] + for msg in messages: + role = "model" if msg["role"] == "assistant" else "user" + contents.append( + genai_types.Content( + role=role, + parts=[genai_types.Part(text=msg["content"])], + ) + ) + + config = genai_types.GenerateContentConfig( + system_instruction=system_prompt, + max_output_tokens=max_tokens, + # No response_mime_type — allow free-form text + ) + + response = await client.aio.models.generate_content( + model=self._model, + contents=contents, + config=config, + ) + + if response.candidates: + finish_reason = getattr(response.candidates[0], "finish_reason", None) + logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model) + if str(finish_reason) == "MAX_TOKENS": + logger.warning( + "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d", + max_tokens, + ) + + text = response.text or "" + input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 + output_tokens = ( + getattr(response.usage_metadata, "candidates_token_count", 0) or 0 + ) + + return text, input_tokens, output_tokens + class AnthropicProvider(AIProvider): """Anthropic Claude provider using the anthropic SDK.""" @@ -130,6 +199,15 @@ class AnthropicProvider(AIProvider): return text, input_tokens, output_tokens + async def generate_text( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + # Anthropic doesn't differentiate between JSON and text mode + return await self.generate_json(system_prompt, messages, max_tokens) + def get_ai_provider() -> AIProvider: """Factory that returns the configured AI provider.