diff --git a/backend/app/core/ai_chat_service.py b/backend/app/core/ai_chat_service.py index 09ae57ee..7c28382f 100644 --- a/backend/app/core/ai_chat_service.py +++ b/backend/app/core/ai_chat_service.py @@ -465,7 +465,9 @@ async def send_message( for msg in history ] - provider = get_ai_provider() + # Resolve model for this action type + action_model = settings.get_model_for_action(action_type) + provider = get_ai_provider(model=action_model) response_text, input_tokens, output_tokens = await provider.generate_text( system_prompt=system_prompt, messages=provider_messages, @@ -584,7 +586,7 @@ Also provide metadata as a separate JSON object after the tree: provider_messages.append({"role": "user", "content": generation_instruction}) - provider = get_ai_provider() + provider = get_ai_provider(model=settings.get_model_for_action("generate_full")) for attempt in range(2): # One try + one retry response_text, input_tokens, output_tokens = await provider.generate_text( diff --git a/backend/app/core/ai_provider.py b/backend/app/core/ai_provider.py index 993012c6..10939843 100644 --- a/backend/app/core/ai_provider.py +++ b/backend/app/core/ai_provider.py @@ -209,9 +209,13 @@ class AnthropicProvider(AIProvider): return await self.generate_json(system_prompt, messages, max_tokens) -def get_ai_provider() -> AIProvider: +def get_ai_provider(model: str | None = None) -> AIProvider: """Factory that returns the configured AI provider. + Args: + model: Optional model override (Anthropic model ID). Only applied to + AnthropicProvider; Gemini always uses settings.AI_MODEL_GEMINI. + Selection logic: 1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider 2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider @@ -230,7 +234,7 @@ def get_ai_provider() -> AIProvider: if settings.ANTHROPIC_API_KEY: return AnthropicProvider( api_key=settings.ANTHROPIC_API_KEY, - model=settings.AI_MODEL_ANTHROPIC, + model=model or settings.AI_MODEL_ANTHROPIC, timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, ) @@ -238,7 +242,7 @@ def get_ai_provider() -> AIProvider: if settings.ANTHROPIC_API_KEY: return AnthropicProvider( api_key=settings.ANTHROPIC_API_KEY, - model=settings.AI_MODEL_ANTHROPIC, + model=model or settings.AI_MODEL_ANTHROPIC, timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, ) # Fallback to Gemini