feat: add AI fix service with prompt building and validation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-26 17:25:34 -05:00
parent 5f8653e481
commit 373736c594
2 changed files with 497 additions and 0 deletions

View File

@@ -0,0 +1,273 @@
"""AI-powered fix service for tree validation errors.
Given a tree structure and validation errors, generates AI-powered
proposals to fix each structural issue while preserving existing content.
"""
import copy
import json
import logging
import re
from typing import Any
from app.core.ai_provider import get_ai_provider
from app.core.ai_tree_validator import validate_generated_tree
logger = logging.getLogger(__name__)
FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.
You will receive:
1. A full flow outline showing the tree structure
2. The specific failing node with its full JSON
3. The validation error message
Your task: Return a FIXED version of the failing node as valid JSON. Rules:
- Fix ONLY the structural issue described in the error message
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
- Every new node must have a unique ID (use descriptive kebab-case IDs)
- Decision nodes must have at least 2 options and at least 2 children
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
- Solution nodes are leaf nodes (no children)
- Return ONLY the fixed node JSON, no explanation"""
# ── Pure helper functions ──
def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
"""Recursively find a node by its ID in the tree structure."""
if not isinstance(tree, dict):
return None
if tree.get("id") == node_id:
return tree
for child in tree.get("children", []):
result = _find_node_by_id(child, node_id)
if result is not None:
return result
return None
def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
"""Find the parent of a node with the given ID."""
if not isinstance(tree, dict):
return None
for child in tree.get("children", []):
if isinstance(child, dict) and child.get("id") == target_id:
return tree
result = _find_parent_node(child, target_id)
if result is not None:
return result
return None
def _serialize_tree_outline(
tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None
) -> str:
"""Serialize tree as a readable outline for AI prompt context.
Format: indented "- [type] label" with "<<< ERROR HERE" marker.
"""
if not isinstance(tree, dict):
return ""
node_type = tree.get("type", "unknown")
label = tree.get("question") or tree.get("title") or tree.get("id", "?")
prefix = " " * indent
marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""
line = f"{prefix}- [{node_type}] {label}{marker}"
lines = [line]
for child in tree.get("children", []):
lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))
return "\n".join(lines)
def _strip_markdown_fences(text: str) -> str:
"""Strip ```json...``` fences from AI response."""
return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip(
"`"
).strip()
def _replace_node_in_tree(
tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
) -> bool:
"""Replace a node in-place by ID. Returns True if found and replaced."""
if not isinstance(tree, dict):
return False
if tree.get("id") == target_id:
tree.clear()
tree.update(replacement)
return True
for child in tree.get("children", []):
if _replace_node_in_tree(child, target_id, replacement):
return True
return False
def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
"""Describe what changed between original and fixed node."""
changes: list[str] = []
orig_children = len(original.get("children", []))
fixed_children = len(fixed.get("children", []))
if fixed_children > orig_children:
changes.append(f"added {fixed_children - orig_children} child node(s)")
orig_options = len(original.get("options", []))
fixed_options = len(fixed.get("options", []))
if fixed_options > orig_options:
changes.append(f"added {fixed_options - orig_options} option(s)")
if fixed.get("next_node_id") and not original.get("next_node_id"):
changes.append("added next_node_id")
if not changes:
changes.append("fixed structural issue")
return "; ".join(changes).capitalize()
# ── Prompt building ──
def _build_fix_prompt(
tree: dict[str, Any],
node_id: str,
error_message: str,
tree_name: str,
tree_type: str,
) -> str:
"""Build the user message for the AI fix request."""
outline = _serialize_tree_outline(tree, error_node_id=node_id)
node = _find_node_by_id(tree, node_id)
node_json = json.dumps(node, indent=2) if node else "{}"
return (
f"Flow name: {tree_name}\n"
f"Flow type: {tree_type}\n\n"
f"## Full flow outline\n```\n{outline}\n```\n\n"
f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n"
f"## Validation error\n{error_message}\n\n"
f"Return the fixed version of this node as JSON."
)
# ── Main entry point ──
async def generate_fixes(
tree_structure: dict[str, Any],
tree_name: str,
tree_type: str,
validation_errors: list[dict[str, str]],
) -> tuple[list[dict[str, Any]], int, int]:
"""Generate AI-powered fixes for tree validation errors.
Args:
tree_structure: Full tree structure dict.
tree_name: Name of the flow.
tree_type: Type of flow (troubleshooting, procedural, maintenance).
validation_errors: List of dicts with "node_id" and "message" keys.
Returns:
Tuple of (fixes_list, total_input_tokens, total_output_tokens).
Each fix dict has: target_node_id, error_message, description,
original_node, fixed_node.
"""
provider = get_ai_provider()
fixes: list[dict[str, Any]] = []
total_input_tokens = 0
total_output_tokens = 0
for error in validation_errors:
node_id = error["node_id"]
error_message = error["message"]
original_node = _find_node_by_id(tree_structure, node_id)
if original_node is None:
logger.warning("Node %s not found in tree, skipping fix", node_id)
continue
original_snapshot = copy.deepcopy(original_node)
# Build prompt and call AI
user_message = _build_fix_prompt(
tree_structure, node_id, error_message, tree_name, tree_type
)
messages = [{"role": "user", "content": user_message}]
try:
text, in_tok, out_tok = await provider.generate_json(
system_prompt=FIX_SYSTEM_PROMPT,
messages=messages,
max_tokens=2048,
)
total_input_tokens += in_tok
total_output_tokens += out_tok
cleaned = _strip_markdown_fences(text)
fixed_node = json.loads(cleaned)
except (json.JSONDecodeError, Exception) as exc:
logger.warning("AI fix failed for node %s: %s", node_id, exc)
continue
# Validate by substituting into a tree copy
tree_copy = copy.deepcopy(tree_structure)
_replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node))
remaining_errors = validate_generated_tree(tree_copy)
# Check if the specific error is still present
still_has_error = any(node_id in e for e in remaining_errors)
if still_has_error:
# Retry once with corrective prompt
retry_message = (
f"Your previous fix still has validation errors:\n"
f"{chr(10).join(remaining_errors)}\n\n"
f"Please fix the node again. Return ONLY the corrected JSON."
)
messages.append({"role": "assistant", "content": text})
messages.append({"role": "user", "content": retry_message})
try:
text2, in_tok2, out_tok2 = await provider.generate_json(
system_prompt=FIX_SYSTEM_PROMPT,
messages=messages,
max_tokens=2048,
)
total_input_tokens += in_tok2
total_output_tokens += out_tok2
cleaned2 = _strip_markdown_fences(text2)
fixed_node = json.loads(cleaned2)
# Re-validate
tree_copy2 = copy.deepcopy(tree_structure)
_replace_node_in_tree(tree_copy2, node_id, copy.deepcopy(fixed_node))
remaining2 = validate_generated_tree(tree_copy2)
still_has_error = any(node_id in e for e in remaining2)
except (json.JSONDecodeError, Exception) as exc:
logger.warning("AI retry fix failed for node %s: %s", node_id, exc)
continue
if still_has_error:
logger.warning("AI could not fix node %s after retry", node_id)
continue
description = _describe_fix(original_snapshot, fixed_node)
fixes.append(
{
"target_node_id": node_id,
"error_message": error_message,
"description": description,
"original_node": original_snapshot,
"fixed_node": fixed_node,
}
)
return fixes, total_input_tokens, total_output_tokens

View File

@@ -0,0 +1,224 @@
"""Unit tests for AI fix service helper functions.
Tests pure Python helpers only — no AI mocking needed.
"""
import pytest
from app.core.ai_fix_service import (
_find_node_by_id,
_find_parent_node,
_serialize_tree_outline,
_strip_markdown_fences,
_replace_node_in_tree,
_describe_fix,
)
# ── Sample tree ──
SAMPLE_TREE = {
"id": "root",
"type": "decision",
"question": "Is the server up?",
"options": [
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
],
"children": [
{
"id": "check-logs",
"type": "action",
"title": "Check Logs",
"description": "Review logs.",
"next_node_id": "logs-ok",
},
{
"id": "logs-ok",
"type": "solution",
"title": "Logs OK",
"description": "Issue in logs.",
},
{
"id": "restart",
"type": "decision",
"question": "Did restart work?",
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
"children": [
{
"id": "done",
"type": "solution",
"title": "Done",
"description": "Fixed.",
}
],
},
],
}
# ── _find_node_by_id ──
class TestFindNodeById:
def test_finds_root(self):
node = _find_node_by_id(SAMPLE_TREE, "root")
assert node is not None
assert node["id"] == "root"
assert node["type"] == "decision"
def test_finds_nested_child(self):
node = _find_node_by_id(SAMPLE_TREE, "done")
assert node is not None
assert node["id"] == "done"
assert node["type"] == "solution"
def test_finds_direct_child(self):
node = _find_node_by_id(SAMPLE_TREE, "check-logs")
assert node is not None
assert node["title"] == "Check Logs"
def test_returns_none_for_missing(self):
node = _find_node_by_id(SAMPLE_TREE, "nonexistent")
assert node is None
def test_returns_none_for_non_dict(self):
assert _find_node_by_id("not a dict", "root") is None
# ── _find_parent_node ──
class TestFindParentNode:
def test_root_has_no_parent(self):
parent = _find_parent_node(SAMPLE_TREE, "root")
assert parent is None
def test_finds_parent_of_direct_child(self):
parent = _find_parent_node(SAMPLE_TREE, "check-logs")
assert parent is not None
assert parent["id"] == "root"
def test_finds_parent_of_deeply_nested(self):
parent = _find_parent_node(SAMPLE_TREE, "done")
assert parent is not None
assert parent["id"] == "restart"
def test_returns_none_for_missing(self):
parent = _find_parent_node(SAMPLE_TREE, "nonexistent")
assert parent is None
def test_returns_none_for_non_dict(self):
assert _find_parent_node("not a dict", "root") is None
# ── _serialize_tree_outline ──
class TestSerializeTreeOutline:
def test_produces_readable_outline(self):
outline = _serialize_tree_outline(SAMPLE_TREE)
assert "- [decision] Is the server up?" in outline
assert " - [action] Check Logs" in outline
assert " - [solution] Logs OK" in outline
assert " - [solution] Done" in outline
def test_marks_error_node(self):
outline = _serialize_tree_outline(SAMPLE_TREE, error_node_id="restart")
assert "<<< ERROR HERE" in outline
# Only the restart node should be marked
lines = outline.split("\n")
error_lines = [l for l in lines if "ERROR HERE" in l]
assert len(error_lines) == 1
assert "Did restart work?" in error_lines[0]
def test_no_error_marker_when_none(self):
outline = _serialize_tree_outline(SAMPLE_TREE)
assert "ERROR HERE" not in outline
def test_handles_non_dict(self):
assert _serialize_tree_outline("not a dict") == ""
def test_indentation_increases_with_depth(self):
outline = _serialize_tree_outline(SAMPLE_TREE)
lines = outline.split("\n")
# Root has no indentation
assert lines[0].startswith("- [decision]")
# Children have 2-space indent
child_lines = [l for l in lines if "Check Logs" in l]
assert child_lines[0].startswith(" - ")
# ── _strip_markdown_fences ──
class TestStripMarkdownFences:
def test_strips_json_fences(self):
text = '```json\n{"key": "value"}\n```'
assert _strip_markdown_fences(text) == '{"key": "value"}'
def test_strips_plain_fences(self):
text = '```\n{"key": "value"}\n```'
assert _strip_markdown_fences(text) == '{"key": "value"}'
def test_passes_through_plain_json(self):
text = '{"key": "value"}'
assert _strip_markdown_fences(text) == '{"key": "value"}'
# ── _replace_node_in_tree ──
class TestReplaceNodeInTree:
def test_replaces_root(self):
import copy
tree = copy.deepcopy(SAMPLE_TREE)
replacement = {"id": "root", "type": "decision", "question": "New question"}
assert _replace_node_in_tree(tree, "root", replacement) is True
assert tree["question"] == "New question"
assert "children" not in tree # cleared and replaced
def test_replaces_nested_node(self):
import copy
tree = copy.deepcopy(SAMPLE_TREE)
replacement = {"id": "done", "type": "solution", "title": "All Done", "description": "Complete."}
assert _replace_node_in_tree(tree, "done", replacement) is True
found = _find_node_by_id(tree, "done")
assert found["title"] == "All Done"
def test_returns_false_for_missing(self):
import copy
tree = copy.deepcopy(SAMPLE_TREE)
assert _replace_node_in_tree(tree, "nonexistent", {"id": "x"}) is False
# ── _describe_fix ──
class TestDescribeFix:
def test_describes_added_children(self):
original = {"id": "n1", "children": [{"id": "c1"}]}
fixed = {"id": "n1", "children": [{"id": "c1"}, {"id": "c2"}]}
desc = _describe_fix(original, fixed)
assert "1 child node" in desc
def test_describes_added_options(self):
original = {"id": "n1", "options": [{"id": "o1"}]}
fixed = {"id": "n1", "options": [{"id": "o1"}, {"id": "o2"}]}
desc = _describe_fix(original, fixed)
assert "1 option" in desc
def test_describes_added_next_node_id(self):
original = {"id": "n1", "type": "action"}
fixed = {"id": "n1", "type": "action", "next_node_id": "n2"}
desc = _describe_fix(original, fixed)
assert "next_node_id" in desc
def test_fallback_description(self):
original = {"id": "n1", "type": "solution"}
fixed = {"id": "n1", "type": "solution"}
desc = _describe_fix(original, fixed)
assert "fixed structural issue" in desc.lower()