feat: add AI fix service with prompt building and validation
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
273
backend/app/core/ai_fix_service.py
Normal file
273
backend/app/core/ai_fix_service.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""AI-powered fix service for tree validation errors.
|
||||
|
||||
Given a tree structure and validation errors, generates AI-powered
|
||||
proposals to fix each structural issue while preserving existing content.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.ai_provider import get_ai_provider
|
||||
from app.core.ai_tree_validator import validate_generated_tree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.
|
||||
|
||||
You will receive:
|
||||
1. A full flow outline showing the tree structure
|
||||
2. The specific failing node with its full JSON
|
||||
3. The validation error message
|
||||
|
||||
Your task: Return a FIXED version of the failing node as valid JSON. Rules:
|
||||
- Fix ONLY the structural issue described in the error message
|
||||
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
|
||||
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
|
||||
- Every new node must have a unique ID (use descriptive kebab-case IDs)
|
||||
- Decision nodes must have at least 2 options and at least 2 children
|
||||
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
|
||||
- Solution nodes are leaf nodes (no children)
|
||||
- Return ONLY the fixed node JSON, no explanation"""
|
||||
|
||||
|
||||
# ── Pure helper functions ──
|
||||
|
||||
|
||||
def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
|
||||
"""Recursively find a node by its ID in the tree structure."""
|
||||
if not isinstance(tree, dict):
|
||||
return None
|
||||
if tree.get("id") == node_id:
|
||||
return tree
|
||||
for child in tree.get("children", []):
|
||||
result = _find_node_by_id(child, node_id)
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
|
||||
"""Find the parent of a node with the given ID."""
|
||||
if not isinstance(tree, dict):
|
||||
return None
|
||||
for child in tree.get("children", []):
|
||||
if isinstance(child, dict) and child.get("id") == target_id:
|
||||
return tree
|
||||
result = _find_parent_node(child, target_id)
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _serialize_tree_outline(
|
||||
tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None
|
||||
) -> str:
|
||||
"""Serialize tree as a readable outline for AI prompt context.
|
||||
|
||||
Format: indented "- [type] label" with "<<< ERROR HERE" marker.
|
||||
"""
|
||||
if not isinstance(tree, dict):
|
||||
return ""
|
||||
|
||||
node_type = tree.get("type", "unknown")
|
||||
label = tree.get("question") or tree.get("title") or tree.get("id", "?")
|
||||
prefix = " " * indent
|
||||
marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""
|
||||
line = f"{prefix}- [{node_type}] {label}{marker}"
|
||||
|
||||
lines = [line]
|
||||
for child in tree.get("children", []):
|
||||
lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _strip_markdown_fences(text: str) -> str:
|
||||
"""Strip ```json...``` fences from AI response."""
|
||||
return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip(
|
||||
"`"
|
||||
).strip()
|
||||
|
||||
|
||||
def _replace_node_in_tree(
|
||||
tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Replace a node in-place by ID. Returns True if found and replaced."""
|
||||
if not isinstance(tree, dict):
|
||||
return False
|
||||
if tree.get("id") == target_id:
|
||||
tree.clear()
|
||||
tree.update(replacement)
|
||||
return True
|
||||
for child in tree.get("children", []):
|
||||
if _replace_node_in_tree(child, target_id, replacement):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
|
||||
"""Describe what changed between original and fixed node."""
|
||||
changes: list[str] = []
|
||||
|
||||
orig_children = len(original.get("children", []))
|
||||
fixed_children = len(fixed.get("children", []))
|
||||
if fixed_children > orig_children:
|
||||
changes.append(f"added {fixed_children - orig_children} child node(s)")
|
||||
|
||||
orig_options = len(original.get("options", []))
|
||||
fixed_options = len(fixed.get("options", []))
|
||||
if fixed_options > orig_options:
|
||||
changes.append(f"added {fixed_options - orig_options} option(s)")
|
||||
|
||||
if fixed.get("next_node_id") and not original.get("next_node_id"):
|
||||
changes.append("added next_node_id")
|
||||
|
||||
if not changes:
|
||||
changes.append("fixed structural issue")
|
||||
|
||||
return "; ".join(changes).capitalize()
|
||||
|
||||
|
||||
# ── Prompt building ──
|
||||
|
||||
|
||||
def _build_fix_prompt(
|
||||
tree: dict[str, Any],
|
||||
node_id: str,
|
||||
error_message: str,
|
||||
tree_name: str,
|
||||
tree_type: str,
|
||||
) -> str:
|
||||
"""Build the user message for the AI fix request."""
|
||||
outline = _serialize_tree_outline(tree, error_node_id=node_id)
|
||||
node = _find_node_by_id(tree, node_id)
|
||||
node_json = json.dumps(node, indent=2) if node else "{}"
|
||||
|
||||
return (
|
||||
f"Flow name: {tree_name}\n"
|
||||
f"Flow type: {tree_type}\n\n"
|
||||
f"## Full flow outline\n```\n{outline}\n```\n\n"
|
||||
f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n"
|
||||
f"## Validation error\n{error_message}\n\n"
|
||||
f"Return the fixed version of this node as JSON."
|
||||
)
|
||||
|
||||
|
||||
# ── Main entry point ──
|
||||
|
||||
|
||||
async def generate_fixes(
|
||||
tree_structure: dict[str, Any],
|
||||
tree_name: str,
|
||||
tree_type: str,
|
||||
validation_errors: list[dict[str, str]],
|
||||
) -> tuple[list[dict[str, Any]], int, int]:
|
||||
"""Generate AI-powered fixes for tree validation errors.
|
||||
|
||||
Args:
|
||||
tree_structure: Full tree structure dict.
|
||||
tree_name: Name of the flow.
|
||||
tree_type: Type of flow (troubleshooting, procedural, maintenance).
|
||||
validation_errors: List of dicts with "node_id" and "message" keys.
|
||||
|
||||
Returns:
|
||||
Tuple of (fixes_list, total_input_tokens, total_output_tokens).
|
||||
Each fix dict has: target_node_id, error_message, description,
|
||||
original_node, fixed_node.
|
||||
"""
|
||||
provider = get_ai_provider()
|
||||
fixes: list[dict[str, Any]] = []
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
for error in validation_errors:
|
||||
node_id = error["node_id"]
|
||||
error_message = error["message"]
|
||||
|
||||
original_node = _find_node_by_id(tree_structure, node_id)
|
||||
if original_node is None:
|
||||
logger.warning("Node %s not found in tree, skipping fix", node_id)
|
||||
continue
|
||||
|
||||
original_snapshot = copy.deepcopy(original_node)
|
||||
|
||||
# Build prompt and call AI
|
||||
user_message = _build_fix_prompt(
|
||||
tree_structure, node_id, error_message, tree_name, tree_type
|
||||
)
|
||||
messages = [{"role": "user", "content": user_message}]
|
||||
|
||||
try:
|
||||
text, in_tok, out_tok = await provider.generate_json(
|
||||
system_prompt=FIX_SYSTEM_PROMPT,
|
||||
messages=messages,
|
||||
max_tokens=2048,
|
||||
)
|
||||
total_input_tokens += in_tok
|
||||
total_output_tokens += out_tok
|
||||
|
||||
cleaned = _strip_markdown_fences(text)
|
||||
fixed_node = json.loads(cleaned)
|
||||
except (json.JSONDecodeError, Exception) as exc:
|
||||
logger.warning("AI fix failed for node %s: %s", node_id, exc)
|
||||
continue
|
||||
|
||||
# Validate by substituting into a tree copy
|
||||
tree_copy = copy.deepcopy(tree_structure)
|
||||
_replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node))
|
||||
remaining_errors = validate_generated_tree(tree_copy)
|
||||
|
||||
# Check if the specific error is still present
|
||||
still_has_error = any(node_id in e for e in remaining_errors)
|
||||
|
||||
if still_has_error:
|
||||
# Retry once with corrective prompt
|
||||
retry_message = (
|
||||
f"Your previous fix still has validation errors:\n"
|
||||
f"{chr(10).join(remaining_errors)}\n\n"
|
||||
f"Please fix the node again. Return ONLY the corrected JSON."
|
||||
)
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
messages.append({"role": "user", "content": retry_message})
|
||||
|
||||
try:
|
||||
text2, in_tok2, out_tok2 = await provider.generate_json(
|
||||
system_prompt=FIX_SYSTEM_PROMPT,
|
||||
messages=messages,
|
||||
max_tokens=2048,
|
||||
)
|
||||
total_input_tokens += in_tok2
|
||||
total_output_tokens += out_tok2
|
||||
|
||||
cleaned2 = _strip_markdown_fences(text2)
|
||||
fixed_node = json.loads(cleaned2)
|
||||
|
||||
# Re-validate
|
||||
tree_copy2 = copy.deepcopy(tree_structure)
|
||||
_replace_node_in_tree(tree_copy2, node_id, copy.deepcopy(fixed_node))
|
||||
remaining2 = validate_generated_tree(tree_copy2)
|
||||
still_has_error = any(node_id in e for e in remaining2)
|
||||
except (json.JSONDecodeError, Exception) as exc:
|
||||
logger.warning("AI retry fix failed for node %s: %s", node_id, exc)
|
||||
continue
|
||||
|
||||
if still_has_error:
|
||||
logger.warning("AI could not fix node %s after retry", node_id)
|
||||
continue
|
||||
|
||||
description = _describe_fix(original_snapshot, fixed_node)
|
||||
fixes.append(
|
||||
{
|
||||
"target_node_id": node_id,
|
||||
"error_message": error_message,
|
||||
"description": description,
|
||||
"original_node": original_snapshot,
|
||||
"fixed_node": fixed_node,
|
||||
}
|
||||
)
|
||||
|
||||
return fixes, total_input_tokens, total_output_tokens
|
||||
Reference in New Issue
Block a user