feat: add AI fix service with prompt building and validation
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
273
backend/app/core/ai_fix_service.py
Normal file
273
backend/app/core/ai_fix_service.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""AI-powered fix service for tree validation errors.
|
||||
|
||||
Given a tree structure and validation errors, generates AI-powered
|
||||
proposals to fix each structural issue while preserving existing content.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.ai_provider import get_ai_provider
|
||||
from app.core.ai_tree_validator import validate_generated_tree
|
||||
|
||||
logger = logging.getLogger(__name__)


# System prompt passed as ``system_prompt`` on every provider call made by
# generate_fixes (both the initial attempt and the retry). It asks the model to
# return only the repaired node as JSON, leaving existing content untouched.
FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.

You will receive:
1. A full flow outline showing the tree structure
2. The specific failing node with its full JSON
3. The validation error message

Your task: Return a FIXED version of the failing node as valid JSON. Rules:
- Fix ONLY the structural issue described in the error message
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
- Every new node must have a unique ID (use descriptive kebab-case IDs)
- Decision nodes must have at least 2 options and at least 2 children
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
- Solution nodes are leaf nodes (no children)
- Return ONLY the fixed node JSON, no explanation"""
|
||||
|
||||
|
||||
# ── Pure helper functions ──
|
||||
|
||||
|
||||
def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
|
||||
"""Recursively find a node by its ID in the tree structure."""
|
||||
if not isinstance(tree, dict):
|
||||
return None
|
||||
if tree.get("id") == node_id:
|
||||
return tree
|
||||
for child in tree.get("children", []):
|
||||
result = _find_node_by_id(child, node_id)
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
|
||||
"""Find the parent of a node with the given ID."""
|
||||
if not isinstance(tree, dict):
|
||||
return None
|
||||
for child in tree.get("children", []):
|
||||
if isinstance(child, dict) and child.get("id") == target_id:
|
||||
return tree
|
||||
result = _find_parent_node(child, target_id)
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _serialize_tree_outline(
|
||||
tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None
|
||||
) -> str:
|
||||
"""Serialize tree as a readable outline for AI prompt context.
|
||||
|
||||
Format: indented "- [type] label" with "<<< ERROR HERE" marker.
|
||||
"""
|
||||
if not isinstance(tree, dict):
|
||||
return ""
|
||||
|
||||
node_type = tree.get("type", "unknown")
|
||||
label = tree.get("question") or tree.get("title") or tree.get("id", "?")
|
||||
prefix = " " * indent
|
||||
marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""
|
||||
line = f"{prefix}- [{node_type}] {label}{marker}"
|
||||
|
||||
lines = [line]
|
||||
for child in tree.get("children", []):
|
||||
lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _strip_markdown_fences(text: str) -> str:
|
||||
"""Strip ```json...``` fences from AI response."""
|
||||
return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip(
|
||||
"`"
|
||||
).strip()
|
||||
|
||||
|
||||
def _replace_node_in_tree(
|
||||
tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Replace a node in-place by ID. Returns True if found and replaced."""
|
||||
if not isinstance(tree, dict):
|
||||
return False
|
||||
if tree.get("id") == target_id:
|
||||
tree.clear()
|
||||
tree.update(replacement)
|
||||
return True
|
||||
for child in tree.get("children", []):
|
||||
if _replace_node_in_tree(child, target_id, replacement):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
|
||||
"""Describe what changed between original and fixed node."""
|
||||
changes: list[str] = []
|
||||
|
||||
orig_children = len(original.get("children", []))
|
||||
fixed_children = len(fixed.get("children", []))
|
||||
if fixed_children > orig_children:
|
||||
changes.append(f"added {fixed_children - orig_children} child node(s)")
|
||||
|
||||
orig_options = len(original.get("options", []))
|
||||
fixed_options = len(fixed.get("options", []))
|
||||
if fixed_options > orig_options:
|
||||
changes.append(f"added {fixed_options - orig_options} option(s)")
|
||||
|
||||
if fixed.get("next_node_id") and not original.get("next_node_id"):
|
||||
changes.append("added next_node_id")
|
||||
|
||||
if not changes:
|
||||
changes.append("fixed structural issue")
|
||||
|
||||
return "; ".join(changes).capitalize()
|
||||
|
||||
|
||||
# ── Prompt building ──
|
||||
|
||||
|
||||
def _build_fix_prompt(
    tree: dict[str, Any],
    node_id: str,
    error_message: str,
    tree_name: str,
    tree_type: str,
) -> str:
    """Build the user message for the AI fix request.

    The message contains, in order: the flow name and type, an outline of the
    whole tree with the failing node flagged, the failing node's JSON, the
    validation error text, and a closing instruction.
    """
    failing_node = _find_node_by_id(tree, node_id)
    if failing_node:
        node_json = json.dumps(failing_node, indent=2)
    else:
        # Node missing (or empty): send an empty JSON object instead.
        node_json = "{}"

    outline = _serialize_tree_outline(tree, error_node_id=node_id)

    sections = [
        f"Flow name: {tree_name}\n",
        f"Flow type: {tree_type}\n\n",
        f"## Full flow outline\n```\n{outline}\n```\n\n",
        f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n",
        f"## Validation error\n{error_message}\n\n",
        "Return the fixed version of this node as JSON.",
    ]
    return "".join(sections)
