12-task plan covering: SDK install, config, provider abstraction, service migration, fix service, endpoint, frontend types/API, ValidationSummary button, review modal, and TreeEditorPage wiring. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
AI Auto-Fix & Gemini Flash Provider Implementation Plan
For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
Goal: Add Gemini 2.5 Flash as the primary AI provider (with Claude as the fallback when no Gemini key is configured), then build an AI-powered "Fix with AI" feature that generates structural fixes for validation errors in the tree editor.
Architecture: A provider abstraction layer (ai_provider.py) wraps Gemini and Anthropic SDKs behind a unified generate_json() interface. The existing ai_tree_generator_service.py swaps its direct Anthropic calls for this abstraction. A new ai_fix_service.py builds prompts from validation errors + tree context and returns proposed node patches. The frontend adds a "Fix with AI" button to ValidationSummary and a review modal for applying fixes.
Tech Stack: Python FastAPI, google-genai SDK, anthropic SDK, Pydantic v2, React 19, TypeScript, Zustand, Tailwind CSS
Design Doc: docs/plans/2026-02-26-ai-autofix-gemini-design.md
Task 1: Install google-genai SDK
Files:
- Modify: backend/requirements.txt
Step 1: Add the dependency
Add google-genai to backend/requirements.txt:
google-genai>=1.0.0
Step 2: Install it
Run: cd backend && pip install google-genai
Step 3: Commit
git add backend/requirements.txt
git commit -m "chore: add google-genai SDK dependency"
Task 2: Add Gemini config vars to Settings
Files:
- Modify: backend/app/core/config.py:75-85
Step 1: Add new config variables
In backend/app/core/config.py, after line 80 (AI_REQUEST_TIMEOUT_SECONDS), add:
# AI Provider selection
AI_PROVIDER: str = "gemini" # "gemini" or "anthropic"
GOOGLE_AI_API_KEY: Optional[str] = None
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
Step 2: Update ai_enabled property
Replace the existing ai_enabled property (lines 82-85) with:
@property
def ai_enabled(self) -> bool:
"""Check if any AI provider is configured."""
return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None
Step 3: Verify no tests break
Run: cd backend && python -m pytest tests/test_ai_tree_validator.py -v
Expected: All pass (config changes don't affect validator tests).
Step 4: Commit
git add backend/app/core/config.py
git commit -m "feat: add Gemini Flash config vars to Settings"
Task 3: Build the AI provider abstraction
Files:
- Create: backend/app/core/ai_provider.py
- Test: backend/tests/test_ai_provider.py
Step 1: Write tests for the provider abstraction
Create backend/tests/test_ai_provider.py:
"""Tests for AI provider abstraction layer."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from app.core.config import settings
class TestGetAIProvider:
"""Test provider factory function."""
def test_returns_gemini_when_configured(self):
original_provider = settings.AI_PROVIDER
original_key = settings.GOOGLE_AI_API_KEY
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = "fake-gemini-key"
try:
from app.core.ai_provider import get_ai_provider, GeminiProvider
provider = get_ai_provider()
assert isinstance(provider, GeminiProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_key
def test_returns_anthropic_when_configured(self):
original_provider = settings.AI_PROVIDER
original_key = settings.ANTHROPIC_API_KEY
settings.AI_PROVIDER = "anthropic"
settings.ANTHROPIC_API_KEY = "fake-anthropic-key"
try:
from app.core.ai_provider import get_ai_provider, AnthropicProvider
provider = get_ai_provider()
assert isinstance(provider, AnthropicProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.ANTHROPIC_API_KEY = original_key
def test_falls_back_to_anthropic_when_gemini_key_missing(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = "fake-anthropic-key"
try:
from app.core.ai_provider import get_ai_provider, AnthropicProvider
provider = get_ai_provider()
assert isinstance(provider, AnthropicProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
def test_raises_when_no_provider_configured(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = None
try:
from app.core.ai_provider import get_ai_provider
with pytest.raises(RuntimeError, match="No AI provider configured"):
get_ai_provider()
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
class TestAnthropicProvider:
"""Test Anthropic provider generate_json."""
@pytest.mark.asyncio
async def test_generate_json_returns_text_and_tokens(self):
from app.core.ai_provider import AnthropicProvider
mock_response = MagicMock()
mock_response.content = [MagicMock(text='{"key": "value"}')]
mock_response.usage = MagicMock(input_tokens=100, output_tokens=50)
mock_client = AsyncMock()
mock_client.messages.create = AsyncMock(return_value=mock_response)
with patch("app.core.ai_provider.anthropic.AsyncAnthropic", return_value=mock_client):
provider = AnthropicProvider(api_key="fake-key")
text, inp, out = await provider.generate_json(
system_prompt="You are a helper.",
messages=[{"role": "user", "content": "Hello"}],
max_tokens=1024,
)
assert text == '{"key": "value"}'
assert inp == 100
assert out == 50
class TestGeminiProvider:
"""Test Gemini provider generate_json."""
@pytest.mark.asyncio
async def test_generate_json_returns_text_and_tokens(self):
from app.core.ai_provider import GeminiProvider
mock_response = MagicMock()
mock_response.text = '{"key": "value"}'
mock_response.usage_metadata = MagicMock(
prompt_token_count=100,
candidates_token_count=50,
)
mock_client = MagicMock()
# google-genai exposes its async API under client.aio
mock_client.aio.models.generate_content = AsyncMock(return_value=mock_response)
with patch("app.core.ai_provider.genai.Client", return_value=mock_client):
provider = GeminiProvider(api_key="fake-key")
text, inp, out = await provider.generate_json(
system_prompt="You are a helper.",
messages=[{"role": "user", "content": "Hello"}],
max_tokens=1024,
)
assert text == '{"key": "value"}'
assert inp == 100
assert out == 50
Step 2: Run tests to verify they fail
Run: cd backend && python -m pytest tests/test_ai_provider.py -v
Expected: ImportError, since the ai_provider module doesn't exist yet.
Step 3: Implement the provider abstraction
Create backend/app/core/ai_provider.py:
"""AI provider abstraction layer.
Supports Gemini (default) and Anthropic (fallback) behind a unified interface.
"""
import logging
from abc import ABC, abstractmethod
from typing import Any
import anthropic
from google import genai
from google.genai import types as genai_types
from app.core.config import settings
logger = logging.getLogger(__name__)
class AIProvider(ABC):
"""Base class for AI providers."""
@abstractmethod
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
"""Generate a JSON response from the AI model.
Args:
system_prompt: System instructions for the model.
messages: List of {"role": "user"|"assistant", "content": str} dicts.
max_tokens: Maximum output tokens.
Returns:
(response_text, input_tokens, output_tokens)
"""
...
class GeminiProvider(AIProvider):
"""Google Gemini provider using google-genai SDK."""
def __init__(self, api_key: str, model: str | None = None):
self._api_key = api_key
self._model = model or settings.AI_MODEL_GEMINI
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
client = genai.Client(api_key=self._api_key)
# Build contents: system instruction is separate in Gemini API
contents: list[genai_types.Content] = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
contents.append(genai_types.Content(
role=role,
parts=[genai_types.Part(text=msg["content"])],
))
config = genai_types.GenerateContentConfig(
system_instruction=system_prompt,
max_output_tokens=max_tokens,
response_mime_type="application/json",
)
# Async calls go through client.aio; client.models.generate_content is synchronous
response = await client.aio.models.generate_content(
model=self._model,
contents=contents,
config=config,
)
text = response.text or ""
input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
output_tokens = getattr(response.usage_metadata, "candidates_token_count", 0) or 0
return text, input_tokens, output_tokens
class AnthropicProvider(AIProvider):
"""Anthropic Claude provider."""
def __init__(self, api_key: str, model: str | None = None):
self._api_key = api_key
self._model = model or settings.AI_MODEL_ANTHROPIC
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
client = anthropic.AsyncAnthropic(
api_key=self._api_key,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
response = await client.messages.create(
model=self._model,
max_tokens=max_tokens,
system=system_prompt,
messages=messages,
)
text = response.content[0].text
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
return text, input_tokens, output_tokens
def get_ai_provider() -> AIProvider:
"""Factory: return the configured AI provider.
Falls back to Anthropic if Gemini key is missing.
Raises RuntimeError if no provider is configured.
"""
if settings.AI_PROVIDER == "gemini" and settings.GOOGLE_AI_API_KEY:
logger.info("Using Gemini provider (%s)", settings.AI_MODEL_GEMINI)
return GeminiProvider(api_key=settings.GOOGLE_AI_API_KEY)
if settings.AI_PROVIDER == "anthropic" and settings.ANTHROPIC_API_KEY:
logger.info("Using Anthropic provider (%s)", settings.AI_MODEL_ANTHROPIC)
return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY)
# Fallback: Gemini requested without a key (or an unrecognised AI_PROVIDER value), try Anthropic
if settings.ANTHROPIC_API_KEY:
logger.warning("AI_PROVIDER=%r not usable, falling back to Anthropic provider", settings.AI_PROVIDER)
return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY)
raise RuntimeError(
"No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
)
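The fallback is decided purely from configuration: if Gemini is configured but a request fails at runtime, nothing in get_ai_provider() retries against Claude. The branch order can be checked in isolation with a pure re-expression of the factory and of Settings.ai_enabled. This is a sketch; resolve_provider_name, ai_enabled and describe are hypothetical helpers, not part of the plan.

```python
from __future__ import annotations


def resolve_provider_name(
    ai_provider: str, google_key: str | None, anthropic_key: str | None
) -> str:
    """Re-expresses get_ai_provider()'s branch order; returns a name, not an instance."""
    if ai_provider == "gemini" and google_key:
        return "gemini"
    if ai_provider == "anthropic" and anthropic_key:
        return "anthropic"
    # Fallback branch: reached for a missing Gemini key *and* for any
    # unrecognised AI_PROVIDER value.
    if anthropic_key:
        return "anthropic"
    raise RuntimeError("No AI provider configured.")


def ai_enabled(google_key: str | None, anthropic_key: str | None) -> bool:
    """Re-expresses Settings.ai_enabled (uses `is not None`, not truthiness)."""
    return anthropic_key is not None or google_key is not None


def describe(ai_provider: str, google_key: str | None, anthropic_key: str | None) -> str:
    """One row of the decision table: inputs, ai_enabled result, chosen provider."""
    try:
        chosen = resolve_provider_name(ai_provider, google_key, anthropic_key)
    except RuntimeError:
        chosen = "RuntimeError"
    enabled = ai_enabled(google_key, anthropic_key)
    return f"{ai_provider}/{google_key}/{anthropic_key}: enabled={enabled} -> {chosen}"


for case in [
    ("gemini", "g-key", "a-key"),
    ("gemini", None, "a-key"),
    ("anthropic", "g-key", None),
    ("gemini", "", None),
]:
    print(describe(*case))
```

The last two rows pass the ai_enabled guard used by the endpoints but still raise from the factory: AI_PROVIDER="anthropic" with only a Gemini key, and an empty-string key (ai_enabled checks is not None, the factory checks truthiness). If those combinations should work, one option is a symmetric Gemini fallback before the raise plus a truthiness check in ai_enabled.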
Step 4: Run tests to verify they pass
Run: cd backend && python -m pytest tests/test_ai_provider.py -v
Expected: All pass.
Step 5: Commit
git add backend/app/core/ai_provider.py backend/tests/test_ai_provider.py
git commit -m "feat: add AI provider abstraction with Gemini and Anthropic support"
Task 4: Migrate ai_tree_generator_service to use provider abstraction
Files:
- Modify: backend/app/core/ai_tree_generator_service.py
- Modify: backend/app/api/endpoints/ai_builder.py
- Test: backend/tests/test_ai_endpoints.py
Step 1: Update ai_tree_generator_service.py
In backend/app/core/ai_tree_generator_service.py:
Replace the import at line 16:
# OLD
import anthropic
# NEW
from app.core.ai_provider import get_ai_provider
Remove the _get_client() function (lines 124-131).
Update scaffold_branches() (starting at line 141). Replace the client creation and API call:
async def scaffold_branches(
wizard_state: dict[str, Any],
) -> tuple[list[dict[str, str]], int, int, float]:
"""Stage 2: AI suggests top-level branches."""
provider = get_ai_provider()
flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "")
description = wizard_state.get("description", "")
tags = wizard_state.get("environment_tags", [])
user_message = (
f"Flow type: {flow_type}\n"
f"Name: {name}\n"
f"Description: {description}\n"
)
if tags:
user_message += f"Environment: {', '.join(tags)}\n"
raw_text, input_tokens, output_tokens = await provider.generate_json(
system_prompt=SCAFFOLD_SYSTEM_PROMPT,
messages=[{"role": "user", "content": user_message}],
max_tokens=1024,
)
Then update the rest of the function to use raw_text instead of response.content[0].text, and input_tokens/output_tokens directly instead of response.usage.*.
Do the same for generate_branch_detail() — replace _get_client() + client.messages.create() with provider.generate_json(). The retry loop structure stays the same; just swap the API call and response parsing.
Step 2: Update ai_builder.py endpoint
In backend/app/api/endpoints/ai_builder.py, line 13:
# OLD
import anthropic
# NEW (remove this import entirely — no longer needed in the endpoint file)
Update the _require_ai_enabled() function (line 50) to check the new config:
def _require_ai_enabled() -> None:
if not settings.ai_enabled:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AI Flow Builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
)
Update any except anthropic.APIError catch blocks in the endpoint to catch generic Exception or a broader error type, since the provider abstraction may raise different errors depending on backend.
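A bare except Exception works, but it also turns programming errors into the same user-facing failure as an upstream outage. One option, sketched here under assumptions, is to have each provider translate its SDK's errors into a single project-level exception. AIProviderError, FakeSDKError and the coroutines below are hypothetical names; in the real providers the except clause would name the SDK's own base error, e.g. anthropic.APIError or, for google-genai, google.genai.errors.APIError (verify against the installed SDK version).

```python
import asyncio


class AIProviderError(RuntimeError):
    """Hypothetical project-level error that every AIProvider would raise.

    Subclassing RuntimeError means the fix-tree endpoint's existing
    `except RuntimeError` branch (HTTP 503) would already catch it.
    """


class FakeSDKError(Exception):
    """Stand-in for an SDK exception such as anthropic.APIError."""


async def _fake_sdk_call() -> str:
    # Simulates the SDK failing mid-request.
    raise FakeSDKError("upstream overloaded")


async def generate_json_wrapped() -> str:
    """Provider side: translate SDK errors into AIProviderError."""
    try:
        return await _fake_sdk_call()
    except FakeSDKError as exc:
        # `from exc` keeps the SDK traceback attached for logging.
        raise AIProviderError(f"AI provider request failed: {exc}") from exc


async def endpoint_like_handler() -> tuple[int, str]:
    """Endpoint side: one except clause, independent of which SDK is active."""
    try:
        return 200, await generate_json_wrapped()
    except AIProviderError as exc:
        return 503, str(exc)


status_code, detail = asyncio.run(endpoint_like_handler())
print(status_code, detail)
```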
Step 3: Run existing AI endpoint tests
Run: cd backend && python -m pytest tests/test_ai_endpoints.py -v
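These tests currently mock the Anthropic client, so they need to patch get_ai_provider where the service looks it up: app.core.ai_tree_generator_service.get_ai_provider. (Patching app.core.ai_provider.get_ai_provider would not work, because the service has already imported the name into its own namespace.) Update the mock targets in the test file: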
# Replace patches like:
# @patch("app.core.ai_tree_generator_service._get_client")
# With:
# @patch("app.core.ai_tree_generator_service.get_ai_provider")
The mock provider should return an AsyncMock that returns (json_text, input_tokens, output_tokens) from generate_json.
Step 4: Verify all tests pass
Run: cd backend && python -m pytest tests/test_ai_endpoints.py tests/test_ai_provider.py tests/test_ai_tree_validator.py -v
Expected: All pass.
Step 5: Commit
git add backend/app/core/ai_tree_generator_service.py backend/app/api/endpoints/ai_builder.py backend/tests/test_ai_endpoints.py
git commit -m "refactor: migrate AI tree generator to provider abstraction"
Task 5: Create AI fix schemas
Files:
- Create: backend/app/schemas/ai_fix.py
Step 1: Create the schemas
Create backend/app/schemas/ai_fix.py:
"""Pydantic schemas for the AI auto-fix feature."""
from typing import Any, Literal
from pydantic import BaseModel, Field
class ValidationErrorInput(BaseModel):
"""A single validation error to fix."""
node_id: str = Field(..., description="ID of the node with the error")
message: str = Field(..., description="The validation error message")
class AIFixTreeRequest(BaseModel):
"""Request to generate AI fixes for validation errors."""
tree_structure: dict[str, Any] = Field(..., description="Full tree structure")
tree_name: str = Field("", max_length=255, description="Name of the flow")
tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field(
"troubleshooting", description="Type of flow"
)
validation_errors: list[ValidationErrorInput] = Field(
..., min_length=1, max_length=10, description="Errors to fix"
)
class AIFixProposal(BaseModel):
"""A single proposed fix from the AI."""
target_node_id: str
error_message: str
description: str
original_node: dict[str, Any]
fixed_node: dict[str, Any]
class AIFixTokenUsage(BaseModel):
input: int = 0
output: int = 0
class AIFixTreeResponse(BaseModel):
"""Response with proposed fixes."""
fixes: list[AIFixProposal]
tokens_used: AIFixTokenUsage
Step 2: Commit
git add backend/app/schemas/ai_fix.py
git commit -m "feat: add Pydantic schemas for AI fix-tree endpoint"
Task 6: Build the AI fix service
Files:
- Create: backend/app/core/ai_fix_service.py
- Test: backend/tests/test_ai_fix_service.py
Step 1: Write tests
Create backend/tests/test_ai_fix_service.py:
"""Tests for AI fix service — prompt building and node extraction."""
import json
import pytest
from app.core.ai_fix_service import _serialize_tree_outline, _find_node_by_id, _find_parent_node
def _make_tree():
return {
"id": "root",
"type": "decision",
"question": "Is the server up?",
"options": [
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
],
"children": [
{
"id": "check-logs",
"type": "action",
"title": "Check Logs",
"description": "Review event logs.",
"next_node_id": "logs-ok",
},
{
"id": "logs-ok",
"type": "solution",
"title": "Logs Resolved",
"description": "Found issue in logs.",
},
{
"id": "restart",
"type": "decision",
"question": "Did restart fix it?",
"options": [
{"id": "opt-r-yes", "label": "Yes", "next_node_id": "restart-ok"},
],
"children": [
{
"id": "restart-ok",
"type": "solution",
"title": "Restart Worked",
"description": "Server is back.",
},
],
},
],
}
class TestFindNodeById:
def test_finds_root(self):
tree = _make_tree()
node = _find_node_by_id(tree, "root")
assert node is not None
assert node["id"] == "root"
def test_finds_nested_child(self):
tree = _make_tree()
node = _find_node_by_id(tree, "restart-ok")
assert node is not None
assert node["id"] == "restart-ok"
def test_returns_none_for_missing(self):
tree = _make_tree()
assert _find_node_by_id(tree, "nonexistent") is None
class TestFindParentNode:
def test_root_has_no_parent(self):
tree = _make_tree()
assert _find_parent_node(tree, "root") is None
def test_finds_parent_of_child(self):
tree = _make_tree()
parent = _find_parent_node(tree, "restart")
assert parent is not None
assert parent["id"] == "root"
def test_finds_parent_of_deeply_nested(self):
tree = _make_tree()
parent = _find_parent_node(tree, "restart-ok")
assert parent is not None
assert parent["id"] == "restart"
class TestSerializeTreeOutline:
def test_produces_readable_outline(self):
tree = _make_tree()
outline = _serialize_tree_outline(tree)
assert "[decision] Is the server up?" in outline
assert "[action] Check Logs" in outline
assert "[solution] Restart Worked" in outline
def test_marks_error_node(self):
tree = _make_tree()
outline = _serialize_tree_outline(tree, error_node_id="restart")
assert "ERROR HERE" in outline
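The tests above don't cover _strip_markdown_fences, which matters mostly on the Anthropic path (Gemini is asked for application/json output directly). A few extra cases could look like this sketch; to run standalone it re-declares the same regex under a local name, strip_markdown_fences, instead of importing from app.core.ai_fix_service.

```python
import json
import re

FENCE = "`" * 3  # three backticks, built programmatically

# Same pattern as _strip_markdown_fences in ai_fix_service.py.
_FENCE_RE = re.compile(FENCE + r"(?:json)?\s*\n?(.*?)\n?\s*" + FENCE, re.DOTALL)


def strip_markdown_fences(text: str) -> str:
    """Local copy of _strip_markdown_fences so this snippet runs on its own."""
    match = _FENCE_RE.search(text)
    if match:
        return match.group(1).strip()
    return text


def test_strips_json_fence():
    raw = FENCE + 'json\n{"id": "restart"}\n' + FENCE
    assert json.loads(strip_markdown_fences(raw)) == {"id": "restart"}


def test_strips_bare_fence_with_surrounding_prose():
    raw = "Here is the fix:\n" + FENCE + '\n{"id": "restart"}\n' + FENCE + "\nDone."
    assert json.loads(strip_markdown_fences(raw)) == {"id": "restart"}


def test_passes_unfenced_json_through():
    raw = '{"id": "restart"}'
    assert strip_markdown_fences(raw) == raw
```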
Step 2: Run tests to verify they fail
Run: cd backend && python -m pytest tests/test_ai_fix_service.py -v
Expected: ImportError — module doesn't exist.
Step 3: Implement the fix service
Create backend/app/core/ai_fix_service.py:
"""AI-powered fix generation for tree validation errors.
Builds targeted prompts for each failing node, sends to the AI provider,
and validates the fix before returning it.
"""
import copy
import json
import logging
import re
from typing import Any
from app.core.ai_provider import get_ai_provider
from app.core.ai_tree_validator import validate_generated_tree
logger = logging.getLogger(__name__)
# ── Helpers ──
def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
"""Recursively find a node by ID in the tree."""
if not isinstance(tree, dict):
return None
if tree.get("id") == node_id:
return tree
for child in tree.get("children", []):
result = _find_node_by_id(child, node_id)
if result is not None:
return result
return None
def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
"""Find the parent of the node with target_id."""
if not isinstance(tree, dict):
return None
for child in tree.get("children", []):
if isinstance(child, dict) and child.get("id") == target_id:
return tree
result = _find_parent_node(child, target_id)
if result is not None:
return result
return None
def _serialize_tree_outline(
tree: dict[str, Any],
indent: int = 0,
error_node_id: str | None = None,
) -> str:
"""Serialize tree as a compact readable outline for the AI prompt."""
if not isinstance(tree, dict):
return ""
node_type = tree.get("type", "?")
label = tree.get("question") or tree.get("title") or tree.get("id", "?")
prefix = " " * indent
marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""
line = f"{prefix}- [{node_type}] {label}{marker}"
lines = [line]
for child in tree.get("children", []):
lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))
return "\n".join(lines)
def _strip_markdown_fences(text: str) -> str:
"""Strip ```json ... ``` fences from AI response."""
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
if match:
return match.group(1).strip()
return text
# ── Prompt ──
FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.
You will receive:
1. A full flow outline showing the tree structure
2. The specific failing node with its full JSON
3. The validation error message
Your task: Return a FIXED version of the failing node as valid JSON. Rules:
- Fix ONLY the structural issue described in the error message
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
- Every new node must have a unique ID (use descriptive kebab-case IDs)
- Decision nodes must have at least 2 options and at least 2 children
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
- Solution nodes are leaf nodes (no children)
- Return ONLY the fixed node JSON, no explanation"""
def _build_fix_prompt(
tree: dict[str, Any],
node_id: str,
error_message: str,
tree_name: str,
tree_type: str,
) -> str:
"""Build the user message for fixing a specific node."""
outline = _serialize_tree_outline(tree, error_node_id=node_id)
failing_node = _find_node_by_id(tree, node_id)
node_json = json.dumps(failing_node, indent=2) if failing_node else "{}"
return (
f"Flow name: {tree_name}\n"
f"Flow type: {tree_type}\n\n"
f"FULL FLOW OUTLINE:\n{outline}\n\n"
f"ERROR: {error_message}\n\n"
f"FAILING NODE (full JSON):\n{node_json}\n\n"
f"Return the fixed version of this node as JSON."
)
# ── Main Service ──
async def generate_fixes(
tree_structure: dict[str, Any],
tree_name: str,
tree_type: str,
validation_errors: list[dict[str, str]],
) -> tuple[list[dict[str, Any]], int, int]:
"""Generate AI fixes for each validation error.
Args:
tree_structure: Full tree JSON.
tree_name: Name of the flow.
tree_type: Type of flow (troubleshooting/procedural/maintenance).
validation_errors: List of {"node_id": str, "message": str}.
Returns:
(fixes, total_input_tokens, total_output_tokens)
Each fix: {"target_node_id", "error_message", "description", "original_node", "fixed_node"}
"""
provider = get_ai_provider()
fixes: list[dict[str, Any]] = []
total_input = 0
total_output = 0
for error in validation_errors:
node_id = error["node_id"]
error_message = error["message"]
original_node = _find_node_by_id(tree_structure, node_id)
if original_node is None:
logger.warning("Node %s not found in tree, skipping fix", node_id)
continue
original_snapshot = copy.deepcopy(original_node)
user_message = _build_fix_prompt(
tree_structure, node_id, error_message, tree_name, tree_type
)
# Attempt fix (with 1 retry using corrective prompt)
messages: list[dict[str, str]] = [{"role": "user", "content": user_message}]
raw_text = ""  # bound before the loop so the except branch can always read it
for attempt in range(2):
try:
raw_text, inp_tokens, out_tokens = await provider.generate_json(
system_prompt=FIX_SYSTEM_PROMPT,
messages=messages,
max_tokens=4096,
)
total_input += inp_tokens
total_output += out_tokens
cleaned = _strip_markdown_fences(raw_text)
fixed_node = json.loads(cleaned)
if not isinstance(fixed_node, dict):
raise TypeError(f"Expected a JSON object, got {type(fixed_node).__name__}")
# Quick validation: check that the fix actually addresses the error
# by substituting into tree and re-validating that specific error is gone
test_tree = copy.deepcopy(tree_structure)
_replace_node_in_tree(test_tree, node_id, fixed_node)
remaining_errors = validate_generated_tree(test_tree)
still_has_error = any(
node_id in e and error_message.split(":")[0].lower() in e.lower()
for e in remaining_errors
)
if still_has_error and attempt == 0:
# Retry with corrective prompt
messages.append({"role": "assistant", "content": raw_text})
messages.append({
"role": "user",
"content": (
f"The fix still has the same validation error. "
f"Remaining errors: {remaining_errors}. "
f"Please try again."
),
})
continue
if still_has_error:
# Second attempt still fails validation: don't propose a broken fix
logger.error("Fix for node %s still fails validation after 2 attempts", node_id)
break
# Extract description from the fix
description = _describe_fix(original_snapshot, fixed_node)
fixes.append({
"target_node_id": node_id,
"error_message": error_message,
"description": description,
"original_node": original_snapshot,
"fixed_node": fixed_node,
})
break
except (json.JSONDecodeError, KeyError, TypeError) as exc:
logger.warning(
"Fix attempt %d for node %s failed: %s", attempt + 1, node_id, exc
)
if attempt == 0:
if raw_text:
messages.append({"role": "assistant", "content": raw_text})
messages.append({
"role": "user",
"content": "Invalid JSON response. Return ONLY valid JSON for the fixed node.",
})
else:
logger.error("Failed to generate fix for node %s after 2 attempts", node_id)
return fixes, total_input, total_output
def _replace_node_in_tree(
tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
) -> bool:
"""Replace a node in the tree by ID. Returns True if found and replaced."""
if tree.get("id") == target_id:
tree.clear()
tree.update(replacement)
return True
for child in tree.get("children", []):
if isinstance(child, dict):
if child.get("id") == target_id:
child.clear()
child.update(replacement)
return True
if _replace_node_in_tree(child, target_id, replacement):
return True
return False
def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
"""Generate a human-readable description of what changed."""
orig_children = len(original.get("children", []))
fixed_children = len(fixed.get("children", []))
orig_options = len(original.get("options", []))
fixed_options = len(fixed.get("options", []))
parts: list[str] = []
if fixed_children > orig_children:
added = fixed_children - orig_children
parts.append(f"Added {added} child node{'s' if added > 1 else ''}")
if fixed_options > orig_options:
added = fixed_options - orig_options
parts.append(f"Added {added} option{'s' if added > 1 else ''}")
if "next_node_id" in fixed and "next_node_id" not in original:
parts.append(f"Added next_node_id '{fixed['next_node_id']}'")
return "; ".join(parts) if parts else "Structural fix applied"
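generate_fixes() only accepts a candidate if the reported error is gone after splicing it into a deep copy of the tree. The toy version below shows that check in isolation. toy_validate is a hypothetical stand-in for validate_generated_tree() that knows a single rule, and its message format ("Node <id>: ...") is invented for the example; because generate_fixes matches errors by substring (node ID plus the text before the first colon), the real check is only as reliable as the real validator's wording.

```python
import copy
from typing import Any


def replace_node(tree: dict[str, Any], target_id: str, replacement: dict[str, Any]) -> bool:
    """Same approach as _replace_node_in_tree: overwrite the matching node in place."""
    if tree.get("id") == target_id:
        tree.clear()
        tree.update(replacement)
        return True
    return any(
        replace_node(child, target_id, replacement)
        for child in tree.get("children", [])
        if isinstance(child, dict)
    )


def toy_validate(tree: dict[str, Any]) -> list[str]:
    """Hypothetical validator with one rule: decision nodes need >= 2 options."""
    errors: list[str] = []
    if tree.get("type") == "decision" and len(tree.get("options", [])) < 2:
        errors.append(f"Node {tree['id']}: Decision node must have at least 2 options")
    for child in tree.get("children", []):
        errors.extend(toy_validate(child))
    return errors


def fix_resolves_error(
    tree: dict[str, Any], node_id: str, error_message: str, candidate: dict[str, Any]
) -> bool:
    """Splice the candidate into a deep copy and check that the same error is gone."""
    test_tree = copy.deepcopy(tree)
    if not replace_node(test_tree, node_id, candidate):
        return False  # target node not found
    key = error_message.split(":")[0].lower()
    return not any(node_id in e and key in e.lower() for e in toy_validate(test_tree))


tree = {
    "id": "root", "type": "decision",
    "options": [{"id": "a"}, {"id": "b"}],
    "children": [{"id": "restart", "type": "decision", "options": [{"id": "r1"}], "children": []}],
}
msg = "Decision node must have at least 2 options"
still_broken = {"id": "restart", "type": "decision", "options": [{"id": "r1"}], "children": []}
fixed = {"id": "restart", "type": "decision", "options": [{"id": "r1"}, {"id": "r2"}], "children": []}

print(fix_resolves_error(tree, "restart", msg, still_broken))  # False
print(fix_resolves_error(tree, "restart", msg, fixed))  # True
print(len(tree["children"][0]["options"]))  # 1: the original tree is untouched
```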
Step 4: Run tests
Run: cd backend && python -m pytest tests/test_ai_fix_service.py -v
Expected: All pass.
Step 5: Commit
git add backend/app/core/ai_fix_service.py backend/tests/test_ai_fix_service.py
git commit -m "feat: add AI fix service with prompt building and validation"
Task 7: Create the fix-tree endpoint
Files:
- Create: backend/app/api/endpoints/ai_fix.py
- Modify: backend/app/api/router.py
- Test: backend/tests/test_ai_fix_endpoint.py
Step 1: Write the endpoint test
Create backend/tests/test_ai_fix_endpoint.py:
"""Integration tests for AI fix-tree endpoint."""
import json
from unittest.mock import AsyncMock, patch
import pytest
from app.core.config import settings
SAMPLE_TREE = {
"id": "root",
"type": "decision",
"question": "Is the server up?",
"options": [
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
],
"children": [
{
"id": "check-logs",
"type": "action",
"title": "Check Logs",
"description": "Review logs.",
"next_node_id": "logs-ok",
},
{
"id": "logs-ok",
"type": "solution",
"title": "Logs OK",
"description": "Issue in logs.",
},
{
"id": "restart",
"type": "decision",
"question": "Did restart work?",
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
"children": [
{"id": "done", "type": "solution", "title": "Done", "description": "Fixed."},
],
},
],
}
@pytest.fixture
def enable_ai():
original_key = settings.GOOGLE_AI_API_KEY
original_provider = settings.AI_PROVIDER
settings.GOOGLE_AI_API_KEY = "fake-key"
settings.AI_PROVIDER = "gemini"
yield
settings.GOOGLE_AI_API_KEY = original_key
settings.AI_PROVIDER = original_provider
@pytest.mark.asyncio
class TestFixTreeEndpoint:
async def test_returns_401_without_auth(self, client):
response = await client.post("/api/v1/ai/fix-tree", json={
"tree_structure": SAMPLE_TREE,
"tree_name": "Test",
"tree_type": "troubleshooting",
"validation_errors": [{"node_id": "restart", "message": "Need 2 options"}],
})
assert response.status_code == 401
async def test_returns_503_when_ai_disabled(self, client, auth_headers):
original = settings.GOOGLE_AI_API_KEY
orig_anthropic = settings.ANTHROPIC_API_KEY
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = None
try:
response = await client.post(
"/api/v1/ai/fix-tree",
json={
"tree_structure": SAMPLE_TREE,
"tree_name": "Test",
"tree_type": "troubleshooting",
"validation_errors": [{"node_id": "restart", "message": "test"}],
},
headers=auth_headers,
)
assert response.status_code == 503
finally:
settings.GOOGLE_AI_API_KEY = original
settings.ANTHROPIC_API_KEY = orig_anthropic
async def test_returns_fixes_on_success(self, client, auth_headers, enable_ai):
fixed_node = {
"id": "restart",
"type": "decision",
"question": "Did restart work?",
"options": [
{"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
{"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
],
"children": [
{"id": "done", "type": "solution", "title": "Done", "description": "Fixed."},
{"id": "escalate", "type": "solution", "title": "Escalate", "description": "Escalate to vendor."},
],
}
mock_provider = AsyncMock()
mock_provider.generate_json = AsyncMock(
return_value=(json.dumps(fixed_node), 500, 300)
)
with patch("app.core.ai_fix_service.get_ai_provider", return_value=mock_provider):
response = await client.post(
"/api/v1/ai/fix-tree",
json={
"tree_structure": SAMPLE_TREE,
"tree_name": "Server Flow",
"tree_type": "troubleshooting",
"validation_errors": [
{"node_id": "restart", "message": "Decision node must have at least 2 options"},
],
},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert len(data["fixes"]) == 1
assert data["fixes"][0]["target_node_id"] == "restart"
assert data["tokens_used"]["input"] == 500
assert data["tokens_used"]["output"] == 300
Step 2: Run to verify it fails
Run: cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v
Expected: Fail — endpoint doesn't exist.
Step 3: Create the endpoint
Create backend/app/api/endpoints/ai_fix.py:
"""AI auto-fix endpoint for tree validation errors."""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.rate_limit import limiter
from app.api.deps import get_db, require_engineer_or_admin
from app.core.config import settings
from app.core.ai_fix_service import generate_fixes
from app.models.user import User
from app.schemas.ai_fix import AIFixTreeRequest, AIFixTreeResponse, AIFixProposal, AIFixTokenUsage
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ai", tags=["ai-fix"])
def _require_ai_enabled() -> None:
if not settings.ai_enabled:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
)
@router.post("/fix-tree", response_model=AIFixTreeResponse)
@limiter.limit("10/minute")
async def fix_tree(
request: Request,
body: AIFixTreeRequest,
user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Generate AI-powered fixes for tree validation errors."""
_require_ai_enabled()
try:
fixes, total_input, total_output = await generate_fixes(
tree_structure=body.tree_structure,
tree_name=body.tree_name,
tree_type=body.tree_type,
validation_errors=[
{"node_id": e.node_id, "message": e.message}
for e in body.validation_errors
],
)
except RuntimeError as exc:
logger.error("AI fix generation failed: %s", exc)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=str(exc),
)
except Exception as exc:
logger.exception("Unexpected error during AI fix generation")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate fixes. Please try again.",
)
return AIFixTreeResponse(
fixes=[
AIFixProposal(
target_node_id=f["target_node_id"],
error_message=f["error_message"],
description=f["description"],
original_node=f["original_node"],
fixed_node=f["fixed_node"],
)
for f in fixes
],
tokens_used=AIFixTokenUsage(input=total_input, output=total_output),
)
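The endpoint returns 503 when AI is not configured or generate_fixes raises RuntimeError, and 500 for anything else; the limiter adds a per-minute cap. The Task 11 handler shows one generic toast for every failure. A minimal sketch of mapping status codes to distinct messages instead, assuming the app registers slowapi's rate-limit handler (which responds with HTTP 429). The function name fixTreeErrorMessage is hypothetical, and how the status is read from the thrown error depends on apiClient (for an axios-based client it would be error.response?.status):

```typescript
// Hypothetical helper: turns the HTTP status of a failed POST /ai/fix-tree
// call into a user-facing message.
// Assumptions: 503 = AI not configured or provider failure (per the endpoint),
// 429 = rate limit exceeded (slowapi's rate-limit handler), anything else = generic.
function fixTreeErrorMessage(status: number | undefined): string {
  switch (status) {
    case 503:
      return 'AI is unavailable right now. Check the AI provider configuration or try again later.'
    case 429:
      return 'Too many AI fix requests. Please wait a minute and try again.'
    default:
      return 'Failed to generate AI fixes. Please try again.'
  }
}
```

The default branch keeps the plan's original wording, so unknown failures read exactly as before.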
Step 4: Register in router
In backend/app/api/router.py, add:
After line 8 (from app.api.endpoints import ai_builder):
from app.api.endpoints import ai_fix
After line 38 (api_router.include_router(ai_builder.router)):
api_router.include_router(ai_fix.router)
Step 5: Run tests
Run: cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v
Expected: All pass.
Step 6: Run full test suite
Run: cd backend && python -m pytest --override-ini="addopts=" -v
Expected: All pass (100+ tests).
Step 7: Commit
git add backend/app/api/endpoints/ai_fix.py backend/app/api/router.py backend/app/schemas/ai_fix.py backend/tests/test_ai_fix_endpoint.py
git commit -m "feat: add POST /ai/fix-tree endpoint for AI-powered validation fixes"
Task 8: Add frontend API client for fix-tree
Files:
- Create:
frontend/src/types/ai-fix.ts
- Modify:
frontend/src/types/index.ts
- Modify:
frontend/src/api/trees.ts
Step 1: Create the types
Create frontend/src/types/ai-fix.ts:
export interface AIFixValidationError {
node_id: string
message: string
}
export interface AIFixProposal {
target_node_id: string
error_message: string
description: string
original_node: Record<string, unknown>
fixed_node: Record<string, unknown>
}
export interface AIFixTreeRequest {
tree_structure: Record<string, unknown>
tree_name: string
tree_type: 'troubleshooting' | 'procedural' | 'maintenance'
validation_errors: AIFixValidationError[]
}
export interface AIFixTreeResponse {
fixes: AIFixProposal[]
tokens_used: { input: number; output: number }
}
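These interfaces exist only at compile time, so nothing checks that the JSON actually returned by /ai/fix-tree has this shape. Since the proposals come from model output, a runtime guard can reject a malformed response before it reaches the review modal. A sketch, with the interfaces repeated so it stands alone; isAIFixTreeResponse and its helpers are hypothetical names, not part of the plan:

```typescript
// Repeated from frontend/src/types/ai-fix.ts so this sketch is self-contained.
interface AIFixProposal {
  target_node_id: string
  error_message: string
  description: string
  original_node: Record<string, unknown>
  fixed_node: Record<string, unknown>
}

interface AIFixTreeResponse {
  fixes: AIFixProposal[]
  tokens_used: { input: number; output: number }
}

// True for non-null, non-array objects (what Record<string, unknown> means in practice).
function isPlainObject(value: unknown): value is Record<string, unknown> {
  return typeof value === 'object' && value !== null && !Array.isArray(value)
}

// Checks every field the review modal reads from a proposal.
function isAIFixProposal(value: unknown): value is AIFixProposal {
  if (!isPlainObject(value)) return false
  return (
    typeof value.target_node_id === 'string' &&
    typeof value.error_message === 'string' &&
    typeof value.description === 'string' &&
    isPlainObject(value.original_node) &&
    isPlainObject(value.fixed_node)
  )
}

// Hypothetical runtime guard for the whole response body.
function isAIFixTreeResponse(value: unknown): value is AIFixTreeResponse {
  if (!isPlainObject(value)) return false
  const fixes = value.fixes
  const tokens = value.tokens_used
  return (
    Array.isArray(fixes) &&
    fixes.every(isAIFixProposal) &&
    isPlainObject(tokens) &&
    typeof tokens.input === 'number' &&
    typeof tokens.output === 'number'
  )
}
```

For example, fixTree could throw when isAIFixTreeResponse(response.data) is false, so a bad response takes the same error path as a failed request.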
Step 2: Export from types/index.ts
Add to frontend/src/types/index.ts:
export type { AIFixTreeRequest, AIFixTreeResponse, AIFixProposal, AIFixValidationError } from './ai-fix'
Step 3: Add API method to trees.ts
In frontend/src/api/trees.ts, add a new method to the treesApi object:
async fixTree(request: AIFixTreeRequest): Promise<AIFixTreeResponse> {
const response = await apiClient.post<AIFixTreeResponse>('/ai/fix-tree', request)
return response.data
},
Import the types at the top of the file:
import type { AIFixTreeRequest, AIFixTreeResponse } from '@/types'
Step 4: Commit
git add frontend/src/types/ai-fix.ts frontend/src/types/index.ts frontend/src/api/trees.ts
git commit -m "feat: add frontend API client and types for AI fix-tree"
Task 9: Add "Fix with AI" button to ValidationSummary
Files:
- Modify:
frontend/src/components/tree-editor/ValidationSummary.tsx
Step 1: Update the component
Add new props and the button to ValidationSummary:
import { useState } from 'react'
import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react'
import { cn } from '@/lib/utils'
import type { ValidationError } from '@/store/treeEditorStore'
interface ValidationSummaryProps {
errors: ValidationError[]
onSelectNode: (nodeId: string) => void
onFixWithAI?: () => void
isFixing?: boolean
}
export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) {
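Add the "Fix with AI" button to the header, between the error/warning count text and the expand/collapse chevron (around lines 60 and 62 respectively). If the header toggle is itself a <button>, restructure it so the new button is a sibling rather than a child: HTML does not allow a <button> inside another <button>, and React logs a DOM-nesting warning for it.
After the </span> that closes the error/warning count (around line 60), add: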
{/* Fix with AI button — only when there are fixable errors */}
{onFixWithAI && errorItems.some(e => e.nodeId) && (
<button
onClick={(e) => {
e.stopPropagation()
onFixWithAI()
}}
disabled={isFixing}
className={cn(
'ml-auto mr-2 flex items-center gap-1.5 rounded-md px-3 py-1 text-xs font-medium transition-colors',
isFixing
? 'bg-primary/10 text-primary cursor-wait'
: 'bg-gradient-brand text-white shadow-sm shadow-primary/20 hover:opacity-90'
)}
>
{isFixing ? (
<>
<Loader2 className="h-3 w-3 animate-spin" />
Generating fixes...
</>
) : (
<>
<Sparkles className="h-3 w-3" />
Fix with AI
</>
)}
</button>
)}
Step 2: Build to verify
Run: cd frontend && npm run build
Expected: Build passes.
Step 3: Commit
git add frontend/src/components/tree-editor/ValidationSummary.tsx
git commit -m "feat: add Fix with AI button to ValidationSummary"
Task 10: Build the AIFixReviewModal
Files:
- Create:
frontend/src/components/tree-editor/AIFixReviewModal.tsx
Step 1: Create the review modal
Create frontend/src/components/tree-editor/AIFixReviewModal.tsx:
import { useState } from 'react'
import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react'
import { cn } from '@/lib/utils'
import type { AIFixProposal } from '@/types'
interface AIFixReviewModalProps {
fixes: AIFixProposal[]
onApply: (fix: AIFixProposal) => void
onApplyAll: () => void
onClose: () => void
}
export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) {
const [appliedIds, setAppliedIds] = useState<Set<string>>(new Set())
const [skippedIds, setSkippedIds] = useState<Set<string>>(new Set())
const [expandedIds, setExpandedIds] = useState<Set<string>>(new Set(fixes.map(f => f.target_node_id)))
const handleApply = (fix: AIFixProposal) => {
onApply(fix)
setAppliedIds(prev => new Set(prev).add(fix.target_node_id))
}
const handleSkip = (fix: AIFixProposal) => {
setSkippedIds(prev => new Set(prev).add(fix.target_node_id))
}
const toggleExpanded = (id: string) => {
setExpandedIds(prev => {
const next = new Set(prev)
if (next.has(id)) next.delete(id)
else next.add(id)
return next
})
}
const pendingFixes = fixes.filter(
f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id)
)
const allHandled = pendingFixes.length === 0
return (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/80 backdrop-blur-sm p-4">
<div className="relative flex h-[80vh] w-full max-w-2xl flex-col bg-card border border-border rounded-2xl shadow-lg">
{/* Header */}
<div className="flex items-center justify-between border-b border-border px-6 py-4">
<div className="flex items-center gap-2">
<Sparkles className="h-5 w-5 text-primary" />
<h2 className="text-lg font-semibold text-foreground">
AI Fix Proposals ({fixes.length})
</h2>
</div>
<button
onClick={onClose}
className="rounded-md p-1 text-muted-foreground hover:bg-accent hover:text-foreground"
>
<X className="h-5 w-5" />
</button>
</div>
{/* Body */}
<div className="flex-1 overflow-y-auto p-4 space-y-3">
{fixes.map((fix) => {
const isApplied = appliedIds.has(fix.target_node_id)
const isSkipped = skippedIds.has(fix.target_node_id)
const isExpanded = expandedIds.has(fix.target_node_id)
return (
<div
key={fix.target_node_id}
className={cn(
'rounded-lg border p-4',
isApplied
? 'border-emerald-400/30 bg-emerald-400/5'
: isSkipped
? 'border-border bg-accent/30 opacity-60'
: 'border-border bg-card'
)}
>
{/* Fix header */}
<div className="flex items-start justify-between gap-3">
<div className="flex-1">
<p className="text-sm text-red-400 mb-1">{fix.error_message}</p>
<p className="text-sm text-foreground">{fix.description}</p>
<p className="text-xs text-muted-foreground mt-1">
Node: {fix.target_node_id}
</p>
</div>
{isApplied && (
<span className="flex items-center gap-1 rounded-full bg-emerald-400/10 px-2 py-1 text-xs text-emerald-400">
<Check className="h-3 w-3" /> Applied
</span>
)}
{isSkipped && (
<span className="text-xs text-muted-foreground">Skipped</span>
)}
</div>
{/* Expand/collapse detail */}
{!isApplied && !isSkipped && (
<>
<button
onClick={() => toggleExpanded(fix.target_node_id)}
className="mt-2 flex items-center gap-1 text-xs text-muted-foreground hover:text-foreground"
>
{isExpanded ? <ChevronUp className="h-3 w-3" /> : <ChevronDown className="h-3 w-3" />}
{isExpanded ? 'Hide' : 'Show'} details
</button>
{isExpanded && (
<div className="mt-3 grid grid-cols-2 gap-3">
<div>
<p className="text-xs font-medium text-muted-foreground mb-1">Before</p>
<pre className="overflow-x-auto rounded bg-accent/50 p-2 text-xs text-muted-foreground max-h-48 overflow-y-auto">
{JSON.stringify(fix.original_node, null, 2)}
</pre>
</div>
<div>
<p className="text-xs font-medium text-emerald-400 mb-1">After</p>
<pre className="overflow-x-auto rounded bg-emerald-400/5 p-2 text-xs text-foreground max-h-48 overflow-y-auto">
{JSON.stringify(fix.fixed_node, null, 2)}
</pre>
</div>
</div>
)}
{/* Action buttons */}
<div className="mt-3 flex gap-2">
<button
onClick={() => handleApply(fix)}
className="flex items-center gap-1 rounded-md bg-gradient-brand px-3 py-1.5 text-xs font-medium text-white shadow-sm shadow-primary/20 hover:opacity-90"
>
<Check className="h-3 w-3" />
Apply
</button>
<button
onClick={() => handleSkip(fix)}
className="flex items-center gap-1 rounded-md border border-border px-3 py-1.5 text-xs font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
>
<SkipForward className="h-3 w-3" />
Skip
</button>
</div>
</>
)}
</div>
)
})}
</div>
{/* Footer */}
<div className="flex items-center justify-between border-t border-border px-6 py-4">
<button
onClick={onClose}
className="rounded-md border border-border px-4 py-2 text-sm font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
>
{allHandled ? 'Done' : 'Cancel'}
</button>
{!allHandled && (
<button
onClick={onApplyAll}
className="rounded-md bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20 hover:opacity-90"
>
Apply All ({pendingFixes.length})
</button>
)}
</div>
</div>
</div>
)
}
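The modal tracks applied, skipped, and expanded state by target_node_id, which is also its React key. If the model returns two proposals for the same node, both cards share a key and applying one marks both as applied. Separately, the Apply All button is labeled with the pending count, but the page's handleApplyAllFixes in Task 11 iterates over every proposal, including skipped ones. A minimal sketch of per-proposal keys plus a pending filter the modal could use, for example by having onApplyAll receive the pending list instead of taking no arguments; proposalKey, withStatus, and pendingProposals are hypothetical helpers:

```typescript
type FixStatus = 'pending' | 'applied' | 'skipped'

// Hypothetical key that stays unique even when two proposals target the same node.
function proposalKey(fix: { target_node_id: string }, index: number): string {
  return `${fix.target_node_id}:${index}`
}

// Returns a new Map so it can be used directly as React state.
function withStatus(
  statusByKey: ReadonlyMap<string, FixStatus>,
  key: string,
  status: FixStatus,
): Map<string, FixStatus> {
  const next = new Map(statusByKey)
  next.set(key, status)
  return next
}

// Proposals the user has neither applied nor skipped, in their original order.
function pendingProposals<T extends { target_node_id: string }>(
  fixes: T[],
  statusByKey: ReadonlyMap<string, FixStatus>,
): T[] {
  return fixes.filter((fix, i) => (statusByKey.get(proposalKey(fix, i)) ?? 'pending') === 'pending')
}
```

Keying by position is safe here because the fixes array does not change while the modal is open.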
Step 2: Build to verify
Run: cd frontend && npm run build
Expected: Build passes.
Step 3: Commit
git add frontend/src/components/tree-editor/AIFixReviewModal.tsx
git commit -m "feat: add AIFixReviewModal component for reviewing AI-proposed fixes"
Task 11: Wire everything together in TreeEditorPage
Files:
- Modify:
frontend/src/pages/TreeEditorPage.tsx
This is the integration task. The page needs to:
- Import the new components and API
- Add state for the fix flow (isFixing, fixProposals)
- Handle the "Fix with AI" button click by calling treesApi.fixTree()
- Show AIFixReviewModal when proposals are available
- Handle apply/skip by calling updateNode() on the tree editor store
- Re-run validate() after applying fixes
Step 1: Find where ValidationSummary is rendered in TreeEditorPage
Search for <ValidationSummary in TreeEditorPage.tsx and note the exact location.
Step 2: Add imports
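Add to the imports in TreeEditorPage.tsx (the handlers below also use useState, toast, useTreeEditorStore, and the TreeStructure type; add imports for any of these the page does not already have):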
import { AIFixReviewModal } from '@/components/tree-editor/AIFixReviewModal'
import { treesApi } from '@/api/trees'
import type { AIFixProposal } from '@/types'
Step 3: Add state
Inside the TreeEditorPage component, add:
const [isFixing, setIsFixing] = useState(false)
const [fixProposals, setFixProposals] = useState<AIFixProposal[] | null>(null)
Step 4: Add handler for "Fix with AI"
const handleFixWithAI = async () => {
const store = useTreeEditorStore.getState()
if (!store.treeStructure) return
// Get only fixable errors (structural errors with nodeId)
const fixableErrors = store.validationErrors
.filter(e => e.severity === 'error' && e.nodeId)
.map(e => ({ node_id: e.nodeId!, message: e.message }))
if (fixableErrors.length === 0) return
setIsFixing(true)
try {
const result = await treesApi.fixTree({
tree_structure: store.treeStructure as Record<string, unknown>,
tree_name: store.name,
tree_type: (store.treeType || 'troubleshooting') as 'troubleshooting' | 'procedural' | 'maintenance',
validation_errors: fixableErrors,
})
if (result.fixes.length > 0) {
setFixProposals(result.fixes)
} else {
toast.info('AI could not generate fixes for these errors')
}
} catch {
toast.error('Failed to generate AI fixes. Please try again.')
} finally {
setIsFixing(false)
}
}
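ValidationSummary shows the button when any error item has a nodeId, while this handler additionally filters on severity === 'error'. If the summary's errorItems can include warnings, the button can appear while the handler silently sends nothing. A sketch of one helper both places could share, assuming ValidationError has an optional nodeId, a message, and a severity of 'error' | 'warning' (the exact store type is not shown in this plan; toFixableErrors is a hypothetical name):

```typescript
// Assumed shape of the store's ValidationError.
interface ValidationError {
  nodeId?: string
  message: string
  severity: 'error' | 'warning'
}

// Matches AIFixValidationError from frontend/src/types/ai-fix.ts.
interface AIFixValidationError {
  node_id: string
  message: string
}

// Keeps only errors attached to a node, drops warnings, and removes exact
// duplicates so the same (node, message) pair is not sent to the model twice.
function toFixableErrors(errors: ValidationError[]): AIFixValidationError[] {
  const seen = new Set<string>()
  const result: AIFixValidationError[] = []
  for (const e of errors) {
    if (e.severity !== 'error' || !e.nodeId) continue
    const key = JSON.stringify([e.nodeId, e.message])
    if (seen.has(key)) continue
    seen.add(key)
    result.push({ node_id: e.nodeId, message: e.message })
  }
  return result
}
```

The button could then render only when toFixableErrors(errors).length > 0, so it is visible exactly when the handler has something to send.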
Step 5: Add handlers for apply/close
const handleApplyFix = (fix: AIFixProposal) => {
const store = useTreeEditorStore.getState()
store.updateNode(fix.target_node_id, fix.fixed_node as Partial<TreeStructure>)
}
const handleApplyAllFixes = () => {
if (!fixProposals) return
for (const fix of fixProposals) {
handleApplyFix(fix)
}
setFixProposals(null)
// Re-validate right away: Zustand state updates are synchronous, so no timeout is needed
useTreeEditorStore.getState().validate()
}
const handleCloseFixModal = () => {
setFixProposals(null)
// Re-validate in case some fixes were applied
useTreeEditorStore.getState().validate()
}
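handleApplyFix relies on the store's updateNode() to merge fixed_node into the matching node; its implementation is not part of this plan. For readers unsure what a safe merge looks like, a minimal immutable sketch over a hypothetical { id, children } node shape (the real TreeStructure type may differ, and applyNodePatch is not an existing store method):

```typescript
// Hypothetical tree node shape used only for this sketch.
interface TreeNode {
  id: string
  children?: TreeNode[]
  [key: string]: unknown
}

// Returns a new tree in which the node with targetId has `patch` shallow-merged in.
// Unchanged subtrees are returned as-is, so reference equality still means "no change".
function applyNodePatch(node: TreeNode, targetId: string, patch: Record<string, unknown>): TreeNode {
  if (node.id === targetId) {
    // Keep the original id so a patch can never re-key the node.
    return { ...node, ...patch, id: node.id } as TreeNode
  }
  const original = node.children
  if (!original) return node
  const children = original.map((child) => applyNodePatch(child, targetId, patch))
  const changed = children.some((child, i) => child !== original[i])
  return changed ? { ...node, children } : node
}
```

Reusing unchanged subtrees keeps re-renders limited for components that select parts of the tree from the store, since Zustand compares selected values by reference by default.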
Step 6: Pass props to ValidationSummary
Update the <ValidationSummary> JSX to include the new props:
<ValidationSummary
errors={validationErrors}
onSelectNode={handleSelectNode}
onFixWithAI={handleFixWithAI}
isFixing={isFixing}
/>
Step 7: Add the review modal
After <ValidationSummary>, add:
{fixProposals && (
<AIFixReviewModal
fixes={fixProposals}
onApply={handleApplyFix}
onApplyAll={handleApplyAllFixes}
onClose={handleCloseFixModal}
/>
)}
Step 8: Build to verify
Run: cd frontend && npm run build
Expected: Build passes.
Step 9: Commit
git add frontend/src/pages/TreeEditorPage.tsx
git commit -m "feat: wire AI fix flow into TreeEditorPage"
Task 12: Final verification
Step 1: Run all backend tests
Run: cd backend && python -m pytest --override-ini="addopts=" -v
Expected: All pass.
Step 2: Run frontend build
Run: cd frontend && npm run build
Expected: Build passes with no errors.
Step 3: Final commit if any adjustments needed
Fix any issues found during verification and commit.