"""AI-powered fix service for tree validation errors.

Given a tree structure and validation errors, generates AI-powered proposals
to fix each structural issue while preserving existing content.
"""

import copy
import json
import logging
import re
from typing import Any

from app.core.ai_provider import get_ai_provider
from app.core.ai_tree_validator import validate_generated_tree
from app.services.llm_utils import strip_markdown_fences as _strip_markdown_fences

logger = logging.getLogger(__name__)

FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.

You will receive:
1. A full flow outline showing the tree structure
2. The specific failing node with its full JSON
3. The validation error message

Your task: Return a FIXED version of the failing node as valid JSON.

Rules:
- Fix ONLY the structural issue described in the error message
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
- Every new node must have a unique ID (use descriptive kebab-case IDs)
- Decision nodes must have at least 2 options and at least 2 children
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
- Solution nodes are leaf nodes (no children)
- Return ONLY the fixed node JSON, no explanation"""

# Upper bound on completion tokens requested for a single fix attempt.
_FIX_MAX_TOKENS = 2048


# ── Pure helper functions ──


def _list_field(node: dict[str, Any], key: str) -> list[Any] | tuple[Any, ...]:
    """Return ``node[key]`` when it is a list/tuple, otherwise an empty list.

    Trees may carry ``"children": null`` (or omit the key entirely);
    ``dict.get(key, [])`` only covers the missing-key case, so every traversal
    goes through this helper to avoid iterating or measuring ``None``.
    """
    value = node.get(key)
    return value if isinstance(value, (list, tuple)) else []


def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
    """Recursively find a node by its ID in the tree structure.

    Args:
        tree: Node dict to search (the node itself is checked first, then its
            ``children`` depth-first).
        node_id: Value to match against each node's ``"id"`` key.

    Returns:
        The first matching node dict (the live object, not a copy), or
        ``None`` when no node matches or ``tree`` is not a dict.
    """
    if not isinstance(tree, dict):
        return None
    if tree.get("id") == node_id:
        return tree
    for child in _list_field(tree, "children"):
        result = _find_node_by_id(child, node_id)
        if result is not None:
            return result
    return None


def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
    """Find the parent of a node with the given ID.

    Args:
        tree: Node dict at which the search starts.
        target_id: ``"id"`` of the node whose parent is wanted.

    Returns:
        The dict whose ``children`` directly contain the target, or ``None``
        when the target is absent, is ``tree`` itself, or ``tree`` is not a
        dict.
    """
    if not isinstance(tree, dict):
        return None
    for child in _list_field(tree, "children"):
        if isinstance(child, dict) and child.get("id") == target_id:
            return tree
        result = _find_parent_node(child, target_id)
        if result is not None:
            return result
    return None


def _serialize_tree_outline(
    tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None
) -> str:
    """Serialize tree as a readable outline for AI prompt context.

    Format: indented "- [type] label" with "<<< ERROR HERE" marker.

    Args:
        tree: Node dict to render; its ``children`` are rendered recursively.
        indent: Nesting depth of ``tree``; controls the leading padding.
        error_node_id: ID of the node to flag with the error marker, if any.

    Returns:
        One line per node joined with newlines, or ``""`` when ``tree`` is not
        a dict.
    """
    if not isinstance(tree, dict):
        return ""

    node_type = tree.get("type", "unknown")
    # Prefer the human-readable text; fall back to the ID so no node is blank.
    label = tree.get("question") or tree.get("title") or tree.get("id", "?")
    prefix = " " * indent
    marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""

    lines = [f"{prefix}- [{node_type}] {label}{marker}"]
    for child in _list_field(tree, "children"):
        # Skip malformed (non-dict) children rather than emitting blank lines
        # into the outline shown to the model.
        if isinstance(child, dict):
            lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))

    return "\n".join(lines)


def _replace_node_in_tree(
    tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
) -> bool:
    """Replace a node in-place by ID. Returns True if found and replaced.

    The matching dict object is emptied and refilled, so every existing
    reference to it (e.g. the parent's ``children`` list) sees the new content.

    Args:
        tree: Node dict at which the search starts; mutated on a match.
        target_id: ``"id"`` of the node to replace.
        replacement: Mapping whose items become the node's new content.

    Returns:
        ``True`` when a node was replaced, ``False`` otherwise.
    """
    if not isinstance(tree, dict):
        return False
    if tree.get("id") == target_id:
        # Snapshot first: if ``replacement`` is the very same object as
        # ``tree``, clearing it would otherwise destroy the data being copied.
        new_content = dict(replacement)
        tree.clear()
        tree.update(new_content)
        return True
    for child in _list_field(tree, "children"):
        if _replace_node_in_tree(child, target_id, replacement):
            return True
    return False


def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
    """Describe what changed between original and fixed node.

    Only additions are reported (more children, more options, a newly set
    ``next_node_id``); anything else falls back to a generic description.

    Args:
        original: Node as it was before the fix.
        fixed: Node proposed by the model.

    Returns:
        A "; "-joined summary with the first letter capitalized.
    """
    changes: list[str] = []

    orig_children = len(_list_field(original, "children"))
    fixed_children = len(_list_field(fixed, "children"))
    if fixed_children > orig_children:
        changes.append(f"added {fixed_children - orig_children} child node(s)")

    orig_options = len(_list_field(original, "options"))
    fixed_options = len(_list_field(fixed, "options"))
    if fixed_options > orig_options:
        changes.append(f"added {fixed_options - orig_options} option(s)")

    if fixed.get("next_node_id") and not original.get("next_node_id"):
        changes.append("added next_node_id")

    if not changes:
        changes.append("fixed structural issue")

    return "; ".join(changes).capitalize()


# ── Prompt building ──


def _build_fix_prompt(
    tree: dict[str, Any],
    node_id: str,
    error_message: str,
    tree_name: str,
    tree_type: str,
) -> str:
    """Build the user message for the AI fix request.

    Args:
        tree: Full tree structure, rendered as an outline for context.
        node_id: ID of the failing node; marked in the outline and dumped as
            JSON.
        error_message: Validation error text to show the model.
        tree_name: Name of the flow.
        tree_type: Type of the flow.

    Returns:
        The prompt text. When the node cannot be found, its JSON section
        contains ``{}``.
    """
    outline = _serialize_tree_outline(tree, error_node_id=node_id)
    node = _find_node_by_id(tree, node_id)
    node_json = json.dumps(node, indent=2) if node else "{}"

    return (
        f"Flow name: {tree_name}\n"
        f"Flow type: {tree_type}\n\n"
        f"## Full flow outline\n```\n{outline}\n```\n\n"
        f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n"
        f"## Validation error\n{error_message}\n\n"
        f"Return the fixed version of this node as JSON."
    )


# ── AI round-trip helpers ──


async def _request_fix(provider: Any, messages: list[dict[str, str]]) -> tuple[str, int, int]:
    """Send the conversation to the provider using the fixed system prompt.

    Returns the provider's ``(text, input_tokens, output_tokens)`` triple
    unchanged so the caller can account for tokens before parsing the text.
    """
    return await provider.generate_json(
        system_prompt=[
            {"type": "text", "text": FIX_SYSTEM_PROMPT},
            # cacheable: stable constant across all fix attempts, including
            # the retry, which re-sends the identical system block
        ],
        messages=messages,
        max_tokens=_FIX_MAX_TOKENS,
    )


def _parse_fixed_node(text: str) -> dict[str, Any]:
    """Parse the model reply into a node dict.

    Raises:
        json.JSONDecodeError: The reply is not valid JSON.
        ValueError: The reply is valid JSON but not an object. A list or
            scalar cannot replace a node and would otherwise crash later,
            outside any error handling.
    """
    parsed = json.loads(_strip_markdown_fences(text))
    if not isinstance(parsed, dict):
        raise ValueError(
            f"expected a JSON object for the fixed node, got {type(parsed).__name__}"
        )
    return parsed


def _validate_candidate(
    tree_structure: dict[str, Any], node_id: str, fixed_node: dict[str, Any]
) -> tuple[bool, list[str]]:
    """Substitute ``fixed_node`` into a copy of the tree and re-validate it.

    The caller's tree is never modified; both the tree and the candidate are
    deep-copied before substitution.

    Returns:
        ``(still_failing, remaining_errors)`` where ``still_failing`` is true
        when any remaining error message mentions ``node_id``.
    """
    tree_copy = copy.deepcopy(tree_structure)
    _replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node))
    remaining_errors = validate_generated_tree(tree_copy)
    # NOTE(review): substring match -- an ID that is contained in another ID
    # (e.g. "node-1" vs "node-10") is reported as still failing. Confirm
    # against the validator's message format.
    still_failing = any(node_id in message for message in remaining_errors)
    return still_failing, remaining_errors


# ── Main entry point ──


async def generate_fixes(
    tree_structure: dict[str, Any],
    tree_name: str,
    tree_type: str,
    validation_errors: list[dict[str, str]],
) -> tuple[list[dict[str, Any]], int, int]:
    """Generate AI-powered fixes for tree validation errors.

    Each error is handled independently and best-effort: a failure for one
    node (provider error, unparsable reply, fix that does not validate after
    one retry) is logged and skipped so the remaining errors are still
    processed. ``tree_structure`` itself is never modified.

    Args:
        tree_structure: Full tree structure dict.
        tree_name: Name of the flow.
        tree_type: Type of flow (troubleshooting, procedural, maintenance).
        validation_errors: List of dicts with "node_id" and "message" keys.
            Entries missing either key are skipped with a warning.

    Returns:
        Tuple of (fixes_list, total_input_tokens, total_output_tokens).
        Each fix dict has: target_node_id, error_message, description,
        original_node, fixed_node. Token totals include attempts whose
        output was ultimately discarded.
    """
    provider = get_ai_provider()
    fixes: list[dict[str, Any]] = []
    total_input_tokens = 0
    total_output_tokens = 0

    for error in validation_errors:
        node_id = error.get("node_id") if isinstance(error, dict) else None
        error_message = error.get("message") if isinstance(error, dict) else None
        if not node_id or error_message is None:
            logger.warning("Malformed validation error %r, skipping fix", error)
            continue

        original_node = _find_node_by_id(tree_structure, node_id)
        if original_node is None:
            logger.warning("Node %s not found in tree, skipping fix", node_id)
            continue

        original_snapshot = copy.deepcopy(original_node)

        # Build prompt and call AI
        user_message = _build_fix_prompt(
            tree_structure, node_id, error_message, tree_name, tree_type
        )
        messages = [{"role": "user", "content": user_message}]

        # Broad catch is deliberate: this is a best-effort boundary around an
        # external provider plus model-authored data, and one bad node must
        # not abort the whole batch.
        try:
            text, in_tok, out_tok = await _request_fix(provider, messages)
            # Count tokens before parsing so failed parses are still billed.
            total_input_tokens += in_tok
            total_output_tokens += out_tok
            fixed_node = _parse_fixed_node(text)
            still_has_error, remaining_errors = _validate_candidate(
                tree_structure, node_id, fixed_node
            )
        except Exception as exc:
            logger.warning("AI fix failed for node %s: %s", node_id, exc)
            continue

        if still_has_error:
            # Retry once with corrective prompt
            joined_errors = "\n".join(remaining_errors)
            retry_message = (
                "Your previous fix still has validation errors:\n"
                f"{joined_errors}\n\n"
                "Please fix the node again. Return ONLY the corrected JSON."
            )
            messages.append({"role": "assistant", "content": text})
            messages.append({"role": "user", "content": retry_message})

            try:
                retry_text, retry_in, retry_out = await _request_fix(provider, messages)
                total_input_tokens += retry_in
                total_output_tokens += retry_out
                fixed_node = _parse_fixed_node(retry_text)
                still_has_error, _ = _validate_candidate(
                    tree_structure, node_id, fixed_node
                )
            except Exception as exc:
                logger.warning("AI retry fix failed for node %s: %s", node_id, exc)
                continue

            if still_has_error:
                logger.warning("AI could not fix node %s after retry", node_id)
                continue

        description = _describe_fix(original_snapshot, fixed_node)
        fixes.append(
            {
                "target_node_id": node_id,
                "error_message": error_message,
                "description": description,
                "original_node": original_snapshot,
                "fixed_node": fixed_node,
            }
        )

    return fixes, total_input_tokens, total_output_tokens