|
||||
|
||||
|
||||
# ── Main entry point ──
|
||||
|
||||
|
||||
async def generate_fixes(
    tree_structure: dict[str, Any],
    tree_name: str,
    tree_type: str,
    validation_errors: list[dict[str, str]],
) -> tuple[list[dict[str, Any]], int, int]:
    """Generate AI-powered fixes for tree validation errors.

    Each error is handled independently against the unmodified
    ``tree_structure``: the AI is asked for a replacement node, the
    replacement is substituted into a deep copy of the tree, and the copy is
    re-validated. If the validator still reports an error mentioning the node
    ID, the AI gets one corrective retry. Nodes that cannot be fixed are
    logged and skipped; ``tree_structure`` itself is never mutated.

    Args:
        tree_structure: Full tree structure dict.
        tree_name: Name of the flow.
        tree_type: Type of flow (troubleshooting, procedural, maintenance).
        validation_errors: List of dicts with "node_id" and "message" keys.

    Returns:
        Tuple of (fixes_list, total_input_tokens, total_output_tokens).
        Each fix dict has: target_node_id, error_message, description,
        original_node, fixed_node.
    """
    provider = get_ai_provider()
    fixes: list[dict[str, Any]] = []
    total_input_tokens = 0
    total_output_tokens = 0
    max_attempts = 2  # the initial request plus one corrective retry

    for error in validation_errors:
        node_id = error["node_id"]
        error_message = error["message"]

        original_node = _find_node_by_id(tree_structure, node_id)
        if original_node is None:
            logger.warning("Node %s not found in tree, skipping fix", node_id)
            continue

        original_snapshot = copy.deepcopy(original_node)

        user_message = _build_fix_prompt(
            tree_structure, node_id, error_message, tree_name, tree_type
        )
        messages = [{"role": "user", "content": user_message}]

        fixed_node: dict[str, Any] | None = None
        for attempt in range(max_attempts):
            is_retry = attempt > 0
            try:
                text, in_tok, out_tok = await provider.generate_json(
                    system_prompt=FIX_SYSTEM_PROMPT,
                    messages=messages,
                    max_tokens=2048,
                )
                # Tokens are billed even when the response turns out unusable.
                total_input_tokens += in_tok
                total_output_tokens += out_tok

                candidate = json.loads(_strip_markdown_fences(text))
                if not isinstance(candidate, dict):
                    # Valid JSON that is not an object (list, string, ...)
                    # cannot stand in for a node.
                    raise TypeError(
                        f"expected a JSON object, got {type(candidate).__name__}"
                    )

                if candidate.get("id") != node_id:
                    # A renamed node would drop out of the validator output
                    # and look "fixed" while breaking references to it.
                    blocking_errors = [
                        f"The fixed node must keep its original id '{node_id}'"
                    ]
                else:
                    # Validate by substituting into a tree copy.
                    tree_copy = copy.deepcopy(tree_structure)
                    _replace_node_in_tree(
                        tree_copy, node_id, copy.deepcopy(candidate)
                    )
                    remaining_errors = validate_generated_tree(tree_copy)
                    # The node counts as still broken if any remaining error
                    # mentions its ID (substring match on the message).
                    still_has_error = any(node_id in e for e in remaining_errors)
                    blocking_errors = remaining_errors if still_has_error else []
            except Exception as exc:
                # Best-effort boundary: one bad provider call or malformed
                # response must not abort the fixes for the other nodes.
                if is_retry:
                    logger.warning(
                        "AI retry fix failed for node %s: %s", node_id, exc
                    )
                else:
                    logger.warning("AI fix failed for node %s: %s", node_id, exc)
                break

            if not blocking_errors:
                fixed_node = candidate
                break

            if is_retry:
                logger.warning("AI could not fix node %s after retry", node_id)
                break

            # Retry once with corrective prompt.
            error_list = "\n".join(blocking_errors)
            retry_message = (
                "Your previous fix still has validation errors:\n"
                f"{error_list}\n\n"
                "Please fix the node again. Return ONLY the corrected JSON."
            )
            messages.append({"role": "assistant", "content": text})
            messages.append({"role": "user", "content": retry_message})

        if fixed_node is None:
            continue

        fixes.append(
            {
                "target_node_id": node_id,
                "error_message": error_message,
                "description": _describe_fix(original_snapshot, fixed_node),
                "original_node": original_snapshot,
                "fixed_node": fixed_node,
            }
        )

    return fixes, total_input_tokens, total_output_tokens
