feat: AI auto-fix + Gemini Flash provider #93

Merged
chihlasm merged 14 commits from feat/ai-autofix-gemini into main 2026-02-27 07:32:24 +00:00
21 changed files with 3516 additions and 120 deletions

View File

@@ -10,7 +10,6 @@
import logging import logging
from typing import Annotated from typing import Annotated
import anthropic
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -52,10 +51,9 @@ def _require_ai_enabled() -> None:
if not settings.ai_enabled: if not settings.ai_enabled:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.", detail="AI flow builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
) )
@router.get("/quota", response_model=AIQuotaStatusResponse) @router.get("/quota", response_model=AIQuotaStatusResponse)
async def get_quota( async def get_quota(
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
@@ -174,27 +172,6 @@ async def scaffold(
branches, input_tokens, output_tokens, cost = await scaffold_branches( branches, input_tokens, output_tokens, cost = await scaffold_branches(
conversation.wizard_state, conversation.wizard_state,
) )
except anthropic.APIError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="scaffold",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e)},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
except ValueError as e: except ValueError as e:
await record_ai_usage( await record_ai_usage(
user_id=current_user.id, user_id=current_user.id,
@@ -216,6 +193,28 @@ async def scaffold(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"AI returned invalid output: {e}", detail=f"AI returned invalid output: {e}",
) )
except Exception as e:
logger.exception("AI scaffold failed: %s: %s", type(e).__name__, e)
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="scaffold",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e)},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"AI provider error ({type(e).__name__}). Please try again.",
)
# Record successful usage # Record successful usage
await record_ai_usage( await record_ai_usage(
@@ -293,27 +292,6 @@ async def branch_detail(
existing_branches, existing_branches,
) )
) )
except anthropic.APIError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="branch_detail",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e), "branch_name": data.branch_name},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
except ValueError as e: except ValueError as e:
await record_ai_usage( await record_ai_usage(
user_id=current_user.id, user_id=current_user.id,
@@ -335,6 +313,28 @@ async def branch_detail(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"AI returned invalid output: {e}", detail=f"AI returned invalid output: {e}",
) )
except Exception as e:
logger.exception("AI branch_detail failed: %s: %s", type(e).__name__, e)
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="branch_detail",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e), "branch_name": data.branch_name},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"AI provider error ({type(e).__name__}). Please try again.",
)
# Record successful usage # Record successful usage
await record_ai_usage( await record_ai_usage(

View File

@@ -0,0 +1,78 @@
"""AI auto-fix endpoint for tree validation errors.
POST /ai/fix-tree — accepts a tree with validation errors and returns
AI-generated fix proposals for each error.
"""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, require_engineer_or_admin
from app.core.config import settings
from app.core.rate_limit import limiter
from app.core.ai_fix_service import generate_fixes
from app.models.user import User
from app.schemas.ai_fix import (
AIFixTreeRequest,
AIFixTreeResponse,
AIFixProposal,
AIFixTokenUsage,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ai", tags=["ai-fix"])
def _require_ai_enabled() -> None:
"""Raise 503 if AI is not configured."""
if not settings.ai_enabled:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AI fix is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
)
@router.post("/fix-tree", response_model=AIFixTreeResponse)
@limiter.limit("10/minute")
async def fix_tree(
request: Request,
body: AIFixTreeRequest,
user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Generate AI-powered fixes for tree validation errors."""
_require_ai_enabled()
validation_errors = [
{"node_id": e.node_id, "message": e.message}
for e in body.validation_errors
]
try:
fixes, input_tokens, output_tokens = await generate_fixes(
tree_structure=body.tree_structure,
tree_name=body.tree_name,
tree_type=body.tree_type,
validation_errors=validation_errors,
)
except RuntimeError as exc:
logger.error("AI provider not available: %s", exc)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=str(exc),
)
except Exception as exc:
logger.exception("Unexpected error in AI fix service")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred while generating fixes.",
)
return AIFixTreeResponse(
fixes=[AIFixProposal(**f) for f in fixes],
tokens_used=AIFixTokenUsage(input=input_tokens, output=output_tokens),
)

View File

@@ -6,6 +6,7 @@ from app.api.endpoints import target_lists
from app.api.endpoints import maintenance_schedules from app.api.endpoints import maintenance_schedules
from app.api.endpoints import feedback from app.api.endpoints import feedback
from app.api.endpoints import ai_builder from app.api.endpoints import ai_builder
from app.api.endpoints import ai_fix
api_router = APIRouter() api_router = APIRouter()
@@ -36,3 +37,4 @@ api_router.include_router(target_lists.router)
api_router.include_router(maintenance_schedules.router) api_router.include_router(maintenance_schedules.router)
api_router.include_router(feedback.router) api_router.include_router(feedback.router)
api_router.include_router(ai_builder.router) api_router.include_router(ai_builder.router)
api_router.include_router(ai_fix.router)

View File

@@ -0,0 +1,273 @@
"""AI-powered fix service for tree validation errors.
Given a tree structure and validation errors, generates AI-powered
proposals to fix each structural issue while preserving existing content.
"""
import copy
import json
import logging
import re
from typing import Any
from app.core.ai_provider import get_ai_provider
from app.core.ai_tree_validator import validate_generated_tree
logger = logging.getLogger(__name__)
FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.
You will receive:
1. A full flow outline showing the tree structure
2. The specific failing node with its full JSON
3. The validation error message
Your task: Return a FIXED version of the failing node as valid JSON. Rules:
- Fix ONLY the structural issue described in the error message
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
- Every new node must have a unique ID (use descriptive kebab-case IDs)
- Decision nodes must have at least 2 options and at least 2 children
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
- Solution nodes are leaf nodes (no children)
- Return ONLY the fixed node JSON, no explanation"""
# ── Pure helper functions ──
def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
"""Recursively find a node by its ID in the tree structure."""
if not isinstance(tree, dict):
return None
if tree.get("id") == node_id:
return tree
for child in tree.get("children", []):
result = _find_node_by_id(child, node_id)
if result is not None:
return result
return None
def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
"""Find the parent of a node with the given ID."""
if not isinstance(tree, dict):
return None
for child in tree.get("children", []):
if isinstance(child, dict) and child.get("id") == target_id:
return tree
result = _find_parent_node(child, target_id)
if result is not None:
return result
return None
def _serialize_tree_outline(
tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None
) -> str:
"""Serialize tree as a readable outline for AI prompt context.
Format: indented "- [type] label" with "<<< ERROR HERE" marker.
"""
if not isinstance(tree, dict):
return ""
node_type = tree.get("type", "unknown")
label = tree.get("question") or tree.get("title") or tree.get("id", "?")
prefix = " " * indent
marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""
line = f"{prefix}- [{node_type}] {label}{marker}"
lines = [line]
for child in tree.get("children", []):
lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))
return "\n".join(lines)
def _strip_markdown_fences(text: str) -> str:
"""Strip ```json...``` fences from AI response."""
return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip(
"`"
).strip()
def _replace_node_in_tree(
tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
) -> bool:
"""Replace a node in-place by ID. Returns True if found and replaced."""
if not isinstance(tree, dict):
return False
if tree.get("id") == target_id:
tree.clear()
tree.update(replacement)
return True
for child in tree.get("children", []):
if _replace_node_in_tree(child, target_id, replacement):
return True
return False
def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
"""Describe what changed between original and fixed node."""
changes: list[str] = []
orig_children = len(original.get("children", []))
fixed_children = len(fixed.get("children", []))
if fixed_children > orig_children:
changes.append(f"added {fixed_children - orig_children} child node(s)")
orig_options = len(original.get("options", []))
fixed_options = len(fixed.get("options", []))
if fixed_options > orig_options:
changes.append(f"added {fixed_options - orig_options} option(s)")
if fixed.get("next_node_id") and not original.get("next_node_id"):
changes.append("added next_node_id")
if not changes:
changes.append("fixed structural issue")
return "; ".join(changes).capitalize()
# ── Prompt building ──
def _build_fix_prompt(
tree: dict[str, Any],
node_id: str,
error_message: str,
tree_name: str,
tree_type: str,
) -> str:
"""Build the user message for the AI fix request."""
outline = _serialize_tree_outline(tree, error_node_id=node_id)
node = _find_node_by_id(tree, node_id)
node_json = json.dumps(node, indent=2) if node else "{}"
return (
f"Flow name: {tree_name}\n"
f"Flow type: {tree_type}\n\n"
f"## Full flow outline\n```\n{outline}\n```\n\n"
f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n"
f"## Validation error\n{error_message}\n\n"
f"Return the fixed version of this node as JSON."
)
# ── Main entry point ──
async def generate_fixes(
tree_structure: dict[str, Any],
tree_name: str,
tree_type: str,
validation_errors: list[dict[str, str]],
) -> tuple[list[dict[str, Any]], int, int]:
"""Generate AI-powered fixes for tree validation errors.
Args:
tree_structure: Full tree structure dict.
tree_name: Name of the flow.
tree_type: Type of flow (troubleshooting, procedural, maintenance).
validation_errors: List of dicts with "node_id" and "message" keys.
Returns:
Tuple of (fixes_list, total_input_tokens, total_output_tokens).
Each fix dict has: target_node_id, error_message, description,
original_node, fixed_node.
"""
provider = get_ai_provider()
fixes: list[dict[str, Any]] = []
total_input_tokens = 0
total_output_tokens = 0
for error in validation_errors:
node_id = error["node_id"]
error_message = error["message"]
original_node = _find_node_by_id(tree_structure, node_id)
if original_node is None:
logger.warning("Node %s not found in tree, skipping fix", node_id)
continue
original_snapshot = copy.deepcopy(original_node)
# Build prompt and call AI
user_message = _build_fix_prompt(
tree_structure, node_id, error_message, tree_name, tree_type
)
messages = [{"role": "user", "content": user_message}]
try:
text, in_tok, out_tok = await provider.generate_json(
system_prompt=FIX_SYSTEM_PROMPT,
messages=messages,
max_tokens=2048,
)
total_input_tokens += in_tok
total_output_tokens += out_tok
cleaned = _strip_markdown_fences(text)
fixed_node = json.loads(cleaned)
except (json.JSONDecodeError, Exception) as exc:
logger.warning("AI fix failed for node %s: %s", node_id, exc)
continue
# Validate by substituting into a tree copy
tree_copy = copy.deepcopy(tree_structure)
_replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node))
remaining_errors = validate_generated_tree(tree_copy)
# Check if the specific error is still present
still_has_error = any(node_id in e for e in remaining_errors)
if still_has_error:
# Retry once with corrective prompt
retry_message = (
f"Your previous fix still has validation errors:\n"
f"{chr(10).join(remaining_errors)}\n\n"
f"Please fix the node again. Return ONLY the corrected JSON."
)
messages.append({"role": "assistant", "content": text})
messages.append({"role": "user", "content": retry_message})
try:
text2, in_tok2, out_tok2 = await provider.generate_json(
system_prompt=FIX_SYSTEM_PROMPT,
messages=messages,
max_tokens=2048,
)
total_input_tokens += in_tok2
total_output_tokens += out_tok2
cleaned2 = _strip_markdown_fences(text2)
fixed_node = json.loads(cleaned2)
# Re-validate
tree_copy2 = copy.deepcopy(tree_structure)
_replace_node_in_tree(tree_copy2, node_id, copy.deepcopy(fixed_node))
remaining2 = validate_generated_tree(tree_copy2)
still_has_error = any(node_id in e for e in remaining2)
except (json.JSONDecodeError, Exception) as exc:
logger.warning("AI retry fix failed for node %s: %s", node_id, exc)
continue
if still_has_error:
logger.warning("AI could not fix node %s after retry", node_id)
continue
description = _describe_fix(original_snapshot, fixed_node)
fixes.append(
{
"target_node_id": node_id,
"error_message": error_message,
"description": description,
"original_node": original_snapshot,
"fixed_node": fixed_node,
}
)
return fixes, total_input_tokens, total_output_tokens

