diff --git a/docs/plans/2026-02-26-ai-autofix-gemini-plan.md b/docs/plans/2026-02-26-ai-autofix-gemini-plan.md new file mode 100644 index 00000000..4406024a --- /dev/null +++ b/docs/plans/2026-02-26-ai-autofix-gemini-plan.md @@ -0,0 +1,1707 @@ +# AI Auto-Fix & Gemini Flash Provider Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add Gemini 2.5 Flash as primary AI provider (with Claude fallback), then build an AI-powered "Fix with AI" feature that generates structural fixes for validation errors in the tree editor. + +**Architecture:** A provider abstraction layer (`ai_provider.py`) wraps Gemini and Anthropic SDKs behind a unified `generate_json()` interface. The existing `ai_tree_generator_service.py` swaps its direct Anthropic calls for this abstraction. A new `ai_fix_service.py` builds prompts from validation errors + tree context and returns proposed node patches. The frontend adds a "Fix with AI" button to `ValidationSummary` and a review modal for applying fixes. + +**Tech Stack:** Python FastAPI, google-genai SDK, anthropic SDK, Pydantic v2, React 19, TypeScript, Zustand, Tailwind CSS + +**Design Doc:** `docs/plans/2026-02-26-ai-autofix-gemini-design.md` + +--- + +## Task 1: Install google-genai SDK + +**Files:** +- Modify: `backend/requirements.txt` + +**Step 1: Add the dependency** + +Add `google-genai` to `backend/requirements.txt`: + +``` +google-genai>=1.0.0 +``` + +**Step 2: Install it** + +Run: `cd backend && pip install google-genai` + +**Step 3: Commit** + +```bash +git add backend/requirements.txt +git commit -m "chore: add google-genai SDK dependency" +``` + +--- + +## Task 2: Add Gemini config vars to Settings + +**Files:** +- Modify: `backend/app/core/config.py:75-85` + +**Step 1: Add new config variables** + +In `backend/app/core/config.py`, after line 80 (`AI_REQUEST_TIMEOUT_SECONDS`), add: + +```python + # AI Provider selection + AI_PROVIDER: str = "gemini" # "gemini" or "anthropic" + GOOGLE_AI_API_KEY: Optional[str] = None + AI_MODEL_GEMINI: str = "gemini-2.5-flash" + AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001" +``` + +**Step 2: Update `ai_enabled` property** + +Replace the existing `ai_enabled` property (lines 82-85) with: + +```python + @property + def ai_enabled(self) -> bool: + """Check if any AI provider is configured.""" + return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None +``` + +**Step 3: Verify no tests break** + +Run: `cd backend && python -m pytest tests/test_ai_tree_validator.py -v` +Expected: All pass (config changes don't affect validator tests). + +**Step 4: Commit** + +```bash +git add backend/app/core/config.py +git commit -m "feat: add Gemini Flash config vars to Settings" +``` + +--- + +## Task 3: Build the AI provider abstraction + +**Files:** +- Create: `backend/app/core/ai_provider.py` +- Test: `backend/tests/test_ai_provider.py` + +**Step 1: Write tests for the provider abstraction** + +Create `backend/tests/test_ai_provider.py`: + +```python +"""Tests for AI provider abstraction layer.""" +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from app.core.config import settings + + +class TestGetAIProvider: + """Test provider factory function.""" + + def test_returns_gemini_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.GOOGLE_AI_API_KEY + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = "fake-gemini-key" + try: + from app.core.ai_provider import get_ai_provider, GeminiProvider + provider = get_ai_provider() + assert isinstance(provider, GeminiProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_key + + def test_returns_anthropic_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.ANTHROPIC_API_KEY + settings.AI_PROVIDER = "anthropic" + settings.ANTHROPIC_API_KEY = "fake-anthropic-key" + try: + from app.core.ai_provider import get_ai_provider, AnthropicProvider + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.ANTHROPIC_API_KEY = original_key + + def test_falls_back_to_anthropic_when_gemini_key_missing(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = "fake-anthropic-key" + try: + from app.core.ai_provider import get_ai_provider, AnthropicProvider + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + def test_raises_when_no_provider_configured(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + try: + from app.core.ai_provider import get_ai_provider + with pytest.raises(RuntimeError, match="No AI provider configured"): + get_ai_provider() + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + +class TestAnthropicProvider: + """Test Anthropic provider generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json_returns_text_and_tokens(self): + from app.core.ai_provider import AnthropicProvider + + mock_response = MagicMock() + mock_response.content = [MagicMock(text='{"key": "value"}')] + mock_response.usage = MagicMock(input_tokens=100, output_tokens=50) + + mock_client = AsyncMock() + mock_client.messages.create = AsyncMock(return_value=mock_response) + + with patch("app.core.ai_provider.anthropic.AsyncAnthropic", return_value=mock_client): + provider = AnthropicProvider(api_key="fake-key") + text, inp, out = await provider.generate_json( + system_prompt="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=1024, + ) + assert text == '{"key": "value"}' + assert inp == 100 + assert out == 50 + + +class TestGeminiProvider: + """Test Gemini provider generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json_returns_text_and_tokens(self): + from app.core.ai_provider import GeminiProvider + + mock_response = MagicMock() + mock_response.text = '{"key": "value"}' + mock_response.usage_metadata = MagicMock( + prompt_token_count=100, + candidates_token_count=50, + ) + + mock_client = MagicMock() + mock_model = MagicMock() + mock_model.generate_content_async = AsyncMock(return_value=mock_response) + mock_client.models = mock_model + + with patch("app.core.ai_provider.genai.Client", return_value=mock_client): + provider = GeminiProvider(api_key="fake-key") + text, inp, out = await provider.generate_json( + system_prompt="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=1024, + ) + assert text == '{"key": "value"}' + assert inp == 100 + assert out == 50 +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && python -m pytest tests/test_ai_provider.py -v` +Expected: ImportError — `ai_provider` module doesn't exist yet. + +**Step 3: Implement the provider abstraction** + +Create `backend/app/core/ai_provider.py`: + +```python +"""AI provider abstraction layer. + +Supports Gemini (default) and Anthropic (fallback) behind a unified interface. +""" +import logging +from abc import ABC, abstractmethod +from typing import Any + +import anthropic +from google import genai +from google.genai import types as genai_types + +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +class AIProvider(ABC): + """Base class for AI providers.""" + + @abstractmethod + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Generate a JSON response from the AI model. + + Args: + system_prompt: System instructions for the model. + messages: List of {"role": "user"|"assistant", "content": str} dicts. + max_tokens: Maximum output tokens. + + Returns: + (response_text, input_tokens, output_tokens) + """ + ... + + +class GeminiProvider(AIProvider): + """Google Gemini provider using google-genai SDK.""" + + def __init__(self, api_key: str, model: str | None = None): + self._api_key = api_key + self._model = model or settings.AI_MODEL_GEMINI + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + client = genai.Client(api_key=self._api_key) + + # Build contents: system instruction is separate in Gemini API + contents: list[genai_types.Content] = [] + for msg in messages: + role = "user" if msg["role"] == "user" else "model" + contents.append(genai_types.Content( + role=role, + parts=[genai_types.Part(text=msg["content"])], + )) + + config = genai_types.GenerateContentConfig( + system_instruction=system_prompt, + max_output_tokens=max_tokens, + response_mime_type="application/json", + ) + + response = await client.models.generate_content_async( + model=self._model, + contents=contents, + config=config, + ) + + text = response.text or "" + input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 + output_tokens = getattr(response.usage_metadata, "candidates_token_count", 0) or 0 + + return text, input_tokens, output_tokens + + +class AnthropicProvider(AIProvider): + """Anthropic Claude provider.""" + + def __init__(self, api_key: str, model: str | None = None): + self._api_key = api_key + self._model = model or settings.AI_MODEL_ANTHROPIC + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + client = anthropic.AsyncAnthropic( + api_key=self._api_key, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + + response = await client.messages.create( + model=self._model, + max_tokens=max_tokens, + system=system_prompt, + messages=messages, + ) + + text = response.content[0].text + input_tokens = response.usage.input_tokens + output_tokens = response.usage.output_tokens + + return text, input_tokens, output_tokens + + +def get_ai_provider() -> AIProvider: + """Factory: return the configured AI provider. + + Falls back to Anthropic if Gemini key is missing. + Raises RuntimeError if no provider is configured. + """ + if settings.AI_PROVIDER == "gemini" and settings.GOOGLE_AI_API_KEY: + logger.info("Using Gemini provider (%s)", settings.AI_MODEL_GEMINI) + return GeminiProvider(api_key=settings.GOOGLE_AI_API_KEY) + + if settings.AI_PROVIDER == "anthropic" and settings.ANTHROPIC_API_KEY: + logger.info("Using Anthropic provider (%s)", settings.AI_MODEL_ANTHROPIC) + return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY) + + # Fallback: if Gemini requested but key missing, try Anthropic + if settings.ANTHROPIC_API_KEY: + logger.warning("Gemini key missing, falling back to Anthropic provider") + return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY) + + raise RuntimeError( + "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY." + ) +``` + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && python -m pytest tests/test_ai_provider.py -v` +Expected: All pass. + +**Step 5: Commit** + +```bash +git add backend/app/core/ai_provider.py backend/tests/test_ai_provider.py +git commit -m "feat: add AI provider abstraction with Gemini and Anthropic support" +``` + +--- + +## Task 4: Migrate ai_tree_generator_service to use provider abstraction + +**Files:** +- Modify: `backend/app/core/ai_tree_generator_service.py` +- Modify: `backend/app/api/endpoints/ai_builder.py` + +**Step 1: Update ai_tree_generator_service.py** + +In `backend/app/core/ai_tree_generator_service.py`: + +Replace the import at line 16: +```python +# OLD +import anthropic +# NEW +from app.core.ai_provider import get_ai_provider +``` + +Remove the `_get_client()` function (lines 124-131). + +Update `scaffold_branches()` (starting at line 141). Replace the client creation and API call: + +```python +async def scaffold_branches( + wizard_state: dict[str, Any], +) -> tuple[list[dict[str, str]], int, int, float]: + """Stage 2: AI suggests top-level branches.""" + provider = get_ai_provider() + + flow_type = wizard_state.get("flow_type", "troubleshooting") + name = wizard_state.get("name", "") + description = wizard_state.get("description", "") + tags = wizard_state.get("environment_tags", []) + + user_message = ( + f"Flow type: {flow_type}\n" + f"Name: {name}\n" + f"Description: {description}\n" + ) + if tags: + user_message += f"Environment: {', '.join(tags)}\n" + + raw_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=SCAFFOLD_SYSTEM_PROMPT, + messages=[{"role": "user", "content": user_message}], + max_tokens=1024, + ) +``` + +Then update the rest of the function to use `raw_text` instead of `response.content[0].text`, and `input_tokens`/`output_tokens` directly instead of `response.usage.*`. + +Do the same for `generate_branch_detail()` — replace `_get_client()` + `client.messages.create()` with `provider.generate_json()`. The retry loop structure stays the same; just swap the API call and response parsing. + +**Step 2: Update ai_builder.py endpoint** + +In `backend/app/api/endpoints/ai_builder.py`, line 13: + +```python +# OLD +import anthropic +# NEW (remove this import entirely — no longer needed in the endpoint file) +``` + +Update the `_require_ai_enabled()` function (line 50) to check the new config: + +```python +def _require_ai_enabled() -> None: + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI Flow Builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) +``` + +Update any `except anthropic.APIError` catch blocks in the endpoint to catch generic `Exception` or a broader error type, since the provider abstraction may raise different errors depending on backend. + +**Step 3: Run existing AI endpoint tests** + +Run: `cd backend && python -m pytest tests/test_ai_endpoints.py -v` + +These tests mock `anthropic.AsyncAnthropic` — they need to be updated to mock `app.core.ai_provider.get_ai_provider` instead. Update the mock targets in the test file: + +```python +# Replace patches like: +# @patch("app.core.ai_tree_generator_service._get_client") +# With: +# @patch("app.core.ai_tree_generator_service.get_ai_provider") +``` + +The mock provider should return an `AsyncMock` that returns `(json_text, input_tokens, output_tokens)` from `generate_json`. + +**Step 4: Verify all tests pass** + +Run: `cd backend && python -m pytest tests/test_ai_endpoints.py tests/test_ai_provider.py tests/test_ai_tree_validator.py -v` +Expected: All pass. + +**Step 5: Commit** + +```bash +git add backend/app/core/ai_tree_generator_service.py backend/app/api/endpoints/ai_builder.py backend/tests/test_ai_endpoints.py +git commit -m "refactor: migrate AI tree generator to provider abstraction" +``` + +--- + +## Task 5: Create AI fix schemas + +**Files:** +- Create: `backend/app/schemas/ai_fix.py` + +**Step 1: Create the schemas** + +Create `backend/app/schemas/ai_fix.py`: + +```python +"""Pydantic schemas for the AI auto-fix feature.""" +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class ValidationErrorInput(BaseModel): + """A single validation error to fix.""" + + node_id: str = Field(..., description="ID of the node with the error") + message: str = Field(..., description="The validation error message") + + +class AIFixTreeRequest(BaseModel): + """Request to generate AI fixes for validation errors.""" + + tree_structure: dict[str, Any] = Field(..., description="Full tree structure") + tree_name: str = Field("", max_length=255, description="Name of the flow") + tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field( + "troubleshooting", description="Type of flow" + ) + validation_errors: list[ValidationErrorInput] = Field( + ..., min_length=1, max_length=10, description="Errors to fix" + ) + + +class AIFixProposal(BaseModel): + """A single proposed fix from the AI.""" + + target_node_id: str + error_message: str + description: str + original_node: dict[str, Any] + fixed_node: dict[str, Any] + + +class AIFixTokenUsage(BaseModel): + input: int = 0 + output: int = 0 + + +class AIFixTreeResponse(BaseModel): + """Response with proposed fixes.""" + + fixes: list[AIFixProposal] + tokens_used: AIFixTokenUsage +``` + +**Step 2: Commit** + +```bash +git add backend/app/schemas/ai_fix.py +git commit -m "feat: add Pydantic schemas for AI fix-tree endpoint" +``` + +--- + +## Task 6: Build the AI fix service + +**Files:** +- Create: `backend/app/core/ai_fix_service.py` +- Test: `backend/tests/test_ai_fix_service.py` + +**Step 1: Write tests** + +Create `backend/tests/test_ai_fix_service.py`: + +```python +"""Tests for AI fix service — prompt building and node extraction.""" +import json +import pytest +from app.core.ai_fix_service import _serialize_tree_outline, _find_node_by_id, _find_parent_node + + +def _make_tree(): + return { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review event logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs Resolved", + "description": "Found issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart fix it?", + "options": [ + {"id": "opt-r-yes", "label": "Yes", "next_node_id": "restart-ok"}, + ], + "children": [ + { + "id": "restart-ok", + "type": "solution", + "title": "Restart Worked", + "description": "Server is back.", + }, + ], + }, + ], + } + + +class TestFindNodeById: + def test_finds_root(self): + tree = _make_tree() + node = _find_node_by_id(tree, "root") + assert node is not None + assert node["id"] == "root" + + def test_finds_nested_child(self): + tree = _make_tree() + node = _find_node_by_id(tree, "restart-ok") + assert node is not None + assert node["id"] == "restart-ok" + + def test_returns_none_for_missing(self): + tree = _make_tree() + assert _find_node_by_id(tree, "nonexistent") is None + + +class TestFindParentNode: + def test_root_has_no_parent(self): + tree = _make_tree() + assert _find_parent_node(tree, "root") is None + + def test_finds_parent_of_child(self): + tree = _make_tree() + parent = _find_parent_node(tree, "restart") + assert parent is not None + assert parent["id"] == "root" + + def test_finds_parent_of_deeply_nested(self): + tree = _make_tree() + parent = _find_parent_node(tree, "restart-ok") + assert parent is not None + assert parent["id"] == "restart" + + +class TestSerializeTreeOutline: + def test_produces_readable_outline(self): + tree = _make_tree() + outline = _serialize_tree_outline(tree) + assert "[decision] Is the server up?" in outline + assert "[action] Check Logs" in outline + assert "[solution] Restart Worked" in outline + + def test_marks_error_node(self): + tree = _make_tree() + outline = _serialize_tree_outline(tree, error_node_id="restart") + assert "ERROR HERE" in outline +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && python -m pytest tests/test_ai_fix_service.py -v` +Expected: ImportError — module doesn't exist. + +**Step 3: Implement the fix service** + +Create `backend/app/core/ai_fix_service.py`: + +```python +"""AI-powered fix generation for tree validation errors. + +Builds targeted prompts for each failing node, sends to the AI provider, +and validates the fix before returning it. +""" +import copy +import json +import logging +import re +from typing import Any + +from app.core.ai_provider import get_ai_provider +from app.core.ai_tree_validator import validate_generated_tree + +logger = logging.getLogger(__name__) + +# ── Helpers ── + + +def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None: + """Recursively find a node by ID in the tree.""" + if not isinstance(tree, dict): + return None + if tree.get("id") == node_id: + return tree + for child in tree.get("children", []): + result = _find_node_by_id(child, node_id) + if result is not None: + return result + return None + + +def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None: + """Find the parent of the node with target_id.""" + if not isinstance(tree, dict): + return None + for child in tree.get("children", []): + if isinstance(child, dict) and child.get("id") == target_id: + return tree + result = _find_parent_node(child, target_id) + if result is not None: + return result + return None + + +def _serialize_tree_outline( + tree: dict[str, Any], + indent: int = 0, + error_node_id: str | None = None, +) -> str: + """Serialize tree as a compact readable outline for the AI prompt.""" + if not isinstance(tree, dict): + return "" + + node_type = tree.get("type", "?") + label = tree.get("question") or tree.get("title") or tree.get("id", "?") + prefix = " " * indent + marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else "" + + line = f"{prefix}- [{node_type}] {label}{marker}" + lines = [line] + + for child in tree.get("children", []): + lines.append(_serialize_tree_outline(child, indent + 1, error_node_id)) + + return "\n".join(lines) + + +def _strip_markdown_fences(text: str) -> str: + """Strip ```json ... ``` fences from AI response.""" + match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL) + if match: + return match.group(1).strip() + return text + + +# ── Prompt ── + +FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers. + +You will receive: +1. A full flow outline showing the tree structure +2. The specific failing node with its full JSON +3. The validation error message + +Your task: Return a FIXED version of the failing node as valid JSON. Rules: +- Fix ONLY the structural issue described in the error message +- Keep ALL existing content (titles, descriptions, questions, options) unchanged +- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic +- Every new node must have a unique ID (use descriptive kebab-case IDs) +- Decision nodes must have at least 2 options and at least 2 children +- Action nodes must have a next_node_id pointing to a sibling node in the parent's children +- Solution nodes are leaf nodes (no children) +- Return ONLY the fixed node JSON, no explanation""" + + +def _build_fix_prompt( + tree: dict[str, Any], + node_id: str, + error_message: str, + tree_name: str, + tree_type: str, +) -> str: + """Build the user message for fixing a specific node.""" + outline = _serialize_tree_outline(tree, error_node_id=node_id) + failing_node = _find_node_by_id(tree, node_id) + node_json = json.dumps(failing_node, indent=2) if failing_node else "{}" + + return ( + f"Flow name: {tree_name}\n" + f"Flow type: {tree_type}\n\n" + f"FULL FLOW OUTLINE:\n{outline}\n\n" + f"ERROR: {error_message}\n\n" + f"FAILING NODE (full JSON):\n{node_json}\n\n" + f"Return the fixed version of this node as JSON." + ) + + +# ── Main Service ── + + +async def generate_fixes( + tree_structure: dict[str, Any], + tree_name: str, + tree_type: str, + validation_errors: list[dict[str, str]], +) -> tuple[list[dict[str, Any]], int, int]: + """Generate AI fixes for each validation error. + + Args: + tree_structure: Full tree JSON. + tree_name: Name of the flow. + tree_type: Type of flow (troubleshooting/procedural/maintenance). + validation_errors: List of {"node_id": str, "message": str}. + + Returns: + (fixes, total_input_tokens, total_output_tokens) + Each fix: {"target_node_id", "error_message", "description", "original_node", "fixed_node"} + """ + provider = get_ai_provider() + fixes: list[dict[str, Any]] = [] + total_input = 0 + total_output = 0 + + for error in validation_errors: + node_id = error["node_id"] + error_message = error["message"] + + original_node = _find_node_by_id(tree_structure, node_id) + if original_node is None: + logger.warning("Node %s not found in tree, skipping fix", node_id) + continue + + original_snapshot = copy.deepcopy(original_node) + user_message = _build_fix_prompt( + tree_structure, node_id, error_message, tree_name, tree_type + ) + + # Attempt fix (with 1 retry using corrective prompt) + messages: list[dict[str, str]] = [{"role": "user", "content": user_message}] + + for attempt in range(2): + try: + raw_text, inp_tokens, out_tokens = await provider.generate_json( + system_prompt=FIX_SYSTEM_PROMPT, + messages=messages, + max_tokens=4096, + ) + total_input += inp_tokens + total_output += out_tokens + + cleaned = _strip_markdown_fences(raw_text) + fixed_node = json.loads(cleaned) + + # Quick validation: check that the fix actually addresses the error + # by substituting into tree and re-validating that specific error is gone + test_tree = copy.deepcopy(tree_structure) + _replace_node_in_tree(test_tree, node_id, fixed_node) + remaining_errors = validate_generated_tree(test_tree) + still_has_error = any( + node_id in e and error_message.split(":")[0].lower() in e.lower() + for e in remaining_errors + ) + + if still_has_error and attempt == 0: + # Retry with corrective prompt + messages.append({"role": "assistant", "content": raw_text}) + messages.append({ + "role": "user", + "content": ( + f"The fix still has the same validation error. " + f"Remaining errors: {remaining_errors}. " + f"Please try again." + ), + }) + continue + + # Extract description from the fix + description = _describe_fix(original_snapshot, fixed_node) + + fixes.append({ + "target_node_id": node_id, + "error_message": error_message, + "description": description, + "original_node": original_snapshot, + "fixed_node": fixed_node, + }) + break + + except (json.JSONDecodeError, KeyError, TypeError) as exc: + logger.warning( + "Fix attempt %d for node %s failed: %s", attempt + 1, node_id, exc + ) + if attempt == 0: + messages.append({"role": "assistant", "content": raw_text if 'raw_text' in dir() else ""}) + messages.append({ + "role": "user", + "content": f"Invalid JSON response. Return ONLY valid JSON for the fixed node.", + }) + else: + logger.error("Failed to generate fix for node %s after 2 attempts", node_id) + + return fixes, total_input, total_output + + +def _replace_node_in_tree( + tree: dict[str, Any], target_id: str, replacement: dict[str, Any] +) -> bool: + """Replace a node in the tree by ID. Returns True if found and replaced.""" + if tree.get("id") == target_id: + tree.clear() + tree.update(replacement) + return True + for child in tree.get("children", []): + if isinstance(child, dict): + if child.get("id") == target_id: + child.clear() + child.update(replacement) + return True + if _replace_node_in_tree(child, target_id, replacement): + return True + return False + + +def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str: + """Generate a human-readable description of what changed.""" + orig_children = len(original.get("children", [])) + fixed_children = len(fixed.get("children", [])) + orig_options = len(original.get("options", [])) + fixed_options = len(fixed.get("options", [])) + + parts: list[str] = [] + if fixed_children > orig_children: + added = fixed_children - orig_children + parts.append(f"Added {added} child node{'s' if added > 1 else ''}") + if fixed_options > orig_options: + added = fixed_options - orig_options + parts.append(f"Added {added} option{'s' if added > 1 else ''}") + if "next_node_id" in fixed and "next_node_id" not in original: + parts.append(f"Added next_node_id '{fixed['next_node_id']}'") + + return "; ".join(parts) if parts else "Structural fix applied" +``` + +**Step 4: Run tests** + +Run: `cd backend && python -m pytest tests/test_ai_fix_service.py -v` +Expected: All pass. + +**Step 5: Commit** + +```bash +git add backend/app/core/ai_fix_service.py backend/tests/test_ai_fix_service.py +git commit -m "feat: add AI fix service with prompt building and validation" +``` + +--- + +## Task 7: Create the fix-tree endpoint + +**Files:** +- Create: `backend/app/api/endpoints/ai_fix.py` +- Modify: `backend/app/api/router.py` +- Test: `backend/tests/test_ai_fix_endpoint.py` + +**Step 1: Write the endpoint test** + +Create `backend/tests/test_ai_fix_endpoint.py`: + +```python +"""Integration tests for AI fix-tree endpoint.""" +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from app.core.config import settings + + +SAMPLE_TREE = { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs OK", + "description": "Issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}], + "children": [ + {"id": "done", "type": "solution", "title": "Done", "description": "Fixed."}, + ], + }, + ], +} + + +@pytest.fixture +def enable_ai(): + original_key = settings.GOOGLE_AI_API_KEY + original_provider = settings.AI_PROVIDER + settings.GOOGLE_AI_API_KEY = "fake-key" + settings.AI_PROVIDER = "gemini" + yield + settings.GOOGLE_AI_API_KEY = original_key + settings.AI_PROVIDER = original_provider + + +@pytest.mark.asyncio +class TestFixTreeEndpoint: + + async def test_returns_401_without_auth(self, client): + response = await client.post("/api/v1/ai/fix-tree", json={ + "tree_structure": SAMPLE_TREE, + "tree_name": "Test", + "tree_type": "troubleshooting", + "validation_errors": [{"node_id": "restart", "message": "Need 2 options"}], + }) + assert response.status_code == 401 + + async def test_returns_503_when_ai_disabled(self, client, auth_headers): + original = settings.GOOGLE_AI_API_KEY + orig_anthropic = settings.ANTHROPIC_API_KEY + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + try: + response = await client.post( + "/api/v1/ai/fix-tree", + json={ + "tree_structure": SAMPLE_TREE, + "tree_name": "Test", + "tree_type": "troubleshooting", + "validation_errors": [{"node_id": "restart", "message": "test"}], + }, + headers=auth_headers, + ) + assert response.status_code == 503 + finally: + settings.GOOGLE_AI_API_KEY = original + settings.ANTHROPIC_API_KEY = orig_anthropic + + async def test_returns_fixes_on_success(self, client, auth_headers, enable_ai): + fixed_node = { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [ + {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"}, + {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"}, + ], + "children": [ + {"id": "done", "type": "solution", "title": "Done", "description": "Fixed."}, + {"id": "escalate", "type": "solution", "title": "Escalate", "description": "Escalate to vendor."}, + ], + } + + mock_provider = AsyncMock() + mock_provider.generate_json = AsyncMock( + return_value=(json.dumps(fixed_node), 500, 300) + ) + + with patch("app.core.ai_fix_service.get_ai_provider", return_value=mock_provider): + response = await client.post( + "/api/v1/ai/fix-tree", + json={ + "tree_structure": SAMPLE_TREE, + "tree_name": "Server Flow", + "tree_type": "troubleshooting", + "validation_errors": [ + {"node_id": "restart", "message": "Decision node must have at least 2 options"}, + ], + }, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["fixes"]) == 1 + assert data["fixes"][0]["target_node_id"] == "restart" + assert data["tokens_used"]["input"] == 500 + assert data["tokens_used"]["output"] == 300 +``` + +**Step 2: Run to verify it fails** + +Run: `cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v` +Expected: Fail — endpoint doesn't exist. + +**Step 3: Create the endpoint** + +Create `backend/app/api/endpoints/ai_fix.py`: + +```python +"""AI auto-fix endpoint for tree validation errors.""" +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.rate_limit import limiter +from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin +from app.core.config import settings +from app.core.ai_fix_service import generate_fixes +from app.models.user import User +from app.schemas.ai_fix import AIFixTreeRequest, AIFixTreeResponse, AIFixProposal, AIFixTokenUsage + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ai", tags=["ai-fix"]) + + +def _require_ai_enabled() -> None: + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) + + +@router.post("/fix-tree", response_model=AIFixTreeResponse) +@limiter.limit("10/minute") +async def fix_tree( + request: Request, + body: AIFixTreeRequest, + user: Annotated[User, Depends(require_engineer_or_admin)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Generate AI-powered fixes for tree validation errors.""" + _require_ai_enabled() + + try: + fixes, total_input, total_output = await generate_fixes( + tree_structure=body.tree_structure, + tree_name=body.tree_name, + tree_type=body.tree_type, + validation_errors=[ + {"node_id": e.node_id, "message": e.message} + for e in body.validation_errors + ], + ) + except RuntimeError as exc: + logger.error("AI fix generation failed: %s", exc) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(exc), + ) + except Exception as exc: + logger.exception("Unexpected error during AI fix generation") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to generate fixes. Please try again.", + ) + + return AIFixTreeResponse( + fixes=[ + AIFixProposal( + target_node_id=f["target_node_id"], + error_message=f["error_message"], + description=f["description"], + original_node=f["original_node"], + fixed_node=f["fixed_node"], + ) + for f in fixes + ], + tokens_used=AIFixTokenUsage(input=total_input, output=total_output), + ) +``` + +**Step 4: Register in router** + +In `backend/app/api/router.py`, add: + +After line 8 (`from app.api.endpoints import ai_builder`): +```python +from app.api.endpoints import ai_fix +``` + +After line 38 (`api_router.include_router(ai_builder.router)`): +```python +api_router.include_router(ai_fix.router) +``` + +**Step 5: Run tests** + +Run: `cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v` +Expected: All pass. + +**Step 6: Run full test suite** + +Run: `cd backend && python -m pytest --override-ini="addopts=" -v` +Expected: All pass (100+ tests). + +**Step 7: Commit** + +```bash +git add backend/app/api/endpoints/ai_fix.py backend/app/api/router.py backend/app/schemas/ai_fix.py backend/tests/test_ai_fix_endpoint.py +git commit -m "feat: add POST /ai/fix-tree endpoint for AI-powered validation fixes" +``` + +--- + +## Task 8: Add frontend API client for fix-tree + +**Files:** +- Modify: `frontend/src/api/trees.ts` +- Create: `frontend/src/types/ai-fix.ts` +- Modify: `frontend/src/types/index.ts` + +**Step 1: Create the types** + +Create `frontend/src/types/ai-fix.ts`: + +```typescript +export interface AIFixValidationError { + node_id: string + message: string +} + +export interface AIFixProposal { + target_node_id: string + error_message: string + description: string + original_node: Record + fixed_node: Record +} + +export interface AIFixTreeRequest { + tree_structure: Record + tree_name: string + tree_type: 'troubleshooting' | 'procedural' | 'maintenance' + validation_errors: AIFixValidationError[] +} + +export interface AIFixTreeResponse { + fixes: AIFixProposal[] + tokens_used: { input: number; output: number } +} +``` + +**Step 2: Export from types/index.ts** + +Add to `frontend/src/types/index.ts`: + +```typescript +export type { AIFixTreeRequest, AIFixTreeResponse, AIFixProposal, AIFixValidationError } from './ai-fix' +``` + +**Step 3: Add API method to trees.ts** + +In `frontend/src/api/trees.ts`, add a new method to the `treesApi` object: + +```typescript + async fixTree(request: AIFixTreeRequest): Promise { + const response = await apiClient.post('/ai/fix-tree', request) + return response.data + }, +``` + +Import the types at the top of the file. + +**Step 4: Commit** + +```bash +git add frontend/src/types/ai-fix.ts frontend/src/types/index.ts frontend/src/api/trees.ts +git commit -m "feat: add frontend API client and types for AI fix-tree" +``` + +--- + +## Task 9: Add "Fix with AI" button to ValidationSummary + +**Files:** +- Modify: `frontend/src/components/tree-editor/ValidationSummary.tsx` + +**Step 1: Update the component** + +Add new props and the button to `ValidationSummary`: + +```typescript +import { useState } from 'react' +import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react' +import { cn } from '@/lib/utils' +import type { ValidationError } from '@/store/treeEditorStore' + +interface ValidationSummaryProps { + errors: ValidationError[] + onSelectNode: (nodeId: string) => void + onFixWithAI?: () => void + isFixing?: boolean +} + +export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) { +``` + +In the header button area (after the expand/collapse chevron at line 62), add the "Fix with AI" button. It should appear between the error count text and the chevron icon. Restructure the header to include the button: + +After the `` closing the error/warning count (around line 60), add: + +```tsx + {/* Fix with AI button — only when there are fixable errors */} + {onFixWithAI && errorItems.some(e => e.nodeId) && ( + + )} +``` + +**Step 2: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build passes. + +**Step 3: Commit** + +```bash +git add frontend/src/components/tree-editor/ValidationSummary.tsx +git commit -m "feat: add Fix with AI button to ValidationSummary" +``` + +--- + +## Task 10: Build the AIFixReviewModal + +**Files:** +- Create: `frontend/src/components/tree-editor/AIFixReviewModal.tsx` + +**Step 1: Create the review modal** + +Create `frontend/src/components/tree-editor/AIFixReviewModal.tsx`: + +```typescript +import { useState } from 'react' +import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react' +import { cn } from '@/lib/utils' +import type { AIFixProposal } from '@/types' + +interface AIFixReviewModalProps { + fixes: AIFixProposal[] + onApply: (fix: AIFixProposal) => void + onApplyAll: () => void + onClose: () => void +} + +export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) { + const [appliedIds, setAppliedIds] = useState>(new Set()) + const [skippedIds, setSkippedIds] = useState>(new Set()) + const [expandedIds, setExpandedIds] = useState>(new Set(fixes.map(f => f.target_node_id))) + + const handleApply = (fix: AIFixProposal) => { + onApply(fix) + setAppliedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const handleSkip = (fix: AIFixProposal) => { + setSkippedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const toggleExpanded = (id: string) => { + setExpandedIds(prev => { + const next = new Set(prev) + if (next.has(id)) next.delete(id) + else next.add(id) + return next + }) + } + + const pendingFixes = fixes.filter( + f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id) + ) + const allHandled = pendingFixes.length === 0 + + return ( +
+
+ {/* Header */} +
+
+ +