|
||||
224
backend/tests/test_ai_fix_service.py
Normal file
224
backend/tests/test_ai_fix_service.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Unit tests for AI fix service helper functions.
|
||||
|
||||
Tests pure Python helpers only — no AI mocking needed.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.ai_fix_service import (
|
||||
_find_node_by_id,
|
||||
_find_parent_node,
|
||||
_serialize_tree_outline,
|
||||
_strip_markdown_fences,
|
||||
_replace_node_in_tree,
|
||||
_describe_fix,
|
||||
)
|
||||
|
||||
|
||||
# ── Sample tree ──
|
||||
|
||||
# Shared fixture: a three-level flow rooted at a decision node.
#   root (decision)
#     check-logs (action) -> next_node_id "logs-ok"
#     logs-ok (solution)
#     restart (decision, one option, one child)
#       done (solution)
# Tests that mutate the tree work on a deep copy of this fixture.
SAMPLE_TREE = {
    "id": "root",
    "type": "decision",
    "question": "Is the server up?",
    "options": [
        {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
        {"id": "opt-no", "label": "No", "next_node_id": "restart"},
    ],
    "children": [
        {
            "id": "check-logs",
            "type": "action",
            "title": "Check Logs",
            "description": "Review logs.",
            "next_node_id": "logs-ok",
        },
        {
            "id": "logs-ok",
            "type": "solution",
            "title": "Logs OK",
            "description": "Issue in logs.",
        },
        {
            "id": "restart",
            "type": "decision",
            "question": "Did restart work?",
            "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
            "children": [
                {
                    "id": "done",
                    "type": "solution",
                    "title": "Done",
                    "description": "Fixed.",
                }
            ],
        },
    ],
}
|
||||
|
||||
|
||||
# ── _find_node_by_id ──
|
||||
|
||||
|
||||
class TestFindNodeById:
    """Lookup of nodes at any depth of SAMPLE_TREE by their ``id``."""

    def test_finds_root(self):
        found = _find_node_by_id(SAMPLE_TREE, "root")
        assert found is not None
        assert found["id"] == "root"
        assert found["type"] == "decision"

    def test_finds_nested_child(self):
        found = _find_node_by_id(SAMPLE_TREE, "done")
        assert found is not None
        assert found["id"] == "done"
        assert found["type"] == "solution"

    def test_finds_direct_child(self):
        found = _find_node_by_id(SAMPLE_TREE, "check-logs")
        assert found is not None
        assert found["title"] == "Check Logs"

    def test_returns_none_for_missing(self):
        assert _find_node_by_id(SAMPLE_TREE, "nonexistent") is None

    def test_returns_none_for_non_dict(self):
        result = _find_node_by_id("not a dict", "root")
        assert result is None
|
||||
|
||||
|
||||
# ── _find_parent_node ──
|
||||
|
||||
|
||||
class TestFindParentNode:
    """Resolution of a node's direct parent within SAMPLE_TREE."""

    def test_root_has_no_parent(self):
        assert _find_parent_node(SAMPLE_TREE, "root") is None

    def test_finds_parent_of_direct_child(self):
        owner = _find_parent_node(SAMPLE_TREE, "check-logs")
        assert owner is not None
        assert owner["id"] == "root"

    def test_finds_parent_of_deeply_nested(self):
        owner = _find_parent_node(SAMPLE_TREE, "done")
        assert owner is not None
        assert owner["id"] == "restart"

    def test_returns_none_for_missing(self):
        assert _find_parent_node(SAMPLE_TREE, "nonexistent") is None

    def test_returns_none_for_non_dict(self):
        result = _find_parent_node("not a dict", "root")
        assert result is None
|
||||
|
||||
|
||||
# ── _serialize_tree_outline ──
|
||||
|
||||
|
||||
class TestSerializeTreeOutline:
    """Rendering of SAMPLE_TREE as an indented text outline."""

    def test_produces_readable_outline(self):
        rendered = _serialize_tree_outline(SAMPLE_TREE)
        expected_fragments = (
            "- [decision] Is the server up?",
            " - [action] Check Logs",
            " - [solution] Logs OK",
            " - [solution] Done",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_marks_error_node(self):
        rendered = _serialize_tree_outline(SAMPLE_TREE, error_node_id="restart")
        assert "<<< ERROR HERE" in rendered
        # Exactly one line may carry the marker: the "restart" node's.
        flagged = [row for row in rendered.split("\n") if "ERROR HERE" in row]
        assert len(flagged) == 1
        assert "Did restart work?" in flagged[0]

    def test_no_error_marker_when_none(self):
        rendered = _serialize_tree_outline(SAMPLE_TREE)
        assert "ERROR HERE" not in rendered

    def test_handles_non_dict(self):
        assert _serialize_tree_outline("not a dict") == ""

    def test_indentation_increases_with_depth(self):
        rows = _serialize_tree_outline(SAMPLE_TREE).split("\n")
        # The root line starts at column zero.
        assert rows[0].startswith("- [decision]")
        # A direct child is indented relative to the root.
        matching = [row for row in rows if "Check Logs" in row]
        assert matching[0].startswith(" - ")
|
||||
|
||||
|
||||
# ── _strip_markdown_fences ──
|
||||
|
||||
|
||||
class TestStripMarkdownFences:
    """Removal of Markdown code fences around a JSON payload."""

    def test_strips_json_fences(self):
        raw = '```json\n{"key": "value"}\n```'
        cleaned = _strip_markdown_fences(raw)
        assert cleaned == '{"key": "value"}'

    def test_strips_plain_fences(self):
        raw = '```\n{"key": "value"}\n```'
        cleaned = _strip_markdown_fences(raw)
        assert cleaned == '{"key": "value"}'

    def test_passes_through_plain_json(self):
        raw = '{"key": "value"}'
        cleaned = _strip_markdown_fences(raw)
        assert cleaned == '{"key": "value"}'
|
||||
|
||||
|
||||
# ── _replace_node_in_tree ──
|
||||
|
||||
|
||||
class TestReplaceNodeInTree:
    """In-place node replacement; each test mutates its own copy of SAMPLE_TREE."""

    def test_replaces_root(self):
        import copy

        working = copy.deepcopy(SAMPLE_TREE)
        new_root = {"id": "root", "type": "decision", "question": "New question"}
        replaced = _replace_node_in_tree(working, "root", new_root)
        assert replaced is True
        assert working["question"] == "New question"
        # The old keys are cleared, so the former children are gone.
        assert "children" not in working

    def test_replaces_nested_node(self):
        import copy

        working = copy.deepcopy(SAMPLE_TREE)
        new_leaf = {
            "id": "done",
            "type": "solution",
            "title": "All Done",
            "description": "Complete.",
        }
        replaced = _replace_node_in_tree(working, "done", new_leaf)
        assert replaced is True
        located = _find_node_by_id(working, "done")
        assert located["title"] == "All Done"

    def test_returns_false_for_missing(self):
        import copy

        working = copy.deepcopy(SAMPLE_TREE)
        replaced = _replace_node_in_tree(working, "nonexistent", {"id": "x"})
        assert replaced is False
|
||||
|
||||
|
||||
# ── _describe_fix ──
|
||||
|
||||
|
||||
class TestDescribeFix:
    """Human-readable summaries of what a fix added to a node."""

    def test_describes_added_children(self):
        before = {"id": "n1", "children": [{"id": "c1"}]}
        after = {"id": "n1", "children": [{"id": "c1"}, {"id": "c2"}]}
        summary = _describe_fix(before, after)
        assert "1 child node" in summary

    def test_describes_added_options(self):
        before = {"id": "n1", "options": [{"id": "o1"}]}
        after = {"id": "n1", "options": [{"id": "o1"}, {"id": "o2"}]}
        summary = _describe_fix(before, after)
        assert "1 option" in summary

    def test_describes_added_next_node_id(self):
        before = {"id": "n1", "type": "action"}
        after = {"id": "n1", "type": "action", "next_node_id": "n2"}
        summary = _describe_fix(before, after)
        assert "next_node_id" in summary

    def test_fallback_description(self):
        before = {"id": "n1", "type": "solution"}
        after = {"id": "n1", "type": "solution"}
        summary = _describe_fix(before, after)
        assert "fixed structural issue" in summary.lower()
|
||||
Reference in New Issue
Block a user