feat: AI auto-fix + Gemini Flash provider #93
@@ -10,7 +10,6 @@
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
import anthropic
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -52,10 +51,9 @@ def _require_ai_enabled() -> None:
|
||||
if not settings.ai_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.",
|
||||
detail="AI flow builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/quota", response_model=AIQuotaStatusResponse)
|
||||
async def get_quota(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -174,27 +172,6 @@ async def scaffold(
|
||||
branches, input_tokens, output_tokens, cost = await scaffold_branches(
|
||||
conversation.wizard_state,
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=conversation.id,
|
||||
generation_type="scaffold",
|
||||
tier=plan,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
estimated_cost=0,
|
||||
succeeded=False,
|
||||
counts_toward_quota=False,
|
||||
error_code=type(e).__name__,
|
||||
extra_data={"error": str(e)},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="AI provider error. Please try again.",
|
||||
)
|
||||
except ValueError as e:
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
@@ -216,6 +193,28 @@ async def scaffold(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"AI returned invalid output: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("AI scaffold failed: %s: %s", type(e).__name__, e)
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=conversation.id,
|
||||
generation_type="scaffold",
|
||||
tier=plan,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
estimated_cost=0,
|
||||
succeeded=False,
|
||||
counts_toward_quota=False,
|
||||
error_code=type(e).__name__,
|
||||
extra_data={"error": str(e)},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"AI provider error ({type(e).__name__}). Please try again.",
|
||||
)
|
||||
|
||||
# Record successful usage
|
||||
await record_ai_usage(
|
||||
@@ -293,27 +292,6 @@ async def branch_detail(
|
||||
existing_branches,
|
||||
)
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=conversation.id,
|
||||
generation_type="branch_detail",
|
||||
tier=plan,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
estimated_cost=0,
|
||||
succeeded=False,
|
||||
counts_toward_quota=False,
|
||||
error_code=type(e).__name__,
|
||||
extra_data={"error": str(e), "branch_name": data.branch_name},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="AI provider error. Please try again.",
|
||||
)
|
||||
except ValueError as e:
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
@@ -335,6 +313,28 @@ async def branch_detail(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"AI returned invalid output: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("AI branch_detail failed: %s: %s", type(e).__name__, e)
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=conversation.id,
|
||||
generation_type="branch_detail",
|
||||
tier=plan,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
estimated_cost=0,
|
||||
succeeded=False,
|
||||
counts_toward_quota=False,
|
||||
error_code=type(e).__name__,
|
||||
extra_data={"error": str(e), "branch_name": data.branch_name},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"AI provider error ({type(e).__name__}). Please try again.",
|
||||
)
|
||||
|
||||
# Record successful usage
|
||||
await record_ai_usage(
|
||||
|
||||
78
backend/app/api/endpoints/ai_fix.py
Normal file
78
backend/app/api/endpoints/ai_fix.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""AI auto-fix endpoint for tree validation errors.
|
||||
|
||||
POST /ai/fix-tree — accepts a tree with validation errors and returns
|
||||
AI-generated fix proposals for each error.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_db, require_engineer_or_admin
|
||||
from app.core.config import settings
|
||||
from app.core.rate_limit import limiter
|
||||
from app.core.ai_fix_service import generate_fixes
|
||||
from app.models.user import User
|
||||
from app.schemas.ai_fix import (
|
||||
AIFixTreeRequest,
|
||||
AIFixTreeResponse,
|
||||
AIFixProposal,
|
||||
AIFixTokenUsage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ai", tags=["ai-fix"])
|
||||
|
||||
|
||||
def _require_ai_enabled() -> None:
|
||||
"""Raise 503 if AI is not configured."""
|
||||
if not settings.ai_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="AI fix is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/fix-tree", response_model=AIFixTreeResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def fix_tree(
|
||||
request: Request,
|
||||
body: AIFixTreeRequest,
|
||||
user: Annotated[User, Depends(require_engineer_or_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Generate AI-powered fixes for tree validation errors."""
|
||||
_require_ai_enabled()
|
||||
|
||||
validation_errors = [
|
||||
{"node_id": e.node_id, "message": e.message}
|
||||
for e in body.validation_errors
|
||||
]
|
||||
|
||||
try:
|
||||
fixes, input_tokens, output_tokens = await generate_fixes(
|
||||
tree_structure=body.tree_structure,
|
||||
tree_name=body.tree_name,
|
||||
tree_type=body.tree_type,
|
||||
validation_errors=validation_errors,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
logger.error("AI provider not available: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=str(exc),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Unexpected error in AI fix service")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred while generating fixes.",
|
||||
)
|
||||
|
||||
return AIFixTreeResponse(
|
||||
fixes=[AIFixProposal(**f) for f in fixes],
|
||||
tokens_used=AIFixTokenUsage(input=input_tokens, output=output_tokens),
|
||||
)
|
||||
@@ -6,6 +6,7 @@ from app.api.endpoints import target_lists
|
||||
from app.api.endpoints import maintenance_schedules
|
||||
from app.api.endpoints import feedback
|
||||
from app.api.endpoints import ai_builder
|
||||
from app.api.endpoints import ai_fix
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -36,3 +37,4 @@ api_router.include_router(target_lists.router)
|
||||
api_router.include_router(maintenance_schedules.router)
|
||||
api_router.include_router(feedback.router)
|
||||
api_router.include_router(ai_builder.router)
|
||||
api_router.include_router(ai_fix.router)
|
||||
|
||||
273
backend/app/core/ai_fix_service.py
Normal file
273
backend/app/core/ai_fix_service.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""AI-powered fix service for tree validation errors.
|
||||
|
||||
Given a tree structure and validation errors, generates AI-powered
|
||||
proposals to fix each structural issue while preserving existing content.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.ai_provider import get_ai_provider
|
||||
from app.core.ai_tree_validator import validate_generated_tree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.
|
||||
|
||||
You will receive:
|
||||
1. A full flow outline showing the tree structure
|
||||
2. The specific failing node with its full JSON
|
||||
3. The validation error message
|
||||
|
||||
Your task: Return a FIXED version of the failing node as valid JSON. Rules:
|
||||
- Fix ONLY the structural issue described in the error message
|
||||
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
|
||||
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
|
||||
- Every new node must have a unique ID (use descriptive kebab-case IDs)
|
||||
- Decision nodes must have at least 2 options and at least 2 children
|
||||
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
|
||||
- Solution nodes are leaf nodes (no children)
|
||||
- Return ONLY the fixed node JSON, no explanation"""
|
||||
|
||||
|
||||
# ── Pure helper functions ──
|
||||
|
||||
|
||||
def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
|
||||
"""Recursively find a node by its ID in the tree structure."""
|
||||
if not isinstance(tree, dict):
|
||||
return None
|
||||
if tree.get("id") == node_id:
|
||||
return tree
|
||||
for child in tree.get("children", []):
|
||||
result = _find_node_by_id(child, node_id)
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
|
||||
"""Find the parent of a node with the given ID."""
|
||||
if not isinstance(tree, dict):
|
||||
return None
|
||||
for child in tree.get("children", []):
|
||||
if isinstance(child, dict) and child.get("id") == target_id:
|
||||
return tree
|
||||
result = _find_parent_node(child, target_id)
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _serialize_tree_outline(
|
||||
tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None
|
||||
) -> str:
|
||||
"""Serialize tree as a readable outline for AI prompt context.
|
||||
|
||||
Format: indented "- [type] label" with "<<< ERROR HERE" marker.
|
||||
"""
|
||||
if not isinstance(tree, dict):
|
||||
return ""
|
||||
|
||||
node_type = tree.get("type", "unknown")
|
||||
label = tree.get("question") or tree.get("title") or tree.get("id", "?")
|
||||
prefix = " " * indent
|
||||
marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""
|
||||
line = f"{prefix}- [{node_type}] {label}{marker}"
|
||||
|
||||
lines = [line]
|
||||
for child in tree.get("children", []):
|
||||
lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _strip_markdown_fences(text: str) -> str:
|
||||
"""Strip ```json...``` fences from AI response."""
|
||||
return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip(
|
||||
"`"
|
||||
).strip()
|
||||
|
||||
|
||||
def _replace_node_in_tree(
|
||||
tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Replace a node in-place by ID. Returns True if found and replaced."""
|
||||
if not isinstance(tree, dict):
|
||||
return False
|
||||
if tree.get("id") == target_id:
|
||||
tree.clear()
|
||||
tree.update(replacement)
|
||||
return True
|
||||
for child in tree.get("children", []):
|
||||
if _replace_node_in_tree(child, target_id, replacement):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
|
||||
"""Describe what changed between original and fixed node."""
|
||||
changes: list[str] = []
|
||||
|
||||
orig_children = len(original.get("children", []))
|
||||
fixed_children = len(fixed.get("children", []))
|
||||
if fixed_children > orig_children:
|
||||
changes.append(f"added {fixed_children - orig_children} child node(s)")
|
||||
|
||||
orig_options = len(original.get("options", []))
|
||||
fixed_options = len(fixed.get("options", []))
|
||||
if fixed_options > orig_options:
|
||||
changes.append(f"added {fixed_options - orig_options} option(s)")
|
||||
|
||||
if fixed.get("next_node_id") and not original.get("next_node_id"):
|
||||
changes.append("added next_node_id")
|
||||
|
||||
if not changes:
|
||||
changes.append("fixed structural issue")
|
||||
|
||||
return "; ".join(changes).capitalize()
|
||||
|
||||
|
||||
# ── Prompt building ──
|
||||
|
||||
|
||||
def _build_fix_prompt(
|
||||
tree: dict[str, Any],
|
||||
node_id: str,
|
||||
error_message: str,
|
||||
tree_name: str,
|
||||
tree_type: str,
|
||||
) -> str:
|
||||
"""Build the user message for the AI fix request."""
|
||||
outline = _serialize_tree_outline(tree, error_node_id=node_id)
|
||||
node = _find_node_by_id(tree, node_id)
|
||||
node_json = json.dumps(node, indent=2) if node else "{}"
|
||||
|
||||
return (
|
||||
f"Flow name: {tree_name}\n"
|
||||
f"Flow type: {tree_type}\n\n"
|
||||
f"## Full flow outline\n```\n{outline}\n```\n\n"
|
||||
f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n"
|
||||
f"## Validation error\n{error_message}\n\n"
|
||||
f"Return the fixed version of this node as JSON."
|
||||
)
|
||||
|
||||
|
||||
# ── Main entry point ──
|
||||
|
||||
|
||||
async def generate_fixes(
|
||||
tree_structure: dict[str, Any],
|
||||
tree_name: str,
|
||||
tree_type: str,
|
||||
validation_errors: list[dict[str, str]],
|
||||
) -> tuple[list[dict[str, Any]], int, int]:
|
||||
"""Generate AI-powered fixes for tree validation errors.
|
||||
|
||||
Args:
|
||||
tree_structure: Full tree structure dict.
|
||||
tree_name: Name of the flow.
|
||||
tree_type: Type of flow (troubleshooting, procedural, maintenance).
|
||||
validation_errors: List of dicts with "node_id" and "message" keys.
|
||||
|
||||
Returns:
|
||||
Tuple of (fixes_list, total_input_tokens, total_output_tokens).
|
||||
Each fix dict has: target_node_id, error_message, description,
|
||||
original_node, fixed_node.
|
||||
"""
|
||||
provider = get_ai_provider()
|
||||
fixes: list[dict[str, Any]] = []
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
for error in validation_errors:
|
||||
node_id = error["node_id"]
|
||||
error_message = error["message"]
|
||||
|
||||
original_node = _find_node_by_id(tree_structure, node_id)
|
||||
if original_node is None:
|
||||
logger.warning("Node %s not found in tree, skipping fix", node_id)
|
||||
continue
|
||||
|
||||
original_snapshot = copy.deepcopy(original_node)
|
||||
|
||||
# Build prompt and call AI
|
||||
user_message = _build_fix_prompt(
|
||||
tree_structure, node_id, error_message, tree_name, tree_type
|
||||
)
|
||||
messages = [{"role": "user", "content": user_message}]
|
||||
|
||||
try:
|
||||
text, in_tok, out_tok = await provider.generate_json(
|
||||
system_prompt=FIX_SYSTEM_PROMPT,
|
||||
messages=messages,
|
||||
max_tokens=2048,
|
||||
)
|
||||
total_input_tokens += in_tok
|
||||
total_output_tokens += out_tok
|
||||
|
||||
cleaned = _strip_markdown_fences(text)
|
||||
fixed_node = json.loads(cleaned)
|
||||
except (json.JSONDecodeError, Exception) as exc:
|
||||
logger.warning("AI fix failed for node %s: %s", node_id, exc)
|
||||
continue
|
||||
|
||||
# Validate by substituting into a tree copy
|
||||
tree_copy = copy.deepcopy(tree_structure)
|
||||
_replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node))
|
||||
remaining_errors = validate_generated_tree(tree_copy)
|
||||
|
||||
# Check if the specific error is still present
|
||||
still_has_error = any(node_id in e for e in remaining_errors)
|
||||
|
||||
if still_has_error:
|
||||
# Retry once with corrective prompt
|
||||
retry_message = (
|
||||
f"Your previous fix still has validation errors:\n"
|
||||
f"{chr(10).join(remaining_errors)}\n\n"
|
||||
f"Please fix the node again. Return ONLY the corrected JSON."
|
||||
)
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
messages.append({"role": "user", "content": retry_message})
|
||||
|
||||
try:
|
||||
text2, in_tok2, out_tok2 = await provider.generate_json(
|
||||
system_prompt=FIX_SYSTEM_PROMPT,
|
||||
messages=messages,
|
||||
max_tokens=2048,
|
||||
)
|
||||
total_input_tokens += in_tok2
|
||||
total_output_tokens += out_tok2
|
||||
|
||||
cleaned2 = _strip_markdown_fences(text2)
|
||||
fixed_node = json.loads(cleaned2)
|
||||
|
||||
# Re-validate
|
||||
tree_copy2 = copy.deepcopy(tree_structure)
|
||||
_replace_node_in_tree(tree_copy2, node_id, copy.deepcopy(fixed_node))
|
||||
remaining2 = validate_generated_tree(tree_copy2)
|
||||
still_has_error = any(node_id in e for e in remaining2)
|
||||
except (json.JSONDecodeError, Exception) as exc:
|
||||
logger.warning("AI retry fix failed for node %s: %s", node_id, exc)
|
||||
continue
|
||||
|
||||
if still_has_error:
|
||||
logger.warning("AI could not fix node %s after retry", node_id)
|
||||
continue
|
||||
|
||||
description = _describe_fix(original_snapshot, fixed_node)
|
||||
fixes.append(
|
||||
{
|
||||
"target_node_id": node_id,
|
||||
"error_message": error_message,
|
||||
"description": description,
|
||||
"original_node": original_snapshot,
|
||||
"fixed_node": fixed_node,
|
||||
}
|
||||
)
|
||||
|
||||
return fixes, total_input_tokens, total_output_tokens
|
||||
175
backend/app/core/ai_provider.py
Normal file
175
backend/app/core/ai_provider.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
AI Provider abstraction layer.
|
||||
|
||||
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
|
||||
backends for JSON generation used by the AI Flow Builder.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
"""Abstract base class for AI providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
"""Generate a JSON response from the AI model.
|
||||
|
||||
Args:
|
||||
system_prompt: System-level instruction for the model.
|
||||
messages: List of message dicts with "role" and "content" keys.
|
||||
max_tokens: Maximum output tokens.
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, input_tokens, output_tokens).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class GeminiProvider(AIProvider):
|
||||
"""Google Gemini provider using the google-genai SDK."""
|
||||
|
||||
def __init__(self, api_key: str, model: str) -> None:
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
|
||||
client = genai.Client(api_key=self._api_key)
|
||||
|
||||
# Convert messages to Gemini Content format
|
||||
contents: list[genai_types.Content] = []
|
||||
for msg in messages:
|
||||
role = "model" if msg["role"] == "assistant" else "user"
|
||||
contents.append(
|
||||
genai_types.Content(
|
||||
role=role,
|
||||
parts=[genai_types.Part(text=msg["content"])],
|
||||
)
|
||||
)
|
||||
|
||||
config = genai_types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
max_output_tokens=max_tokens,
|
||||
response_mime_type="application/json",
|
||||
)
|
||||
|
||||
response = await client.aio.models.generate_content(
|
||||
model=self._model,
|
||||
contents=contents,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Log finish reason to detect truncation
|
||||
if response.candidates:
|
||||
finish_reason = getattr(response.candidates[0], "finish_reason", None)
|
||||
logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
|
||||
if str(finish_reason) == "MAX_TOKENS":
|
||||
logger.warning(
|
||||
"Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
text = response.text or ""
|
||||
input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
|
||||
output_tokens = (
|
||||
getattr(response.usage_metadata, "candidates_token_count", 0) or 0
|
||||
)
|
||||
|
||||
return text, input_tokens, output_tokens
|
||||
|
||||
|
||||
class AnthropicProvider(AIProvider):
|
||||
"""Anthropic Claude provider using the anthropic SDK."""
|
||||
|
||||
def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
self._timeout = timeout
|
||||
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
import anthropic
|
||||
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=self._api_key,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
|
||||
response = await client.messages.create(
|
||||
model=self._model,
|
||||
max_tokens=max_tokens,
|
||||
system=system_prompt,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
text = response.content[0].text
|
||||
input_tokens = response.usage.input_tokens
|
||||
output_tokens = response.usage.output_tokens
|
||||
|
||||
return text, input_tokens, output_tokens
|
||||
|
||||
|
||||
def get_ai_provider() -> AIProvider:
|
||||
"""Factory that returns the configured AI provider.
|
||||
|
||||
Selection logic:
|
||||
1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
|
||||
2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
|
||||
3. Fallback: if preferred provider key missing, try the other one
|
||||
4. If nothing configured -> raise RuntimeError
|
||||
"""
|
||||
provider = settings.AI_PROVIDER
|
||||
|
||||
if provider == "gemini":
|
||||
if settings.GOOGLE_AI_API_KEY:
|
||||
return GeminiProvider(
|
||||
api_key=settings.GOOGLE_AI_API_KEY,
|
||||
model=settings.AI_MODEL_GEMINI,
|
||||
)
|
||||
# Fallback to Anthropic
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
return AnthropicProvider(
|
||||
api_key=settings.ANTHROPIC_API_KEY,
|
||||
model=settings.AI_MODEL_ANTHROPIC,
|
||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
elif provider == "anthropic":
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
return AnthropicProvider(
|
||||
api_key=settings.ANTHROPIC_API_KEY,
|
||||
model=settings.AI_MODEL_ANTHROPIC,
|
||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
# Fallback to Gemini
|
||||
if settings.GOOGLE_AI_API_KEY:
|
||||
return GeminiProvider(
|
||||
api_key=settings.GOOGLE_AI_API_KEY,
|
||||
model=settings.AI_MODEL_GEMINI,
|
||||
)
|
||||
|
||||
raise RuntimeError(
|
||||
"No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
|
||||
)
|
||||
@@ -1,11 +1,11 @@
|
||||
"""AI-powered tree generation service using Anthropic Claude API.
|
||||
"""AI-powered tree generation service.
|
||||
|
||||
Implements the 4-stage wizard flow:
|
||||
Stage 2 (scaffold): AI suggests 4-7 top-level branches
|
||||
Stage 3 (branch_detail): AI generates detailed nodes per branch
|
||||
Stage 4 (assemble): Pure assembly logic — zero AI calls
|
||||
|
||||
System prompts are static constants to enable Anthropic prompt caching.
|
||||
Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic).
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
@@ -13,8 +13,7 @@ import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import anthropic
|
||||
|
||||
from app.core.ai_provider import get_ai_provider
|
||||
from app.core.config import settings
|
||||
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
|
||||
|
||||
@@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def _get_client() -> anthropic.AsyncAnthropic:
|
||||
"""Get configured async Anthropic client."""
|
||||
if not settings.ANTHROPIC_API_KEY:
|
||||
raise RuntimeError("ANTHROPIC_API_KEY not configured")
|
||||
return anthropic.AsyncAnthropic(
|
||||
api_key=settings.ANTHROPIC_API_KEY,
|
||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
|
||||
"""Estimate USD cost from token counts."""
|
||||
@@ -146,7 +136,7 @@ async def scaffold_branches(
|
||||
Returns (branches, input_tokens, output_tokens, estimated_cost).
|
||||
Raises ValueError on invalid response.
|
||||
"""
|
||||
client = _get_client()
|
||||
provider = get_ai_provider()
|
||||
|
||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||
name = wizard_state.get("name", "")
|
||||
@@ -161,21 +151,26 @@ async def scaffold_branches(
|
||||
if tags:
|
||||
user_message += f"Environment: {', '.join(tags)}\n"
|
||||
|
||||
response = await client.messages.create(
|
||||
model=settings.AI_MODEL,
|
||||
max_tokens=1024,
|
||||
system=SCAFFOLD_SYSTEM_PROMPT,
|
||||
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt=SCAFFOLD_SYSTEM_PROMPT,
|
||||
messages=[{"role": "user", "content": user_message}],
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
raw_text = _strip_markdown_fences(response.content[0].text)
|
||||
input_tokens = response.usage.input_tokens
|
||||
output_tokens = response.usage.output_tokens
|
||||
logger.info(
|
||||
"scaffold raw response (tokens in=%d out=%d, len=%d): %s",
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
len(raw_text),
|
||||
raw_text[:500],
|
||||
)
|
||||
raw_text = _strip_markdown_fences(raw_text)
|
||||
cost = _estimate_cost(input_tokens, output_tokens)
|
||||
|
||||
try:
|
||||
data = json.loads(raw_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("scaffold JSON parse failed. Full text (%d chars): %s", len(raw_text), raw_text)
|
||||
raise ValueError(f"AI returned invalid JSON: {e}")
|
||||
|
||||
branches = data.get("branches", [])
|
||||
@@ -196,7 +191,7 @@ async def generate_branch_detail(
|
||||
On validation failure, retries once with corrective prompt.
|
||||
Raises ValueError if both attempts fail.
|
||||
"""
|
||||
client = _get_client()
|
||||
provider = get_ai_provider()
|
||||
|
||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||
name = wizard_state.get("name", "")
|
||||
@@ -217,35 +212,30 @@ async def generate_branch_detail(
|
||||
total_output = 0
|
||||
|
||||
for attempt in range(3):
|
||||
response = await client.messages.create(
|
||||
model=settings.AI_MODEL,
|
||||
max_tokens=8192,
|
||||
system=BRANCH_DETAIL_SYSTEM_PROMPT,
|
||||
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT,
|
||||
messages=messages,
|
||||
max_tokens=8192,
|
||||
)
|
||||
|
||||
total_input += response.usage.input_tokens
|
||||
total_output += response.usage.output_tokens
|
||||
total_input += input_tokens
|
||||
total_output += output_tokens
|
||||
logger.debug(
|
||||
"branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d",
|
||||
"branch_detail attempt=%d output_tokens=%d",
|
||||
attempt,
|
||||
response.stop_reason,
|
||||
len(response.content),
|
||||
response.usage.output_tokens,
|
||||
output_tokens,
|
||||
)
|
||||
if response.stop_reason == "max_tokens":
|
||||
logger.warning(
|
||||
"branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated",
|
||||
attempt,
|
||||
response.usage.output_tokens,
|
||||
)
|
||||
raw_text = _strip_markdown_fences(response.content[0].text) if response.content else ""
|
||||
raw_text = _strip_markdown_fences(raw_text) if raw_text else ""
|
||||
if not raw_text:
|
||||
logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason)
|
||||
logger.warning("branch_detail attempt=%d returned empty text", attempt)
|
||||
|
||||
try:
|
||||
branch_tree = json.loads(raw_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
"branch_detail attempt=%d JSON parse failed (%d chars): %s",
|
||||
attempt, len(raw_text), raw_text[:500],
|
||||
)
|
||||
if attempt < 2:
|
||||
messages.append({"role": "assistant", "content": raw_text})
|
||||
messages.append({
|
||||
|
||||
@@ -78,11 +78,16 @@ class Settings(BaseSettings):
|
||||
AI_CONVERSATION_TTL_HOURS: int = 24
|
||||
AI_MAX_CALLS_PER_FLOW: int = 10
|
||||
AI_REQUEST_TIMEOUT_SECONDS: int = 45
|
||||
# AI Provider selection
|
||||
AI_PROVIDER: str = "gemini" # "gemini" or "anthropic"
|
||||
GOOGLE_AI_API_KEY: Optional[str] = None
|
||||
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
|
||||
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
|
||||
|
||||
@property
|
||||
def ai_enabled(self) -> bool:
|
||||
"""Check if AI Flow Builder is configured."""
|
||||
return self.ANTHROPIC_API_KEY is not None
|
||||
"""Check if any AI provider is configured."""
|
||||
return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None
|
||||
|
||||
# Deployment – auto-seed test data on PR environments
|
||||
SEED_ON_DEPLOY: bool = False
|
||||
|
||||
52
backend/app/schemas/ai_fix.py
Normal file
52
backend/app/schemas/ai_fix.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Pydantic schemas for the AI auto-fix feature."""
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── Requests ──
|
||||
|
||||
|
||||
class ValidationErrorInput(BaseModel):
|
||||
"""A single validation error to fix."""
|
||||
|
||||
node_id: str = Field(..., description="ID of the node with the error")
|
||||
message: str = Field(..., description="The validation error message")
|
||||
|
||||
|
||||
class AIFixTreeRequest(BaseModel):
|
||||
"""Request to generate AI fixes for validation errors."""
|
||||
|
||||
tree_structure: dict[str, Any] = Field(..., description="Full tree structure")
|
||||
tree_name: str = Field("", max_length=255, description="Name of the flow")
|
||||
tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field(
|
||||
"troubleshooting", description="Type of flow"
|
||||
)
|
||||
validation_errors: list[ValidationErrorInput] = Field(
|
||||
..., min_length=1, max_length=10, description="Errors to fix"
|
||||
)
|
||||
|
||||
|
||||
# ── Responses ──
|
||||
|
||||
|
||||
class AIFixProposal(BaseModel):
|
||||
"""A single proposed fix from the AI."""
|
||||
|
||||
target_node_id: str
|
||||
error_message: str
|
||||
description: str
|
||||
original_node: dict[str, Any]
|
||||
fixed_node: dict[str, Any]
|
||||
|
||||
|
||||
class AIFixTokenUsage(BaseModel):
|
||||
input: int = 0
|
||||
output: int = 0
|
||||
|
||||
|
||||
class AIFixTreeResponse(BaseModel):
|
||||
"""Response with proposed fixes."""
|
||||
|
||||
fixes: list[AIFixProposal]
|
||||
tokens_used: AIFixTokenUsage
|
||||
@@ -33,6 +33,7 @@ httpx>=0.27.0
|
||||
|
||||
# AI Flow Builder
|
||||
anthropic>=0.40.0
|
||||
google-genai>=1.0.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv==1.0.1
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Integration tests for AI Flow Builder endpoints.
|
||||
|
||||
All Anthropic API calls are mocked — zero real API spend.
|
||||
All AI provider calls are mocked — zero real API spend.
|
||||
"""
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
@@ -64,12 +64,11 @@ BRANCH_DETAIL_JSON = json.dumps({
|
||||
})
|
||||
|
||||
|
||||
def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200):
|
||||
"""Create a mock Anthropic API response."""
|
||||
response = MagicMock()
|
||||
response.content = [MagicMock(text=text)]
|
||||
response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens)
|
||||
return response
|
||||
def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200):
|
||||
"""Create a mock AI provider whose generate_json returns the given text and token counts."""
|
||||
provider = MagicMock()
|
||||
provider.generate_json = AsyncMock(return_value=(text, input_tokens, output_tokens))
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -194,11 +193,9 @@ async def test_scaffold_success(client, auth_headers, enable_ai):
|
||||
)
|
||||
conversation_id = start_resp.json()["conversation_id"]
|
||||
|
||||
# Mock Anthropic
|
||||
mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Mock AI provider
|
||||
mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
|
||||
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=mock_provider):
|
||||
response = await client.post(
|
||||
"/api/v1/ai/scaffold",
|
||||
json={"conversation_id": conversation_id},
|
||||
@@ -241,9 +238,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
|
||||
)
|
||||
conversation_id = start_resp.json()["conversation_id"]
|
||||
|
||||
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
|
||||
scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
|
||||
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
|
||||
await client.post(
|
||||
"/api/v1/ai/scaffold",
|
||||
json={"conversation_id": conversation_id},
|
||||
@@ -251,10 +247,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
|
||||
)
|
||||
|
||||
# Now generate branch detail
|
||||
detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON)
|
||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock)
|
||||
|
||||
detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON)
|
||||
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=detail_provider):
|
||||
response = await client.post(
|
||||
"/api/v1/ai/branch-detail",
|
||||
json={
|
||||
@@ -290,9 +284,8 @@ async def test_assemble_success(client, auth_headers, enable_ai):
|
||||
conversation_id = start_resp.json()["conversation_id"]
|
||||
|
||||
# Scaffold
|
||||
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
|
||||
scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
|
||||
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
|
||||
await client.post(
|
||||
"/api/v1/ai/scaffold",
|
||||
json={"conversation_id": conversation_id},
|
||||
|
||||
169
backend/tests/test_ai_fix_endpoint.py
Normal file
169
backend/tests/test_ai_fix_endpoint.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Integration tests for the POST /ai/fix-tree endpoint."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
# ── Sample tree (has a decision node with only 1 option + 1 child) ──
|
||||
|
||||
SAMPLE_TREE = {
|
||||
"id": "root",
|
||||
"type": "decision",
|
||||
"question": "Is the server up?",
|
||||
"options": [
|
||||
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
|
||||
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
|
||||
],
|
||||
"children": [
|
||||
{
|
||||
"id": "check-logs",
|
||||
"type": "action",
|
||||
"title": "Check Logs",
|
||||
"description": "Review logs.",
|
||||
"next_node_id": "logs-ok",
|
||||
},
|
||||
{
|
||||
"id": "logs-ok",
|
||||
"type": "solution",
|
||||
"title": "Logs OK",
|
||||
"description": "Issue in logs.",
|
||||
},
|
||||
{
|
||||
"id": "restart",
|
||||
"type": "decision",
|
||||
"question": "Did restart work?",
|
||||
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
|
||||
"children": [
|
||||
{
|
||||
"id": "done",
|
||||
"type": "solution",
|
||||
"title": "Done",
|
||||
"description": "Fixed.",
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Fixed version of the "restart" node — 2 options, 2 children
|
||||
FIXED_RESTART_NODE = {
|
||||
"id": "restart",
|
||||
"type": "decision",
|
||||
"question": "Did restart work?",
|
||||
"options": [
|
||||
{"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
|
||||
{"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
|
||||
],
|
||||
"children": [
|
||||
{
|
||||
"id": "done",
|
||||
"type": "solution",
|
||||
"title": "Done",
|
||||
"description": "Fixed.",
|
||||
},
|
||||
{
|
||||
"id": "escalate",
|
||||
"type": "solution",
|
||||
"title": "Escalate",
|
||||
"description": "Escalate to senior engineer.",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
FIX_REQUEST_BODY = {
|
||||
"tree_structure": SAMPLE_TREE,
|
||||
"tree_name": "Server Troubleshooting",
|
||||
"tree_type": "troubleshooting",
|
||||
"validation_errors": [
|
||||
{
|
||||
"node_id": "restart",
|
||||
"message": "Decision node 'restart' must have at least 2 options",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _mock_ai_provider(response_text: str, input_tokens: int = 50, output_tokens: int = 100):
|
||||
"""Create a mock provider whose generate_json returns given text."""
|
||||
provider = MagicMock()
|
||||
provider.generate_json = AsyncMock(return_value=(response_text, input_tokens, output_tokens))
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_ai():
|
||||
"""Temporarily enable AI by setting a fake API key."""
|
||||
original = settings.GOOGLE_AI_API_KEY
|
||||
settings.GOOGLE_AI_API_KEY = "test-key-fake"
|
||||
yield
|
||||
settings.GOOGLE_AI_API_KEY = original
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_ai():
|
||||
"""Ensure AI is disabled."""
|
||||
orig_google = settings.GOOGLE_AI_API_KEY
|
||||
orig_anthropic = settings.ANTHROPIC_API_KEY
|
||||
settings.GOOGLE_AI_API_KEY = None
|
||||
settings.ANTHROPIC_API_KEY = None
|
||||
yield
|
||||
settings.GOOGLE_AI_API_KEY = orig_google
|
||||
settings.ANTHROPIC_API_KEY = orig_anthropic
|
||||
|
||||
|
||||
# ── Tests ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_401_without_auth(client):
|
||||
"""POST /ai/fix-tree without auth token returns 401."""
|
||||
response = await client.post("/api/v1/ai/fix-tree", json=FIX_REQUEST_BODY)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_503_when_ai_disabled(client, auth_headers, disable_ai):
|
||||
"""POST /ai/fix-tree returns 503 when no AI keys are configured."""
|
||||
response = await client.post(
|
||||
"/api/v1/ai/fix-tree",
|
||||
json=FIX_REQUEST_BODY,
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_fixes_on_success(client, auth_headers, enable_ai):
|
||||
"""POST /ai/fix-tree returns fix proposals when AI succeeds."""
|
||||
mock_provider = _mock_ai_provider(json.dumps(FIXED_RESTART_NODE))
|
||||
|
||||
with patch(
|
||||
"app.core.ai_fix_service.get_ai_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/ai/fix-tree",
|
||||
json=FIX_REQUEST_BODY,
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "fixes" in data
|
||||
assert "tokens_used" in data
|
||||
assert len(data["fixes"]) == 1
|
||||
|
||||
fix = data["fixes"][0]
|
||||
assert fix["target_node_id"] == "restart"
|
||||
assert fix["error_message"] == "Decision node 'restart' must have at least 2 options"
|
||||
assert fix["original_node"]["id"] == "restart"
|
||||
assert fix["fixed_node"]["id"] == "restart"
|
||||
assert len(fix["fixed_node"]["options"]) == 2
|
||||
assert len(fix["fixed_node"]["children"]) == 2
|
||||
|
||||
assert data["tokens_used"]["input"] == 50
|
||||
assert data["tokens_used"]["output"] == 100
|
||||
224
backend/tests/test_ai_fix_service.py
Normal file
224
backend/tests/test_ai_fix_service.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Unit tests for AI fix service helper functions.
|
||||
|
||||
Tests pure Python helpers only — no AI mocking needed.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.ai_fix_service import (
|
||||
_find_node_by_id,
|
||||
_find_parent_node,
|
||||
_serialize_tree_outline,
|
||||
_strip_markdown_fences,
|
||||
_replace_node_in_tree,
|
||||
_describe_fix,
|
||||
)
|
||||
|
||||
|
||||
# ── Sample tree ──
|
||||
|
||||
SAMPLE_TREE = {
|
||||
"id": "root",
|
||||
"type": "decision",
|
||||
"question": "Is the server up?",
|
||||
"options": [
|
||||
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
|
||||
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
|
||||
],
|
||||
"children": [
|
||||
{
|
||||
"id": "check-logs",
|
||||
"type": "action",
|
||||
"title": "Check Logs",
|
||||
"description": "Review logs.",
|
||||
"next_node_id": "logs-ok",
|
||||
},
|
||||
{
|
||||
"id": "logs-ok",
|
||||
"type": "solution",
|
||||
"title": "Logs OK",
|
||||
"description": "Issue in logs.",
|
||||
},
|
||||
{
|
||||
"id": "restart",
|
||||
"type": "decision",
|
||||
"question": "Did restart work?",
|
||||
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
|
||||
"children": [
|
||||
{
|
||||
"id": "done",
|
||||
"type": "solution",
|
||||
"title": "Done",
|
||||
"description": "Fixed.",
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ── _find_node_by_id ──
|
||||
|
||||
|
||||
class TestFindNodeById:
|
||||
def test_finds_root(self):
|
||||
node = _find_node_by_id(SAMPLE_TREE, "root")
|
||||
assert node is not None
|
||||
assert node["id"] == "root"
|
||||
assert node["type"] == "decision"
|
||||
|
||||
def test_finds_nested_child(self):
|
||||
node = _find_node_by_id(SAMPLE_TREE, "done")
|
||||
assert node is not None
|
||||
assert node["id"] == "done"
|
||||
assert node["type"] == "solution"
|
||||
|
||||
def test_finds_direct_child(self):
|
||||
node = _find_node_by_id(SAMPLE_TREE, "check-logs")
|
||||
assert node is not None
|
||||
assert node["title"] == "Check Logs"
|
||||
|
||||
def test_returns_none_for_missing(self):
|
||||
node = _find_node_by_id(SAMPLE_TREE, "nonexistent")
|
||||
assert node is None
|
||||
|
||||
def test_returns_none_for_non_dict(self):
|
||||
assert _find_node_by_id("not a dict", "root") is None
|
||||
|
||||
|
||||
# ── _find_parent_node ──
|
||||
|
||||
|
||||
class TestFindParentNode:
|
||||
def test_root_has_no_parent(self):
|
||||
parent = _find_parent_node(SAMPLE_TREE, "root")
|
||||
assert parent is None
|
||||
|
||||
def test_finds_parent_of_direct_child(self):
|
||||
parent = _find_parent_node(SAMPLE_TREE, "check-logs")
|
||||
assert parent is not None
|
||||
assert parent["id"] == "root"
|
||||
|
||||
def test_finds_parent_of_deeply_nested(self):
|
||||
parent = _find_parent_node(SAMPLE_TREE, "done")
|
||||
assert parent is not None
|
||||
assert parent["id"] == "restart"
|
||||
|
||||
def test_returns_none_for_missing(self):
|
||||
parent = _find_parent_node(SAMPLE_TREE, "nonexistent")
|
||||
assert parent is None
|
||||
|
||||
def test_returns_none_for_non_dict(self):
|
||||
assert _find_parent_node("not a dict", "root") is None
|
||||
|
||||
|
||||
# ── _serialize_tree_outline ──
|
||||
|
||||
|
||||
class TestSerializeTreeOutline:
|
||||
def test_produces_readable_outline(self):
|
||||
outline = _serialize_tree_outline(SAMPLE_TREE)
|
||||
assert "- [decision] Is the server up?" in outline
|
||||
assert " - [action] Check Logs" in outline
|
||||
assert " - [solution] Logs OK" in outline
|
||||
assert " - [solution] Done" in outline
|
||||
|
||||
def test_marks_error_node(self):
|
||||
outline = _serialize_tree_outline(SAMPLE_TREE, error_node_id="restart")
|
||||
assert "<<< ERROR HERE" in outline
|
||||
# Only the restart node should be marked
|
||||
lines = outline.split("\n")
|
||||
error_lines = [l for l in lines if "ERROR HERE" in l]
|
||||
assert len(error_lines) == 1
|
||||
assert "Did restart work?" in error_lines[0]
|
||||
|
||||
def test_no_error_marker_when_none(self):
|
||||
outline = _serialize_tree_outline(SAMPLE_TREE)
|
||||
assert "ERROR HERE" not in outline
|
||||
|
||||
def test_handles_non_dict(self):
|
||||
assert _serialize_tree_outline("not a dict") == ""
|
||||
|
||||
def test_indentation_increases_with_depth(self):
|
||||
outline = _serialize_tree_outline(SAMPLE_TREE)
|
||||
lines = outline.split("\n")
|
||||
# Root has no indentation
|
||||
assert lines[0].startswith("- [decision]")
|
||||
# Children have 2-space indent
|
||||
child_lines = [l for l in lines if "Check Logs" in l]
|
||||
assert child_lines[0].startswith(" - ")
|
||||
|
||||
|
||||
# ── _strip_markdown_fences ──
|
||||
|
||||
|
||||
class TestStripMarkdownFences:
|
||||
def test_strips_json_fences(self):
|
||||
text = '```json\n{"key": "value"}\n```'
|
||||
assert _strip_markdown_fences(text) == '{"key": "value"}'
|
||||
|
||||
def test_strips_plain_fences(self):
|
||||
text = '```\n{"key": "value"}\n```'
|
||||
assert _strip_markdown_fences(text) == '{"key": "value"}'
|
||||
|
||||
def test_passes_through_plain_json(self):
|
||||
text = '{"key": "value"}'
|
||||
assert _strip_markdown_fences(text) == '{"key": "value"}'
|
||||
|
||||
|
||||
# ── _replace_node_in_tree ──
|
||||
|
||||
|
||||
class TestReplaceNodeInTree:
|
||||
def test_replaces_root(self):
|
||||
import copy
|
||||
|
||||
tree = copy.deepcopy(SAMPLE_TREE)
|
||||
replacement = {"id": "root", "type": "decision", "question": "New question"}
|
||||
assert _replace_node_in_tree(tree, "root", replacement) is True
|
||||
assert tree["question"] == "New question"
|
||||
assert "children" not in tree # cleared and replaced
|
||||
|
||||
def test_replaces_nested_node(self):
|
||||
import copy
|
||||
|
||||
tree = copy.deepcopy(SAMPLE_TREE)
|
||||
replacement = {"id": "done", "type": "solution", "title": "All Done", "description": "Complete."}
|
||||
assert _replace_node_in_tree(tree, "done", replacement) is True
|
||||
found = _find_node_by_id(tree, "done")
|
||||
assert found["title"] == "All Done"
|
||||
|
||||
def test_returns_false_for_missing(self):
|
||||
import copy
|
||||
|
||||
tree = copy.deepcopy(SAMPLE_TREE)
|
||||
assert _replace_node_in_tree(tree, "nonexistent", {"id": "x"}) is False
|
||||
|
||||
|
||||
# ── _describe_fix ──
|
||||
|
||||
|
||||
class TestDescribeFix:
|
||||
def test_describes_added_children(self):
|
||||
original = {"id": "n1", "children": [{"id": "c1"}]}
|
||||
fixed = {"id": "n1", "children": [{"id": "c1"}, {"id": "c2"}]}
|
||||
desc = _describe_fix(original, fixed)
|
||||
assert "1 child node" in desc
|
||||
|
||||
def test_describes_added_options(self):
|
||||
original = {"id": "n1", "options": [{"id": "o1"}]}
|
||||
fixed = {"id": "n1", "options": [{"id": "o1"}, {"id": "o2"}]}
|
||||
desc = _describe_fix(original, fixed)
|
||||
assert "1 option" in desc
|
||||
|
||||
def test_describes_added_next_node_id(self):
|
||||
original = {"id": "n1", "type": "action"}
|
||||
fixed = {"id": "n1", "type": "action", "next_node_id": "n2"}
|
||||
desc = _describe_fix(original, fixed)
|
||||
assert "next_node_id" in desc
|
||||
|
||||
def test_fallback_description(self):
|
||||
original = {"id": "n1", "type": "solution"}
|
||||
fixed = {"id": "n1", "type": "solution"}
|
||||
desc = _describe_fix(original, fixed)
|
||||
assert "fixed structural issue" in desc.lower()
|
||||
216
backend/tests/test_ai_provider.py
Normal file
216
backend/tests/test_ai_provider.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Tests for the AI provider abstraction layer."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import sys
|
||||
|
||||
from app.core.ai_provider import (
|
||||
AIProvider,
|
||||
AnthropicProvider,
|
||||
GeminiProvider,
|
||||
get_ai_provider,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestGetAIProvider:
|
||||
"""Tests for the get_ai_provider factory function."""
|
||||
|
||||
def test_returns_gemini_when_configured(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_key = settings.GOOGLE_AI_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "gemini"
|
||||
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
|
||||
provider = get_ai_provider()
|
||||
assert isinstance(provider, GeminiProvider)
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.GOOGLE_AI_API_KEY = original_key
|
||||
|
||||
def test_returns_anthropic_when_configured(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_key = settings.ANTHROPIC_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "anthropic"
|
||||
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
|
||||
provider = get_ai_provider()
|
||||
assert isinstance(provider, AnthropicProvider)
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.ANTHROPIC_API_KEY = original_key
|
||||
|
||||
def test_fallback_to_anthropic_when_gemini_key_missing(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_gemini_key = settings.GOOGLE_AI_API_KEY
|
||||
original_anthropic_key = settings.ANTHROPIC_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "gemini"
|
||||
settings.GOOGLE_AI_API_KEY = None
|
||||
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
|
||||
provider = get_ai_provider()
|
||||
assert isinstance(provider, AnthropicProvider)
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.GOOGLE_AI_API_KEY = original_gemini_key
|
||||
settings.ANTHROPIC_API_KEY = original_anthropic_key
|
||||
|
||||
def test_fallback_to_gemini_when_anthropic_key_missing(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_gemini_key = settings.GOOGLE_AI_API_KEY
|
||||
original_anthropic_key = settings.ANTHROPIC_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "anthropic"
|
||||
settings.ANTHROPIC_API_KEY = None
|
||||
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
|
||||
provider = get_ai_provider()
|
||||
assert isinstance(provider, GeminiProvider)
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.GOOGLE_AI_API_KEY = original_gemini_key
|
||||
settings.ANTHROPIC_API_KEY = original_anthropic_key
|
||||
|
||||
def test_raises_when_no_provider_configured(self):
|
||||
original_provider = settings.AI_PROVIDER
|
||||
original_gemini_key = settings.GOOGLE_AI_API_KEY
|
||||
original_anthropic_key = settings.ANTHROPIC_API_KEY
|
||||
try:
|
||||
settings.AI_PROVIDER = "gemini"
|
||||
settings.GOOGLE_AI_API_KEY = None
|
||||
settings.ANTHROPIC_API_KEY = None
|
||||
with pytest.raises(RuntimeError, match="No AI provider configured"):
|
||||
get_ai_provider()
|
||||
finally:
|
||||
settings.AI_PROVIDER = original_provider
|
||||
settings.GOOGLE_AI_API_KEY = original_gemini_key
|
||||
settings.ANTHROPIC_API_KEY = original_anthropic_key
|
||||
|
||||
|
||||
class TestAnthropicProvider:
|
||||
"""Tests for AnthropicProvider.generate_json."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_json(self):
|
||||
provider = AnthropicProvider(
|
||||
api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30
|
||||
)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text='{"result": "ok"}')]
|
||||
mock_response.usage = MagicMock(input_tokens=100, output_tokens=50)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.messages.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
|
||||
text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt="You are a helper.",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=1024,
|
||||
)
|
||||
|
||||
assert text == '{"result": "ok"}'
|
||||
assert input_tokens == 100
|
||||
assert output_tokens == 50
|
||||
|
||||
mock_client.messages.create.assert_called_once_with(
|
||||
model="claude-haiku-4-5-20251001",
|
||||
max_tokens=1024,
|
||||
system="You are a helper.",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
)
|
||||
|
||||
|
||||
class TestGeminiProvider:
|
||||
"""Tests for GeminiProvider.generate_json."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_json(self):
|
||||
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
|
||||
|
||||
mock_usage = MagicMock()
|
||||
mock_usage.prompt_token_count = 80
|
||||
mock_usage.candidates_token_count = 40
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '{"answer": 42}'
|
||||
mock_response.usage_metadata = mock_usage
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.aio.models.generate_content = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
mock_genai_module = MagicMock()
|
||||
mock_genai_module.Client.return_value = mock_client
|
||||
|
||||
mock_types = MagicMock()
|
||||
mock_types.Content.side_effect = lambda **kw: kw
|
||||
mock_types.Part.side_effect = lambda **kw: kw
|
||||
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
|
||||
|
||||
mock_google = MagicMock()
|
||||
mock_google.genai = mock_genai_module
|
||||
mock_genai_module.types = mock_types
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"google": mock_google,
|
||||
"google.genai": mock_genai_module,
|
||||
"google.genai.types": mock_types,
|
||||
}):
|
||||
text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt="Generate JSON.",
|
||||
messages=[
|
||||
{"role": "user", "content": "Give me data"},
|
||||
{"role": "assistant", "content": "Here it is"},
|
||||
{"role": "user", "content": "More please"},
|
||||
],
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
assert text == '{"answer": 42}'
|
||||
assert input_tokens == 80
|
||||
assert output_tokens == 40
|
||||
|
||||
mock_client.aio.models.generate_content.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_json_handles_none_usage(self):
|
||||
"""Token counts default to 0 when usage_metadata attributes are None."""
|
||||
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
|
||||
|
||||
mock_usage = MagicMock(spec=[]) # No attributes at all
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "{}"
|
||||
mock_response.usage_metadata = mock_usage
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.aio.models.generate_content = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
mock_genai_module = MagicMock()
|
||||
mock_genai_module.Client.return_value = mock_client
|
||||
|
||||
mock_types = MagicMock()
|
||||
mock_types.Content.side_effect = lambda **kw: kw
|
||||
mock_types.Part.side_effect = lambda **kw: kw
|
||||
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
|
||||
|
||||
mock_google = MagicMock()
|
||||
mock_google.genai = mock_genai_module
|
||||
mock_genai_module.types = mock_types
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"google": mock_google,
|
||||
"google.genai": mock_genai_module,
|
||||
"google.genai.types": mock_types,
|
||||
}):
|
||||
text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt="test",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
)
|
||||
|
||||
assert text == "{}"
|
||||
assert input_tokens == 0
|
||||
assert output_tokens == 0
|
||||
209
docs/plans/2026-02-26-ai-autofix-gemini-design.md
Normal file
209
docs/plans/2026-02-26-ai-autofix-gemini-design.md
Normal file
@@ -0,0 +1,209 @@
|
||||
# AI Auto-Fix & Gemini Flash Provider Design
|
||||
|
||||
> **Date:** 2026-02-26
|
||||
> **Status:** Approved
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Two combined features:
|
||||
|
||||
1. **AI Provider Abstraction** — Add Gemini 2.5 Flash as the default AI provider with Claude as fallback, behind a unified interface.
|
||||
2. **AI Auto-Fix for Validation Errors** — When a flow fails validation, offer an AI-powered "Fix with AI" button that generates structural fixes for review.
|
||||
|
||||
---
|
||||
|
||||
## Section 1: AI Provider Abstraction
|
||||
|
||||
### Design
|
||||
|
||||
New `backend/app/core/ai_provider.py` with a unified interface:
|
||||
|
||||
```python
|
||||
class AIProvider(ABC):
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
"""Returns (text, input_tokens, output_tokens)"""
|
||||
```
|
||||
|
||||
Two implementations:
|
||||
|
||||
| Provider | Model | SDK | Role |
|
||||
|----------|-------|-----|------|
|
||||
| `GeminiProvider` | `gemini-2.5-flash` | `google-genai` | Default |
|
||||
| `AnthropicProvider` | `claude-haiku-4-5-20251001` | `anthropic` | Fallback |
|
||||
|
||||
### Provider Selection
|
||||
|
||||
- `get_ai_provider()` factory reads `AI_PROVIDER` env var (default: `"gemini"`)
|
||||
- Falls back to Anthropic if Gemini key is missing
|
||||
- Existing `ai_tree_generator_service.py` swaps direct Anthropic calls for `get_ai_provider()`
|
||||
|
||||
### New Environment Variables
|
||||
|
||||
| Variable | Default | Purpose |
|
||||
|----------|---------|---------|
|
||||
| `AI_PROVIDER` | `"gemini"` | Which provider to use (`gemini` or `anthropic`) |
|
||||
| `GOOGLE_AI_API_KEY` | — | Gemini API key |
|
||||
|
||||
Existing `ANTHROPIC_API_KEY` remains for fallback.
|
||||
|
||||
### Config Changes (`core/config.py`)
|
||||
|
||||
```python
|
||||
AI_PROVIDER: str = "gemini"
|
||||
GOOGLE_AI_API_KEY: str | None = None
|
||||
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
|
||||
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Section 2: AI Auto-Fix Feature
|
||||
|
||||
### Backend Endpoint
|
||||
|
||||
**`POST /api/v1/ai/fix-tree`**
|
||||
|
||||
Request:
|
||||
```json
|
||||
{
|
||||
"tree_structure": { /* full tree */ },
|
||||
"tree_name": "Router Troubleshooting",
|
||||
"tree_type": "troubleshooting",
|
||||
"validation_errors": [
|
||||
{
|
||||
"node_id": "node_abc",
|
||||
"message": "Decision node must have at least 2 children (branches)"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"fixes": [
|
||||
{
|
||||
"target_node_id": "node_abc",
|
||||
"error_message": "Decision node must have at least 2 children (branches)",
|
||||
"description": "Added second branch 'Check firmware version' with solution node",
|
||||
"original_node": { /* snapshot before fix */ },
|
||||
"fixed_node": { /* replacement node with corrected subtree */ }
|
||||
}
|
||||
],
|
||||
"tokens_used": { "input": 1200, "output": 800 }
|
||||
}
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
1. For each validation error tied to a `node_id`, extract that node + its parent + siblings from the tree.
|
||||
2. Build a prompt with:
|
||||
- The **full tree structure** serialized as a simplified outline (node titles + types + structure) for context
|
||||
- The **specific failing node** highlighted with full JSON detail
|
||||
- The **validation error message**
|
||||
- Instructions: "Fix ONLY this node's structural issue. Keep all existing content. Generate domain-relevant additions that fit the flow's topic."
|
||||
3. AI returns a corrected version of that node (with children/options adjusted).
|
||||
4. Backend re-validates the fixed node before returning it.
|
||||
5. If re-validation fails, retry once with the error fed back (corrective prompt pattern).
|
||||
|
||||
### Prompt Strategy
|
||||
|
||||
The prompt gives the AI the full tree as a compact outline, then zooms into the failing node:
|
||||
|
||||
```
|
||||
You are fixing a validation error in a troubleshooting flow called "Router Troubleshooting".
|
||||
|
||||
FULL FLOW OUTLINE:
|
||||
- [decision] Is the router powered on?
|
||||
- [action] Check power cable → [solution] Power restored
|
||||
- [decision] Are lights blinking? ← ERROR HERE
|
||||
- [solution] Contact ISP
|
||||
|
||||
ERROR: Decision node "Are lights blinking?" must have at least 2 children (branches).
|
||||
|
||||
FAILING NODE (full detail):
|
||||
{...json...}
|
||||
|
||||
Fix this node by adding the minimum structure needed to resolve the error.
|
||||
Return ONLY the fixed node as JSON.
|
||||
```
|
||||
|
||||
### Frontend UX
|
||||
|
||||
1. **Trigger**: "Fix with AI" button in `ValidationSummary` — appears when there are fixable errors (structural errors with a `node_id`).
|
||||
2. **Loading state**: Button shows spinner + "Generating fixes..." — disabled during request.
|
||||
3. **Review modal** (`AIFixReviewModal`): Shows each proposed fix as a card:
|
||||
- Error message at top
|
||||
- Before/after view of the node change
|
||||
- "Apply" / "Skip" buttons per fix
|
||||
- "Apply All" button in footer
|
||||
4. **Apply**: Each accepted fix calls `updateNode(targetNodeId, fixedNode)` in the tree editor store.
|
||||
5. **Re-validate**: After applying fixes, auto-run `validate()` to confirm resolution.
|
||||
|
||||
---
|
||||
|
||||
## Section 3: Scope & Constraints
|
||||
|
||||
### Fixable Errors (Auto-Fix Scope)
|
||||
|
||||
Only structural validation errors with a `node_id`:
|
||||
- Decision node missing children/branches
|
||||
- Decision node missing options
|
||||
- Action node missing `next_node_id`
|
||||
- Dead-end decision nodes (no children)
|
||||
|
||||
### NOT Fixable
|
||||
|
||||
- Global checks (tree too small/large, not enough solutions) — require rethinking the whole tree
|
||||
- Content quality issues — out of scope
|
||||
- Errors without a `node_id` (root-level issues)
|
||||
|
||||
Non-fixable errors still show in ValidationSummary but without the "Fix with AI" option.
|
||||
|
||||
### Token Budget
|
||||
|
||||
- Tree outline: ~50-100 tokens for a typical 15-node tree
|
||||
- Failing node detail: ~100-200 tokens
|
||||
- System prompt + instructions: ~300 tokens
|
||||
- **Total input per fix: ~500-600 tokens**
|
||||
- One API call per failing node (not batched)
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Provider failure (rate limit, network): toast error, user can retry
|
||||
- Fix fails re-validation: "AI couldn't generate a valid fix" with retry option
|
||||
- Max 1 retry with corrective prompt per attempt
|
||||
- Both provider and fallback fail: surface error to user
|
||||
|
||||
### Auth
|
||||
|
||||
- Requires `engineer` role or above (`require_engineer_or_admin`)
|
||||
|
||||
---
|
||||
|
||||
## New Files
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `backend/app/core/ai_provider.py` | Provider abstraction + Gemini/Anthropic implementations |
|
||||
| `backend/app/core/ai_fix_service.py` | Fix generation logic + prompt building |
|
||||
| `backend/app/api/endpoints/ai.py` | `POST /ai/fix-tree` endpoint |
|
||||
| `backend/app/schemas/ai.py` | Request/response schemas for AI endpoints |
|
||||
| `frontend/src/components/tree-editor/AIFixReviewModal.tsx` | Review modal for proposed fixes |
|
||||
|
||||
## Modified Files
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `backend/app/core/config.py` | Add Gemini config vars |
|
||||
| `backend/app/core/ai_tree_generator_service.py` | Swap Anthropic calls for provider abstraction |
|
||||
| `backend/app/api/router.py` | Register `/ai` routes |
|
||||
| `frontend/src/api/trees.ts` | Add `fixTree()` API call |
|
||||
| `frontend/src/components/tree-editor/ValidationSummary.tsx` | Add "Fix with AI" button |
|
||||
1707
docs/plans/2026-02-26-ai-autofix-gemini-plan.md
Normal file
1707
docs/plans/2026-02-26-ai-autofix-gemini-plan.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
import apiClient from './client'
|
||||
import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse } from '@/types'
|
||||
import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse, AIFixTreeRequest, AIFixTreeResponse } from '@/types'
|
||||
|
||||
export const treesApi = {
|
||||
async list(params?: TreeFilters): Promise<TreeListItem[]> {
|
||||
@@ -65,6 +65,12 @@ export const treesApi = {
|
||||
const response = await apiClient.post<TreeValidationResponse>(`/trees/${id}/can-publish`)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// AI auto-fix
|
||||
async fixTree(request: AIFixTreeRequest): Promise<AIFixTreeResponse> {
|
||||
const response = await apiClient.post<AIFixTreeResponse>('/ai/fix-tree', request)
|
||||
return response.data
|
||||
},
|
||||
}
|
||||
|
||||
export default treesApi
|
||||
|
||||
170
frontend/src/components/tree-editor/AIFixReviewModal.tsx
Normal file
170
frontend/src/components/tree-editor/AIFixReviewModal.tsx
Normal file
@@ -0,0 +1,170 @@
|
||||
import { useState } from 'react'
|
||||
import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react'
|
||||
import { cn } from '@/lib/utils'
|
||||
import type { AIFixProposal } from '@/types'
|
||||
|
||||
interface AIFixReviewModalProps {
|
||||
fixes: AIFixProposal[]
|
||||
onApply: (fix: AIFixProposal) => void
|
||||
onApplyAll: () => void
|
||||
onClose: () => void
|
||||
}
|
||||
|
||||
export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) {
|
||||
const [appliedIds, setAppliedIds] = useState<Set<string>>(new Set())
|
||||
const [skippedIds, setSkippedIds] = useState<Set<string>>(new Set())
|
||||
const [expandedIds, setExpandedIds] = useState<Set<string>>(new Set(fixes.map(f => f.target_node_id)))
|
||||
|
||||
const handleApply = (fix: AIFixProposal) => {
|
||||
onApply(fix)
|
||||
setAppliedIds(prev => new Set(prev).add(fix.target_node_id))
|
||||
}
|
||||
|
||||
const handleSkip = (fix: AIFixProposal) => {
|
||||
setSkippedIds(prev => new Set(prev).add(fix.target_node_id))
|
||||
}
|
||||
|
||||
const toggleExpanded = (id: string) => {
|
||||
setExpandedIds(prev => {
|
||||
const next = new Set(prev)
|
||||
if (next.has(id)) next.delete(id)
|
||||
else next.add(id)
|
||||
return next
|
||||
})
|
||||
}
|
||||
|
||||
const pendingFixes = fixes.filter(
|
||||
f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id)
|
||||
)
|
||||
const allHandled = pendingFixes.length === 0
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/80 backdrop-blur-sm p-4">
|
||||
<div className="relative flex h-[80vh] w-full max-w-2xl flex-col bg-card border border-border rounded-2xl shadow-lg">
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between border-b border-border px-6 py-4">
|
||||
<div className="flex items-center gap-2">
|
||||
<Sparkles className="h-5 w-5 text-primary" />
|
||||
<h2 className="text-lg font-semibold text-foreground">
|
||||
AI Fix Proposals ({fixes.length})
|
||||
</h2>
|
||||
</div>
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="rounded-md p-1 text-muted-foreground hover:bg-accent hover:text-foreground"
|
||||
>
|
||||
<X className="h-5 w-5" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Body */}
|
||||
<div className="flex-1 overflow-y-auto p-4 space-y-3">
|
||||
{fixes.map((fix) => {
|
||||
const isApplied = appliedIds.has(fix.target_node_id)
|
||||
const isSkipped = skippedIds.has(fix.target_node_id)
|
||||
const isExpanded = expandedIds.has(fix.target_node_id)
|
||||
|
||||
return (
|
||||
<div
|
||||
key={fix.target_node_id}
|
||||
className={cn(
|
||||
'rounded-lg border p-4',
|
||||
isApplied
|
||||
? 'border-emerald-400/30 bg-emerald-400/5'
|
||||
: isSkipped
|
||||
? 'border-border bg-accent/30 opacity-60'
|
||||
: 'border-border bg-card'
|
||||
)}
|
||||
>
|
||||
{/* Fix header */}
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div className="flex-1">
|
||||
<p className="text-sm text-red-400 mb-1">{fix.error_message}</p>
|
||||
<p className="text-sm text-foreground">{fix.description}</p>
|
||||
<p className="text-xs text-muted-foreground mt-1">
|
||||
Node: {fix.target_node_id}
|
||||
</p>
|
||||
</div>
|
||||
{isApplied && (
|
||||
<span className="flex items-center gap-1 rounded-full bg-emerald-400/10 px-2 py-1 text-xs text-emerald-400">
|
||||
<Check className="h-3 w-3" /> Applied
|
||||
</span>
|
||||
)}
|
||||
{isSkipped && (
|
||||
<span className="text-xs text-muted-foreground">Skipped</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Expand/collapse detail */}
|
||||
{!isApplied && !isSkipped && (
|
||||
<>
|
||||
<button
|
||||
onClick={() => toggleExpanded(fix.target_node_id)}
|
||||
className="mt-2 flex items-center gap-1 text-xs text-muted-foreground hover:text-foreground"
|
||||
>
|
||||
{isExpanded ? <ChevronUp className="h-3 w-3" /> : <ChevronDown className="h-3 w-3" />}
|
||||
{isExpanded ? 'Hide' : 'Show'} details
|
||||
</button>
|
||||
|
||||
{isExpanded && (
|
||||
<div className="mt-3 grid grid-cols-2 gap-3">
|
||||
<div>
|
||||
<p className="text-xs font-medium text-muted-foreground mb-1">Before</p>
|
||||
<pre className="overflow-x-auto rounded bg-accent/50 p-2 text-xs text-muted-foreground max-h-48 overflow-y-auto">
|
||||
{JSON.stringify(fix.original_node, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
<div>
|
||||
<p className="text-xs font-medium text-emerald-400 mb-1">After</p>
|
||||
<pre className="overflow-x-auto rounded bg-emerald-400/5 p-2 text-xs text-foreground max-h-48 overflow-y-auto">
|
||||
{JSON.stringify(fix.fixed_node, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Action buttons */}
|
||||
<div className="mt-3 flex gap-2">
|
||||
<button
|
||||
onClick={() => handleApply(fix)}
|
||||
className="flex items-center gap-1 rounded-md bg-gradient-brand px-3 py-1.5 text-xs font-medium text-white shadow-sm shadow-primary/20 hover:opacity-90"
|
||||
>
|
||||
<Check className="h-3 w-3" />
|
||||
Apply
|
||||
</button>
|
||||
<button
|
||||
onClick={() => handleSkip(fix)}
|
||||
className="flex items-center gap-1 rounded-md border border-border px-3 py-1.5 text-xs font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
|
||||
>
|
||||
<SkipForward className="h-3 w-3" />
|
||||
Skip
|
||||
</button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* Footer */}
|
||||
<div className="flex items-center justify-between border-t border-border px-6 py-4">
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="rounded-md border border-border px-4 py-2 text-sm font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
|
||||
>
|
||||
{allHandled ? 'Done' : 'Cancel'}
|
||||
</button>
|
||||
{!allHandled && (
|
||||
<button
|
||||
onClick={onApplyAll}
|
||||
className="rounded-md bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20 hover:opacity-90"
|
||||
>
|
||||
Apply All ({pendingFixes.length})
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -1,14 +1,16 @@
|
||||
import { useState } from 'react'
|
||||
import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp } from 'lucide-react'
|
||||
import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react'
|
||||
import { cn } from '@/lib/utils'
|
||||
import type { ValidationError } from '@/store/treeEditorStore'
|
||||
|
||||
interface ValidationSummaryProps {
|
||||
errors: ValidationError[]
|
||||
onSelectNode: (nodeId: string) => void
|
||||
onFixWithAI?: () => void
|
||||
isFixing?: boolean
|
||||
}
|
||||
|
||||
export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryProps) {
|
||||
export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) {
|
||||
const [isExpanded, setIsExpanded] = useState(true)
|
||||
|
||||
const errorItems = errors.filter(e => e.severity === 'error')
|
||||
@@ -22,6 +24,8 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
|
||||
}
|
||||
}
|
||||
|
||||
const hasFixableErrors = errorItems.some(e => e.nodeId)
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -32,14 +36,16 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
|
||||
)}
|
||||
>
|
||||
{/* Header */}
|
||||
<button
|
||||
onClick={() => setIsExpanded(!isExpanded)}
|
||||
<div
|
||||
className={cn(
|
||||
'flex w-full items-center justify-between p-3 text-left transition-colors hover:bg-accent',
|
||||
'flex w-full items-center justify-between p-3 transition-colors',
|
||||
errorItems.length > 0 ? 'text-red-400' : 'text-yellow-400'
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
onClick={() => setIsExpanded(!isExpanded)}
|
||||
className="flex items-center gap-2 text-left hover:opacity-80"
|
||||
>
|
||||
{errorItems.length > 0 ? (
|
||||
<AlertCircle className="h-5 w-5" />
|
||||
) : (
|
||||
@@ -58,9 +64,35 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
|
||||
</>
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
{isExpanded ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />}
|
||||
</button>
|
||||
{isExpanded ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />}
|
||||
</button>
|
||||
|
||||
{/* Fix with AI button */}
|
||||
{onFixWithAI && hasFixableErrors && (
|
||||
<button
|
||||
onClick={onFixWithAI}
|
||||
disabled={isFixing}
|
||||
className={cn(
|
||||
'flex items-center gap-1.5 rounded-md px-3 py-1 text-xs font-medium transition-colors',
|
||||
isFixing
|
||||
? 'bg-primary/10 text-primary cursor-wait'
|
||||
: 'bg-gradient-brand text-white shadow-sm shadow-primary/20 hover:opacity-90'
|
||||
)}
|
||||
>
|
||||
{isFixing ? (
|
||||
<>
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
Generating fixes...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Sparkles className="h-3 w-3" />
|
||||
Fix with AI
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Error/Warning List */}
|
||||
{isExpanded && (
|
||||
|
||||
@@ -5,10 +5,11 @@ import { Undo2, Redo2, Save, CheckCircle2, Monitor, FileText, Code2, LayoutList,
|
||||
import { getMonacoEditor } from '@/components/tree-editor/code-mode'
|
||||
import { treesApi } from '@/api/trees'
|
||||
import { treeMarkdownApi } from '@/api/treeMarkdown'
|
||||
import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure } from '@/types'
|
||||
import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure, AIFixProposal } from '@/types'
|
||||
import { useTreeEditorStore, useTreeEditorTemporal } from '@/store/treeEditorStore'
|
||||
import { TreeEditorLayout } from '@/components/tree-editor/TreeEditorLayout'
|
||||
import { ValidationSummary } from '@/components/tree-editor/ValidationSummary'
|
||||
import { AIFixReviewModal } from '@/components/tree-editor/AIFixReviewModal'
|
||||
import { useKeyboardShortcuts } from '@/hooks/useKeyboardShortcuts'
|
||||
import { usePermissions } from '@/hooks/usePermissions'
|
||||
import { Spinner } from '@/components/common/Spinner'
|
||||
@@ -58,6 +59,8 @@ export function TreeEditorPage() {
|
||||
const [showAnalytics, setShowAnalytics] = useState(false)
|
||||
const [isMetadataOpen, setIsMetadataOpen] = useState(false)
|
||||
const [editingNodeId, setEditingNodeId] = useState<string | null>(null)
|
||||
const [isFixing, setIsFixing] = useState(false)
|
||||
const [fixProposals, setFixProposals] = useState<AIFixProposal[] | null>(null)
|
||||
|
||||
// Mobile detection
|
||||
const [isMobile, setIsMobile] = useState(false)
|
||||
@@ -217,6 +220,54 @@ export function TreeEditorPage() {
|
||||
selectNode(nodeId)
|
||||
}
|
||||
|
||||
const handleFixWithAI = async () => {
|
||||
const store = useTreeEditorStore.getState()
|
||||
if (!store.treeStructure) return
|
||||
|
||||
const fixableErrors = store.validationErrors
|
||||
.filter(e => e.severity === 'error' && e.nodeId)
|
||||
.map(e => ({ node_id: e.nodeId!, message: e.message }))
|
||||
|
||||
if (fixableErrors.length === 0) return
|
||||
|
||||
setIsFixing(true)
|
||||
try {
|
||||
const result = await treesApi.fixTree({
|
||||
tree_structure: store.treeStructure as unknown as Record<string, unknown>,
|
||||
tree_name: store.name,
|
||||
tree_type: 'troubleshooting',
|
||||
validation_errors: fixableErrors,
|
||||
})
|
||||
if (result.fixes.length > 0) {
|
||||
setFixProposals(result.fixes)
|
||||
} else {
|
||||
toast.info('AI could not generate fixes for these errors')
|
||||
}
|
||||
} catch {
|
||||
toast.error('Failed to generate AI fixes. Please try again.')
|
||||
} finally {
|
||||
setIsFixing(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleApplyFix = (fix: AIFixProposal) => {
|
||||
updateNode(fix.target_node_id, fix.fixed_node as Partial<TreeStructure>)
|
||||
}
|
||||
|
||||
const handleApplyAllFixes = () => {
|
||||
if (!fixProposals) return
|
||||
for (const fix of fixProposals) {
|
||||
handleApplyFix(fix)
|
||||
}
|
||||
setFixProposals(null)
|
||||
setTimeout(() => { validate() }, 100)
|
||||
}
|
||||
|
||||
const handleCloseFixModal = () => {
|
||||
setFixProposals(null)
|
||||
validate()
|
||||
}
|
||||
|
||||
const handleNodeSelect = useCallback((nodeId: string | null) => {
|
||||
if (nodeId) {
|
||||
setIsMetadataOpen(false) // close metadata when opening node editor
|
||||
@@ -685,6 +736,8 @@ export function TreeEditorPage() {
|
||||
<ValidationSummary
|
||||
errors={validationErrors}
|
||||
onSelectNode={handleSelectNode}
|
||||
onFixWithAI={handleFixWithAI}
|
||||
isFixing={isFixing}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
@@ -705,6 +758,16 @@ export function TreeEditorPage() {
|
||||
<FlowAnalyticsPanel treeId={id} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* AI Fix Review Modal */}
|
||||
{fixProposals && (
|
||||
<AIFixReviewModal
|
||||
fixes={fixProposals}
|
||||
onApply={handleApplyFix}
|
||||
onApplyAll={handleApplyAllFixes}
|
||||
onClose={handleCloseFixModal}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
24
frontend/src/types/ai-fix.ts
Normal file
24
frontend/src/types/ai-fix.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
export interface AIFixValidationError {
|
||||
node_id: string
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface AIFixProposal {
|
||||
target_node_id: string
|
||||
error_message: string
|
||||
description: string
|
||||
original_node: Record<string, unknown>
|
||||
fixed_node: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AIFixTreeRequest {
|
||||
tree_structure: Record<string, unknown>
|
||||
tree_name: string
|
||||
tree_type: 'troubleshooting' | 'procedural' | 'maintenance'
|
||||
validation_errors: AIFixValidationError[]
|
||||
}
|
||||
|
||||
export interface AIFixTreeResponse {
|
||||
fixes: AIFixProposal[]
|
||||
tokens_used: { input: number; output: number }
|
||||
}
|
||||
@@ -45,3 +45,10 @@ export type {
|
||||
AIAssembleResponse,
|
||||
AIWizardPhase,
|
||||
} from './ai'
|
||||
|
||||
export type {
|
||||
AIFixTreeRequest,
|
||||
AIFixTreeResponse,
|
||||
AIFixProposal,
|
||||
AIFixValidationError,
|
||||
} from './ai-fix'
|
||||
|
||||
Reference in New Issue
Block a user