View File

@@ -0,0 +1,175 @@
"""
AI Provider abstraction layer.
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends for JSON generation used by the AI Flow Builder.
"""
import logging
from abc import ABC, abstractmethod
from app.core.config import settings
logger = logging.getLogger(__name__)
class AIProvider(ABC):
"""Abstract base class for AI providers."""
@abstractmethod
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
"""Generate a JSON response from the AI model.
Args:
system_prompt: System-level instruction for the model.
messages: List of message dicts with "role" and "content" keys.
max_tokens: Maximum output tokens.
Returns:
Tuple of (response_text, input_tokens, output_tokens).
"""
...
class GeminiProvider(AIProvider):
"""Google Gemini provider using the google-genai SDK."""
def __init__(self, api_key: str, model: str) -> None:
self._api_key = api_key
self._model = model
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
from google import genai
from google.genai import types as genai_types
client = genai.Client(api_key=self._api_key)
# Convert messages to Gemini Content format
contents: list[genai_types.Content] = []
for msg in messages:
role = "model" if msg["role"] == "assistant" else "user"
contents.append(
genai_types.Content(
role=role,
parts=[genai_types.Part(text=msg["content"])],
)
)
config = genai_types.GenerateContentConfig(
system_instruction=system_prompt,
max_output_tokens=max_tokens,
response_mime_type="application/json",
)
response = await client.aio.models.generate_content(
model=self._model,
contents=contents,
config=config,
)
# Log finish reason to detect truncation
if response.candidates:
finish_reason = getattr(response.candidates[0], "finish_reason", None)
logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
if str(finish_reason) == "MAX_TOKENS":
logger.warning(
"Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
max_tokens,
)
text = response.text or ""
input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
output_tokens = (
getattr(response.usage_metadata, "candidates_token_count", 0) or 0
)
return text, input_tokens, output_tokens
class AnthropicProvider(AIProvider):
"""Anthropic Claude provider using the anthropic SDK."""
def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
self._api_key = api_key
self._model = model
self._timeout = timeout
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
import anthropic
client = anthropic.AsyncAnthropic(
api_key=self._api_key,
timeout=self._timeout,
)
response = await client.messages.create(
model=self._model,
max_tokens=max_tokens,
system=system_prompt,
messages=messages,
)
text = response.content[0].text
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
return text, input_tokens, output_tokens
def get_ai_provider() -> AIProvider:
"""Factory that returns the configured AI provider.
Selection logic:
1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
3. Fallback: if preferred provider key missing, try the other one
4. If nothing configured -> raise RuntimeError
"""
provider = settings.AI_PROVIDER
if provider == "gemini":
if settings.GOOGLE_AI_API_KEY:
return GeminiProvider(
api_key=settings.GOOGLE_AI_API_KEY,
model=settings.AI_MODEL_GEMINI,
)
# Fallback to Anthropic
if settings.ANTHROPIC_API_KEY:
return AnthropicProvider(
api_key=settings.ANTHROPIC_API_KEY,
model=settings.AI_MODEL_ANTHROPIC,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
elif provider == "anthropic":
if settings.ANTHROPIC_API_KEY:
return AnthropicProvider(
api_key=settings.ANTHROPIC_API_KEY,
model=settings.AI_MODEL_ANTHROPIC,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
# Fallback to Gemini
if settings.GOOGLE_AI_API_KEY:
return GeminiProvider(
api_key=settings.GOOGLE_AI_API_KEY,
model=settings.AI_MODEL_GEMINI,
)
raise RuntimeError(
"No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
)

View File

@@ -1,11 +1,11 @@
"""AI-powered tree generation service using Anthropic Claude API. """AI-powered tree generation service.
Implements the 4-stage wizard flow: Implements the 4-stage wizard flow:
Stage 2 (scaffold): AI suggests 4-7 top-level branches Stage 2 (scaffold): AI suggests 4-7 top-level branches
Stage 3 (branch_detail): AI generates detailed nodes per branch Stage 3 (branch_detail): AI generates detailed nodes per branch
Stage 4 (assemble): Pure assembly logic — zero AI calls Stage 4 (assemble): Pure assembly logic — zero AI calls
System prompts are static constants to enable Anthropic prompt caching. Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic).
""" """
import json import json
import logging import logging
@@ -13,8 +13,7 @@ import re
import uuid import uuid
from typing import Any from typing import Any
import anthropic from app.core.ai_provider import get_ai_provider
from app.core.config import settings from app.core.config import settings
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
@@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str:
return text return text
def _get_client() -> anthropic.AsyncAnthropic:
"""Get configured async Anthropic client."""
if not settings.ANTHROPIC_API_KEY:
raise RuntimeError("ANTHROPIC_API_KEY not configured")
return anthropic.AsyncAnthropic(
api_key=settings.ANTHROPIC_API_KEY,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
def _estimate_cost(input_tokens: int, output_tokens: int) -> float: def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
"""Estimate USD cost from token counts.""" """Estimate USD cost from token counts."""
@@ -146,7 +136,7 @@ async def scaffold_branches(
Returns (branches, input_tokens, output_tokens, estimated_cost). Returns (branches, input_tokens, output_tokens, estimated_cost).
Raises ValueError on invalid response. Raises ValueError on invalid response.
""" """
client = _get_client() provider = get_ai_provider()
flow_type = wizard_state.get("flow_type", "troubleshooting") flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "") name = wizard_state.get("name", "")
@@ -161,21 +151,26 @@ async def scaffold_branches(
if tags: if tags:
user_message += f"Environment: {', '.join(tags)}\n" user_message += f"Environment: {', '.join(tags)}\n"
response = await client.messages.create( raw_text, input_tokens, output_tokens = await provider.generate_json(
model=settings.AI_MODEL, system_prompt=SCAFFOLD_SYSTEM_PROMPT,
max_tokens=1024,
system=SCAFFOLD_SYSTEM_PROMPT,
messages=[{"role": "user", "content": user_message}], messages=[{"role": "user", "content": user_message}],
max_tokens=2048,
) )
raw_text = _strip_markdown_fences(response.content[0].text) logger.info(
input_tokens = response.usage.input_tokens "scaffold raw response (tokens in=%d out=%d, len=%d): %s",
output_tokens = response.usage.output_tokens input_tokens,
output_tokens,
len(raw_text),
raw_text[:500],
)
raw_text = _strip_markdown_fences(raw_text)
cost = _estimate_cost(input_tokens, output_tokens) cost = _estimate_cost(input_tokens, output_tokens)
try: try:
data = json.loads(raw_text) data = json.loads(raw_text)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error("scaffold JSON parse failed. Full text (%d chars): %s", len(raw_text), raw_text)
raise ValueError(f"AI returned invalid JSON: {e}") raise ValueError(f"AI returned invalid JSON: {e}")
branches = data.get("branches", []) branches = data.get("branches", [])
@@ -196,7 +191,7 @@ async def generate_branch_detail(
On validation failure, retries once with corrective prompt. On validation failure, retries once with corrective prompt.
Raises ValueError if both attempts fail. Raises ValueError if both attempts fail.
""" """
client = _get_client() provider = get_ai_provider()
flow_type = wizard_state.get("flow_type", "troubleshooting") flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "") name = wizard_state.get("name", "")
@@ -217,35 +212,30 @@ async def generate_branch_detail(
total_output = 0 total_output = 0
for attempt in range(3): for attempt in range(3):
response = await client.messages.create( raw_text, input_tokens, output_tokens = await provider.generate_json(
model=settings.AI_MODEL, system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT,
max_tokens=8192,
system=BRANCH_DETAIL_SYSTEM_PROMPT,
messages=messages, messages=messages,
max_tokens=8192,
) )
total_input += response.usage.input_tokens total_input += input_tokens
total_output += response.usage.output_tokens total_output += output_tokens
logger.debug( logger.debug(
"branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d", "branch_detail attempt=%d output_tokens=%d",
attempt, attempt,
response.stop_reason, output_tokens,
len(response.content),
response.usage.output_tokens,
) )
if response.stop_reason == "max_tokens": raw_text = _strip_markdown_fences(raw_text) if raw_text else ""
logger.warning(
"branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated",
attempt,
response.usage.output_tokens,
)
raw_text = _strip_markdown_fences(response.content[0].text) if response.content else ""
if not raw_text: if not raw_text:
logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason) logger.warning("branch_detail attempt=%d returned empty text", attempt)
try: try:
branch_tree = json.loads(raw_text) branch_tree = json.loads(raw_text)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(
"branch_detail attempt=%d JSON parse failed (%d chars): %s",
attempt, len(raw_text), raw_text[:500],
)
if attempt < 2: if attempt < 2:
messages.append({"role": "assistant", "content": raw_text}) messages.append({"role": "assistant", "content": raw_text})
messages.append({ messages.append({

View File

@@ -78,11 +78,16 @@ class Settings(BaseSettings):
AI_CONVERSATION_TTL_HOURS: int = 24 AI_CONVERSATION_TTL_HOURS: int = 24
AI_MAX_CALLS_PER_FLOW: int = 10 AI_MAX_CALLS_PER_FLOW: int = 10
AI_REQUEST_TIMEOUT_SECONDS: int = 45 AI_REQUEST_TIMEOUT_SECONDS: int = 45
# AI Provider selection
AI_PROVIDER: str = "gemini" # "gemini" or "anthropic"
GOOGLE_AI_API_KEY: Optional[str] = None
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
@property @property
def ai_enabled(self) -> bool: def ai_enabled(self) -> bool:
"""Check if AI Flow Builder is configured.""" """Check if any AI provider is configured."""
return self.ANTHROPIC_API_KEY is not None return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None
# Deployment auto-seed test data on PR environments # Deployment auto-seed test data on PR environments
SEED_ON_DEPLOY: bool = False SEED_ON_DEPLOY: bool = False

View File

@@ -0,0 +1,52 @@
"""Pydantic schemas for the AI auto-fix feature."""
from typing import Any, Literal
from pydantic import BaseModel, Field
# ── Requests ──
class ValidationErrorInput(BaseModel):
"""A single validation error to fix."""
node_id: str = Field(..., description="ID of the node with the error")
message: str = Field(..., description="The validation error message")
class AIFixTreeRequest(BaseModel):
"""Request to generate AI fixes for validation errors."""
tree_structure: dict[str, Any] = Field(..., description="Full tree structure")
tree_name: str = Field("", max_length=255, description="Name of the flow")
tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field(
"troubleshooting", description="Type of flow"
)
validation_errors: list[ValidationErrorInput] = Field(
..., min_length=1, max_length=10, description="Errors to fix"
)
# ── Responses ──
class AIFixProposal(BaseModel):
"""A single proposed fix from the AI."""
target_node_id: str
error_message: str
description: str
original_node: dict[str, Any]
fixed_node: dict[str, Any]
class AIFixTokenUsage(BaseModel):
input: int = 0
output: int = 0
class AIFixTreeResponse(BaseModel):
"""Response with proposed fixes."""
fixes: list[AIFixProposal]
tokens_used: AIFixTokenUsage

View File

@@ -33,6 +33,7 @@ httpx>=0.27.0
# AI Flow Builder # AI Flow Builder
anthropic>=0.40.0 anthropic>=0.40.0
google-genai>=1.0.0
# Utilities # Utilities
python-dotenv==1.0.1 python-dotenv==1.0.1

View File

@@ -1,6 +1,6 @@
"""Integration tests for AI Flow Builder endpoints. """Integration tests for AI Flow Builder endpoints.
All Anthropic API calls are mocked — zero real API spend. All AI provider calls are mocked — zero real API spend.
""" """
import json import json
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
@@ -64,12 +64,11 @@ BRANCH_DETAIL_JSON = json.dumps({
}) })
def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200): def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200):
"""Create a mock Anthropic API response.""" """Create a mock AI provider whose generate_json returns the given text and token counts."""
response = MagicMock() provider = MagicMock()
response.content = [MagicMock(text=text)] provider.generate_json = AsyncMock(return_value=(text, input_tokens, output_tokens))
response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens) return provider
return response
@pytest.fixture @pytest.fixture
@@ -194,11 +193,9 @@ async def test_scaffold_success(client, auth_headers, enable_ai):
) )
conversation_id = start_resp.json()["conversation_id"] conversation_id = start_resp.json()["conversation_id"]
# Mock Anthropic # Mock AI provider
mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client: with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=mock_provider):
mock_client.return_value.messages.create = AsyncMock(return_value=mock_response)
response = await client.post( response = await client.post(
"/api/v1/ai/scaffold", "/api/v1/ai/scaffold",
json={"conversation_id": conversation_id}, json={"conversation_id": conversation_id},
@@ -241,9 +238,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
) )
conversation_id = start_resp.json()["conversation_id"] conversation_id = start_resp.json()["conversation_id"]
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client: with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
await client.post( await client.post(
"/api/v1/ai/scaffold", "/api/v1/ai/scaffold",
json={"conversation_id": conversation_id}, json={"conversation_id": conversation_id},
@@ -251,10 +247,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
) )
# Now generate branch detail # Now generate branch detail
detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON) detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client: with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=detail_provider):
mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock)
response = await client.post( response = await client.post(
"/api/v1/ai/branch-detail", "/api/v1/ai/branch-detail",
json={ json={
@@ -290,9 +284,8 @@ async def test_assemble_success(client, auth_headers, enable_ai):
conversation_id = start_resp.json()["conversation_id"] conversation_id = start_resp.json()["conversation_id"]
# Scaffold # Scaffold
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client: with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
await client.post( await client.post(
"/api/v1/ai/scaffold", "/api/v1/ai/scaffold",
json={"conversation_id": conversation_id}, json={"conversation_id": conversation_id},

View File

@@ -0,0 +1,169 @@
"""Integration tests for the POST /ai/fix-tree endpoint."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.core.config import settings
# ── Sample tree (has a decision node with only 1 option + 1 child) ──
SAMPLE_TREE = {
"id": "root",
"type": "decision",
"question": "Is the server up?",
"options": [
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
],
"children": [
{
"id": "check-logs",
"type": "action",
"title": "Check Logs",
"description": "Review logs.",
"next_node_id": "logs-ok",
},
{
"id": "logs-ok",
"type": "solution",
"title": "Logs OK",
"description": "Issue in logs.",
},
{
"id": "restart",
"type": "decision",
"question": "Did restart work?",
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
"children": [
{
"id": "done",
"type": "solution",
"title": "Done",
"description": "Fixed.",
}
],
},
],
}
# Fixed version of the "restart" node — 2 options, 2 children
FIXED_RESTART_NODE = {
"id": "restart",
"type": "decision",
"question": "Did restart work?",
"options": [
{"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
{"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
],
"children": [
{
"id": "done",
"type": "solution",
"title": "Done",
"description": "Fixed.",
},
{
"id": "escalate",
"type": "solution",
"title": "Escalate",
"description": "Escalate to senior engineer.",
},
],
}
FIX_REQUEST_BODY = {
"tree_structure": SAMPLE_TREE,
"tree_name": "Server Troubleshooting",
"tree_type": "troubleshooting",
"validation_errors": [
{
"node_id": "restart",
"message": "Decision node 'restart' must have at least 2 options",
}
],
}
def _mock_ai_provider(response_text: str, input_tokens: int = 50, output_tokens: int = 100):
"""Create a mock provider whose generate_json returns given text."""
provider = MagicMock()
provider.generate_json = AsyncMock(return_value=(response_text, input_tokens, output_tokens))
return provider
@pytest.fixture
def enable_ai():
"""Temporarily enable AI by setting a fake API key."""
original = settings.GOOGLE_AI_API_KEY
settings.GOOGLE_AI_API_KEY = "test-key-fake"
yield
settings.GOOGLE_AI_API_KEY = original
@pytest.fixture
def disable_ai():
"""Ensure AI is disabled."""
orig_google = settings.GOOGLE_AI_API_KEY
orig_anthropic = settings.ANTHROPIC_API_KEY
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = None
yield
settings.GOOGLE_AI_API_KEY = orig_google
settings.ANTHROPIC_API_KEY = orig_anthropic
# ── Tests ──
@pytest.mark.asyncio
async def test_returns_401_without_auth(client):
"""POST /ai/fix-tree without auth token returns 401."""
response = await client.post("/api/v1/ai/fix-tree", json=FIX_REQUEST_BODY)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_returns_503_when_ai_disabled(client, auth_headers, disable_ai):
"""POST /ai/fix-tree returns 503 when no AI keys are configured."""
response = await client.post(
"/api/v1/ai/fix-tree",
json=FIX_REQUEST_BODY,
headers=auth_headers,
)
assert response.status_code == 503
@pytest.mark.asyncio
async def test_returns_fixes_on_success(client, auth_headers, enable_ai):
"""POST /ai/fix-tree returns fix proposals when AI succeeds."""
mock_provider = _mock_ai_provider(json.dumps(FIXED_RESTART_NODE))
with patch(
"app.core.ai_fix_service.get_ai_provider",
return_value=mock_provider,
):
response = await client.post(
"/api/v1/ai/fix-tree",
json=FIX_REQUEST_BODY,
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "fixes" in data
assert "tokens_used" in data
assert len(data["fixes"]) == 1
fix = data["fixes"][0]
assert fix["target_node_id"] == "restart"
assert fix["error_message"] == "Decision node 'restart' must have at least 2 options"
assert fix["original_node"]["id"] == "restart"
assert fix["fixed_node"]["id"] == "restart"
assert len(fix["fixed_node"]["options"]) == 2
assert len(fix["fixed_node"]["children"]) == 2
assert data["tokens_used"]["input"] == 50
assert data["tokens_used"]["output"] == 100

View File

@@ -0,0 +1,224 @@
"""Unit tests for AI fix service helper functions.
Tests pure Python helpers only — no AI mocking needed.
"""
import pytest
from app.core.ai_fix_service import (
_find_node_by_id,
_find_parent_node,
_serialize_tree_outline,
_strip_markdown_fences,
_replace_node_in_tree,
_describe_fix,
)
# ── Sample tree ──
SAMPLE_TREE = {
"id": "root",
"type": "decision",
"question": "Is the server up?",
"options": [
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
],
"children": [
{
"id": "check-logs",
"type": "action",
"title": "Check Logs",
"description": "Review logs.",
"next_node_id": "logs-ok",
},
{
"id": "logs-ok",
"type": "solution",
"title": "Logs OK",
"description": "Issue in logs.",
},
{
"id": "restart",
"type": "decision",
"question": "Did restart work?",
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
"children": [
{
"id": "done",
"type": "solution",
"title": "Done",
"description": "Fixed.",
}
],
},
],
}
# ── _find_node_by_id ──
class TestFindNodeById:
def test_finds_root(self):
node = _find_node_by_id(SAMPLE_TREE, "root")
assert node is not None
assert node["id"] == "root"
assert node["type"] == "decision"
def test_finds_nested_child(self):
node = _find_node_by_id(SAMPLE_TREE, "done")
assert node is not None
assert node["id"] == "done"
assert node["type"] == "solution"
def test_finds_direct_child(self):
node = _find_node_by_id(SAMPLE_TREE, "check-logs")
assert node is not None
assert node["title"] == "Check Logs"
def test_returns_none_for_missing(self):
node = _find_node_by_id(SAMPLE_TREE, "nonexistent")
assert node is None
def test_returns_none_for_non_dict(self):
assert _find_node_by_id("not a dict", "root") is None
# ── _find_parent_node ──
class TestFindParentNode:
def test_root_has_no_parent(self):
parent = _find_parent_node(SAMPLE_TREE, "root")
assert parent is None
def test_finds_parent_of_direct_child(self):
parent = _find_parent_node(SAMPLE_TREE, "check-logs")
assert parent is not None
assert parent["id"] == "root"
def test_finds_parent_of_deeply_nested(self):
parent = _find_parent_node(SAMPLE_TREE, "done")
assert parent is not None
assert parent["id"] == "restart"
def test_returns_none_for_missing(self):
parent = _find_parent_node(SAMPLE_TREE, "nonexistent")
assert parent is None
def test_returns_none_for_non_dict(self):
assert _find_parent_node("not a dict", "root") is None
# ── _serialize_tree_outline ──
class TestSerializeTreeOutline:
def test_produces_readable_outline(self):
outline = _serialize_tree_outline(SAMPLE_TREE)
assert "- [decision] Is the server up?" in outline
assert " - [action] Check Logs" in outline
assert " - [solution] Logs OK" in outline
assert " - [solution] Done" in outline
def test_marks_error_node(self):
outline = _serialize_tree_outline(SAMPLE_TREE, error_node_id="restart")
assert "<<< ERROR HERE" in outline
# Only the restart node should be marked
lines = outline.split("\n")
error_lines = [l for l in lines if "ERROR HERE" in l]
assert len(error_lines) == 1
assert "Did restart work?" in error_lines[0]
def test_no_error_marker_when_none(self):
outline = _serialize_tree_outline(SAMPLE_TREE)
assert "ERROR HERE" not in outline
def test_handles_non_dict(self):
assert _serialize_tree_outline("not a dict") == ""
def test_indentation_increases_with_depth(self):
outline = _serialize_tree_outline(SAMPLE_TREE)
lines = outline.split("\n")
# Root has no indentation
assert lines[0].startswith("- [decision]")
# Children have 2-space indent
child_lines = [l for l in lines if "Check Logs" in l]
assert child_lines[0].startswith(" - ")
# ── _strip_markdown_fences ──
class TestStripMarkdownFences:
def test_strips_json_fences(self):
text = '```json\n{"key": "value"}\n```'
assert _strip_markdown_fences(text) == '{"key": "value"}'
def test_strips_plain_fences(self):
text = '```\n{"key": "value"}\n```'
assert _strip_markdown_fences(text) == '{"key": "value"}'
def test_passes_through_plain_json(self):
text = '{"key": "value"}'
assert _strip_markdown_fences(text) == '{"key": "value"}'
# ── _replace_node_in_tree ──
class TestReplaceNodeInTree:
def test_replaces_root(self):
import copy
tree = copy.deepcopy(SAMPLE_TREE)
replacement = {"id": "root", "type": "decision", "question": "New question"}
assert _replace_node_in_tree(tree, "root", replacement) is True
assert tree["question"] == "New question"
assert "children" not in tree # cleared and replaced
def test_replaces_nested_node(self):
import copy
tree = copy.deepcopy(SAMPLE_TREE)
replacement = {"id": "done", "type": "solution", "title": "All Done", "description": "Complete."}
assert _replace_node_in_tree(tree, "done", replacement) is True
found = _find_node_by_id(tree, "done")
assert found["title"] == "All Done"
def test_returns_false_for_missing(self):
import copy
tree = copy.deepcopy(SAMPLE_TREE)
assert _replace_node_in_tree(tree, "nonexistent", {"id": "x"}) is False
# ── _describe_fix ──
class TestDescribeFix:
def test_describes_added_children(self):
original = {"id": "n1", "children": [{"id": "c1"}]}
fixed = {"id": "n1", "children": [{"id": "c1"}, {"id": "c2"}]}
desc = _describe_fix(original, fixed)
assert "1 child node" in desc
def test_describes_added_options(self):
original = {"id": "n1", "options": [{"id": "o1"}]}
fixed = {"id": "n1", "options": [{"id": "o1"}, {"id": "o2"}]}
desc = _describe_fix(original, fixed)
assert "1 option" in desc
def test_describes_added_next_node_id(self):
original = {"id": "n1", "type": "action"}
fixed = {"id": "n1", "type": "action", "next_node_id": "n2"}
desc = _describe_fix(original, fixed)
assert "next_node_id" in desc
def test_fallback_description(self):
original = {"id": "n1", "type": "solution"}
fixed = {"id": "n1", "type": "solution"}
desc = _describe_fix(original, fixed)
assert "fixed structural issue" in desc.lower()

View File

@@ -0,0 +1,216 @@
"""Tests for the AI provider abstraction layer."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import sys
from app.core.ai_provider import (
AIProvider,
AnthropicProvider,
GeminiProvider,
get_ai_provider,
)
from app.core.config import settings
class TestGetAIProvider:
"""Tests for the get_ai_provider factory function."""
def test_returns_gemini_when_configured(self):
original_provider = settings.AI_PROVIDER
original_key = settings.GOOGLE_AI_API_KEY
try:
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
provider = get_ai_provider()
assert isinstance(provider, GeminiProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_key
def test_returns_anthropic_when_configured(self):
original_provider = settings.AI_PROVIDER
original_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "anthropic"
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
provider = get_ai_provider()
assert isinstance(provider, AnthropicProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.ANTHROPIC_API_KEY = original_key
def test_fallback_to_anthropic_when_gemini_key_missing(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
provider = get_ai_provider()
assert isinstance(provider, AnthropicProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
def test_fallback_to_gemini_when_anthropic_key_missing(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "anthropic"
settings.ANTHROPIC_API_KEY = None
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
provider = get_ai_provider()
assert isinstance(provider, GeminiProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
def test_raises_when_no_provider_configured(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = None
with pytest.raises(RuntimeError, match="No AI provider configured"):
get_ai_provider()
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
class TestAnthropicProvider:
"""Tests for AnthropicProvider.generate_json."""
@pytest.mark.asyncio
async def test_generate_json(self):
provider = AnthropicProvider(
api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30
)
mock_response = MagicMock()
mock_response.content = [MagicMock(text='{"result": "ok"}')]
mock_response.usage = MagicMock(input_tokens=100, output_tokens=50)
mock_client = AsyncMock()
mock_client.messages.create = AsyncMock(return_value=mock_response)
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
text, input_tokens, output_tokens = await provider.generate_json(
system_prompt="You are a helper.",
messages=[{"role": "user", "content": "Hello"}],
max_tokens=1024,
)
assert text == '{"result": "ok"}'
assert input_tokens == 100
assert output_tokens == 50
mock_client.messages.create.assert_called_once_with(
model="claude-haiku-4-5-20251001",
max_tokens=1024,
system="You are a helper.",
messages=[{"role": "user", "content": "Hello"}],
)
class TestGeminiProvider:
"""Tests for GeminiProvider.generate_json."""
@pytest.mark.asyncio
async def test_generate_json(self):
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
mock_usage = MagicMock()
mock_usage.prompt_token_count = 80
mock_usage.candidates_token_count = 40
mock_response = MagicMock()
mock_response.text = '{"answer": 42}'
mock_response.usage_metadata = mock_usage
mock_client = MagicMock()
mock_client.aio.models.generate_content = AsyncMock(
return_value=mock_response
)
mock_genai_module = MagicMock()
mock_genai_module.Client.return_value = mock_client
mock_types = MagicMock()
mock_types.Content.side_effect = lambda **kw: kw
mock_types.Part.side_effect = lambda **kw: kw
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
mock_google = MagicMock()
mock_google.genai = mock_genai_module
mock_genai_module.types = mock_types
with patch.dict(sys.modules, {
"google": mock_google,
"google.genai": mock_genai_module,
"google.genai.types": mock_types,
}):
text, input_tokens, output_tokens = await provider.generate_json(
system_prompt="Generate JSON.",
messages=[
{"role": "user", "content": "Give me data"},
{"role": "assistant", "content": "Here it is"},
{"role": "user", "content": "More please"},
],
max_tokens=2048,
)
assert text == '{"answer": 42}'
assert input_tokens == 80
assert output_tokens == 40
mock_client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_generate_json_handles_none_usage(self):
"""Token counts default to 0 when usage_metadata attributes are None."""
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
mock_usage = MagicMock(spec=[]) # No attributes at all
mock_response = MagicMock()
mock_response.text = "{}"
mock_response.usage_metadata = mock_usage
mock_client = MagicMock()
mock_client.aio.models.generate_content = AsyncMock(
return_value=mock_response
)
mock_genai_module = MagicMock()
mock_genai_module.Client.return_value = mock_client
mock_types = MagicMock()
mock_types.Content.side_effect = lambda **kw: kw
mock_types.Part.side_effect = lambda **kw: kw
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
mock_google = MagicMock()
mock_google.genai = mock_genai_module
mock_genai_module.types = mock_types
with patch.dict(sys.modules, {
"google": mock_google,
"google.genai": mock_genai_module,
"google.genai.types": mock_types,
}):
text, input_tokens, output_tokens = await provider.generate_json(
system_prompt="test",
messages=[{"role": "user", "content": "test"}],
)
assert text == "{}"
assert input_tokens == 0
assert output_tokens == 0

View File

@@ -0,0 +1,209 @@
# AI Auto-Fix & Gemini Flash Provider Design
> **Date:** 2026-02-26
> **Status:** Approved
---
## Overview
Two combined features:
1. **AI Provider Abstraction** — Add Gemini 2.5 Flash as the default AI provider with Claude as fallback, behind a unified interface.
2. **AI Auto-Fix for Validation Errors** — When a flow fails validation, offer an AI-powered "Fix with AI" button that generates structural fixes for review.
---
## Section 1: AI Provider Abstraction
### Design
New `backend/app/core/ai_provider.py` with a unified interface:
```python
class AIProvider(ABC):
async def generate_json(
self,
system_prompt: str,
messages: list[dict],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
"""Returns (text, input_tokens, output_tokens)"""
```
Two implementations:
| Provider | Model | SDK | Role |
|----------|-------|-----|------|
| `GeminiProvider` | `gemini-2.5-flash` | `google-genai` | Default |
| `AnthropicProvider` | `claude-haiku-4-5-20251001` | `anthropic` | Fallback |
### Provider Selection
- `get_ai_provider()` factory reads `AI_PROVIDER` env var (default: `"gemini"`)
- Falls back to Anthropic if Gemini key is missing
- Existing `ai_tree_generator_service.py` swaps direct Anthropic calls for `get_ai_provider()`
### New Environment Variables
| Variable | Default | Purpose |
|----------|---------|---------|
| `AI_PROVIDER` | `"gemini"` | Which provider to use (`gemini` or `anthropic`) |
| `GOOGLE_AI_API_KEY` | — | Gemini API key |
Existing `ANTHROPIC_API_KEY` remains for fallback.
### Config Changes (`core/config.py`)
```python
AI_PROVIDER: str = "gemini"
GOOGLE_AI_API_KEY: str | None = None
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
```
---
## Section 2: AI Auto-Fix Feature
### Backend Endpoint
**`POST /api/v1/ai/fix-tree`**
Request:
```json
{
"tree_structure": { /* full tree */ },
"tree_name": "Router Troubleshooting",
"tree_type": "troubleshooting",
"validation_errors": [
{
"node_id": "node_abc",
"message": "Decision node must have at least 2 children (branches)"
}
]
}
```
Response:
```json
{
"fixes": [
{
"target_node_id": "node_abc",
"error_message": "Decision node must have at least 2 children (branches)",
"description": "Added second branch 'Check firmware version' with solution node",
"original_node": { /* snapshot before fix */ },
"fixed_node": { /* replacement node with corrected subtree */ }
}
],
"tokens_used": { "input": 1200, "output": 800 }
}
```
### How It Works
1. For each validation error tied to a `node_id`, extract that node + its parent + siblings from the tree.
2. Build a prompt with:
- The **full tree structure** serialized as a simplified outline (node titles + types + structure) for context
- The **specific failing node** highlighted with full JSON detail
- The **validation error message**
- Instructions: "Fix ONLY this node's structural issue. Keep all existing content. Generate domain-relevant additions that fit the flow's topic."
3. AI returns a corrected version of that node (with children/options adjusted).
4. Backend re-validates the fixed node before returning it.
5. If re-validation fails, retry once with the error fed back (corrective prompt pattern).
### Prompt Strategy
The prompt gives the AI the full tree as a compact outline, then zooms into the failing node:
```
You are fixing a validation error in a troubleshooting flow called "Router Troubleshooting".
FULL FLOW OUTLINE:
- [decision] Is the router powered on?
- [action] Check power cable → [solution] Power restored
- [decision] Are lights blinking? ← ERROR HERE
- [solution] Contact ISP
ERROR: Decision node "Are lights blinking?" must have at least 2 children (branches).
FAILING NODE (full detail):
{...json...}
Fix this node by adding the minimum structure needed to resolve the error.
Return ONLY the fixed node as JSON.
```
### Frontend UX
1. **Trigger**: "Fix with AI" button in `ValidationSummary` — appears when there are fixable errors (structural errors with a `node_id`).
2. **Loading state**: Button shows spinner + "Generating fixes..." — disabled during request.
3. **Review modal** (`AIFixReviewModal`): Shows each proposed fix as a card:
- Error message at top
- Before/after view of the node change
- "Apply" / "Skip" buttons per fix
- "Apply All" button in footer
4. **Apply**: Each accepted fix calls `updateNode(targetNodeId, fixedNode)` in the tree editor store.
5. **Re-validate**: After applying fixes, auto-run `validate()` to confirm resolution.
---
## Section 3: Scope & Constraints
### Fixable Errors (Auto-Fix Scope)
Only structural validation errors with a `node_id`:
- Decision node missing children/branches
- Decision node missing options
- Action node missing `next_node_id`
- Dead-end decision nodes (no children)
### NOT Fixable
- Global checks (tree too small/large, not enough solutions) — require rethinking the whole tree
- Content quality issues — out of scope
- Errors without a `node_id` (root-level issues)
Non-fixable errors still show in ValidationSummary but without the "Fix with AI" option.
### Token Budget
- Tree outline: ~50-100 tokens for a typical 15-node tree
- Failing node detail: ~100-200 tokens
- System prompt + instructions: ~300 tokens
- **Total input per fix: ~500-600 tokens**
- One API call per failing node (not batched)
### Error Handling
- Provider failure (rate limit, network): toast error, user can retry
- Fix fails re-validation: "AI couldn't generate a valid fix" with retry option
- Max 1 retry with corrective prompt per attempt
- Both provider and fallback fail: surface error to user
### Auth
- Requires `engineer` role or above (`require_engineer_or_admin`)
---
## New Files
| File | Purpose |
|------|---------|
| `backend/app/core/ai_provider.py` | Provider abstraction + Gemini/Anthropic implementations |
| `backend/app/core/ai_fix_service.py` | Fix generation logic + prompt building |
| `backend/app/api/endpoints/ai.py` | `POST /ai/fix-tree` endpoint |
| `backend/app/schemas/ai.py` | Request/response schemas for AI endpoints |
| `frontend/src/components/tree-editor/AIFixReviewModal.tsx` | Review modal for proposed fixes |
## Modified Files
| File | Change |
|------|--------|
| `backend/app/core/config.py` | Add Gemini config vars |
| `backend/app/core/ai_tree_generator_service.py` | Swap Anthropic calls for provider abstraction |
| `backend/app/api/router.py` | Register `/ai` routes |
| `frontend/src/api/trees.ts` | Add `fixTree()` API call |
| `frontend/src/components/tree-editor/ValidationSummary.tsx` | Add "Fix with AI" button |

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
import apiClient from './client' import apiClient from './client'
import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse } from '@/types' import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse, AIFixTreeRequest, AIFixTreeResponse } from '@/types'
export const treesApi = { export const treesApi = {
async list(params?: TreeFilters): Promise<TreeListItem[]> { async list(params?: TreeFilters): Promise<TreeListItem[]> {
@@ -65,6 +65,12 @@ export const treesApi = {
const response = await apiClient.post<TreeValidationResponse>(`/trees/${id}/can-publish`) const response = await apiClient.post<TreeValidationResponse>(`/trees/${id}/can-publish`)
return response.data return response.data
}, },
// AI auto-fix
async fixTree(request: AIFixTreeRequest): Promise<AIFixTreeResponse> {
const response = await apiClient.post<AIFixTreeResponse>('/ai/fix-tree', request)
return response.data
},
} }
export default treesApi export default treesApi

View File

@@ -0,0 +1,170 @@
import { useState } from 'react'
import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react'
import { cn } from '@/lib/utils'
import type { AIFixProposal } from '@/types'
interface AIFixReviewModalProps {
fixes: AIFixProposal[]
onApply: (fix: AIFixProposal) => void
onApplyAll: () => void
onClose: () => void
}
export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) {
const [appliedIds, setAppliedIds] = useState<Set<string>>(new Set())
const [skippedIds, setSkippedIds] = useState<Set<string>>(new Set())
const [expandedIds, setExpandedIds] = useState<Set<string>>(new Set(fixes.map(f => f.target_node_id)))
const handleApply = (fix: AIFixProposal) => {
onApply(fix)
setAppliedIds(prev => new Set(prev).add(fix.target_node_id))
}
const handleSkip = (fix: AIFixProposal) => {
setSkippedIds(prev => new Set(prev).add(fix.target_node_id))
}
const toggleExpanded = (id: string) => {
setExpandedIds(prev => {
const next = new Set(prev)
if (next.has(id)) next.delete(id)
else next.add(id)
return next
})
}
const pendingFixes = fixes.filter(
f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id)
)
const allHandled = pendingFixes.length === 0
return (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/80 backdrop-blur-sm p-4">
<div className="relative flex h-[80vh] w-full max-w-2xl flex-col bg-card border border-border rounded-2xl shadow-lg">
{/* Header */}
<div className="flex items-center justify-between border-b border-border px-6 py-4">
<div className="flex items-center gap-2">
<Sparkles className="h-5 w-5 text-primary" />
<h2 className="text-lg font-semibold text-foreground">
AI Fix Proposals ({fixes.length})
</h2>
</div>
<button
onClick={onClose}
className="rounded-md p-1 text-muted-foreground hover:bg-accent hover:text-foreground"
>
<X className="h-5 w-5" />
</button>
</div>
{/* Body */}
<div className="flex-1 overflow-y-auto p-4 space-y-3">
{fixes.map((fix) => {
const isApplied = appliedIds.has(fix.target_node_id)
const isSkipped = skippedIds.has(fix.target_node_id)
const isExpanded = expandedIds.has(fix.target_node_id)
return (
<div
key={fix.target_node_id}
className={cn(
'rounded-lg border p-4',
isApplied
? 'border-emerald-400/30 bg-emerald-400/5'
: isSkipped
? 'border-border bg-accent/30 opacity-60'
: 'border-border bg-card'
)}
>
{/* Fix header */}
<div className="flex items-start justify-between gap-3">
<div className="flex-1">
<p className="text-sm text-red-400 mb-1">{fix.error_message}</p>
<p className="text-sm text-foreground">{fix.description}</p>
<p className="text-xs text-muted-foreground mt-1">
Node: {fix.target_node_id}
</p>
</div>
{isApplied && (
<span className="flex items-center gap-1 rounded-full bg-emerald-400/10 px-2 py-1 text-xs text-emerald-400">
<Check className="h-3 w-3" /> Applied
</span>
)}
{isSkipped && (
<span className="text-xs text-muted-foreground">Skipped</span>
)}
</div>
{/* Expand/collapse detail */}
{!isApplied && !isSkipped && (
<>
<button
onClick={() => toggleExpanded(fix.target_node_id)}
className="mt-2 flex items-center gap-1 text-xs text-muted-foreground hover:text-foreground"
>
{isExpanded ? <ChevronUp className="h-3 w-3" /> : <ChevronDown className="h-3 w-3" />}
{isExpanded ? 'Hide' : 'Show'} details
</button>
{isExpanded && (
<div className="mt-3 grid grid-cols-2 gap-3">
<div>
<p className="text-xs font-medium text-muted-foreground mb-1">Before</p>
<pre className="overflow-x-auto rounded bg-accent/50 p-2 text-xs text-muted-foreground max-h-48 overflow-y-auto">
{JSON.stringify(fix.original_node, null, 2)}
</pre>
</div>
<div>
<p className="text-xs font-medium text-emerald-400 mb-1">After</p>
<pre className="overflow-x-auto rounded bg-emerald-400/5 p-2 text-xs text-foreground max-h-48 overflow-y-auto">
{JSON.stringify(fix.fixed_node, null, 2)}
</pre>
</div>
</div>
)}
{/* Action buttons */}
<div className="mt-3 flex gap-2">
<button
onClick={() => handleApply(fix)}
className="flex items-center gap-1 rounded-md bg-gradient-brand px-3 py-1.5 text-xs font-medium text-white shadow-sm shadow-primary/20 hover:opacity-90"
>
<Check className="h-3 w-3" />
Apply
</button>
<button
onClick={() => handleSkip(fix)}
className="flex items-center gap-1 rounded-md border border-border px-3 py-1.5 text-xs font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
>
<SkipForward className="h-3 w-3" />
Skip
</button>
</div>
</>
)}
</div>
)
})}
</div>
{/* Footer */}
<div className="flex items-center justify-between border-t border-border px-6 py-4">
<button
onClick={onClose}
className="rounded-md border border-border px-4 py-2 text-sm font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
>
{allHandled ? 'Done' : 'Cancel'}
</button>
{!allHandled && (
<button
onClick={onApplyAll}
className="rounded-md bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20 hover:opacity-90"
>
Apply All ({pendingFixes.length})
</button>
)}
</div>
</div>
</div>
)
}

View File

@@ -1,14 +1,16 @@
import { useState } from 'react' import { useState } from 'react'
import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp } from 'lucide-react' import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import type { ValidationError } from '@/store/treeEditorStore' import type { ValidationError } from '@/store/treeEditorStore'
interface ValidationSummaryProps { interface ValidationSummaryProps {
errors: ValidationError[] errors: ValidationError[]
onSelectNode: (nodeId: string) => void onSelectNode: (nodeId: string) => void
onFixWithAI?: () => void
isFixing?: boolean
} }
export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryProps) { export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) {
const [isExpanded, setIsExpanded] = useState(true) const [isExpanded, setIsExpanded] = useState(true)
const errorItems = errors.filter(e => e.severity === 'error') const errorItems = errors.filter(e => e.severity === 'error')
@@ -22,6 +24,8 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
} }
} }
const hasFixableErrors = errorItems.some(e => e.nodeId)
return ( return (
<div <div
className={cn( className={cn(
@@ -32,14 +36,16 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
)} )}
> >
{/* Header */} {/* Header */}
<button <div
onClick={() => setIsExpanded(!isExpanded)}
className={cn( className={cn(
'flex w-full items-center justify-between p-3 text-left transition-colors hover:bg-accent', 'flex w-full items-center justify-between p-3 transition-colors',
errorItems.length > 0 ? 'text-red-400' : 'text-yellow-400' errorItems.length > 0 ? 'text-red-400' : 'text-yellow-400'
)} )}
> >
<div className="flex items-center gap-2"> <button
onClick={() => setIsExpanded(!isExpanded)}
className="flex items-center gap-2 text-left hover:opacity-80"
>
{errorItems.length > 0 ? ( {errorItems.length > 0 ? (
<AlertCircle className="h-5 w-5" /> <AlertCircle className="h-5 w-5" />
) : ( ) : (
@@ -58,9 +64,35 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
</> </>
)} )}
</span> </span>
</div> {isExpanded ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />}
{isExpanded ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />} </button>
</button>
{/* Fix with AI button */}
{onFixWithAI && hasFixableErrors && (
<button
onClick={onFixWithAI}
disabled={isFixing}
className={cn(
'flex items-center gap-1.5 rounded-md px-3 py-1 text-xs font-medium transition-colors',
isFixing
? 'bg-primary/10 text-primary cursor-wait'
: 'bg-gradient-brand text-white shadow-sm shadow-primary/20 hover:opacity-90'
)}
>
{isFixing ? (
<>
<Loader2 className="h-3 w-3 animate-spin" />
Generating fixes...
</>
) : (
<>
<Sparkles className="h-3 w-3" />
Fix with AI
</>
)}
</button>
)}
</div>
{/* Error/Warning List */} {/* Error/Warning List */}
{isExpanded && ( {isExpanded && (

View File

@@ -5,10 +5,11 @@ import { Undo2, Redo2, Save, CheckCircle2, Monitor, FileText, Code2, LayoutList,
import { getMonacoEditor } from '@/components/tree-editor/code-mode' import { getMonacoEditor } from '@/components/tree-editor/code-mode'
import { treesApi } from '@/api/trees' import { treesApi } from '@/api/trees'
import { treeMarkdownApi } from '@/api/treeMarkdown' import { treeMarkdownApi } from '@/api/treeMarkdown'
import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure } from '@/types' import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure, AIFixProposal } from '@/types'
import { useTreeEditorStore, useTreeEditorTemporal } from '@/store/treeEditorStore' import { useTreeEditorStore, useTreeEditorTemporal } from '@/store/treeEditorStore'
import { TreeEditorLayout } from '@/components/tree-editor/TreeEditorLayout' import { TreeEditorLayout } from '@/components/tree-editor/TreeEditorLayout'
import { ValidationSummary } from '@/components/tree-editor/ValidationSummary' import { ValidationSummary } from '@/components/tree-editor/ValidationSummary'
import { AIFixReviewModal } from '@/components/tree-editor/AIFixReviewModal'
import { useKeyboardShortcuts } from '@/hooks/useKeyboardShortcuts' import { useKeyboardShortcuts } from '@/hooks/useKeyboardShortcuts'
import { usePermissions } from '@/hooks/usePermissions' import { usePermissions } from '@/hooks/usePermissions'
import { Spinner } from '@/components/common/Spinner' import { Spinner } from '@/components/common/Spinner'
@@ -58,6 +59,8 @@ export function TreeEditorPage() {
const [showAnalytics, setShowAnalytics] = useState(false) const [showAnalytics, setShowAnalytics] = useState(false)
const [isMetadataOpen, setIsMetadataOpen] = useState(false) const [isMetadataOpen, setIsMetadataOpen] = useState(false)
const [editingNodeId, setEditingNodeId] = useState<string | null>(null) const [editingNodeId, setEditingNodeId] = useState<string | null>(null)
const [isFixing, setIsFixing] = useState(false)
const [fixProposals, setFixProposals] = useState<AIFixProposal[] | null>(null)
// Mobile detection // Mobile detection
const [isMobile, setIsMobile] = useState(false) const [isMobile, setIsMobile] = useState(false)
@@ -217,6 +220,54 @@ export function TreeEditorPage() {
selectNode(nodeId) selectNode(nodeId)
} }
const handleFixWithAI = async () => {
const store = useTreeEditorStore.getState()
if (!store.treeStructure) return
const fixableErrors = store.validationErrors
.filter(e => e.severity === 'error' && e.nodeId)
.map(e => ({ node_id: e.nodeId!, message: e.message }))
if (fixableErrors.length === 0) return
setIsFixing(true)
try {
const result = await treesApi.fixTree({
tree_structure: store.treeStructure as unknown as Record<string, unknown>,
tree_name: store.name,
tree_type: 'troubleshooting',
validation_errors: fixableErrors,
})
if (result.fixes.length > 0) {
setFixProposals(result.fixes)
} else {
toast.info('AI could not generate fixes for these errors')
}
} catch {
toast.error('Failed to generate AI fixes. Please try again.')
} finally {
setIsFixing(false)
}
}
const handleApplyFix = (fix: AIFixProposal) => {
updateNode(fix.target_node_id, fix.fixed_node as Partial<TreeStructure>)
}
const handleApplyAllFixes = () => {
if (!fixProposals) return
for (const fix of fixProposals) {
handleApplyFix(fix)
}
setFixProposals(null)
setTimeout(() => { validate() }, 100)
}
const handleCloseFixModal = () => {
setFixProposals(null)
validate()
}
const handleNodeSelect = useCallback((nodeId: string | null) => { const handleNodeSelect = useCallback((nodeId: string | null) => {
if (nodeId) { if (nodeId) {
setIsMetadataOpen(false) // close metadata when opening node editor setIsMetadataOpen(false) // close metadata when opening node editor
@@ -685,6 +736,8 @@ export function TreeEditorPage() {
<ValidationSummary <ValidationSummary
errors={validationErrors} errors={validationErrors}
onSelectNode={handleSelectNode} onSelectNode={handleSelectNode}
onFixWithAI={handleFixWithAI}
isFixing={isFixing}
/> />
</div> </div>
)} )}
@@ -705,6 +758,16 @@ export function TreeEditorPage() {
<FlowAnalyticsPanel treeId={id} /> <FlowAnalyticsPanel treeId={id} />
</div> </div>
)} )}
{/* AI Fix Review Modal */}
{fixProposals && (
<AIFixReviewModal
fixes={fixProposals}
onApply={handleApplyFix}
onApplyAll={handleApplyAllFixes}
onClose={handleCloseFixModal}
/>
)}
</div> </div>
) )
} }

View File

@@ -0,0 +1,24 @@
export interface AIFixValidationError {
node_id: string
message: string
}
export interface AIFixProposal {
target_node_id: string
error_message: string
description: string
original_node: Record<string, unknown>
fixed_node: Record<string, unknown>
}
export interface AIFixTreeRequest {
tree_structure: Record<string, unknown>
tree_name: string
tree_type: 'troubleshooting' | 'procedural' | 'maintenance'
validation_errors: AIFixValidationError[]
}
export interface AIFixTreeResponse {
fixes: AIFixProposal[]
tokens_used: { input: number; output: number }
}

View File

@@ -45,3 +45,10 @@ export type {
AIAssembleResponse, AIAssembleResponse,
AIWizardPhase, AIWizardPhase,
} from './ai' } from './ai'
export type {
AIFixTreeRequest,
AIFixTreeResponse,
AIFixProposal,
AIFixValidationError,
} from './ai-fix'