diff --git a/backend/app/core/config.py b/backend/app/core/config.py index a55d2e96..79afe728 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -84,6 +84,27 @@ class Settings(BaseSettings): AI_MODEL_GEMINI: str = "gemini-2.5-flash" AI_MODEL_ANTHROPIC: str = "claude-sonnet-4-6" + # Model tier routing — maps action types to model tiers + AI_MODEL_TIERS: dict[str, str] = { + "fast": "claude-haiku-4-5-20251001", + "standard": "claude-sonnet-4-6-20250514", + } + + ACTION_MODEL_MAP: dict[str, str] = { + "generate_full": "standard", + "generate_branch": "standard", + "modify_node": "fast", + "add_steps": "standard", + "quick_action": "fast", + "open_chat": "standard", + "variable_inference": "fast", + } + + def get_model_for_action(self, action_type: str) -> str: + """Resolve an action type to a concrete model name via tier routing.""" + tier = self.ACTION_MODEL_MAP.get(action_type, "standard") + return self.AI_MODEL_TIERS.get(tier, self.AI_MODEL_TIERS["standard"]) + # MCP (Model Context Protocol) integrations ENABLE_MCP_MICROSOFT_LEARN: bool = True diff --git a/backend/tests/test_config_model_tiers.py b/backend/tests/test_config_model_tiers.py new file mode 100644 index 00000000..ddcf6704 --- /dev/null +++ b/backend/tests/test_config_model_tiers.py @@ -0,0 +1,24 @@ +"""Tests for AI model tier configuration.""" +from app.core.config import settings + + +def test_ai_model_tiers_exist(): + assert "fast" in settings.AI_MODEL_TIERS + assert "standard" in settings.AI_MODEL_TIERS + + +def test_action_model_map_covers_all_actions(): + valid_tiers = set(settings.AI_MODEL_TIERS.keys()) + for action, tier in settings.ACTION_MODEL_MAP.items(): + assert tier in valid_tiers, f"Action '{action}' maps to unknown tier '{tier}'" + + +def test_get_model_for_action(): + model = settings.get_model_for_action("generate_full") + assert isinstance(model, str) + assert len(model) > 0 + + +def test_get_model_for_action_unknown_falls_back(): + model = settings.get_model_for_action("nonexistent_action") + assert model == settings.AI_MODEL_TIERS["standard"]