# AI Auto-Fix & Gemini Flash Provider Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Add Gemini 2.5 Flash as primary AI provider (with Claude fallback), then build an AI-powered "Fix with AI" feature that generates structural fixes for validation errors in the tree editor.

**Architecture:** A provider abstraction layer (`ai_provider.py`) wraps Gemini and Anthropic SDKs behind a unified `generate_json()` interface. The existing `ai_tree_generator_service.py` swaps its direct Anthropic calls for this abstraction. A new `ai_fix_service.py` builds prompts from validation errors + tree context and returns proposed node patches. The frontend adds a "Fix with AI" button to `ValidationSummary` and a review modal for applying fixes.
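
As a rough illustration of the `generate_json()` contract described above, the sketch below expresses it as a `typing.Protocol` and satisfies it with an offline stub. `JSONProvider`, `EchoProvider` and `call_any` are hypothetical names used only here; the plan itself defines the contract as the `AIProvider` ABC in Task 3.

```python
"""Sketch of the unified provider contract (illustrative stand-in, not the real module)."""
import asyncio
import json
from typing import Protocol


class JSONProvider(Protocol):
    """Anything with an async generate_json() returning (text, input_tokens, output_tokens)."""

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]: ...


class EchoProvider:
    """Hypothetical offline provider that satisfies JSONProvider structurally."""

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        last = messages[-1]["content"]
        # Whitespace "token" counts stand in for real usage metadata.
        return json.dumps({"echo": last}), len(last.split()), 1


async def call_any(provider: JSONProvider) -> tuple[str, int, int]:
    # The caller depends only on the contract, never on a vendor SDK.
    return await provider.generate_json(
        system_prompt="Return JSON.",
        messages=[{"role": "user", "content": "ping"}],
    )


text, inp, out = asyncio.run(call_any(EchoProvider()))
print(text, inp, out)  # {"echo": "ping"} 1 1
```

A Protocol checks shape rather than inheritance, so test doubles can stand in for real providers without subclassing; the ABC in Task 3 plays the same role for the real implementations.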

**Tech Stack:** Python FastAPI, google-genai SDK, anthropic SDK, Pydantic v2, React 19, TypeScript, Zustand, Tailwind CSS

**Design Doc:** `docs/plans/2026-02-26-ai-autofix-gemini-design.md`

---

## Task 1: Install google-genai SDK

**Files:**
- Modify: `backend/requirements.txt`

**Step 1: Add the dependency**

Add `google-genai` to `backend/requirements.txt`:

```
google-genai>=1.0.0
```

**Step 2: Install it**

Run: `cd backend && pip install google-genai`

**Step 3: Commit**

```bash
git add backend/requirements.txt
git commit -m "chore: add google-genai SDK dependency"
```

---

## Task 2: Add Gemini config vars to Settings

**Files:**
- Modify: `backend/app/core/config.py:75-85`

**Step 1: Add new config variables**

In `backend/app/core/config.py`, after line 80 (`AI_REQUEST_TIMEOUT_SECONDS`), add:

```python
    # AI Provider selection
    AI_PROVIDER: str = "gemini"  # "gemini" or "anthropic"
    GOOGLE_AI_API_KEY: Optional[str] = None
    AI_MODEL_GEMINI: str = "gemini-2.5-flash"
    AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
```

**Step 2: Update `ai_enabled` property**

Replace the existing `ai_enabled` property (lines 82-85) with:
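
```python
    @property
    def ai_enabled(self) -> bool:
        """Check if any AI provider is configured."""
        return bool(self.ANTHROPIC_API_KEY or self.GOOGLE_AI_API_KEY)
```

Checking truthiness rather than `is not None` keeps this property consistent with `get_ai_provider()` (Task 3), which treats an empty key as missing; otherwise an empty `GOOGLE_AI_API_KEY=` in the environment could report AI as enabled while the factory raises.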

**Step 3: Verify no tests break**

Run: `cd backend && python -m pytest tests/test_ai_tree_validator.py -v`
Expected: All pass (config changes don't affect validator tests).

**Step 4: Commit**

```bash
git add backend/app/core/config.py
git commit -m "feat: add Gemini Flash config vars to Settings"
```

---

## Task 3: Build the AI provider abstraction

**Files:**
- Create: `backend/app/core/ai_provider.py`
- Test: `backend/tests/test_ai_provider.py`

**Step 1: Write tests for the provider abstraction**

Create `backend/tests/test_ai_provider.py`:

```python
"""Tests for AI provider abstraction layer."""
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.core.config import settings


class TestGetAIProvider:
    """Test provider factory function.

    monkeypatch restores every patched setting after each test, even on failure.
    """

    def test_returns_gemini_when_configured(self, monkeypatch):
        monkeypatch.setattr(settings, "AI_PROVIDER", "gemini")
        monkeypatch.setattr(settings, "GOOGLE_AI_API_KEY", "fake-gemini-key")

        from app.core.ai_provider import GeminiProvider, get_ai_provider

        assert isinstance(get_ai_provider(), GeminiProvider)

    def test_returns_anthropic_when_configured(self, monkeypatch):
        monkeypatch.setattr(settings, "AI_PROVIDER", "anthropic")
        monkeypatch.setattr(settings, "ANTHROPIC_API_KEY", "fake-anthropic-key")

        from app.core.ai_provider import AnthropicProvider, get_ai_provider

        assert isinstance(get_ai_provider(), AnthropicProvider)

    def test_falls_back_to_anthropic_when_gemini_key_missing(self, monkeypatch):
        monkeypatch.setattr(settings, "AI_PROVIDER", "gemini")
        monkeypatch.setattr(settings, "GOOGLE_AI_API_KEY", None)
        monkeypatch.setattr(settings, "ANTHROPIC_API_KEY", "fake-anthropic-key")

        from app.core.ai_provider import AnthropicProvider, get_ai_provider

        assert isinstance(get_ai_provider(), AnthropicProvider)

    def test_raises_when_no_provider_configured(self, monkeypatch):
        monkeypatch.setattr(settings, "AI_PROVIDER", "gemini")
        monkeypatch.setattr(settings, "GOOGLE_AI_API_KEY", None)
        monkeypatch.setattr(settings, "ANTHROPIC_API_KEY", None)

        from app.core.ai_provider import get_ai_provider

        with pytest.raises(RuntimeError, match="No AI provider configured"):
            get_ai_provider()


class TestAnthropicProvider:
    """Test Anthropic provider generate_json."""

    @pytest.mark.asyncio
    async def test_generate_json_returns_text_and_tokens(self):
        from app.core.ai_provider import AnthropicProvider

        mock_response = MagicMock()
        mock_response.content = [MagicMock(text='{"key": "value"}')]
        mock_response.usage = MagicMock(input_tokens=100, output_tokens=50)

        mock_client = MagicMock()
        mock_client.messages.create = AsyncMock(return_value=mock_response)

        with patch("app.core.ai_provider.anthropic.AsyncAnthropic", return_value=mock_client):
            provider = AnthropicProvider(api_key="fake-key")
            text, inp, out = await provider.generate_json(
                system_prompt="You are a helper.",
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=1024,
            )

        assert text == '{"key": "value"}'
        assert inp == 100
        assert out == 50


class TestGeminiProvider:
    """Test Gemini provider generate_json."""

    @pytest.mark.asyncio
    async def test_generate_json_returns_text_and_tokens(self):
        from app.core.ai_provider import GeminiProvider

        mock_response = MagicMock()
        mock_response.text = '{"key": "value"}'
        mock_response.usage_metadata = MagicMock(
            prompt_token_count=100,
            candidates_token_count=50,
        )

        # google-genai exposes its async methods under client.aio
        mock_client = MagicMock()
        mock_client.aio.models.generate_content = AsyncMock(return_value=mock_response)

        with patch("app.core.ai_provider.genai.Client", return_value=mock_client):
            provider = GeminiProvider(api_key="fake-key")
            text, inp, out = await provider.generate_json(
                system_prompt="You are a helper.",
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=1024,
            )

        assert text == '{"key": "value"}'
        assert inp == 100
        assert out == 50
```

**Step 2: Run tests to verify they fail**

Run: `cd backend && python -m pytest tests/test_ai_provider.py -v`
Expected: ImportError (the `ai_provider` module doesn't exist yet).

**Step 3: Implement the provider abstraction**

Create `backend/app/core/ai_provider.py`:

```python
"""AI provider abstraction layer.

Supports Gemini (default) and Anthropic (fallback) behind a unified interface.
"""
import logging
from abc import ABC, abstractmethod

import anthropic
from google import genai
from google.genai import types as genai_types

from app.core.config import settings

logger = logging.getLogger(__name__)


class AIProvider(ABC):
    """Base class for AI providers."""

    @abstractmethod
    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        """Generate a JSON response from the AI model.

        Args:
            system_prompt: System instructions for the model.
            messages: List of {"role": "user"|"assistant", "content": str} dicts.
            max_tokens: Maximum output tokens.

        Returns:
            (response_text, input_tokens, output_tokens)
        """
        ...


class GeminiProvider(AIProvider):
    """Google Gemini provider using google-genai SDK."""

    def __init__(self, api_key: str, model: str | None = None):
        self._api_key = api_key
        self._model = model or settings.AI_MODEL_GEMINI

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        client = genai.Client(api_key=self._api_key)

        # Gemini takes the system prompt via config rather than as a message,
        # and calls the assistant role "model".
        contents: list[genai_types.Content] = []
        for msg in messages:
            role = "user" if msg["role"] == "user" else "model"
            contents.append(genai_types.Content(
                role=role,
                parts=[genai_types.Part(text=msg["content"])],
            ))

        config = genai_types.GenerateContentConfig(
            system_instruction=system_prompt,
            max_output_tokens=max_tokens,
            response_mime_type="application/json",
        )

        # Async calls live under client.aio in google-genai.
        response = await client.aio.models.generate_content(
            model=self._model,
            contents=contents,
            config=config,
        )

        text = response.text or ""
        usage = response.usage_metadata
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0

        return text, input_tokens, output_tokens


class AnthropicProvider(AIProvider):
    """Anthropic Claude provider."""

    def __init__(self, api_key: str, model: str | None = None):
        self._api_key = api_key
        self._model = model or settings.AI_MODEL_ANTHROPIC

    async def generate_json(
        self,
        system_prompt: str,
        messages: list[dict[str, str]],
        max_tokens: int = 4096,
    ) -> tuple[str, int, int]:
        client = anthropic.AsyncAnthropic(
            api_key=self._api_key,
            timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
        )

        response = await client.messages.create(
            model=self._model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=messages,
        )

        text = response.content[0].text
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens

        return text, input_tokens, output_tokens


def get_ai_provider() -> AIProvider:
    """Factory: return the configured AI provider.

    Uses settings.AI_PROVIDER when that provider has a key; otherwise falls back
    to whichever provider does. Raises RuntimeError if no provider is configured.
    """
    if settings.AI_PROVIDER == "gemini" and settings.GOOGLE_AI_API_KEY:
        logger.info("Using Gemini provider (%s)", settings.AI_MODEL_GEMINI)
        return GeminiProvider(api_key=settings.GOOGLE_AI_API_KEY)

    if settings.AI_PROVIDER == "anthropic" and settings.ANTHROPIC_API_KEY:
        logger.info("Using Anthropic provider (%s)", settings.AI_MODEL_ANTHROPIC)
        return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY)

    # Fallback: the requested provider has no key; use whichever provider does.
    if settings.ANTHROPIC_API_KEY:
        logger.warning(
            "%s provider unavailable, falling back to Anthropic", settings.AI_PROVIDER
        )
        return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY)

    if settings.GOOGLE_AI_API_KEY:
        logger.warning(
            "%s provider unavailable, falling back to Gemini", settings.AI_PROVIDER
        )
        return GeminiProvider(api_key=settings.GOOGLE_AI_API_KEY)

    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )
```
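
`get_ai_provider()` reads the global `settings` and constructs SDK-backed objects, so its decision table can only be tested by mutating settings. One option, sketched here as an assumption rather than part of the plan, is to factor the choice into a pure function the factory calls. `choose_provider` is a hypothetical name; this version uses the preferred provider when its key is set, then falls back to Anthropic, then Gemini.

```python
"""Sketch: provider selection as a pure, SDK-free function (hypothetical helper)."""
from typing import Optional


def choose_provider(
    preferred: str,
    gemini_key: Optional[str],
    anthropic_key: Optional[str],
) -> str:
    """Return "gemini" or "anthropic".

    Uses the preferred provider when its key is set; otherwise falls back to
    Anthropic, then Gemini. Empty strings count as missing keys.
    """
    keys = {"gemini": gemini_key, "anthropic": anthropic_key}
    if keys.get(preferred):
        return preferred
    for name in ("anthropic", "gemini"):
        if keys[name]:
            return name
    raise RuntimeError(
        "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
    )


print(choose_provider("gemini", None, "sk-ant"))  # anthropic
```

The factory would then only map the returned name onto `GeminiProvider` or `AnthropicProvider`, and every row of the decision table can be covered with plain assertions.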

**Step 4: Run tests to verify they pass**

Run: `cd backend && python -m pytest tests/test_ai_provider.py -v`
Expected: All pass.

**Step 5: Commit**

```bash
git add backend/app/core/ai_provider.py backend/tests/test_ai_provider.py
git commit -m "feat: add AI provider abstraction with Gemini and Anthropic support"
```

---

## Task 4: Migrate ai_tree_generator_service to use provider abstraction

**Files:**
- Modify: `backend/app/core/ai_tree_generator_service.py`
- Modify: `backend/app/api/endpoints/ai_builder.py`

**Step 1: Update ai_tree_generator_service.py**

In `backend/app/core/ai_tree_generator_service.py`:

Replace the import at line 16:
```python
# OLD
import anthropic
# NEW
from app.core.ai_provider import get_ai_provider
```

Remove the `_get_client()` function (lines 124-131).

Update `scaffold_branches()` (starting at line 141). Replace the client creation and API call:

```python
async def scaffold_branches(
    wizard_state: dict[str, Any],
) -> tuple[list[dict[str, str]], int, int, float]:
    """Stage 2: AI suggests top-level branches."""
    provider = get_ai_provider()

    flow_type = wizard_state.get("flow_type", "troubleshooting")
    name = wizard_state.get("name", "")
    description = wizard_state.get("description", "")
    tags = wizard_state.get("environment_tags", [])

    user_message = (
        f"Flow type: {flow_type}\n"
        f"Name: {name}\n"
        f"Description: {description}\n"
    )
    if tags:
        user_message += f"Environment: {', '.join(tags)}\n"

    raw_text, input_tokens, output_tokens = await provider.generate_json(
        system_prompt=SCAFFOLD_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": user_message}],
        max_tokens=1024,
    )
```

Then update the rest of the function to use `raw_text` instead of `response.content[0].text`, and `input_tokens`/`output_tokens` directly instead of `response.usage.*`.

Do the same for `generate_branch_detail()`: replace `_get_client()` + `client.messages.create()` with `provider.generate_json()`. The retry loop structure stays the same; just swap the API call and response parsing.
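
As an illustration of what "swap the API call and response parsing" amounts to, the sketch below wraps `provider.generate_json()` with fence stripping, `json.loads`, and one corrective retry, summing tokens across attempts. `generate_parsed_json` and `FlakyProvider` are hypothetical names; the fence regex is the same pattern `_strip_markdown_fences` uses in Task 6.

```python
"""Sketch: one provider call plus JSON parsing, retried once on bad JSON (hypothetical helper)."""
import asyncio
import json
import re
from typing import Any

FENCE_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?\s*```", re.DOTALL)


async def generate_parsed_json(
    provider: Any,
    system_prompt: str,
    messages: list[dict[str, str]],
    max_tokens: int = 1024,
    max_attempts: int = 2,
) -> tuple[Any, int, int]:
    """Return (parsed_json, total_input_tokens, total_output_tokens).

    Tokens are summed across attempts so usage reporting includes retries.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be >= 1")
    history = list(messages)
    total_in = total_out = 0
    for attempt in range(max_attempts):
        raw_text, inp, out = await provider.generate_json(
            system_prompt=system_prompt, messages=history, max_tokens=max_tokens
        )
        total_in += inp
        total_out += out
        match = FENCE_RE.search(raw_text)
        cleaned = match.group(1).strip() if match else raw_text
        try:
            return json.loads(cleaned), total_in, total_out
        except json.JSONDecodeError:
            if attempt == max_attempts - 1:
                raise
            # Show the model its own bad reply and ask for JSON only.
            history += [
                {"role": "assistant", "content": raw_text},
                {"role": "user", "content": "Invalid JSON. Return ONLY valid JSON."},
            ]
    raise AssertionError("unreachable")  # the loop always returns or raises


class FlakyProvider:
    """Hypothetical stub: first reply is prose, second is fenced JSON."""

    def __init__(self) -> None:
        self.replies = ["Sure! Here you go:", '```json\n{"branches": []}\n```']
        self.calls = 0

    async def generate_json(self, system_prompt, messages, max_tokens=4096):
        reply = self.replies[self.calls]
        self.calls += 1
        return reply, 10, 5


stub = FlakyProvider()
data, tokens_in, tokens_out = asyncio.run(
    generate_parsed_json(stub, "sys", [{"role": "user", "content": "hi"}])
)
print(data, tokens_in, tokens_out, stub.calls)  # {'branches': []} 20 10 2
```

Summing tokens across attempts matters because the caller reports token usage; counting only the final attempt would under-report retried calls.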

**Step 2: Update ai_builder.py endpoint**

In `backend/app/api/endpoints/ai_builder.py`, line 13:

```python
# OLD
import anthropic
# NEW (remove this import entirely; it is no longer needed in the endpoint file)
```

Update the `_require_ai_enabled()` function (line 50) to check the new config:

```python
def _require_ai_enabled() -> None:
    if not settings.ai_enabled:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="AI Flow Builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
        )
```
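
Because `import anthropic` is removed above, any remaining `except anthropic.APIError` blocks in the endpoint would fail with `NameError` as soon as an exception reaches them. Change them to catch `Exception` or a broader error type, since the provider abstraction may raise different exception types depending on the backend.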
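
Catching bare `Exception` in the endpoint works, but it also swallows unrelated programming errors raised elsewhere in the handler. A narrower alternative, offered only as a sketch, is to wrap just the SDK call so that any vendor failure surfaces as one app-level exception. `AIProviderError` and `call_provider_safely` are hypothetical names, not part of the plan or of either SDK.

```python
"""Sketch: normalise provider SDK errors into one app-level exception (hypothetical names)."""
import asyncio
from typing import Any, Awaitable, Callable


class AIProviderError(Exception):
    """Raised for any failure inside an AI provider call."""

    def __init__(self, provider: str, original: BaseException) -> None:
        super().__init__(f"{provider} request failed: {original}")
        self.provider = provider
        self.original = original


async def call_provider_safely(
    provider_name: str,
    call: Callable[..., Awaitable[Any]],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Await call(*args, **kwargs), re-raising other errors as AIProviderError."""
    try:
        return await call(*args, **kwargs)
    except AIProviderError:
        raise
    except Exception as exc:  # e.g. SDK-specific API or timeout errors
        raise AIProviderError(provider_name, exc) from exc


async def _failing_call() -> None:
    raise TimeoutError("upstream timed out")


caught = None
try:
    asyncio.run(call_provider_safely("gemini", _failing_call))
except AIProviderError as err:
    caught = err
print(caught)  # gemini request failed: upstream timed out
```

The endpoint could then use `except AIProviderError` to return a 502 (or whatever status the existing handlers use), while bugs outside the provider call still surface as 500s.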

**Step 3: Run existing AI endpoint tests**

Run: `cd backend && python -m pytest tests/test_ai_endpoints.py -v`

These tests mock `anthropic.AsyncAnthropic`, so they need to be updated to mock `get_ai_provider` instead. Patch the name where it is looked up, i.e. in the service module that imports it, not where it is defined: patching `app.core.ai_provider.get_ai_provider` would leave the service's already-imported reference untouched. Update the mock targets in the test file:

```python
# Replace patches like:
# @patch("app.core.ai_tree_generator_service._get_client")
# With:
# @patch("app.core.ai_tree_generator_service.get_ai_provider")
```

The patched `get_ai_provider` should return a provider object whose `generate_json` is an `AsyncMock` returning `(json_text, input_tokens, output_tokens)`.
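
A minimal sketch of such a provider double, using only `unittest.mock`. `make_mock_provider` is a hypothetical helper name, and the default token counts are arbitrary placeholders.

```python
"""Sketch: a reusable fake provider for tests (hypothetical helper name)."""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock


def make_mock_provider(
    payload: object, input_tokens: int = 120, output_tokens: int = 80
) -> MagicMock:
    """Return a stand-in provider whose generate_json() resolves to a fixed tuple."""
    provider = MagicMock()
    provider.generate_json = AsyncMock(
        return_value=(json.dumps(payload), input_tokens, output_tokens)
    )
    return provider


provider = make_mock_provider({"branches": [{"title": "Network"}]})
text, inp, out = asyncio.run(
    provider.generate_json(
        system_prompt="sys",
        messages=[{"role": "user", "content": "hi"}],
        max_tokens=1024,
    )
)
print(json.loads(text), inp, out)  # {'branches': [{'title': 'Network'}]} 120 80
provider.generate_json.assert_awaited_once()
```

In a test, `mock_get.return_value = make_mock_provider(...)` inside the body of a `@patch("app.core.ai_tree_generator_service.get_ai_provider")`-decorated test keeps each test's double separate, and `provider.generate_json.call_args.kwargs["messages"]` lets the test assert on the prompt that was sent.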

**Step 4: Verify all tests pass**

Run: `cd backend && python -m pytest tests/test_ai_endpoints.py tests/test_ai_provider.py tests/test_ai_tree_validator.py -v`
Expected: All pass.

**Step 5: Commit**

```bash
git add backend/app/core/ai_tree_generator_service.py backend/app/api/endpoints/ai_builder.py backend/tests/test_ai_endpoints.py
git commit -m "refactor: migrate AI tree generator to provider abstraction"
```

---

## Task 5: Create AI fix schemas

**Files:**
- Create: `backend/app/schemas/ai_fix.py`

**Step 1: Create the schemas**

Create `backend/app/schemas/ai_fix.py`:

```python
"""Pydantic schemas for the AI auto-fix feature."""
from typing import Any, Literal

from pydantic import BaseModel, Field


class ValidationErrorInput(BaseModel):
    """A single validation error to fix."""

    node_id: str = Field(..., description="ID of the node with the error")
    message: str = Field(..., description="The validation error message")


class AIFixTreeRequest(BaseModel):
    """Request to generate AI fixes for validation errors."""

    tree_structure: dict[str, Any] = Field(..., description="Full tree structure")
    tree_name: str = Field("", max_length=255, description="Name of the flow")
    tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field(
        "troubleshooting", description="Type of flow"
    )
    validation_errors: list[ValidationErrorInput] = Field(
        ..., min_length=1, max_length=10, description="Errors to fix"
    )


class AIFixProposal(BaseModel):
    """A single proposed fix from the AI."""

    target_node_id: str
    error_message: str
    description: str
    original_node: dict[str, Any]
    fixed_node: dict[str, Any]


class AIFixTokenUsage(BaseModel):
    input: int = 0
    output: int = 0


class AIFixTreeResponse(BaseModel):
    """Response with proposed fixes."""

    fixes: list[AIFixProposal]
    tokens_used: AIFixTokenUsage
```
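
`validation_errors` accepts between 1 and 10 entries, so a request built from a tree with more errors would be rejected by request validation (FastAPI answers with a 422). The plan's real caller is the React frontend; purely to illustrate the constraint, here is a Python sketch of a hypothetical helper that splits errors into request-sized batches.

```python
"""Sketch: split validation errors into AIFixTreeRequest-sized batches (hypothetical helper)."""
from typing import Any

MAX_ERRORS_PER_REQUEST = 10  # mirrors max_length=10 on validation_errors


def build_fix_requests(
    tree: dict[str, Any],
    tree_name: str,
    tree_type: str,
    errors: list[dict[str, str]],
) -> list[dict[str, Any]]:
    """Return one request body per batch of at most MAX_ERRORS_PER_REQUEST errors.

    An empty error list yields no requests, since the schema requires min_length=1.
    """
    return [
        {
            "tree_structure": tree,
            "tree_name": tree_name,
            "tree_type": tree_type,
            "validation_errors": errors[i : i + MAX_ERRORS_PER_REQUEST],
        }
        for i in range(0, len(errors), MAX_ERRORS_PER_REQUEST)
    ]


errs = [{"node_id": f"n{i}", "message": "Need 2 options"} for i in range(23)]
batches = build_fix_requests({"id": "root"}, "Demo", "troubleshooting", errs)
print([len(b["validation_errors"]) for b in batches])  # [10, 10, 3]
```

Batches are independent: each request sees the original tree, so fixes from different batches are generated without knowledge of one another and would need re-validation once applied together.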

**Step 2: Commit**

```bash
git add backend/app/schemas/ai_fix.py
git commit -m "feat: add Pydantic schemas for AI fix-tree endpoint"
```

---

## Task 6: Build the AI fix service

**Files:**
- Create: `backend/app/core/ai_fix_service.py`
- Test: `backend/tests/test_ai_fix_service.py`

**Step 1: Write tests**

Create `backend/tests/test_ai_fix_service.py`:

```python
"""Tests for AI fix service: prompt building and node extraction."""
from app.core.ai_fix_service import _find_node_by_id, _find_parent_node, _serialize_tree_outline


def _make_tree():
    return {
        "id": "root",
        "type": "decision",
        "question": "Is the server up?",
        "options": [
            {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
            {"id": "opt-no", "label": "No", "next_node_id": "restart"},
        ],
        "children": [
            {
                "id": "check-logs",
                "type": "action",
                "title": "Check Logs",
                "description": "Review event logs.",
                "next_node_id": "logs-ok",
            },
            {
                "id": "logs-ok",
                "type": "solution",
                "title": "Logs Resolved",
                "description": "Found issue in logs.",
            },
            {
                "id": "restart",
                "type": "decision",
                "question": "Did restart fix it?",
                "options": [
                    {"id": "opt-r-yes", "label": "Yes", "next_node_id": "restart-ok"},
                ],
                "children": [
                    {
                        "id": "restart-ok",
                        "type": "solution",
                        "title": "Restart Worked",
                        "description": "Server is back.",
                    },
                ],
            },
        ],
    }


class TestFindNodeById:
    def test_finds_root(self):
        tree = _make_tree()
        node = _find_node_by_id(tree, "root")
        assert node is not None
        assert node["id"] == "root"

    def test_finds_nested_child(self):
        tree = _make_tree()
        node = _find_node_by_id(tree, "restart-ok")
        assert node is not None
        assert node["id"] == "restart-ok"

    def test_returns_none_for_missing(self):
        tree = _make_tree()
        assert _find_node_by_id(tree, "nonexistent") is None


class TestFindParentNode:
    def test_root_has_no_parent(self):
        tree = _make_tree()
        assert _find_parent_node(tree, "root") is None

    def test_finds_parent_of_child(self):
        tree = _make_tree()
        parent = _find_parent_node(tree, "restart")
        assert parent is not None
        assert parent["id"] == "root"

    def test_finds_parent_of_deeply_nested(self):
        tree = _make_tree()
        parent = _find_parent_node(tree, "restart-ok")
        assert parent is not None
        assert parent["id"] == "restart"


class TestSerializeTreeOutline:
    def test_produces_readable_outline(self):
        tree = _make_tree()
        outline = _serialize_tree_outline(tree)
        assert "[decision] Is the server up?" in outline
        assert "[action] Check Logs" in outline
        assert "[solution] Restart Worked" in outline

    def test_marks_error_node(self):
        tree = _make_tree()
        outline = _serialize_tree_outline(tree, error_node_id="restart")
        assert "ERROR HERE" in outline
```

**Step 2: Run tests to verify they fail**

Run: `cd backend && python -m pytest tests/test_ai_fix_service.py -v`
Expected: ImportError (the module doesn't exist yet).

**Step 3: Implement the fix service**

Create `backend/app/core/ai_fix_service.py`:

```python
"""AI-powered fix generation for tree validation errors.

Builds targeted prompts for each failing node, sends to the AI provider,
and validates the fix before returning it.
"""
import copy
import json
import logging
import re
from typing import Any

from app.core.ai_provider import get_ai_provider
from app.core.ai_tree_validator import validate_generated_tree

logger = logging.getLogger(__name__)

# ── Helpers ──


def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
    """Recursively find a node by ID in the tree."""
    if not isinstance(tree, dict):
        return None
    if tree.get("id") == node_id:
        return tree
    for child in tree.get("children", []):
        result = _find_node_by_id(child, node_id)
        if result is not None:
            return result
    return None


def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
    """Find the parent of the node with target_id."""
    if not isinstance(tree, dict):
        return None
    for child in tree.get("children", []):
        if isinstance(child, dict) and child.get("id") == target_id:
            return tree
        result = _find_parent_node(child, target_id)
        if result is not None:
            return result
    return None


def _serialize_tree_outline(
    tree: dict[str, Any],
    indent: int = 0,
    error_node_id: str | None = None,
) -> str:
    """Serialize tree as a compact readable outline for the AI prompt."""
    if not isinstance(tree, dict):
        return ""

    node_type = tree.get("type", "?")
    label = tree.get("question") or tree.get("title") or tree.get("id", "?")
    prefix = " " * indent
    # Guard on error_node_id so id-less nodes are not marked when it is None.
    is_error = error_node_id is not None and tree.get("id") == error_node_id
    marker = " <<< ERROR HERE" if is_error else ""

    lines = [f"{prefix}- [{node_type}] {label}{marker}"]
    for child in tree.get("children", []):
        lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))

    return "\n".join(lines)


def _strip_markdown_fences(text: str) -> str:
    """Strip ```json ... ``` fences from AI response."""
    match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return text


# ── Prompt ──

FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.

You will receive:
1. A full flow outline showing the tree structure
2. The specific failing node with its full JSON
3. The validation error message

Your task: Return a FIXED version of the failing node as valid JSON. Rules:
- Fix ONLY the structural issue described in the error message
- Keep the failing node's "id" unchanged
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
- Every new node must have a unique ID (use descriptive kebab-case IDs)
- Decision nodes must have at least 2 options and at least 2 children
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
- Solution nodes are leaf nodes (no children)
- Return ONLY the fixed node JSON, no explanation"""


def _build_fix_prompt(
    tree: dict[str, Any],
    node_id: str,
    error_message: str,
    tree_name: str,
    tree_type: str,
) -> str:
    """Build the user message for fixing a specific node."""
    outline = _serialize_tree_outline(tree, error_node_id=node_id)
    failing_node = _find_node_by_id(tree, node_id)
    node_json = json.dumps(failing_node, indent=2) if failing_node else "{}"

    return (
        f"Flow name: {tree_name}\n"
        f"Flow type: {tree_type}\n\n"
        f"FULL FLOW OUTLINE:\n{outline}\n\n"
        f"ERROR: {error_message}\n\n"
        f"FAILING NODE (full JSON):\n{node_json}\n\n"
        "Return the fixed version of this node as JSON."
    )


# ── Main Service ──


async def generate_fixes(
    tree_structure: dict[str, Any],
    tree_name: str,
    tree_type: str,
    validation_errors: list[dict[str, str]],
) -> tuple[list[dict[str, Any]], int, int]:
    """Generate AI fixes for each validation error.

    Args:
        tree_structure: Full tree JSON.
        tree_name: Name of the flow.
        tree_type: Type of flow (troubleshooting/procedural/maintenance).
        validation_errors: List of {"node_id": str, "message": str}.

    Returns:
        (fixes, total_input_tokens, total_output_tokens)
        Each fix: {"target_node_id", "error_message", "description", "original_node", "fixed_node"}
        Fixes that still fail re-validation after the retry are omitted.
    """
    provider = get_ai_provider()
    fixes: list[dict[str, Any]] = []
    total_input = 0
    total_output = 0

    for error in validation_errors:
        node_id = error["node_id"]
        error_message = error["message"]

        original_node = _find_node_by_id(tree_structure, node_id)
        if original_node is None:
            logger.warning("Node %s not found in tree, skipping fix", node_id)
            continue

        original_snapshot = copy.deepcopy(original_node)
        user_message = _build_fix_prompt(
            tree_structure, node_id, error_message, tree_name, tree_type
        )

        # Attempt fix (with 1 retry using corrective prompt)
        messages: list[dict[str, str]] = [{"role": "user", "content": user_message}]

        for attempt in range(2):
            raw_text = ""  # reset so a reply from a previous node is never reused
            try:
                raw_text, inp_tokens, out_tokens = await provider.generate_json(
                    system_prompt=FIX_SYSTEM_PROMPT,
                    messages=messages,
                    max_tokens=4096,
                )
                total_input += inp_tokens
                total_output += out_tokens

                fixed_node = json.loads(_strip_markdown_fences(raw_text))
                if not isinstance(fixed_node, dict) or fixed_node.get("id") != node_id:
                    raise ValueError("response is not a node object with the original id")

                # Substitute the fix into a copy of the tree and re-validate to
                # check that this specific error is gone.
                test_tree = copy.deepcopy(tree_structure)
                _replace_node_in_tree(test_tree, node_id, fixed_node)
                remaining_errors = validate_generated_tree(test_tree)
                still_has_error = any(
                    node_id in e and error_message.split(":")[0].lower() in e.lower()
                    for e in remaining_errors
                )

                if still_has_error:
                    if attempt == 0:
                        # Retry with corrective prompt
                        messages.append({"role": "assistant", "content": raw_text})
                        messages.append({
                            "role": "user",
                            "content": (
                                "The fix still has the same validation error. "
                                f"Remaining errors: {remaining_errors}. "
                                "Please try again."
                            ),
                        })
                        continue
                    logger.error(
                        "Fix for node %s still fails validation after 2 attempts, skipping",
                        node_id,
                    )
                    break

                fixes.append({
                    "target_node_id": node_id,
                    "error_message": error_message,
                    "description": _describe_fix(original_snapshot, fixed_node),
                    "original_node": original_snapshot,
                    "fixed_node": fixed_node,
                })
                break

            except (json.JSONDecodeError, ValueError, KeyError, TypeError) as exc:
                logger.warning(
                    "Fix attempt %d for node %s failed: %s", attempt + 1, node_id, exc
                )
                if attempt == 0:
                    # Only echo the reply back if there was one; providers reject
                    # empty assistant turns. Otherwise just resend the prompt.
                    if raw_text:
                        messages.append({"role": "assistant", "content": raw_text})
                        messages.append({
                            "role": "user",
                            "content": (
                                "Invalid response. Return ONLY valid JSON for the "
                                f'fixed node, keeping its id "{node_id}".'
                            ),
                        })
                else:
                    logger.error("Failed to generate fix for node %s after 2 attempts", node_id)

    return fixes, total_input, total_output


def _replace_node_in_tree(
    tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
) -> bool:
    """Replace a node in the tree by ID, in place. Returns True if found and replaced."""
    if tree.get("id") == target_id:
        tree.clear()
        tree.update(replacement)
        return True
    for child in tree.get("children", []):
        if isinstance(child, dict) and _replace_node_in_tree(child, target_id, replacement):
            return True
    return False


def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
    """Generate a human-readable description of what changed."""
    orig_children = len(original.get("children", []))
    fixed_children = len(fixed.get("children", []))
    orig_options = len(original.get("options", []))
    fixed_options = len(fixed.get("options", []))

    parts: list[str] = []
    if fixed_children > orig_children:
        added = fixed_children - orig_children
        parts.append(f"Added {added} child node{'s' if added > 1 else ''}")
    if fixed_options > orig_options:
        added = fixed_options - orig_options
        parts.append(f"Added {added} option{'s' if added > 1 else ''}")
    if "next_node_id" in fixed and "next_node_id" not in original:
        parts.append(f"Added next_node_id '{fixed['next_node_id']}'")

    return "; ".join(parts) if parts else "Structural fix applied"
```
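
Full re-validation relies on `validate_generated_tree`, whose rules this plan does not show. As a hedged illustration, the node-level rules stated in `FIX_SYSTEM_PROMPT` could also be checked locally before the deep-copy-and-revalidate step. `check_fixed_node` is hypothetical; it encodes only the prompt's rules plus one labelled assumption (decision options route to the node's own children, as in the sample trees), not the validator's actual behaviour. The sibling constraint on action nodes needs the parent and is not checked.

```python
"""Sketch: local pre-check of a proposed node against FIX_SYSTEM_PROMPT's rules (hypothetical)."""
from typing import Any


def check_fixed_node(node_id: str, fixed: Any) -> list[str]:
    """Return a list of problems; an empty list means the node passes these checks."""
    if not isinstance(fixed, dict):
        return ["fix is not a JSON object"]
    problems: list[str] = []
    if fixed.get("id") != node_id:
        problems.append(f"id changed from {node_id!r} to {fixed.get('id')!r}")

    children = [c for c in fixed.get("children", []) if isinstance(c, dict)]
    child_ids = [c.get("id") for c in children]
    if len(child_ids) != len(set(child_ids)):
        problems.append("duplicate child ids")

    node_type = fixed.get("type")
    if node_type == "decision":
        options = [o for o in fixed.get("options", []) if isinstance(o, dict)]
        if len(options) < 2 or len(children) < 2:
            problems.append("decision needs at least 2 options and 2 children")
        for opt in options:
            # Assumption: each option routes to one of this node's own children.
            if opt.get("next_node_id") not in child_ids:
                problems.append(f"option {opt.get('id')!r} points outside children")
    elif node_type == "action" and not fixed.get("next_node_id"):
        problems.append("action node needs a next_node_id")
    elif node_type == "solution" and fixed.get("children"):
        problems.append("solution nodes must be leaves")
    return problems


fixed = {
    "id": "restart",
    "type": "decision",
    "question": "Did restart work?",
    "options": [
        {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
        {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
    ],
    "children": [
        {"id": "done", "type": "solution", "title": "Done"},
        {"id": "escalate", "type": "solution", "title": "Escalate"},
    ],
}
print(check_fixed_node("restart", fixed))  # []
one_option = {**fixed, "options": fixed["options"][:1]}
print(check_fixed_node("restart", one_option))
# ['decision needs at least 2 options and 2 children']
```

A non-empty result could be sent back in the corrective retry message, giving the model a specific reason instead of the full validator output.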

**Step 4: Run tests**

Run: `cd backend && python -m pytest tests/test_ai_fix_service.py -v`
Expected: All pass.

**Step 5: Commit**

```bash
git add backend/app/core/ai_fix_service.py backend/tests/test_ai_fix_service.py
git commit -m "feat: add AI fix service with prompt building and validation"
```

---

## Task 7: Create the fix-tree endpoint

**Files:**
- Create: `backend/app/api/endpoints/ai_fix.py`
- Modify: `backend/app/api/router.py`
- Test: `backend/tests/test_ai_fix_endpoint.py`

**Step 1: Write the endpoint test**

Create `backend/tests/test_ai_fix_endpoint.py`:

```python
"""Integration tests for AI fix-tree endpoint."""
import json
from unittest.mock import AsyncMock, patch

import pytest

from app.core.config import settings


SAMPLE_TREE = {
    "id": "root",
    "type": "decision",
    "question": "Is the server up?",
    "options": [
        {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
        {"id": "opt-no", "label": "No", "next_node_id": "restart"},
    ],
    "children": [
        {
            "id": "check-logs",
            "type": "action",
            "title": "Check Logs",
            "description": "Review logs.",
            "next_node_id": "logs-ok",
        },
        {
            "id": "logs-ok",
            "type": "solution",
            "title": "Logs OK",
            "description": "Issue in logs.",
        },
        {
            "id": "restart",
            "type": "decision",
            "question": "Did restart work?",
            "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
            "children": [
                {"id": "done", "type": "solution", "title": "Done", "description": "Fixed."},
            ],
        },
    ],
}


@pytest.fixture
def enable_ai():
    original_key = settings.GOOGLE_AI_API_KEY
    original_provider = settings.AI_PROVIDER
    settings.GOOGLE_AI_API_KEY = "fake-key"
    settings.AI_PROVIDER = "gemini"
    yield
    settings.GOOGLE_AI_API_KEY = original_key
    settings.AI_PROVIDER = original_provider


@pytest.mark.asyncio
class TestFixTreeEndpoint:

    async def test_returns_401_without_auth(self, client):
        response = await client.post("/api/v1/ai/fix-tree", json={
            "tree_structure": SAMPLE_TREE,
            "tree_name": "Test",
            "tree_type": "troubleshooting",
            "validation_errors": [{"node_id": "restart", "message": "Need 2 options"}],
        })
        assert response.status_code == 401

    async def test_returns_503_when_ai_disabled(self, client, auth_headers):
        original = settings.GOOGLE_AI_API_KEY
        orig_anthropic = settings.ANTHROPIC_API_KEY
        settings.GOOGLE_AI_API_KEY = None
        settings.ANTHROPIC_API_KEY = None
        try:
            response = await client.post(
                "/api/v1/ai/fix-tree",
                json={
                    "tree_structure": SAMPLE_TREE,
                    "tree_name": "Test",
                    "tree_type": "troubleshooting",
                    "validation_errors": [{"node_id": "restart", "message": "test"}],
                },
                headers=auth_headers,
            )
            assert response.status_code == 503
        finally:
            settings.GOOGLE_AI_API_KEY = original
            settings.ANTHROPIC_API_KEY = orig_anthropic

    async def test_returns_fixes_on_success(self, client, auth_headers, enable_ai):
        fixed_node = {
            "id": "restart",
            "type": "decision",
            "question": "Did restart work?",
            "options": [
                {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
                {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
            ],
            "children": [
                {"id": "done", "type": "solution", "title": "Done", "description": "Fixed."},
                {"id": "escalate", "type": "solution", "title": "Escalate", "description": "Escalate to vendor."},
|
|
],
|
|
}
|
|
|
|
mock_provider = AsyncMock()
|
|
mock_provider.generate_json = AsyncMock(
|
|
return_value=(json.dumps(fixed_node), 500, 300)
|
|
)
|
|
|
|
with patch("app.core.ai_fix_service.get_ai_provider", return_value=mock_provider):
|
|
response = await client.post(
|
|
"/api/v1/ai/fix-tree",
|
|
json={
|
|
"tree_structure": SAMPLE_TREE,
|
|
"tree_name": "Server Flow",
|
|
"tree_type": "troubleshooting",
|
|
"validation_errors": [
|
|
{"node_id": "restart", "message": "Decision node must have at least 2 options"},
|
|
],
|
|
},
|
|
headers=auth_headers,
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert len(data["fixes"]) == 1
|
|
assert data["fixes"][0]["target_node_id"] == "restart"
|
|
assert data["tokens_used"]["input"] == 500
|
|
assert data["tokens_used"]["output"] == 300
|
|
```
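`validation_errors` is a flat list, so a node that fails more than one check appears in it more than once. The review modal built in Task 10 keys proposals, and their applied/skipped state, by `target_node_id`, which only works cleanly with at most one proposal per node. This plan does not say whether `generate_fixes` already groups errors; if it does not, a small helper like the hypothetical `group_errors_by_node` below (not part of the existing code) could run before prompting:

```python
def group_errors_by_node(errors: list[dict[str, str]]) -> list[dict[str, str]]:
    """Merge validation errors that share a node_id into one entry per node.

    Hypothetical helper: keeps the order in which nodes first appear and joins
    their messages with "; " so the prompt sees every problem on that node.
    """
    merged: dict[str, list[str]] = {}
    for err in errors:
        merged.setdefault(err["node_id"], []).append(err["message"])
    return [
        {"node_id": node_id, "message": "; ".join(messages)}
        for node_id, messages in merged.items()
    ]


errors = [
    {"node_id": "restart", "message": "Decision node must have at least 2 options"},
    {"node_id": "root", "message": "example: another check"},
    {"node_id": "restart", "message": "example: a second check on the same node"},
]
for entry in group_errors_by_node(errors):
    print(entry["node_id"], "->", entry["message"])
```

Merging keeps one prompt and one proposal per node; the trade-off is that a single proposal has to address all of that node's errors at once.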

**Step 2: Run to verify it fails**

Run: `cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v`

Expected: Fail, because the endpoint doesn't exist yet.

**Step 3: Create the endpoint**

Create `backend/app/api/endpoints/ai_fix.py`:

```python
"""AI auto-fix endpoint for tree validation errors."""
import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_db, require_engineer_or_admin
from app.core.ai_fix_service import generate_fixes
from app.core.config import settings
from app.core.rate_limit import limiter
from app.models.user import User
from app.schemas.ai_fix import AIFixProposal, AIFixTokenUsage, AIFixTreeRequest, AIFixTreeResponse

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/ai", tags=["ai-fix"])


def _require_ai_enabled() -> None:
    if not settings.ai_enabled:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
        )


@router.post("/fix-tree", response_model=AIFixTreeResponse)
@limiter.limit("10/minute")
async def fix_tree(
    request: Request,  # slowapi's @limiter.limit needs the Request in the signature
    body: AIFixTreeRequest,
    user: Annotated[User, Depends(require_engineer_or_admin)],  # role check / auth gate
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Generate AI-powered fixes for tree validation errors."""
    _require_ai_enabled()

    try:
        fixes, total_input, total_output = await generate_fixes(
            tree_structure=body.tree_structure,
            tree_name=body.tree_name,
            tree_type=body.tree_type,
            validation_errors=[
                {"node_id": e.node_id, "message": e.message}
                for e in body.validation_errors
            ],
        )
    except RuntimeError as exc:
        # Provider-level failure (e.g. both Gemini and the Claude fallback failed)
        logger.error("AI fix generation failed: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        logger.exception("Unexpected error during AI fix generation")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to generate fixes. Please try again.",
        ) from exc

    return AIFixTreeResponse(
        fixes=[
            AIFixProposal(
                target_node_id=f["target_node_id"],
                error_message=f["error_message"],
                description=f["description"],
                original_node=f["original_node"],
                fixed_node=f["fixed_node"],
            )
            for f in fixes
        ],
        tokens_used=AIFixTokenUsage(input=total_input, output=total_output),
    )
```
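The endpoint imports its models from `app.schemas.ai_fix`, and Step 7 commits that file, but none of the steps here show its contents. The sketch below is inferred from the fields the endpoint, the endpoint test, and the frontend types in Task 8 use (Pydantic v2). The `min_length=1` constraint on `validation_errors` is an assumption, not something the plan specifies:

```python
"""Request/response models for POST /ai/fix-tree (sketch, Pydantic v2)."""
from typing import Any, Literal

from pydantic import BaseModel, Field


class AIFixValidationError(BaseModel):
    node_id: str
    message: str


class AIFixTreeRequest(BaseModel):
    tree_structure: dict[str, Any]
    tree_name: str
    tree_type: Literal["troubleshooting", "procedural", "maintenance"]
    # Assumption: an empty list is a client bug, so reject it with a 422.
    validation_errors: list[AIFixValidationError] = Field(min_length=1)


class AIFixProposal(BaseModel):
    target_node_id: str
    error_message: str
    description: str
    original_node: dict[str, Any]
    fixed_node: dict[str, Any]


class AIFixTokenUsage(BaseModel):
    input: int   # prompt tokens reported by the provider
    output: int  # completion tokens reported by the provider


class AIFixTreeResponse(BaseModel):
    fixes: list[AIFixProposal]
    tokens_used: AIFixTokenUsage


request = AIFixTreeRequest.model_validate({
    "tree_structure": {"id": "root", "type": "decision"},
    "tree_name": "Server Flow",
    "tree_type": "troubleshooting",
    "validation_errors": [{"node_id": "root", "message": "Need 2 options"}],
})
print(request.validation_errors[0].node_id)  # root
```

With `response_model=AIFixTreeResponse`, FastAPI serializes `tokens_used` as `{"input": ..., "output": ...}`, which is the shape the endpoint test asserts on.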

**Step 4: Register in router**

In `backend/app/api/router.py`, make two additions.

After line 8 (`from app.api.endpoints import ai_builder`), add:

```python
from app.api.endpoints import ai_fix
```

After line 38 (`api_router.include_router(ai_builder.router)`), add:

```python
api_router.include_router(ai_fix.router)
```

**Step 5: Run tests**

Run: `cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v`

Expected: All pass.

**Step 6: Run full test suite**

Run: `cd backend && python -m pytest --override-ini="addopts=" -v`

Expected: All pass (100+ tests).

**Step 7: Commit**

```bash
git add backend/app/api/endpoints/ai_fix.py backend/app/api/router.py backend/app/schemas/ai_fix.py backend/tests/test_ai_fix_endpoint.py
git commit -m "feat: add POST /ai/fix-tree endpoint for AI-powered validation fixes"
```

---

## Task 8: Add frontend API client for fix-tree

**Files:**
- Modify: `frontend/src/api/trees.ts`
- Create: `frontend/src/types/ai-fix.ts`
- Modify: `frontend/src/types/index.ts`

**Step 1: Create the types**

Create `frontend/src/types/ai-fix.ts`:

```typescript
export interface AIFixValidationError {
  node_id: string
  message: string
}

export interface AIFixProposal {
  target_node_id: string
  error_message: string
  description: string
  original_node: Record<string, unknown>
  fixed_node: Record<string, unknown>
}

export interface AIFixTreeRequest {
  tree_structure: Record<string, unknown>
  tree_name: string
  tree_type: 'troubleshooting' | 'procedural' | 'maintenance'
  validation_errors: AIFixValidationError[]
}

export interface AIFixTreeResponse {
  fixes: AIFixProposal[]
  tokens_used: { input: number; output: number }
}
```
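These interfaces mirror the backend schema field-for-field, so renaming a field on either side breaks the UI without any compile error. One cheap guard, sketched here as a hypothetical helper for the backend endpoint test (`response_shape_problems` and the key sets are not part of the codebase), is to compare the JSON keys against the names used in `ai-fix.ts`:

```python
from typing import Any

# Field names copied from frontend/src/types/ai-fix.ts; keep in sync by hand.
RESPONSE_KEYS = {"fixes", "tokens_used"}
PROPOSAL_KEYS = {"target_node_id", "error_message", "description", "original_node", "fixed_node"}
TOKEN_KEYS = {"input", "output"}


def response_shape_problems(data: dict[str, Any]) -> list[str]:
    """List mismatches between a /ai/fix-tree JSON body and the TypeScript interfaces."""
    problems: list[str] = []
    if set(data) != RESPONSE_KEYS:
        problems.append(f"top-level keys {sorted(data)} != {sorted(RESPONSE_KEYS)}")
    for i, fix in enumerate(data.get("fixes", [])):
        if set(fix) != PROPOSAL_KEYS:
            missing = sorted(PROPOSAL_KEYS - set(fix))
            extra = sorted(set(fix) - PROPOSAL_KEYS)
            problems.append(f"fixes[{i}]: missing {missing}, extra {extra}")
    tokens = data.get("tokens_used", {})
    if set(tokens) != TOKEN_KEYS:
        problems.append(f"tokens_used keys {sorted(tokens)} != {sorted(TOKEN_KEYS)}")
    return problems


ok = {
    "fixes": [{"target_node_id": "restart", "error_message": "e", "description": "d",
               "original_node": {}, "fixed_node": {}}],
    "tokens_used": {"input": 500, "output": 300},
}
print(response_shape_problems(ok))  # []
```

In `test_returns_fixes_on_success`, `assert response_shape_problems(response.json()) == []` would fail as soon as a backend field is renamed. It does not read the TypeScript file, so the key sets still need a manual update whenever the interfaces change.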

**Step 2: Export from types/index.ts**

Add to `frontend/src/types/index.ts`:

```typescript
export type { AIFixTreeRequest, AIFixTreeResponse, AIFixProposal, AIFixValidationError } from './ai-fix'
```

**Step 3: Add API method to trees.ts**

In `frontend/src/api/trees.ts`, add a new method to the `treesApi` object:

```typescript
async fixTree(request: AIFixTreeRequest): Promise<AIFixTreeResponse> {
  const response = await apiClient.post<AIFixTreeResponse>('/ai/fix-tree', request)
  return response.data
},
```

Import the two types at the top of the file: `import type { AIFixTreeRequest, AIFixTreeResponse } from '@/types'`.

**Step 4: Commit**

```bash
git add frontend/src/types/ai-fix.ts frontend/src/types/index.ts frontend/src/api/trees.ts
git commit -m "feat: add frontend API client and types for AI fix-tree"
```

---

## Task 9: Add "Fix with AI" button to ValidationSummary

**Files:**
- Modify: `frontend/src/components/tree-editor/ValidationSummary.tsx`

**Step 1: Update the component**

Add the new optional props and the extra icon imports to `ValidationSummary`:

```typescript
import { useState } from 'react'
import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react'
import { cn } from '@/lib/utils'
import type { ValidationError } from '@/store/treeEditorStore'

interface ValidationSummaryProps {
  errors: ValidationError[]
  onSelectNode: (nodeId: string) => void
  onFixWithAI?: () => void
  isFixing?: boolean
}

export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) {
```

Place the "Fix with AI" button in the header, between the error/warning count and the expand/collapse chevron: insert it right after the `</span>` that closes the count (around line 60) and before the chevron (around line 62). If the header is itself a `<button>`, restructure it so the new button is not nested inside it (for example, a flex row holding the existing toggle button and the Fix button), because a `<button>` inside another `<button>` is invalid HTML and React warns about it. The `e.stopPropagation()` call keeps a click on the Fix button from also toggling the panel:

```tsx
{/* Fix with AI button: only shown when at least one error points at a node */}
{onFixWithAI && errorItems.some(e => e.nodeId) && (
  <button
    type="button"
    onClick={(e) => {
      e.stopPropagation()
      onFixWithAI()
    }}
    disabled={isFixing}
    className={cn(
      'ml-auto mr-2 flex items-center gap-1.5 rounded-md px-3 py-1 text-xs font-medium transition-colors',
      isFixing
        ? 'bg-primary/10 text-primary cursor-wait'
        : 'bg-gradient-brand text-white shadow-sm shadow-primary/20 hover:opacity-90'
    )}
  >
    {isFixing ? (
      <>
        <Loader2 className="h-3 w-3 animate-spin" />
        Generating fixes...
      </>
    ) : (
      <>
        <Sparkles className="h-3 w-3" />
        Fix with AI
      </>
    )}
  </button>
)}
```

**Step 2: Build to verify**

Run: `cd frontend && npm run build`

Expected: Build passes.

**Step 3: Commit**

```bash
git add frontend/src/components/tree-editor/ValidationSummary.tsx
git commit -m "feat: add Fix with AI button to ValidationSummary"
```

---

## Task 10: Build the AIFixReviewModal

**Files:**
- Create: `frontend/src/components/tree-editor/AIFixReviewModal.tsx`

**Step 1: Create the review modal**

Create `frontend/src/components/tree-editor/AIFixReviewModal.tsx`. `onApplyAll` receives only the fixes that are still pending, so fixes the user already applied or skipped are not applied a second time:

```typescript
import { useState } from 'react'
import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react'
import { cn } from '@/lib/utils'
import type { AIFixProposal } from '@/types'

interface AIFixReviewModalProps {
  fixes: AIFixProposal[]
  onApply: (fix: AIFixProposal) => void
  // Receives only the fixes that were neither applied nor skipped
  onApplyAll: (pendingFixes: AIFixProposal[]) => void
  onClose: () => void
}

export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) {
  const [appliedIds, setAppliedIds] = useState<Set<string>>(new Set())
  const [skippedIds, setSkippedIds] = useState<Set<string>>(new Set())
  const [expandedIds, setExpandedIds] = useState<Set<string>>(new Set(fixes.map(f => f.target_node_id)))

  const handleApply = (fix: AIFixProposal) => {
    onApply(fix)
    setAppliedIds(prev => new Set(prev).add(fix.target_node_id))
  }

  const handleSkip = (fix: AIFixProposal) => {
    setSkippedIds(prev => new Set(prev).add(fix.target_node_id))
  }

  const toggleExpanded = (id: string) => {
    setExpandedIds(prev => {
      const next = new Set(prev)
      if (next.has(id)) next.delete(id)
      else next.add(id)
      return next
    })
  }

  const pendingFixes = fixes.filter(
    f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id)
  )
  const allHandled = pendingFixes.length === 0

  return (
    <div className="fixed inset-0 z-50 flex items-center justify-center bg-black/80 backdrop-blur-sm p-4">
      <div
        role="dialog"
        aria-modal="true"
        className="relative flex h-[80vh] w-full max-w-2xl flex-col bg-card border border-border rounded-2xl shadow-lg"
      >
        {/* Header */}
        <div className="flex items-center justify-between border-b border-border px-6 py-4">
          <div className="flex items-center gap-2">
            <Sparkles className="h-5 w-5 text-primary" />
            <h2 className="text-lg font-semibold text-foreground">
              AI Fix Proposals ({fixes.length})
            </h2>
          </div>
          <button
            onClick={onClose}
            aria-label="Close"
            className="rounded-md p-1 text-muted-foreground hover:bg-accent hover:text-foreground"
          >
            <X className="h-5 w-5" />
          </button>
        </div>

        {/* Body */}
        <div className="flex-1 overflow-y-auto p-4 space-y-3">
          {fixes.map((fix) => {
            const isApplied = appliedIds.has(fix.target_node_id)
            const isSkipped = skippedIds.has(fix.target_node_id)
            const isExpanded = expandedIds.has(fix.target_node_id)

            return (
              <div
                key={fix.target_node_id}
                className={cn(
                  'rounded-lg border p-4',
                  isApplied
                    ? 'border-emerald-400/30 bg-emerald-400/5'
                    : isSkipped
                      ? 'border-border bg-accent/30 opacity-60'
                      : 'border-border bg-card'
                )}
              >
                {/* Fix header */}
                <div className="flex items-start justify-between gap-3">
                  <div className="flex-1">
                    <p className="text-sm text-red-400 mb-1">{fix.error_message}</p>
                    <p className="text-sm text-foreground">{fix.description}</p>
                    <p className="text-xs text-muted-foreground mt-1">
                      Node: {fix.target_node_id}
                    </p>
                  </div>
                  {isApplied && (
                    <span className="flex items-center gap-1 rounded-full bg-emerald-400/10 px-2 py-1 text-xs text-emerald-400">
                      <Check className="h-3 w-3" /> Applied
                    </span>
                  )}
                  {isSkipped && (
                    <span className="text-xs text-muted-foreground">Skipped</span>
                  )}
                </div>

                {/* Expand/collapse detail */}
                {!isApplied && !isSkipped && (
                  <>
                    <button
                      onClick={() => toggleExpanded(fix.target_node_id)}
                      className="mt-2 flex items-center gap-1 text-xs text-muted-foreground hover:text-foreground"
                    >
                      {isExpanded ? <ChevronUp className="h-3 w-3" /> : <ChevronDown className="h-3 w-3" />}
                      {isExpanded ? 'Hide' : 'Show'} details
                    </button>

                    {isExpanded && (
                      <div className="mt-3 grid grid-cols-2 gap-3">
                        <div>
                          <p className="text-xs font-medium text-muted-foreground mb-1">Before</p>
                          <pre className="overflow-x-auto rounded bg-accent/50 p-2 text-xs text-muted-foreground max-h-48 overflow-y-auto">
                            {JSON.stringify(fix.original_node, null, 2)}
                          </pre>
                        </div>
                        <div>
                          <p className="text-xs font-medium text-emerald-400 mb-1">After</p>
                          <pre className="overflow-x-auto rounded bg-emerald-400/5 p-2 text-xs text-foreground max-h-48 overflow-y-auto">
                            {JSON.stringify(fix.fixed_node, null, 2)}
                          </pre>
                        </div>
                      </div>
                    )}

                    {/* Action buttons */}
                    <div className="mt-3 flex gap-2">
                      <button
                        onClick={() => handleApply(fix)}
                        className="flex items-center gap-1 rounded-md bg-gradient-brand px-3 py-1.5 text-xs font-medium text-white shadow-sm shadow-primary/20 hover:opacity-90"
                      >
                        <Check className="h-3 w-3" />
                        Apply
                      </button>
                      <button
                        onClick={() => handleSkip(fix)}
                        className="flex items-center gap-1 rounded-md border border-border px-3 py-1.5 text-xs font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
                      >
                        <SkipForward className="h-3 w-3" />
                        Skip
                      </button>
                    </div>
                  </>
                )}
              </div>
            )
          })}
        </div>

        {/* Footer */}
        <div className="flex items-center justify-between border-t border-border px-6 py-4">
          <button
            onClick={onClose}
            className="rounded-md border border-border px-4 py-2 text-sm font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
          >
            {allHandled ? 'Done' : 'Cancel'}
          </button>
          {!allHandled && (
            <button
              onClick={() => onApplyAll(pendingFixes)}
              className="rounded-md bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20 hover:opacity-90"
            >
              Apply All ({pendingFixes.length})
            </button>
          )}
        </div>
      </div>
    </div>
  )
}
```

**Step 2: Build to verify**

Run: `cd frontend && npm run build`

Expected: Build passes.

**Step 3: Commit**

```bash
git add frontend/src/components/tree-editor/AIFixReviewModal.tsx
git commit -m "feat: add AIFixReviewModal component for reviewing AI-proposed fixes"
```

---

## Task 11: Wire everything together in TreeEditorPage

**Files:**
- Modify: `frontend/src/pages/TreeEditorPage.tsx`

This is the integration task. The page needs to:

1. Import the new components and API
2. Add state for the fix flow (`isFixing`, `fixProposals`)
3. Handle the "Fix with AI" button click by calling `treesApi.fixTree()`
4. Show `AIFixReviewModal` when proposals are available
5. Apply accepted fixes by calling `updateNode()` on the tree editor store (skips are tracked only inside the modal)
6. Re-run `validate()` after applying fixes

**Step 1: Find where ValidationSummary is rendered in TreeEditorPage**

Search for `<ValidationSummary` in `TreeEditorPage.tsx` and note the exact location.

**Step 2: Add imports**

Add to the imports in `TreeEditorPage.tsx`:

```typescript
import { AIFixReviewModal } from '@/components/tree-editor/AIFixReviewModal'
import { treesApi } from '@/api/trees'
import type { AIFixProposal } from '@/types'
```

The handlers below also use `useState`, `toast`, `useTreeEditorStore`, and the `TreeStructure` type; add imports for any of these that the page does not already have.

**Step 3: Add state**

Inside the `TreeEditorPage` component, add:

```typescript
const [isFixing, setIsFixing] = useState(false)
const [fixProposals, setFixProposals] = useState<AIFixProposal[] | null>(null)
```

**Step 4: Add handler for "Fix with AI"**

```typescript
const handleFixWithAI = async () => {
  const store = useTreeEditorStore.getState()
  if (!store.treeStructure) return

  // Get only fixable errors (structural errors with nodeId)
  const fixableErrors = store.validationErrors
    .filter(e => e.severity === 'error' && e.nodeId)
    .map(e => ({ node_id: e.nodeId!, message: e.message }))

  if (fixableErrors.length === 0) return

  setIsFixing(true)
  try {
    const result = await treesApi.fixTree({
      tree_structure: store.treeStructure as Record<string, unknown>,
      tree_name: store.name,
      tree_type: (store.treeType || 'troubleshooting') as 'troubleshooting' | 'procedural' | 'maintenance',
      validation_errors: fixableErrors,
    })
    if (result.fixes.length > 0) {
      setFixProposals(result.fixes)
    } else {
      toast.info('AI could not generate fixes for these errors')
    }
  } catch {
    toast.error('Failed to generate AI fixes. Please try again.')
  } finally {
    setIsFixing(false)
  }
}
```

**Step 5: Add handlers for apply/close**

```typescript
const handleApplyFix = (fix: AIFixProposal) => {
  const store = useTreeEditorStore.getState()
  store.updateNode(fix.target_node_id, fix.fixed_node as Partial<TreeStructure>)
}

const handleApplyAllFixes = (pendingFixes: AIFixProposal[]) => {
  // The modal passes only the fixes that were neither applied nor skipped
  for (const fix of pendingFixes) {
    handleApplyFix(fix)
  }
  setFixProposals(null)
  // Zustand's set() is synchronous, so the store already holds the patched tree here
  useTreeEditorStore.getState().validate()
}

const handleCloseFixModal = () => {
  setFixProposals(null)
  // Re-validate in case some fixes were applied individually
  useTreeEditorStore.getState().validate()
}
```
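Checking a proposal outside the UI, for instance in a backend test that re-validates the patched tree, needs the same apply step. The sketch below assumes `updateNode` shallow-merges the patch into the matching node (the store's actual behavior is not shown in this plan), and `apply_fix_to_tree` is a hypothetical helper, not existing code:

```python
import copy
from typing import Any


def apply_fix_to_tree(tree: dict[str, Any], node_id: str, fixed_node: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of `tree` with `fixed_node` shallow-merged into the node whose id is `node_id`.

    Assumes the editor store's updateNode does the same kind of shallow merge.
    Raises KeyError if no node has that id.
    """
    result = copy.deepcopy(tree)  # never mutate the caller's tree
    stack = [result]
    while stack:
        node = stack.pop()
        if node.get("id") == node_id:
            node.update(copy.deepcopy(fixed_node))  # top-level keys are replaced wholesale
            return result
        stack.extend(node.get("children", []))
    raise KeyError(node_id)


tree = {
    "id": "root", "type": "decision",
    "children": [
        {"id": "restart", "type": "decision",
         "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
         "children": [{"id": "done", "type": "solution"}]},
    ],
}
fix = {
    "options": [
        {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
        {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
    ],
    "children": [{"id": "done", "type": "solution"}, {"id": "escalate", "type": "solution"}],
}
patched = apply_fix_to_tree(tree, "restart", fix)
print([c["id"] for c in patched["children"][0]["children"]])  # ['done', 'escalate']
```

Because the merge is shallow, a `fixed_node` that includes `children` replaces the node's whole child list, so a proposal has to carry every existing child, not only the new ones.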

**Step 6: Pass props to ValidationSummary**

Update the `<ValidationSummary>` JSX to include the new props:

```tsx
<ValidationSummary
  errors={validationErrors}
  onSelectNode={handleSelectNode}
  onFixWithAI={handleFixWithAI}
  isFixing={isFixing}
/>
```

**Step 7: Add the review modal**

After `<ValidationSummary>`, add:

```tsx
{fixProposals && (
  <AIFixReviewModal
    fixes={fixProposals}
    onApply={handleApplyFix}
    onApplyAll={handleApplyAllFixes}
    onClose={handleCloseFixModal}
  />
)}
```

**Step 8: Build to verify**

Run: `cd frontend && npm run build`

Expected: Build passes.

**Step 9: Commit**

```bash
git add frontend/src/pages/TreeEditorPage.tsx
git commit -m "feat: wire AI fix flow into TreeEditorPage"
```

---

## Task 12: Final verification

**Step 1: Run all backend tests**

Run: `cd backend && python -m pytest --override-ini="addopts=" -v`

Expected: All pass.

**Step 2: Run frontend build**

Run: `cd frontend && npm run build`

Expected: Build passes with no errors.

**Step 3: Final commit if any adjustments needed**

Fix any issues found during verification and commit.