feat: AI auto-fix + Gemini Flash provider #93

Merged
chihlasm merged 14 commits from feat/ai-autofix-gemini into main 2026-02-27 07:32:24 +00:00
21 changed files with 3516 additions and 120 deletions

View File

@@ -10,7 +10,6 @@
import logging
from typing import Annotated
import anthropic
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
@@ -52,10 +51,9 @@ def _require_ai_enabled() -> None:
if not settings.ai_enabled:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.",
detail="AI flow builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
)
@router.get("/quota", response_model=AIQuotaStatusResponse)
async def get_quota(
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -174,27 +172,6 @@ async def scaffold(
branches, input_tokens, output_tokens, cost = await scaffold_branches(
conversation.wizard_state,
)
except anthropic.APIError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="scaffold",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e)},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
except ValueError as e:
await record_ai_usage(
user_id=current_user.id,
@@ -216,6 +193,28 @@ async def scaffold(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"AI returned invalid output: {e}",
)
except Exception as e:
logger.exception("AI scaffold failed: %s: %s", type(e).__name__, e)
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="scaffold",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e)},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"AI provider error ({type(e).__name__}). Please try again.",
)
# Record successful usage
await record_ai_usage(
@@ -293,27 +292,6 @@ async def branch_detail(
existing_branches,
)
)
except anthropic.APIError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="branch_detail",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e), "branch_name": data.branch_name},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
except ValueError as e:
await record_ai_usage(
user_id=current_user.id,
@@ -335,6 +313,28 @@ async def branch_detail(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"AI returned invalid output: {e}",
)
except Exception as e:
logger.exception("AI branch_detail failed: %s: %s", type(e).__name__, e)
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="branch_detail",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e), "branch_name": data.branch_name},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"AI provider error ({type(e).__name__}). Please try again.",
)
# Record successful usage
await record_ai_usage(

View File

@@ -0,0 +1,78 @@
"""AI auto-fix endpoint for tree validation errors.
POST /ai/fix-tree — accepts a tree with validation errors and returns
AI-generated fix proposals for each error.
"""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, require_engineer_or_admin
from app.core.config import settings
from app.core.rate_limit import limiter
from app.core.ai_fix_service import generate_fixes
from app.models.user import User
from app.schemas.ai_fix import (
AIFixTreeRequest,
AIFixTreeResponse,
AIFixProposal,
AIFixTokenUsage,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ai", tags=["ai-fix"])
def _require_ai_enabled() -> None:
"""Raise 503 if AI is not configured."""
if not settings.ai_enabled:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AI fix is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
)
@router.post("/fix-tree", response_model=AIFixTreeResponse)
@limiter.limit("10/minute")
async def fix_tree(
request: Request,
body: AIFixTreeRequest,
user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Generate AI-powered fixes for tree validation errors."""
_require_ai_enabled()
validation_errors = [
{"node_id": e.node_id, "message": e.message}
for e in body.validation_errors
]
try:
fixes, input_tokens, output_tokens = await generate_fixes(
tree_structure=body.tree_structure,
tree_name=body.tree_name,
tree_type=body.tree_type,
validation_errors=validation_errors,
)
except RuntimeError as exc:
logger.error("AI provider not available: %s", exc)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=str(exc),
)
except Exception as exc:
logger.exception("Unexpected error in AI fix service")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred while generating fixes.",
)
return AIFixTreeResponse(
fixes=[AIFixProposal(**f) for f in fixes],
tokens_used=AIFixTokenUsage(input=input_tokens, output=output_tokens),
)

View File

@@ -6,6 +6,7 @@ from app.api.endpoints import target_lists
from app.api.endpoints import maintenance_schedules
from app.api.endpoints import feedback
from app.api.endpoints import ai_builder
from app.api.endpoints import ai_fix
api_router = APIRouter()
@@ -36,3 +37,4 @@ api_router.include_router(target_lists.router)
api_router.include_router(maintenance_schedules.router)
api_router.include_router(feedback.router)
api_router.include_router(ai_builder.router)
api_router.include_router(ai_fix.router)

View File

@@ -0,0 +1,273 @@
"""AI-powered fix service for tree validation errors.
Given a tree structure and validation errors, generates AI-powered
proposals to fix each structural issue while preserving existing content.
"""
import copy
import json
import logging
import re
from typing import Any
from app.core.ai_provider import get_ai_provider
from app.core.ai_tree_validator import validate_generated_tree
logger = logging.getLogger(__name__)
FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.
You will receive:
1. A full flow outline showing the tree structure
2. The specific failing node with its full JSON
3. The validation error message
Your task: Return a FIXED version of the failing node as valid JSON. Rules:
- Fix ONLY the structural issue described in the error message
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
- Every new node must have a unique ID (use descriptive kebab-case IDs)
- Decision nodes must have at least 2 options and at least 2 children
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
- Solution nodes are leaf nodes (no children)
- Return ONLY the fixed node JSON, no explanation"""
# ── Pure helper functions ──
def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
"""Recursively find a node by its ID in the tree structure."""
if not isinstance(tree, dict):
return None
if tree.get("id") == node_id:
return tree
for child in tree.get("children", []):
result = _find_node_by_id(child, node_id)
if result is not None:
return result
return None
def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
"""Find the parent of a node with the given ID."""
if not isinstance(tree, dict):
return None
for child in tree.get("children", []):
if isinstance(child, dict) and child.get("id") == target_id:
return tree
result = _find_parent_node(child, target_id)
if result is not None:
return result
return None
def _serialize_tree_outline(
tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None
) -> str:
"""Serialize tree as a readable outline for AI prompt context.
Format: indented "- [type] label" with "<<< ERROR HERE" marker.
"""
if not isinstance(tree, dict):
return ""
node_type = tree.get("type", "unknown")
label = tree.get("question") or tree.get("title") or tree.get("id", "?")
prefix = " " * indent
marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""
line = f"{prefix}- [{node_type}] {label}{marker}"
lines = [line]
for child in tree.get("children", []):
lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))
return "\n".join(lines)
def _strip_markdown_fences(text: str) -> str:
"""Strip ```json...``` fences from AI response."""
return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip(
"`"
).strip()
def _replace_node_in_tree(
tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
) -> bool:
"""Replace a node in-place by ID. Returns True if found and replaced."""
if not isinstance(tree, dict):
return False
if tree.get("id") == target_id:
tree.clear()
tree.update(replacement)
return True
for child in tree.get("children", []):
if _replace_node_in_tree(child, target_id, replacement):
return True
return False
def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
"""Describe what changed between original and fixed node."""
changes: list[str] = []
orig_children = len(original.get("children", []))
fixed_children = len(fixed.get("children", []))
if fixed_children > orig_children:
changes.append(f"added {fixed_children - orig_children} child node(s)")
orig_options = len(original.get("options", []))
fixed_options = len(fixed.get("options", []))
if fixed_options > orig_options:
changes.append(f"added {fixed_options - orig_options} option(s)")
if fixed.get("next_node_id") and not original.get("next_node_id"):
changes.append("added next_node_id")
if not changes:
changes.append("fixed structural issue")
return "; ".join(changes).capitalize()
# ── Prompt building ──
def _build_fix_prompt(
tree: dict[str, Any],
node_id: str,
error_message: str,
tree_name: str,
tree_type: str,
) -> str:
"""Build the user message for the AI fix request."""
outline = _serialize_tree_outline(tree, error_node_id=node_id)
node = _find_node_by_id(tree, node_id)
node_json = json.dumps(node, indent=2) if node else "{}"
return (
f"Flow name: {tree_name}\n"
f"Flow type: {tree_type}\n\n"
f"## Full flow outline\n```\n{outline}\n```\n\n"
f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n"
f"## Validation error\n{error_message}\n\n"
f"Return the fixed version of this node as JSON."
)
# ── Main entry point ──
async def generate_fixes(
tree_structure: dict[str, Any],
tree_name: str,
tree_type: str,
validation_errors: list[dict[str, str]],
) -> tuple[list[dict[str, Any]], int, int]:
"""Generate AI-powered fixes for tree validation errors.
Args:
tree_structure: Full tree structure dict.
tree_name: Name of the flow.
tree_type: Type of flow (troubleshooting, procedural, maintenance).
validation_errors: List of dicts with "node_id" and "message" keys.
Returns:
Tuple of (fixes_list, total_input_tokens, total_output_tokens).
Each fix dict has: target_node_id, error_message, description,
original_node, fixed_node.
"""
provider = get_ai_provider()
fixes: list[dict[str, Any]] = []
total_input_tokens = 0
total_output_tokens = 0
for error in validation_errors:
node_id = error["node_id"]
error_message = error["message"]
original_node = _find_node_by_id(tree_structure, node_id)
if original_node is None:
logger.warning("Node %s not found in tree, skipping fix", node_id)
continue
original_snapshot = copy.deepcopy(original_node)
# Build prompt and call AI
user_message = _build_fix_prompt(
tree_structure, node_id, error_message, tree_name, tree_type
)
messages = [{"role": "user", "content": user_message}]
try:
text, in_tok, out_tok = await provider.generate_json(
system_prompt=FIX_SYSTEM_PROMPT,
messages=messages,
max_tokens=2048,
)
total_input_tokens += in_tok
total_output_tokens += out_tok
cleaned = _strip_markdown_fences(text)
fixed_node = json.loads(cleaned)
except (json.JSONDecodeError, Exception) as exc:
logger.warning("AI fix failed for node %s: %s", node_id, exc)
continue
# Validate by substituting into a tree copy
tree_copy = copy.deepcopy(tree_structure)
_replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node))
remaining_errors = validate_generated_tree(tree_copy)
# Check if the specific error is still present
still_has_error = any(node_id in e for e in remaining_errors)
if still_has_error:
# Retry once with corrective prompt
retry_message = (
f"Your previous fix still has validation errors:\n"
f"{chr(10).join(remaining_errors)}\n\n"
f"Please fix the node again. Return ONLY the corrected JSON."
)
messages.append({"role": "assistant", "content": text})
messages.append({"role": "user", "content": retry_message})
try:
text2, in_tok2, out_tok2 = await provider.generate_json(
system_prompt=FIX_SYSTEM_PROMPT,
messages=messages,
max_tokens=2048,
)
total_input_tokens += in_tok2
total_output_tokens += out_tok2
cleaned2 = _strip_markdown_fences(text2)
fixed_node = json.loads(cleaned2)
# Re-validate
tree_copy2 = copy.deepcopy(tree_structure)
_replace_node_in_tree(tree_copy2, node_id, copy.deepcopy(fixed_node))
remaining2 = validate_generated_tree(tree_copy2)
still_has_error = any(node_id in e for e in remaining2)
except (json.JSONDecodeError, Exception) as exc:
logger.warning("AI retry fix failed for node %s: %s", node_id, exc)
continue
if still_has_error:
logger.warning("AI could not fix node %s after retry", node_id)
continue
description = _describe_fix(original_snapshot, fixed_node)
fixes.append(
{
"target_node_id": node_id,
"error_message": error_message,
"description": description,
"original_node": original_snapshot,
"fixed_node": fixed_node,
}
)
return fixes, total_input_tokens, total_output_tokens