+ AI Fix Proposals ({fixes.length}) +

+
+ +
+ + {/* Body */} +
+ {fixes.map((fix) => { + const isApplied = appliedIds.has(fix.target_node_id) + const isSkipped = skippedIds.has(fix.target_node_id) + const isExpanded = expandedIds.has(fix.target_node_id) + + return ( +
+ {/* Fix header */} +
+
+

{fix.error_message}

+

{fix.description}

+

+ Node: {fix.target_node_id} +

+
+ {isApplied && ( + + Applied + + )} + {isSkipped && ( + Skipped + )} +
+ + {/* Expand/collapse detail */} + {!isApplied && !isSkipped && ( + <> + + + {isExpanded && ( +
+
+

Before

+
+                            {JSON.stringify(fix.original_node, null, 2)}
+                          
+
+
+

After

+
+                            {JSON.stringify(fix.fixed_node, null, 2)}
+                          
+
+
+ )} + + {/* Action buttons */} +
+ + +
+ + )} +
+ ) + })} +
+ + {/* Footer */} +
+ + {!allHandled && ( + + )} +
+
+
+ ) +} +``` + +**Step 2: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build passes. + +**Step 3: Commit** + +```bash +git add frontend/src/components/tree-editor/AIFixReviewModal.tsx +git commit -m "feat: add AIFixReviewModal component for reviewing AI-proposed fixes" +``` + +--- + +## Task 11: Wire everything together in TreeEditorPage + +**Files:** +- Modify: `frontend/src/pages/TreeEditorPage.tsx` + +This is the integration task. The page needs to: + +1. Import the new components and API +2. Add state for fix flow (`isFixing`, `fixProposals`) +3. Handle "Fix with AI" button click — call `treesApi.fixTree()` +4. Show `AIFixReviewModal` when proposals are available +5. Handle apply/skip — call `updateNode()` on the tree editor store +6. Re-run `validate()` after applying fixes + +**Step 1: Find where ValidationSummary is rendered in TreeEditorPage** + +Search for `(null) +``` + +**Step 4: Add handler for "Fix with AI"** + +```typescript +const handleFixWithAI = async () => { + const store = useTreeEditorStore.getState() + if (!store.treeStructure) return + + // Get only fixable errors (structural errors with nodeId) + const fixableErrors = store.validationErrors + .filter(e => e.severity === 'error' && e.nodeId) + .map(e => ({ node_id: e.nodeId!, message: e.message })) + + if (fixableErrors.length === 0) return + + setIsFixing(true) + try { + const result = await treesApi.fixTree({ + tree_structure: store.treeStructure as Record, + tree_name: store.name, + tree_type: (store.treeType || 'troubleshooting') as 'troubleshooting' | 'procedural' | 'maintenance', + validation_errors: fixableErrors, + }) + if (result.fixes.length > 0) { + setFixProposals(result.fixes) + } else { + toast.info('AI could not generate fixes for these errors') + } + } catch { + toast.error('Failed to generate AI fixes. Please try again.') + } finally { + setIsFixing(false) + } +} +``` + +**Step 5: Add handlers for apply/close** + +```typescript +const handleApplyFix = (fix: AIFixProposal) => { + const store = useTreeEditorStore.getState() + store.updateNode(fix.target_node_id, fix.fixed_node as Partial) +} + +const handleApplyAllFixes = () => { + if (!fixProposals) return + for (const fix of fixProposals) { + handleApplyFix(fix) + } + setFixProposals(null) + // Re-validate after applying all fixes + setTimeout(() => { + useTreeEditorStore.getState().validate() + }, 100) +} + +const handleCloseFixModal = () => { + setFixProposals(null) + // Re-validate in case some fixes were applied + useTreeEditorStore.getState().validate() +} +``` + +**Step 6: Pass props to ValidationSummary** + +Update the `` JSX to include the new props: + +```tsx + +``` + +**Step 7: Add the review modal** + +After ``, add: + +```tsx +{fixProposals && ( + +)} +``` + +**Step 8: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build passes. + +**Step 9: Commit** + +```bash +git add frontend/src/pages/TreeEditorPage.tsx +git commit -m "feat: wire AI fix flow into TreeEditorPage" +``` + +--- + +## Task 12: Final verification + +**Step 1: Run all backend tests** + +Run: `cd backend && python -m pytest --override-ini="addopts=" -v` +Expected: All pass. + +**Step 2: Run frontend build** + +Run: `cd frontend && npm run build` +Expected: Build passes with no errors. + +**Step 3: Final commit if any adjustments needed** + +Fix any issues found during verification and commit.