diff --git a/backend/app/core/ai_fix_service.py b/backend/app/core/ai_fix_service.py new file mode 100644 index 00000000..02350a15 --- /dev/null +++ b/backend/app/core/ai_fix_service.py @@ -0,0 +1,273 @@ +"""AI-powered fix service for tree validation errors. + +Given a tree structure and validation errors, generates AI-powered +proposals to fix each structural issue while preserving existing content. +""" + +import copy +import json +import logging +import re +from typing import Any + +from app.core.ai_provider import get_ai_provider +from app.core.ai_tree_validator import validate_generated_tree + +logger = logging.getLogger(__name__) + + +FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers. + +You will receive: +1. A full flow outline showing the tree structure +2. The specific failing node with its full JSON +3. The validation error message + +Your task: Return a FIXED version of the failing node as valid JSON. 
Rules: +- Fix ONLY the structural issue described in the error message +- Keep ALL existing content (titles, descriptions, questions, options) unchanged +- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic +- Every new node must have a unique ID (use descriptive kebab-case IDs) +- Decision nodes must have at least 2 options and at least 2 children +- Action nodes must have a next_node_id pointing to a sibling node in the parent's children +- Solution nodes are leaf nodes (no children) +- Return ONLY the fixed node JSON, no explanation""" + + +# ── Pure helper functions ── + + +def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None: + """Recursively find a node by its ID in the tree structure.""" + if not isinstance(tree, dict): + return None + if tree.get("id") == node_id: + return tree + for child in tree.get("children", []): + result = _find_node_by_id(child, node_id) + if result is not None: + return result + return None + + +def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None: + """Find the parent of a node with the given ID.""" + if not isinstance(tree, dict): + return None + for child in tree.get("children", []): + if isinstance(child, dict) and child.get("id") == target_id: + return tree + result = _find_parent_node(child, target_id) + if result is not None: + return result + return None + + +def _serialize_tree_outline( + tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None +) -> str: + """Serialize tree as a readable outline for AI prompt context. + + Format: indented "- [type] label" with "<<< ERROR HERE" marker. 
+ """ + if not isinstance(tree, dict): + return "" + + node_type = tree.get("type", "unknown") + label = tree.get("question") or tree.get("title") or tree.get("id", "?") + prefix = " " * indent + marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else "" + line = f"{prefix}- [{node_type}] {label}{marker}" + + lines = [line] + for child in tree.get("children", []): + lines.append(_serialize_tree_outline(child, indent + 1, error_node_id)) + + return "\n".join(lines) + + +def _strip_markdown_fences(text: str) -> str: + """Strip ```json...``` fences from AI response.""" + return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip( + "`" + ).strip() + + +def _replace_node_in_tree( + tree: dict[str, Any], target_id: str, replacement: dict[str, Any] +) -> bool: + """Replace a node in-place by ID. Returns True if found and replaced.""" + if not isinstance(tree, dict): + return False + if tree.get("id") == target_id: + tree.clear() + tree.update(replacement) + return True + for child in tree.get("children", []): + if _replace_node_in_tree(child, target_id, replacement): + return True + return False + + +def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str: + """Describe what changed between original and fixed node.""" + changes: list[str] = [] + + orig_children = len(original.get("children", [])) + fixed_children = len(fixed.get("children", [])) + if fixed_children > orig_children: + changes.append(f"added {fixed_children - orig_children} child node(s)") + + orig_options = len(original.get("options", [])) + fixed_options = len(fixed.get("options", [])) + if fixed_options > orig_options: + changes.append(f"added {fixed_options - orig_options} option(s)") + + if fixed.get("next_node_id") and not original.get("next_node_id"): + changes.append("added next_node_id") + + if not changes: + changes.append("fixed structural issue") + + return "; ".join(changes).capitalize() + + +# ── Prompt building ── + + +def 
def _build_fix_prompt(
    tree: dict[str, Any],
    node_id: str,
    error_message: str,
    tree_name: str,
    tree_type: str,
) -> str:
    """Build the user message for the AI fix request.

    Args:
        tree: Full tree structure; rendered as an outline with the failing
            node flagged.
        node_id: ID of the failing node; its JSON is embedded in the prompt.
        error_message: Validation error to show to the model.
        tree_name: Name of the flow.
        tree_type: Type of the flow.

    Returns:
        The prompt text: flow name/type, outline, failing node JSON and error.
    """
    outline = _serialize_tree_outline(tree, error_node_id=node_id)
    node = _find_node_by_id(tree, node_id)
    node_json = json.dumps(node, indent=2) if node else "{}"

    return (
        f"Flow name: {tree_name}\n"
        f"Flow type: {tree_type}\n\n"
        f"## Full flow outline\n```\n{outline}\n```\n\n"
        f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n"
        f"## Validation error\n{error_message}\n\n"
        f"Return the fixed version of this node as JSON."
    )


# ── Main entry point ──

# Token budget for a single fix response.
_FIX_MAX_TOKENS = 2048
# One initial attempt plus one corrective retry.
_MAX_FIX_ATTEMPTS = 2


def _parse_fixed_node(text: str) -> dict[str, Any]:
    """Parse the model's reply into a node dict.

    Raises:
        ValueError: If the reply is not valid JSON (``json.JSONDecodeError`` is
            a ``ValueError``) or is valid JSON that is not an object.
    """
    parsed = json.loads(_strip_markdown_fences(text))
    if not isinstance(parsed, dict):
        # A list/string/number cannot replace a node; reject it here instead of
        # letting the in-place replacement blow up later.
        raise ValueError(
            f"expected a JSON object for the fixed node, got {type(parsed).__name__}"
        )
    return parsed


def _errors_blocking_fix(
    tree_structure: dict[str, Any], node_id: str, candidate: dict[str, Any]
) -> list[str]:
    """Return the errors that prevent ``candidate`` from being accepted.

    The candidate is substituted into a deep copy of the tree and the copy is
    re-validated; the caller's tree is never modified.

    Returns:
        An empty list when the fix is acceptable; otherwise the error messages
        to feed back to the model.
    """
    if candidate.get("id") != node_id:
        # A renamed node would look "fixed" (no error mentions the old ID any
        # more) while silently breaking references to it.
        return [
            f"The fixed node must keep the original id '{node_id}' "
            f"(got {candidate.get('id')!r})"
        ]

    tree_copy = copy.deepcopy(tree_structure)
    _replace_node_in_tree(tree_copy, node_id, copy.deepcopy(candidate))
    remaining = list(validate_generated_tree(tree_copy))

    # NOTE(review): this relies on validator messages containing the node ID,
    # and a substring test also matches IDs that merely contain ``node_id``
    # (e.g. "restart" vs "restart-2") — confirm against the validator's format.
    if any(node_id in err for err in remaining):
        return remaining
    return []


async def generate_fixes(
    tree_structure: dict[str, Any],
    tree_name: str,
    tree_type: str,
    validation_errors: list[dict[str, str]],
) -> tuple[list[dict[str, Any]], int, int]:
    """Generate AI-powered fixes for tree validation errors.

    Each error is handled independently against the unmodified input tree: the
    model is asked for a fixed node, the result is validated on a tree copy,
    and one corrective retry is made if the node still has errors. Failures
    for one node are logged and skipped so the rest of the batch still runs.

    Args:
        tree_structure: Full tree structure dict. Not mutated.
        tree_name: Name of the flow.
        tree_type: Type of flow (troubleshooting, procedural, maintenance).
        validation_errors: List of dicts with "node_id" and "message" keys.
            Entries missing either key are skipped with a warning.

    Returns:
        Tuple of (fixes_list, total_input_tokens, total_output_tokens).
        Each fix dict has: target_node_id, error_message, description,
        original_node, fixed_node. Token totals include failed attempts.
    """
    provider = get_ai_provider()
    fixes: list[dict[str, Any]] = []
    total_input_tokens = 0
    total_output_tokens = 0

    for error in validation_errors:
        node_id = error.get("node_id")
        error_message = error.get("message")
        if not node_id or not error_message:
            logger.warning("Skipping malformed validation error entry: %r", error)
            continue

        original_node = _find_node_by_id(tree_structure, node_id)
        if original_node is None:
            logger.warning("Node %s not found in tree, skipping fix", node_id)
            continue

        original_snapshot = copy.deepcopy(original_node)

        user_message = _build_fix_prompt(
            tree_structure, node_id, error_message, tree_name, tree_type
        )
        messages = [{"role": "user", "content": user_message}]

        fixed_node: dict[str, Any] | None = None
        for attempt in range(_MAX_FIX_ATTEMPTS):
            is_retry = attempt > 0
            try:
                text, in_tok, out_tok = await provider.generate_json(
                    system_prompt=FIX_SYSTEM_PROMPT,
                    messages=messages,
                    max_tokens=_FIX_MAX_TOKENS,
                )
                # Count tokens before parsing so failed attempts are billed too.
                total_input_tokens += in_tok
                total_output_tokens += out_tok

                candidate = _parse_fixed_node(text)
                blocking = _errors_blocking_fix(tree_structure, node_id, candidate)
            except Exception as exc:
                # Best-effort boundary: a provider, parsing or validation
                # failure for one node must not abort the whole batch.
                if is_retry:
                    logger.warning(
                        "AI retry fix failed for node %s: %s", node_id, exc
                    )
                else:
                    logger.warning("AI fix failed for node %s: %s", node_id, exc)
                break

            if not blocking:
                fixed_node = candidate
                break

            if attempt + 1 < _MAX_FIX_ATTEMPTS:
                # Feed the failed answer and the remaining errors back to the
                # model for a corrective retry.
                error_listing = "\n".join(blocking)
                retry_message = (
                    f"Your previous fix still has validation errors:\n"
                    f"{error_listing}\n\n"
                    f"Please fix the node again. Return ONLY the corrected JSON."
                )
                messages.append({"role": "assistant", "content": text})
                messages.append({"role": "user", "content": retry_message})
            else:
                logger.warning("AI could not fix node %s after retry", node_id)

        if fixed_node is None:
            continue

        fixes.append(
            {
                "target_node_id": node_id,
                "error_message": error_message,
                "description": _describe_fix(original_snapshot, fixed_node),
                "original_node": original_snapshot,
                "fixed_node": fixed_node,
            }
        )

    return fixes, total_input_tokens, total_output_tokens
import copy

import pytest

from app.core.ai_fix_service import (
    _find_node_by_id,
    _find_parent_node,
    _serialize_tree_outline,
    _strip_markdown_fences,
    _replace_node_in_tree,
    _describe_fix,
)


# ── Sample tree ──

# Shared fixture: a root decision with an action, a solution and a nested
# decision. Tests that mutate the tree work on a deep copy.
SAMPLE_TREE = {
    "id": "root",
    "type": "decision",
    "question": "Is the server up?",
    "options": [
        {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
        {"id": "opt-no", "label": "No", "next_node_id": "restart"},
    ],
    "children": [
        {
            "id": "check-logs",
            "type": "action",
            "title": "Check Logs",
            "description": "Review logs.",
            "next_node_id": "logs-ok",
        },
        {
            "id": "logs-ok",
            "type": "solution",
            "title": "Logs OK",
            "description": "Issue in logs.",
        },
        {
            "id": "restart",
            "type": "decision",
            "question": "Did restart work?",
            "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
            "children": [
                {
                    "id": "done",
                    "type": "solution",
                    "title": "Done",
                    "description": "Fixed.",
                }
            ],
        },
    ],
}


# ── _find_node_by_id ──


class TestFindNodeById:
    """Lookup of nodes by ID at every depth of the sample tree."""

    def test_finds_root(self):
        found = _find_node_by_id(SAMPLE_TREE, "root")
        assert found is not None
        assert found["id"] == "root"
        assert found["type"] == "decision"

    def test_finds_nested_child(self):
        found = _find_node_by_id(SAMPLE_TREE, "done")
        assert found is not None
        assert found["id"] == "done"
        assert found["type"] == "solution"

    def test_finds_direct_child(self):
        found = _find_node_by_id(SAMPLE_TREE, "check-logs")
        assert found is not None
        assert found["title"] == "Check Logs"

    def test_returns_none_for_missing(self):
        assert _find_node_by_id(SAMPLE_TREE, "nonexistent") is None

    def test_returns_none_for_non_dict(self):
        assert _find_node_by_id("not a dict", "root") is None


# ── _find_parent_node ──


class TestFindParentNode:
    """Lookup of the node that directly contains a given child."""

    def test_root_has_no_parent(self):
        assert _find_parent_node(SAMPLE_TREE, "root") is None

    def test_finds_parent_of_direct_child(self):
        owner = _find_parent_node(SAMPLE_TREE, "check-logs")
        assert owner is not None
        assert owner["id"] == "root"

    def test_finds_parent_of_deeply_nested(self):
        owner = _find_parent_node(SAMPLE_TREE, "done")
        assert owner is not None
        assert owner["id"] == "restart"

    def test_returns_none_for_missing(self):
        assert _find_parent_node(SAMPLE_TREE, "nonexistent") is None

    def test_returns_none_for_non_dict(self):
        assert _find_parent_node("not a dict", "root") is None


# ── _serialize_tree_outline ──


class TestSerializeTreeOutline:
    """Outline rendering: labels, indentation and the error marker."""

    def test_produces_readable_outline(self):
        rendered = _serialize_tree_outline(SAMPLE_TREE)
        assert "- [decision] Is the server up?" in rendered
        assert "  - [action] Check Logs" in rendered
        assert "  - [solution] Logs OK" in rendered
        assert "    - [solution] Done" in rendered

    def test_marks_error_node(self):
        rendered = _serialize_tree_outline(SAMPLE_TREE, error_node_id="restart")
        assert "<<< ERROR HERE" in rendered
        # Only the restart node should be marked
        flagged = [row for row in rendered.split("\n") if "ERROR HERE" in row]
        assert len(flagged) == 1
        assert "Did restart work?" in flagged[0]

    def test_no_error_marker_when_none(self):
        assert "ERROR HERE" not in _serialize_tree_outline(SAMPLE_TREE)

    def test_handles_non_dict(self):
        assert _serialize_tree_outline("not a dict") == ""

    def test_indentation_increases_with_depth(self):
        rows = _serialize_tree_outline(SAMPLE_TREE).split("\n")
        # Root has no indentation
        assert rows[0].startswith("- [decision]")
        # Children have 2-space indent
        check_logs_rows = [row for row in rows if "Check Logs" in row]
        assert check_logs_rows[0].startswith("  - ")


# ── _strip_markdown_fences ──


class TestStripMarkdownFences:
    """Removal of Markdown code fences around a JSON payload."""

    def test_strips_json_fences(self):
        fenced = '```json\n{"key": "value"}\n```'
        assert _strip_markdown_fences(fenced) == '{"key": "value"}'

    def test_strips_plain_fences(self):
        fenced = '```\n{"key": "value"}\n```'
        assert _strip_markdown_fences(fenced) == '{"key": "value"}'

    def test_passes_through_plain_json(self):
        bare = '{"key": "value"}'
        assert _strip_markdown_fences(bare) == '{"key": "value"}'


# ── _replace_node_in_tree ──


class TestReplaceNodeInTree:
    """In-place replacement of a node, always on a private copy of the tree."""

    def test_replaces_root(self):
        working_tree = copy.deepcopy(SAMPLE_TREE)
        new_root = {"id": "root", "type": "decision", "question": "New question"}
        assert _replace_node_in_tree(working_tree, "root", new_root) is True
        assert working_tree["question"] == "New question"
        assert "children" not in working_tree  # cleared and replaced

    def test_replaces_nested_node(self):
        working_tree = copy.deepcopy(SAMPLE_TREE)
        new_leaf = {
            "id": "done",
            "type": "solution",
            "title": "All Done",
            "description": "Complete.",
        }
        assert _replace_node_in_tree(working_tree, "done", new_leaf) is True
        assert _find_node_by_id(working_tree, "done")["title"] == "All Done"

    def test_returns_false_for_missing(self):
        working_tree = copy.deepcopy(SAMPLE_TREE)
        assert _replace_node_in_tree(working_tree, "nonexistent", {"id": "x"}) is False


# ── _describe_fix ──
class TestDescribeFix:
    """Human-readable summaries of the difference between two node versions."""

    def test_describes_added_children(self):
        before = {"id": "n1", "children": [{"id": "c1"}]}
        after = {"id": "n1", "children": [{"id": "c1"}, {"id": "c2"}]}
        assert "1 child node" in _describe_fix(before, after)

    def test_describes_added_options(self):
        before = {"id": "n1", "options": [{"id": "o1"}]}
        after = {"id": "n1", "options": [{"id": "o1"}, {"id": "o2"}]}
        assert "1 option" in _describe_fix(before, after)

    def test_describes_added_next_node_id(self):
        before = {"id": "n1", "type": "action"}
        after = {"id": "n1", "type": "action", "next_node_id": "n2"}
        assert "next_node_id" in _describe_fix(before, after)

    def test_fallback_description(self):
        before = {"id": "n1", "type": "solution"}
        after = {"id": "n1", "type": "solution"}
        summary = _describe_fix(before, after)
        assert "fixed structural issue" in summary.lower()