View File

@@ -0,0 +1,175 @@
"""
AI Provider abstraction layer.
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
backends for JSON generation used by the AI Flow Builder.
"""
import logging
from abc import ABC, abstractmethod
from app.core.config import settings
logger = logging.getLogger(__name__)
class AIProvider(ABC):
"""Abstract base class for AI providers."""
@abstractmethod
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
"""Generate a JSON response from the AI model.
Args:
system_prompt: System-level instruction for the model.
messages: List of message dicts with "role" and "content" keys.
max_tokens: Maximum output tokens.
Returns:
Tuple of (response_text, input_tokens, output_tokens).
"""
...
class GeminiProvider(AIProvider):
"""Google Gemini provider using the google-genai SDK."""
def __init__(self, api_key: str, model: str) -> None:
self._api_key = api_key
self._model = model
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
from google import genai
from google.genai import types as genai_types
client = genai.Client(api_key=self._api_key)
# Convert messages to Gemini Content format
contents: list[genai_types.Content] = []
for msg in messages:
role = "model" if msg["role"] == "assistant" else "user"
contents.append(
genai_types.Content(
role=role,
parts=[genai_types.Part(text=msg["content"])],
)
)
config = genai_types.GenerateContentConfig(
system_instruction=system_prompt,
max_output_tokens=max_tokens,
response_mime_type="application/json",
)
response = await client.aio.models.generate_content(
model=self._model,
contents=contents,
config=config,
)
# Log finish reason to detect truncation
if response.candidates:
finish_reason = getattr(response.candidates[0], "finish_reason", None)
logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
if str(finish_reason) == "MAX_TOKENS":
logger.warning(
"Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
max_tokens,
)
text = response.text or ""
input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
output_tokens = (
getattr(response.usage_metadata, "candidates_token_count", 0) or 0
)
return text, input_tokens, output_tokens
class AnthropicProvider(AIProvider):
"""Anthropic Claude provider using the anthropic SDK."""
def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
self._api_key = api_key
self._model = model
self._timeout = timeout
async def generate_json(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
import anthropic
client = anthropic.AsyncAnthropic(
api_key=self._api_key,
timeout=self._timeout,
)
response = await client.messages.create(
model=self._model,
max_tokens=max_tokens,
system=system_prompt,
messages=messages,
)
text = response.content[0].text
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
return text, input_tokens, output_tokens
def get_ai_provider() -> AIProvider:
"""Factory that returns the configured AI provider.
Selection logic:
1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
3. Fallback: if preferred provider key missing, try the other one
4. If nothing configured -> raise RuntimeError
"""
provider = settings.AI_PROVIDER
if provider == "gemini":
if settings.GOOGLE_AI_API_KEY:
return GeminiProvider(
api_key=settings.GOOGLE_AI_API_KEY,
model=settings.AI_MODEL_GEMINI,
)
# Fallback to Anthropic
if settings.ANTHROPIC_API_KEY:
return AnthropicProvider(
api_key=settings.ANTHROPIC_API_KEY,
model=settings.AI_MODEL_ANTHROPIC,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
elif provider == "anthropic":
if settings.ANTHROPIC_API_KEY:
return AnthropicProvider(
api_key=settings.ANTHROPIC_API_KEY,
model=settings.AI_MODEL_ANTHROPIC,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
# Fallback to Gemini
if settings.GOOGLE_AI_API_KEY:
return GeminiProvider(
api_key=settings.GOOGLE_AI_API_KEY,
model=settings.AI_MODEL_GEMINI,
)
raise RuntimeError(
"No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
)

View File

@@ -1,11 +1,11 @@
"""AI-powered tree generation service using Anthropic Claude API.
"""AI-powered tree generation service.
Implements the 4-stage wizard flow:
Stage 2 (scaffold): AI suggests 4-7 top-level branches
Stage 3 (branch_detail): AI generates detailed nodes per branch
Stage 4 (assemble): Pure assembly logic — zero AI calls
System prompts are static constants to enable Anthropic prompt caching.
Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic).
"""
import json
import logging
@@ -13,8 +13,7 @@ import re
import uuid
from typing import Any
import anthropic
from app.core.ai_provider import get_ai_provider
from app.core.config import settings
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
@@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str:
return text
def _get_client() -> anthropic.AsyncAnthropic:
"""Get configured async Anthropic client."""
if not settings.ANTHROPIC_API_KEY:
raise RuntimeError("ANTHROPIC_API_KEY not configured")
return anthropic.AsyncAnthropic(
api_key=settings.ANTHROPIC_API_KEY,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
"""Estimate USD cost from token counts."""
@@ -146,7 +136,7 @@ async def scaffold_branches(
Returns (branches, input_tokens, output_tokens, estimated_cost).
Raises ValueError on invalid response.
"""
client = _get_client()
provider = get_ai_provider()
flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "")
@@ -161,21 +151,26 @@ async def scaffold_branches(
if tags:
user_message += f"Environment: {', '.join(tags)}\n"
response = await client.messages.create(
model=settings.AI_MODEL,
max_tokens=1024,
system=SCAFFOLD_SYSTEM_PROMPT,
raw_text, input_tokens, output_tokens = await provider.generate_json(
system_prompt=SCAFFOLD_SYSTEM_PROMPT,
messages=[{"role": "user", "content": user_message}],
max_tokens=2048,
)
raw_text = _strip_markdown_fences(response.content[0].text)
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
logger.info(
"scaffold raw response (tokens in=%d out=%d, len=%d): %s",
input_tokens,
output_tokens,
len(raw_text),
raw_text[:500],
)
raw_text = _strip_markdown_fences(raw_text)
cost = _estimate_cost(input_tokens, output_tokens)
try:
data = json.loads(raw_text)
except json.JSONDecodeError as e:
logger.error("scaffold JSON parse failed. Full text (%d chars): %s", len(raw_text), raw_text)
raise ValueError(f"AI returned invalid JSON: {e}")
branches = data.get("branches", [])
@@ -196,7 +191,7 @@ async def generate_branch_detail(
On validation failure, retries once with corrective prompt.
Raises ValueError if both attempts fail.
"""
client = _get_client()
provider = get_ai_provider()
flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "")
@@ -217,35 +212,30 @@ async def generate_branch_detail(
total_output = 0
for attempt in range(3):
response = await client.messages.create(
model=settings.AI_MODEL,
max_tokens=8192,
system=BRANCH_DETAIL_SYSTEM_PROMPT,
raw_text, input_tokens, output_tokens = await provider.generate_json(
system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT,
messages=messages,
max_tokens=8192,
)
total_input += response.usage.input_tokens
total_output += response.usage.output_tokens
total_input += input_tokens
total_output += output_tokens
logger.debug(
"branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d",
"branch_detail attempt=%d output_tokens=%d",
attempt,
response.stop_reason,
len(response.content),
response.usage.output_tokens,
output_tokens,
)
if response.stop_reason == "max_tokens":
logger.warning(
"branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated",
attempt,
response.usage.output_tokens,
)
raw_text = _strip_markdown_fences(response.content[0].text) if response.content else ""
raw_text = _strip_markdown_fences(raw_text) if raw_text else ""
if not raw_text:
logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason)
logger.warning("branch_detail attempt=%d returned empty text", attempt)
try:
branch_tree = json.loads(raw_text)
except json.JSONDecodeError as e:
logger.error(
"branch_detail attempt=%d JSON parse failed (%d chars): %s",
attempt, len(raw_text), raw_text[:500],
)
if attempt < 2:
messages.append({"role": "assistant", "content": raw_text})
messages.append({

View File

@@ -78,11 +78,16 @@ class Settings(BaseSettings):
AI_CONVERSATION_TTL_HOURS: int = 24
AI_MAX_CALLS_PER_FLOW: int = 10
AI_REQUEST_TIMEOUT_SECONDS: int = 45
# AI Provider selection
AI_PROVIDER: str = "gemini" # "gemini" or "anthropic"
GOOGLE_AI_API_KEY: Optional[str] = None
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
@property
def ai_enabled(self) -> bool:
"""Check if AI Flow Builder is configured."""
return self.ANTHROPIC_API_KEY is not None
"""Check if any AI provider is configured."""
return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None
# Deployment auto-seed test data on PR environments
SEED_ON_DEPLOY: bool = False

View File

@@ -0,0 +1,52 @@
"""Pydantic schemas for the AI auto-fix feature."""
from typing import Any, Literal
from pydantic import BaseModel, Field
# ── Requests ──
class ValidationErrorInput(BaseModel):
"""A single validation error to fix."""
node_id: str = Field(..., description="ID of the node with the error")
message: str = Field(..., description="The validation error message")
class AIFixTreeRequest(BaseModel):
"""Request to generate AI fixes for validation errors."""
tree_structure: dict[str, Any] = Field(..., description="Full tree structure")
tree_name: str = Field("", max_length=255, description="Name of the flow")
tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field(
"troubleshooting", description="Type of flow"
)
validation_errors: list[ValidationErrorInput] = Field(
..., min_length=1, max_length=10, description="Errors to fix"
)
# ── Responses ──
class AIFixProposal(BaseModel):
"""A single proposed fix from the AI."""
target_node_id: str
error_message: str
description: str
original_node: dict[str, Any]
fixed_node: dict[str, Any]
class AIFixTokenUsage(BaseModel):
input: int = 0
output: int = 0
class AIFixTreeResponse(BaseModel):
"""Response with proposed fixes."""
fixes: list[AIFixProposal]
tokens_used: AIFixTokenUsage

View File

@@ -33,6 +33,7 @@ httpx>=0.27.0
# AI Flow Builder
anthropic>=0.40.0
google-genai>=1.0.0
# Utilities
python-dotenv==1.0.1

View File

@@ -1,6 +1,6 @@
"""Integration tests for AI Flow Builder endpoints.
All Anthropic API calls are mocked — zero real API spend.
All AI provider calls are mocked — zero real API spend.
"""
import json
from unittest.mock import AsyncMock, patch, MagicMock
@@ -64,12 +64,11 @@ BRANCH_DETAIL_JSON = json.dumps({
})
def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200):
"""Create a mock Anthropic API response."""
response = MagicMock()
response.content = [MagicMock(text=text)]
response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens)
return response
def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200):
"""Create a mock AI provider whose generate_json returns the given text and token counts."""
provider = MagicMock()
provider.generate_json = AsyncMock(return_value=(text, input_tokens, output_tokens))
return provider
@pytest.fixture
@@ -194,11 +193,9 @@ async def test_scaffold_success(client, auth_headers, enable_ai):
)
conversation_id = start_resp.json()["conversation_id"]
# Mock Anthropic
mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(return_value=mock_response)
# Mock AI provider
mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=mock_provider):
response = await client.post(
"/api/v1/ai/scaffold",
json={"conversation_id": conversation_id},
@@ -241,9 +238,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
)
conversation_id = start_resp.json()["conversation_id"]
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
await client.post(
"/api/v1/ai/scaffold",
json={"conversation_id": conversation_id},
@@ -251,10 +247,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
)
# Now generate branch detail
detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock)
detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON)
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=detail_provider):
response = await client.post(
"/api/v1/ai/branch-detail",
json={
@@ -290,9 +284,8 @@ async def test_assemble_success(client, auth_headers, enable_ai):
conversation_id = start_resp.json()["conversation_id"]
# Scaffold
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
await client.post(
"/api/v1/ai/scaffold",
json={"conversation_id": conversation_id},

View File

@@ -0,0 +1,169 @@
"""Integration tests for the POST /ai/fix-tree endpoint."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.core.config import settings
# ── Sample tree (has a decision node with only 1 option + 1 child) ──
SAMPLE_TREE = {
"id": "root",
"type": "decision",
"question": "Is the server up?",
"options": [
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
],
"children": [
{
"id": "check-logs",
"type": "action",
"title": "Check Logs",
"description": "Review logs.",
"next_node_id": "logs-ok",
},
{
"id": "logs-ok",
"type": "solution",
"title": "Logs OK",
"description": "Issue in logs.",
},
{
"id": "restart",
"type": "decision",
"question": "Did restart work?",
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
"children": [
{
"id": "done",
"type": "solution",
"title": "Done",
"description": "Fixed.",
}
],
},
],
}
# Fixed version of the "restart" node — 2 options, 2 children
FIXED_RESTART_NODE = {
"id": "restart",
"type": "decision",
"question": "Did restart work?",
"options": [
{"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
{"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
],
"children": [
{
"id": "done",
"type": "solution",
"title": "Done",
"description": "Fixed.",
},
{
"id": "escalate",
"type": "solution",
"title": "Escalate",
"description": "Escalate to senior engineer.",
},
],
}
FIX_REQUEST_BODY = {
"tree_structure": SAMPLE_TREE,
"tree_name": "Server Troubleshooting",
"tree_type": "troubleshooting",
"validation_errors": [
{
"node_id": "restart",
"message": "Decision node 'restart' must have at least 2 options",
}
],
}
def _mock_ai_provider(response_text: str, input_tokens: int = 50, output_tokens: int = 100):
"""Create a mock provider whose generate_json returns given text."""
provider = MagicMock()
provider.generate_json = AsyncMock(return_value=(response_text, input_tokens, output_tokens))
return provider
@pytest.fixture
def enable_ai():
"""Temporarily enable AI by setting a fake API key."""
original = settings.GOOGLE_AI_API_KEY
settings.GOOGLE_AI_API_KEY = "test-key-fake"
yield
settings.GOOGLE_AI_API_KEY = original
@pytest.fixture
def disable_ai():
"""Ensure AI is disabled."""
orig_google = settings.GOOGLE_AI_API_KEY
orig_anthropic = settings.ANTHROPIC_API_KEY
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = None
yield
settings.GOOGLE_AI_API_KEY = orig_google
settings.ANTHROPIC_API_KEY = orig_anthropic
# ── Tests ──
@pytest.mark.asyncio
async def test_returns_401_without_auth(client):
"""POST /ai/fix-tree without auth token returns 401."""
response = await client.post("/api/v1/ai/fix-tree", json=FIX_REQUEST_BODY)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_returns_503_when_ai_disabled(client, auth_headers, disable_ai):
"""POST /ai/fix-tree returns 503 when no AI keys are configured."""
response = await client.post(
"/api/v1/ai/fix-tree",
json=FIX_REQUEST_BODY,
headers=auth_headers,
)
assert response.status_code == 503
@pytest.mark.asyncio
async def test_returns_fixes_on_success(client, auth_headers, enable_ai):
"""POST /ai/fix-tree returns fix proposals when AI succeeds."""
mock_provider = _mock_ai_provider(json.dumps(FIXED_RESTART_NODE))
with patch(
"app.core.ai_fix_service.get_ai_provider",
return_value=mock_provider,
):
response = await client.post(
"/api/v1/ai/fix-tree",
json=FIX_REQUEST_BODY,
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "fixes" in data
assert "tokens_used" in data
assert len(data["fixes"]) == 1
fix = data["fixes"][0]
assert fix["target_node_id"] == "restart"
assert fix["error_message"] == "Decision node 'restart' must have at least 2 options"
assert fix["original_node"]["id"] == "restart"
assert fix["fixed_node"]["id"] == "restart"
assert len(fix["fixed_node"]["options"]) == 2
assert len(fix["fixed_node"]["children"]) == 2
assert data["tokens_used"]["input"] == 50
assert data["tokens_used"]["output"] == 100

View File

@@ -0,0 +1,224 @@
"""Unit tests for AI fix service helper functions.
Tests pure Python helpers only — no AI mocking needed.
"""
import pytest
from app.core.ai_fix_service import (
_find_node_by_id,
_find_parent_node,
_serialize_tree_outline,
_strip_markdown_fences,
_replace_node_in_tree,
_describe_fix,
)
# ── Sample tree ──
SAMPLE_TREE = {
"id": "root",
"type": "decision",
"question": "Is the server up?",
"options": [
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
],
"children": [
{
"id": "check-logs",
"type": "action",
"title": "Check Logs",
"description": "Review logs.",
"next_node_id": "logs-ok",
},
{
"id": "logs-ok",
"type": "solution",
"title": "Logs OK",
"description": "Issue in logs.",
},
{
"id": "restart",
"type": "decision",
"question": "Did restart work?",
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
"children": [
{
"id": "done",
"type": "solution",
"title": "Done",
"description": "Fixed.",
}
],
},
],
}
# ── _find_node_by_id ──
class TestFindNodeById:
def test_finds_root(self):
node = _find_node_by_id(SAMPLE_TREE, "root")
assert node is not None
assert node["id"] == "root"
assert node["type"] == "decision"
def test_finds_nested_child(self):
node = _find_node_by_id(SAMPLE_TREE, "done")
assert node is not None
assert node["id"] == "done"
assert node["type"] == "solution"
def test_finds_direct_child(self):
node = _find_node_by_id(SAMPLE_TREE, "check-logs")
assert node is not None
assert node["title"] == "Check Logs"
def test_returns_none_for_missing(self):
node = _find_node_by_id(SAMPLE_TREE, "nonexistent")
assert node is None
def test_returns_none_for_non_dict(self):
assert _find_node_by_id("not a dict", "root") is None
# ── _find_parent_node ──
class TestFindParentNode:
def test_root_has_no_parent(self):
parent = _find_parent_node(SAMPLE_TREE, "root")
assert parent is None
def test_finds_parent_of_direct_child(self):
parent = _find_parent_node(SAMPLE_TREE, "check-logs")
assert parent is not None
assert parent["id"] == "root"
def test_finds_parent_of_deeply_nested(self):
parent = _find_parent_node(SAMPLE_TREE, "done")
assert parent is not None
assert parent["id"] == "restart"
def test_returns_none_for_missing(self):
parent = _find_parent_node(SAMPLE_TREE, "nonexistent")
assert parent is None
def test_returns_none_for_non_dict(self):
assert _find_parent_node("not a dict", "root") is None
# ── _serialize_tree_outline ──
class TestSerializeTreeOutline:
def test_produces_readable_outline(self):
outline = _serialize_tree_outline(SAMPLE_TREE)
assert "- [decision] Is the server up?" in outline
assert " - [action] Check Logs" in outline
assert " - [solution] Logs OK" in outline
assert " - [solution] Done" in outline
def test_marks_error_node(self):
outline = _serialize_tree_outline(SAMPLE_TREE, error_node_id="restart")
assert "<<< ERROR HERE" in outline
# Only the restart node should be marked
lines = outline.split("\n")
error_lines = [l for l in lines if "ERROR HERE" in l]
assert len(error_lines) == 1
assert "Did restart work?" in error_lines[0]
def test_no_error_marker_when_none(self):
outline = _serialize_tree_outline(SAMPLE_TREE)
assert "ERROR HERE" not in outline
def test_handles_non_dict(self):
assert _serialize_tree_outline("not a dict") == ""
def test_indentation_increases_with_depth(self):
outline = _serialize_tree_outline(SAMPLE_TREE)
lines = outline.split("\n")
# Root has no indentation
assert lines[0].startswith("- [decision]")
# Children have 2-space indent
child_lines = [l for l in lines if "Check Logs" in l]
assert child_lines[0].startswith(" - ")
# ── _strip_markdown_fences ──
class TestStripMarkdownFences:
def test_strips_json_fences(self):
text = '```json\n{"key": "value"}\n```'
assert _strip_markdown_fences(text) == '{"key": "value"}'
def test_strips_plain_fences(self):
text = '```\n{"key": "value"}\n```'
assert _strip_markdown_fences(text) == '{"key": "value"}'
def test_passes_through_plain_json(self):
text = '{"key": "value"}'
assert _strip_markdown_fences(text) == '{"key": "value"}'
# ── _replace_node_in_tree ──
class TestReplaceNodeInTree:
def test_replaces_root(self):
import copy
tree = copy.deepcopy(SAMPLE_TREE)
replacement = {"id": "root", "type": "decision", "question": "New question"}
assert _replace_node_in_tree(tree, "root", replacement) is True
assert tree["question"] == "New question"
assert "children" not in tree # cleared and replaced
def test_replaces_nested_node(self):
import copy
tree = copy.deepcopy(SAMPLE_TREE)
replacement = {"id": "done", "type": "solution", "title": "All Done", "description": "Complete."}
assert _replace_node_in_tree(tree, "done", replacement) is True
found = _find_node_by_id(tree, "done")
assert found["title"] == "All Done"
def test_returns_false_for_missing(self):
import copy
tree = copy.deepcopy(SAMPLE_TREE)
assert _replace_node_in_tree(tree, "nonexistent", {"id": "x"}) is False
# ── _describe_fix ──
class TestDescribeFix:
def test_describes_added_children(self):
original = {"id": "n1", "children": [{"id": "c1"}]}
fixed = {"id": "n1", "children": [{"id": "c1"}, {"id": "c2"}]}
desc = _describe_fix(original, fixed)
assert "1 child node" in desc
def test_describes_added_options(self):
original = {"id": "n1", "options": [{"id": "o1"}]}
fixed = {"id": "n1", "options": [{"id": "o1"}, {"id": "o2"}]}
desc = _describe_fix(original, fixed)
assert "1 option" in desc
def test_describes_added_next_node_id(self):
original = {"id": "n1", "type": "action"}
fixed = {"id": "n1", "type": "action", "next_node_id": "n2"}
desc = _describe_fix(original, fixed)
assert "next_node_id" in desc
def test_fallback_description(self):
original = {"id": "n1", "type": "solution"}
fixed = {"id": "n1", "type": "solution"}
desc = _describe_fix(original, fixed)
assert "fixed structural issue" in desc.lower()

View File

@@ -0,0 +1,216 @@
"""Tests for the AI provider abstraction layer."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import sys
from app.core.ai_provider import (
AIProvider,
AnthropicProvider,
GeminiProvider,
get_ai_provider,
)
from app.core.config import settings
class TestGetAIProvider:
"""Tests for the get_ai_provider factory function."""
def test_returns_gemini_when_configured(self):
original_provider = settings.AI_PROVIDER
original_key = settings.GOOGLE_AI_API_KEY
try:
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
provider = get_ai_provider()
assert isinstance(provider, GeminiProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_key
def test_returns_anthropic_when_configured(self):
original_provider = settings.AI_PROVIDER
original_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "anthropic"
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
provider = get_ai_provider()
assert isinstance(provider, AnthropicProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.ANTHROPIC_API_KEY = original_key
def test_fallback_to_anthropic_when_gemini_key_missing(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
provider = get_ai_provider()
assert isinstance(provider, AnthropicProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
def test_fallback_to_gemini_when_anthropic_key_missing(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "anthropic"
settings.ANTHROPIC_API_KEY = None
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
provider = get_ai_provider()
assert isinstance(provider, GeminiProvider)
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
def test_raises_when_no_provider_configured(self):
original_provider = settings.AI_PROVIDER
original_gemini_key = settings.GOOGLE_AI_API_KEY
original_anthropic_key = settings.ANTHROPIC_API_KEY
try:
settings.AI_PROVIDER = "gemini"
settings.GOOGLE_AI_API_KEY = None
settings.ANTHROPIC_API_KEY = None
with pytest.raises(RuntimeError, match="No AI provider configured"):
get_ai_provider()
finally:
settings.AI_PROVIDER = original_provider
settings.GOOGLE_AI_API_KEY = original_gemini_key
settings.ANTHROPIC_API_KEY = original_anthropic_key
class TestAnthropicProvider:
"""Tests for AnthropicProvider.generate_json."""
@pytest.mark.asyncio
async def test_generate_json(self):
provider = AnthropicProvider(
api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30
)
mock_response = MagicMock()
mock_response.content = [MagicMock(text='{"result": "ok"}')]
mock_response.usage = MagicMock(input_tokens=100, output_tokens=50)
mock_client = AsyncMock()
mock_client.messages.create = AsyncMock(return_value=mock_response)
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
text, input_tokens, output_tokens = await provider.generate_json(
system_prompt="You are a helper.",
messages=[{"role": "user", "content": "Hello"}],
max_tokens=1024,
)
assert text == '{"result": "ok"}'
assert input_tokens == 100
assert output_tokens == 50
mock_client.messages.create.assert_called_once_with(
model="claude-haiku-4-5-20251001",
max_tokens=1024,
system="You are a helper.",
messages=[{"role": "user", "content": "Hello"}],
)
class TestGeminiProvider:
"""Tests for GeminiProvider.generate_json."""
@pytest.mark.asyncio
async def test_generate_json(self):
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
mock_usage = MagicMock()
mock_usage.prompt_token_count = 80
mock_usage.candidates_token_count = 40
mock_response = MagicMock()
mock_response.text = '{"answer": 42}'
mock_response.usage_metadata = mock_usage
mock_client = MagicMock()
mock_client.aio.models.generate_content = AsyncMock(
return_value=mock_response
)
mock_genai_module = MagicMock()
mock_genai_module.Client.return_value = mock_client
mock_types = MagicMock()
mock_types.Content.side_effect = lambda **kw: kw
mock_types.Part.side_effect = lambda **kw: kw
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
mock_google = MagicMock()
mock_google.genai = mock_genai_module
mock_genai_module.types = mock_types
with patch.dict(sys.modules, {
"google": mock_google,
"google.genai": mock_genai_module,
"google.genai.types": mock_types,
}):
text, input_tokens, output_tokens = await provider.generate_json(
system_prompt="Generate JSON.",
messages=[
{"role": "user", "content": "Give me data"},
{"role": "assistant", "content": "Here it is"},
{"role": "user", "content": "More please"},
],
max_tokens=2048,
)
assert text == '{"answer": 42}'
assert input_tokens == 80
assert output_tokens == 40
mock_client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_generate_json_handles_none_usage(self):
"""Token counts default to 0 when usage_metadata attributes are None."""
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
mock_usage = MagicMock(spec=[]) # No attributes at all
mock_response = MagicMock()
mock_response.text = "{}"
mock_response.usage_metadata = mock_usage
mock_client = MagicMock()
mock_client.aio.models.generate_content = AsyncMock(
return_value=mock_response
)
mock_genai_module = MagicMock()
mock_genai_module.Client.return_value = mock_client
mock_types = MagicMock()
mock_types.Content.side_effect = lambda **kw: kw
mock_types.Part.side_effect = lambda **kw: kw
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
mock_google = MagicMock()
mock_google.genai = mock_genai_module
mock_genai_module.types = mock_types
with patch.dict(sys.modules, {
"google": mock_google,
"google.genai": mock_genai_module,
"google.genai.types": mock_types,
}):
text, input_tokens, output_tokens = await provider.generate_json(
system_prompt="test",
messages=[{"role": "user", "content": "test"}],
)
assert text == "{}"
assert input_tokens == 0
assert output_tokens == 0

View File

@@ -0,0 +1,209 @@
# AI Auto-Fix & Gemini Flash Provider Design
> **Date:** 2026-02-26
> **Status:** Approved
---
## Overview
Two combined features:
1. **AI Provider Abstraction** — Add Gemini 2.5 Flash as the default AI provider with Claude as fallback, behind a unified interface.
2. **AI Auto-Fix for Validation Errors** — When a flow fails validation, offer an AI-powered "Fix with AI" button that generates structural fixes for review.
---
## Section 1: AI Provider Abstraction
### Design
New `backend/app/core/ai_provider.py` with a unified interface:
```python
class AIProvider(ABC):
async def generate_json(
self,
system_prompt: str,
messages: list[dict],
max_tokens: int = 4096,
) -> tuple[str, int, int]:
"""Returns (text, input_tokens, output_tokens)"""
```
Two implementations:
| Provider | Model | SDK | Role |
|----------|-------|-----|------|
| `GeminiProvider` | `gemini-2.5-flash` | `google-genai` | Default |
| `AnthropicProvider` | `claude-haiku-4-5-20251001` | `anthropic` | Fallback |
### Provider Selection
- `get_ai_provider()` factory reads `AI_PROVIDER` env var (default: `"gemini"`)
- Falls back to Anthropic if Gemini key is missing
- Existing `ai_tree_generator_service.py` swaps direct Anthropic calls for `get_ai_provider()`
### New Environment Variables
| Variable | Default | Purpose |
|----------|---------|---------|
| `AI_PROVIDER` | `"gemini"` | Which provider to use (`gemini` or `anthropic`) |
| `GOOGLE_AI_API_KEY` | — | Gemini API key |
Existing `ANTHROPIC_API_KEY` remains for fallback.
### Config Changes (`core/config.py`)
```python
AI_PROVIDER: str = "gemini"
GOOGLE_AI_API_KEY: str | None = None
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
```
---
## Section 2: AI Auto-Fix Feature
### Backend Endpoint
**`POST /api/v1/ai/fix-tree`**
Request:
```json
{
"tree_structure": { /* full tree */ },
"tree_name": "Router Troubleshooting",
"tree_type": "troubleshooting",
"validation_errors": [
{
"node_id": "node_abc",
"message": "Decision node must have at least 2 children (branches)"
}
]
}
```
Response:
```json
{
"fixes": [
{
"target_node_id": "node_abc",
"error_message": "Decision node must have at least 2 children (branches)",
"description": "Added second branch 'Check firmware version' with solution node",
"original_node": { /* snapshot before fix */ },
"fixed_node": { /* replacement node with corrected subtree */ }
}
],
"tokens_used": { "input": 1200, "output": 800 }
}
```
### How It Works
1. For each validation error tied to a `node_id`, extract that node + its parent + siblings from the tree.
2. Build a prompt with:
- The **full tree structure** serialized as a simplified outline (node titles + types + structure) for context
- The **specific failing node** highlighted with full JSON detail
- The **validation error message**
- Instructions: "Fix ONLY this node's structural issue. Keep all existing content. Generate domain-relevant additions that fit the flow's topic."
3. AI returns a corrected version of that node (with children/options adjusted).
4. Backend re-validates the fixed node before returning it.
5. If re-validation fails, retry once with the error fed back (corrective prompt pattern).
### Prompt Strategy
The prompt gives the AI the full tree as a compact outline, then zooms into the failing node:
```
You are fixing a validation error in a troubleshooting flow called "Router Troubleshooting".
FULL FLOW OUTLINE:
- [decision] Is the router powered on?
- [action] Check power cable → [solution] Power restored
- [decision] Are lights blinking? ← ERROR HERE
- [solution] Contact ISP
ERROR: Decision node "Are lights blinking?" must have at least 2 children (branches).
FAILING NODE (full detail):
{...json...}
Fix this node by adding the minimum structure needed to resolve the error.
Return ONLY the fixed node as JSON.
```
### Frontend UX
1. **Trigger**: "Fix with AI" button in `ValidationSummary` — appears when there are fixable errors (structural errors with a `node_id`).
2. **Loading state**: Button shows spinner + "Generating fixes..." — disabled during request.
3. **Review modal** (`AIFixReviewModal`): Shows each proposed fix as a card:
- Error message at top
- Before/after view of the node change
- "Apply" / "Skip" buttons per fix
- "Apply All" button in footer
4. **Apply**: Each accepted fix calls `updateNode(targetNodeId, fixedNode)` in the tree editor store.
5. **Re-validate**: After applying fixes, auto-run `validate()` to confirm resolution.
---
## Section 3: Scope & Constraints
### Fixable Errors (Auto-Fix Scope)
Only structural validation errors with a `node_id`:
- Decision node missing children/branches
- Decision node missing options
- Action node missing `next_node_id`
- Dead-end decision nodes (no children)
### NOT Fixable
- Global checks (tree too small/large, not enough solutions) — require rethinking the whole tree
- Content quality issues — out of scope
- Errors without a `node_id` (root-level issues)
Non-fixable errors still show in ValidationSummary but without the "Fix with AI" option.
### Token Budget
- Tree outline: ~50-100 tokens for a typical 15-node tree
- Failing node detail: ~100-200 tokens
- System prompt + instructions: ~300 tokens
- **Total input per fix: ~500-600 tokens**
- One API call per failing node (not batched)
### Error Handling
- Provider failure (rate limit, network): toast error, user can retry
- Fix fails re-validation: "AI couldn't generate a valid fix" with retry option
- Max 1 retry with corrective prompt per attempt
- Both provider and fallback fail: surface error to user
### Auth
- Requires `engineer` role or above (`require_engineer_or_admin`)
---
## New Files
| File | Purpose |
|------|---------|
| `backend/app/core/ai_provider.py` | Provider abstraction + Gemini/Anthropic implementations |
| `backend/app/core/ai_fix_service.py` | Fix generation logic + prompt building |
| `backend/app/api/endpoints/ai.py` | `POST /ai/fix-tree` endpoint |
| `backend/app/schemas/ai.py` | Request/response schemas for AI endpoints |
| `frontend/src/components/tree-editor/AIFixReviewModal.tsx` | Review modal for proposed fixes |
## Modified Files
| File | Change |
|------|--------|
| `backend/app/core/config.py` | Add Gemini config vars |
| `backend/app/core/ai_tree_generator_service.py` | Swap Anthropic calls for provider abstraction |
| `backend/app/api/router.py` | Register `/ai` routes |
| `frontend/src/api/trees.ts` | Add `fixTree()` API call |
| `frontend/src/components/tree-editor/ValidationSummary.tsx` | Add "Fix with AI" button |

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
import apiClient from './client'
import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse } from '@/types'
import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse, AIFixTreeRequest, AIFixTreeResponse } from '@/types'
export const treesApi = {
async list(params?: TreeFilters): Promise<TreeListItem[]> {
@@ -65,6 +65,12 @@ export const treesApi = {
const response = await apiClient.post<TreeValidationResponse>(`/trees/${id}/can-publish`)
return response.data
},
// AI auto-fix
async fixTree(request: AIFixTreeRequest): Promise<AIFixTreeResponse> {
const response = await apiClient.post<AIFixTreeResponse>('/ai/fix-tree', request)
return response.data
},
}
export default treesApi

View File

@@ -0,0 +1,170 @@
import { useState } from 'react'
import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react'
import { cn } from '@/lib/utils'
import type { AIFixProposal } from '@/types'
interface AIFixReviewModalProps {
fixes: AIFixProposal[]
onApply: (fix: AIFixProposal) => void
onApplyAll: () => void
onClose: () => void
}
export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) {
const [appliedIds, setAppliedIds] = useState<Set<string>>(new Set())
const [skippedIds, setSkippedIds] = useState<Set<string>>(new Set())
const [expandedIds, setExpandedIds] = useState<Set<string>>(new Set(fixes.map(f => f.target_node_id)))
const handleApply = (fix: AIFixProposal) => {
onApply(fix)
setAppliedIds(prev => new Set(prev).add(fix.target_node_id))
}
const handleSkip = (fix: AIFixProposal) => {
setSkippedIds(prev => new Set(prev).add(fix.target_node_id))
}
const toggleExpanded = (id: string) => {
setExpandedIds(prev => {
const next = new Set(prev)
if (next.has(id)) next.delete(id)
else next.add(id)
return next
})
}
const pendingFixes = fixes.filter(
f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id)
)
const allHandled = pendingFixes.length === 0
return (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/80 backdrop-blur-sm p-4">
<div className="relative flex h-[80vh] w-full max-w-2xl flex-col bg-card border border-border rounded-2xl shadow-lg">
{/* Header */}
<div className="flex items-center justify-between border-b border-border px-6 py-4">
<div className="flex items-center gap-2">
<Sparkles className="h-5 w-5 text-primary" />
<h2 className="text-lg font-semibold text-foreground">
AI Fix Proposals ({fixes.length})
</h2>
</div>
<button
onClick={onClose}
className="rounded-md p-1 text-muted-foreground hover:bg-accent hover:text-foreground"
>
<X className="h-5 w-5" />
</button>
</div>
{/* Body */}
<div className="flex-1 overflow-y-auto p-4 space-y-3">
{fixes.map((fix) => {
const isApplied = appliedIds.has(fix.target_node_id)
const isSkipped = skippedIds.has(fix.target_node_id)
const isExpanded = expandedIds.has(fix.target_node_id)
return (
<div
key={fix.target_node_id}
className={cn(
'rounded-lg border p-4',
isApplied
? 'border-emerald-400/30 bg-emerald-400/5'
: isSkipped
? 'border-border bg-accent/30 opacity-60'
: 'border-border bg-card'
)}
>
{/* Fix header */}
<div className="flex items-start justify-between gap-3">
<div className="flex-1">
<p className="text-sm text-red-400 mb-1">{fix.error_message}</p>
<p className="text-sm text-foreground">{fix.description}</p>
<p className="text-xs text-muted-foreground mt-1">
Node: {fix.target_node_id}
</p>
</div>
{isApplied && (
<span className="flex items-center gap-1 rounded-full bg-emerald-400/10 px-2 py-1 text-xs text-emerald-400">
<Check className="h-3 w-3" /> Applied
</span>
)}
{isSkipped && (
<span className="text-xs text-muted-foreground">Skipped</span>
)}
</div>
{/* Expand/collapse detail */}
{!isApplied && !isSkipped && (
<>
<button
onClick={() => toggleExpanded(fix.target_node_id)}
className="mt-2 flex items-center gap-1 text-xs text-muted-foreground hover:text-foreground"
>
{isExpanded ? <ChevronUp className="h-3 w-3" /> : <ChevronDown className="h-3 w-3" />}
{isExpanded ? 'Hide' : 'Show'} details
</button>
{isExpanded && (
<div className="mt-3 grid grid-cols-2 gap-3">
<div>
<p className="text-xs font-medium text-muted-foreground mb-1">Before</p>
<pre className="overflow-x-auto rounded bg-accent/50 p-2 text-xs text-muted-foreground max-h-48 overflow-y-auto">
{JSON.stringify(fix.original_node, null, 2)}
</pre>
</div>
<div>
<p className="text-xs font-medium text-emerald-400 mb-1">After</p>
<pre className="overflow-x-auto rounded bg-emerald-400/5 p-2 text-xs text-foreground max-h-48 overflow-y-auto">
{JSON.stringify(fix.fixed_node, null, 2)}
</pre>
</div>
</div>
)}
{/* Action buttons */}
<div className="mt-3 flex gap-2">
<button
onClick={() => handleApply(fix)}
className="flex items-center gap-1 rounded-md bg-gradient-brand px-3 py-1.5 text-xs font-medium text-white shadow-sm shadow-primary/20 hover:opacity-90"
>
<Check className="h-3 w-3" />
Apply
</button>
<button
onClick={() => handleSkip(fix)}
className="flex items-center gap-1 rounded-md border border-border px-3 py-1.5 text-xs font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
>
<SkipForward className="h-3 w-3" />
Skip
</button>
</div>
</>
)}
</div>
)
})}
</div>
{/* Footer */}
<div className="flex items-center justify-between border-t border-border px-6 py-4">
<button
onClick={onClose}
className="rounded-md border border-border px-4 py-2 text-sm font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
>
{allHandled ? 'Done' : 'Cancel'}
</button>
{!allHandled && (
<button
onClick={onApplyAll}
className="rounded-md bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20 hover:opacity-90"
>
Apply All ({pendingFixes.length})
</button>
)}
</div>
</div>
</div>
)
}

View File

@@ -1,14 +1,16 @@
import { useState } from 'react'
import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp } from 'lucide-react'
import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react'
import { cn } from '@/lib/utils'
import type { ValidationError } from '@/store/treeEditorStore'
interface ValidationSummaryProps {
errors: ValidationError[]
onSelectNode: (nodeId: string) => void
onFixWithAI?: () => void
isFixing?: boolean
}
export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryProps) {
export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) {
const [isExpanded, setIsExpanded] = useState(true)
const errorItems = errors.filter(e => e.severity === 'error')
@@ -22,6 +24,8 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
}
}
const hasFixableErrors = errorItems.some(e => e.nodeId)
return (
<div
className={cn(
@@ -32,14 +36,16 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
)}
>
{/* Header */}
<button
onClick={() => setIsExpanded(!isExpanded)}
<div
className={cn(
'flex w-full items-center justify-between p-3 text-left transition-colors hover:bg-accent',
'flex w-full items-center justify-between p-3 transition-colors',
errorItems.length > 0 ? 'text-red-400' : 'text-yellow-400'
)}
>
<div className="flex items-center gap-2">
<button
onClick={() => setIsExpanded(!isExpanded)}
className="flex items-center gap-2 text-left hover:opacity-80"
>
{errorItems.length > 0 ? (
<AlertCircle className="h-5 w-5" />
) : (
@@ -58,9 +64,35 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
</>
)}
</span>
</div>
{isExpanded ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />}
</button>
{isExpanded ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />}
</button>
{/* Fix with AI button */}
{onFixWithAI && hasFixableErrors && (
<button
onClick={onFixWithAI}
disabled={isFixing}
className={cn(
'flex items-center gap-1.5 rounded-md px-3 py-1 text-xs font-medium transition-colors',
isFixing
? 'bg-primary/10 text-primary cursor-wait'
: 'bg-gradient-brand text-white shadow-sm shadow-primary/20 hover:opacity-90'
)}
>
{isFixing ? (
<>
<Loader2 className="h-3 w-3 animate-spin" />
Generating fixes...
</>
) : (
<>
<Sparkles className="h-3 w-3" />
Fix with AI
</>
)}
</button>
)}
</div>
{/* Error/Warning List */}
{isExpanded && (

View File

@@ -5,10 +5,11 @@ import { Undo2, Redo2, Save, CheckCircle2, Monitor, FileText, Code2, LayoutList,
import { getMonacoEditor } from '@/components/tree-editor/code-mode'
import { treesApi } from '@/api/trees'
import { treeMarkdownApi } from '@/api/treeMarkdown'
import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure } from '@/types'
import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure, AIFixProposal } from '@/types'
import { useTreeEditorStore, useTreeEditorTemporal } from '@/store/treeEditorStore'
import { TreeEditorLayout } from '@/components/tree-editor/TreeEditorLayout'
import { ValidationSummary } from '@/components/tree-editor/ValidationSummary'
import { AIFixReviewModal } from '@/components/tree-editor/AIFixReviewModal'
import { useKeyboardShortcuts } from '@/hooks/useKeyboardShortcuts'
import { usePermissions } from '@/hooks/usePermissions'
import { Spinner } from '@/components/common/Spinner'
@@ -58,6 +59,8 @@ export function TreeEditorPage() {
const [showAnalytics, setShowAnalytics] = useState(false)
const [isMetadataOpen, setIsMetadataOpen] = useState(false)
const [editingNodeId, setEditingNodeId] = useState<string | null>(null)
const [isFixing, setIsFixing] = useState(false)
const [fixProposals, setFixProposals] = useState<AIFixProposal[] | null>(null)
// Mobile detection
const [isMobile, setIsMobile] = useState(false)
@@ -217,6 +220,54 @@ export function TreeEditorPage() {
selectNode(nodeId)
}
const handleFixWithAI = async () => {
const store = useTreeEditorStore.getState()
if (!store.treeStructure) return
const fixableErrors = store.validationErrors
.filter(e => e.severity === 'error' && e.nodeId)
.map(e => ({ node_id: e.nodeId!, message: e.message }))
if (fixableErrors.length === 0) return
setIsFixing(true)
try {
const result = await treesApi.fixTree({
tree_structure: store.treeStructure as unknown as Record<string, unknown>,
tree_name: store.name,
tree_type: 'troubleshooting',
validation_errors: fixableErrors,
})
if (result.fixes.length > 0) {
setFixProposals(result.fixes)
} else {
toast.info('AI could not generate fixes for these errors')
}
} catch {
toast.error('Failed to generate AI fixes. Please try again.')
} finally {
setIsFixing(false)
}
}
const handleApplyFix = (fix: AIFixProposal) => {
updateNode(fix.target_node_id, fix.fixed_node as Partial<TreeStructure>)
}
const handleApplyAllFixes = () => {
if (!fixProposals) return
for (const fix of fixProposals) {
handleApplyFix(fix)
}
setFixProposals(null)
setTimeout(() => { validate() }, 100)
}
const handleCloseFixModal = () => {
setFixProposals(null)
validate()
}
const handleNodeSelect = useCallback((nodeId: string | null) => {
if (nodeId) {
setIsMetadataOpen(false) // close metadata when opening node editor
@@ -685,6 +736,8 @@ export function TreeEditorPage() {
<ValidationSummary
errors={validationErrors}
onSelectNode={handleSelectNode}
onFixWithAI={handleFixWithAI}
isFixing={isFixing}
/>
</div>
)}
@@ -705,6 +758,16 @@ export function TreeEditorPage() {
<FlowAnalyticsPanel treeId={id} />
</div>
)}
{/* AI Fix Review Modal */}
{fixProposals && (
<AIFixReviewModal
fixes={fixProposals}
onApply={handleApplyFix}
onApplyAll={handleApplyAllFixes}
onClose={handleCloseFixModal}
/>
)}
</div>
)
}

View File

@@ -0,0 +1,24 @@
export interface AIFixValidationError {
node_id: string
message: string
}
export interface AIFixProposal {
target_node_id: string
error_message: string
description: string
original_node: Record<string, unknown>
fixed_node: Record<string, unknown>
}
export interface AIFixTreeRequest {
tree_structure: Record<string, unknown>
tree_name: string
tree_type: 'troubleshooting' | 'procedural' | 'maintenance'
validation_errors: AIFixValidationError[]
}
export interface AIFixTreeResponse {
fixes: AIFixProposal[]
tokens_used: { input: number; output: number }
}

View File

@@ -45,3 +45,10 @@ export type {
AIAssembleResponse,
AIWizardPhase,
} from './ai'
export type {
AIFixTreeRequest,
AIFixTreeResponse,
AIFixProposal,
AIFixValidationError,
} from './ai-fix'