Files
resolutionflow/backend/app/core/ai_fix_service.py
Michael Chihlas da93ae55c3 feat(ai): opt-in structured-system-block caching for one-shot generators (Phase 0.3)
Wraps each static system prompt in a single-block list so that Phase 0.1's
AnthropicProvider applies cache_control: ephemeral automatically (policy α:
the first block is marked when no caller-authored cache_control is present).

Call sites:
- ai_tree_generator.scaffold_branches: SCAFFOLD_SYSTEM_PROMPT (~1k tokens)
- ai_tree_generator.generate_branch_detail: BRANCH_DETAIL_SYSTEM_PROMPT
  (~2.5k tokens with few-shot example); retries inside the same function
  re-read the cached block instead of paying full input cost on each attempt
- kb_conversion.convert_document: TROUBLESHOOTING or PROCEDURAL prompt
  (each caches independently by text content)
- ai_fix.generate_fixes: FIX_SYSTEM_PROMPT on first attempt + corrective retry
- script_builder.send_message: SYSTEM_PROMPT_TEMPLATE (per-session language
  substitution — same-language sessions share cache entries)

Each edit includes an inline comment explaining why the block is cacheable
(stable-constant, retry-reuse, per-language variant) so a future dev can
see the intent at the cache_control marker site.

script_builder history caching deliberately deferred — per Phase 0.1
decision (option i), AnthropicProvider does not automatically cache the
message list. If script_builder's growing 20-message history turns out
to be a visible cost driver via the anthropic.cache telemetry, route
that caller through the 0.4 chat wrapper which handles history caching.

No runtime verification was done from code-server; cache-hit behavior will
be confirmed against the new dev environment once it is up, per the inline
TODO(phase0-verify) in ai_provider.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-17 16:29:45 +00:00

277 lines
9.7 KiB
Python

"""AI-powered fix service for tree validation errors.
Given a tree structure and validation errors, generates AI-powered
proposals to fix each structural issue while preserving existing content.
"""
import copy
import json
import logging
import re
from typing import Any
from app.core.ai_provider import get_ai_provider
from app.core.ai_tree_validator import validate_generated_tree
# Module logger; nodes that cannot be fixed are reported at WARNING level
# and skipped rather than raised.
logger = logging.getLogger(__name__)
# System prompt for every fix request (first attempt and corrective retry).
# generate_fixes passes it as the sole text block of a structured system
# prompt. It deliberately contains no per-request data so the identical block
# can be served from the provider's prompt cache across calls; keep edits
# static (put per-request details in the user message instead).
FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.
You will receive:
1. A full flow outline showing the tree structure
2. The specific failing node with its full JSON
3. The validation error message
Your task: Return a FIXED version of the failing node as valid JSON. Rules:
- Fix ONLY the structural issue described in the error message
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
- Every new node must have a unique ID (use descriptive kebab-case IDs)
- Decision nodes must have at least 2 options and at least 2 children
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
- Solution nodes are leaf nodes (no children)
- Return ONLY the fixed node JSON, no explanation"""
# ── Pure helper functions ──
def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
"""Recursively find a node by its ID in the tree structure."""
if not isinstance(tree, dict):
return None
if tree.get("id") == node_id:
return tree
for child in tree.get("children", []):
result = _find_node_by_id(child, node_id)
if result is not None:
return result
return None
def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
"""Find the parent of a node with the given ID."""
if not isinstance(tree, dict):
return None
for child in tree.get("children", []):
if isinstance(child, dict) and child.get("id") == target_id:
return tree
result = _find_parent_node(child, target_id)
if result is not None:
return result
return None
def _serialize_tree_outline(
tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None
) -> str:
"""Serialize tree as a readable outline for AI prompt context.
Format: indented "- [type] label" with "<<< ERROR HERE" marker.
"""
if not isinstance(tree, dict):
return ""
node_type = tree.get("type", "unknown")
label = tree.get("question") or tree.get("title") or tree.get("id", "?")
prefix = " " * indent
marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""
line = f"{prefix}- [{node_type}] {label}{marker}"
lines = [line]
for child in tree.get("children", []):
lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))
return "\n".join(lines)
from app.services.llm_utils import strip_markdown_fences as _strip_markdown_fences
def _replace_node_in_tree(
    tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
) -> bool:
    """Replace a node in-place by ID. Returns True if found and replaced.

    The first node (depth-first, pre-order) with ``id == target_id`` has its
    contents swapped for those of *replacement*. The same dict object is kept,
    so the reference held in its parent's ``children`` list stays valid.

    *replacement*'s top-level mapping is copied (shallowly) before the target
    is cleared, which means:

    - passing the target itself as *replacement* leaves it intact instead of
      wiping it;
    - a value that cannot be converted to a dict raises before the tree is
      modified.

    A ``children`` value of ``None`` is treated as "no children".
    """
    if not isinstance(tree, dict):
        return False
    if tree.get("id") == target_id:
        new_contents = dict(replacement)  # snapshot first: aliasing / bad input safety
        tree.clear()
        tree.update(new_contents)
        return True
    for child in tree.get("children") or []:
        if _replace_node_in_tree(child, target_id, replacement):
            return True
    return False
def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
    """Describe what changed between original and fixed node.

    Reports, joined with ``"; "``:

    - how many children were added;
    - how many options were added;
    - whether a ``next_node_id`` was newly set.

    ``None`` values for ``children``/``options`` count as empty. Removals and
    in-place edits are not described. When none of the above is detected,
    the generic ``"Fixed structural issue"`` is returned. The first character
    of the result is upper-cased.
    """
    changes: list[str] = []
    added_children = len(fixed.get("children") or []) - len(original.get("children") or [])
    if added_children > 0:
        changes.append(f"added {added_children} child node(s)")
    added_options = len(fixed.get("options") or []) - len(original.get("options") or [])
    if added_options > 0:
        changes.append(f"added {added_options} option(s)")
    if fixed.get("next_node_id") and not original.get("next_node_id"):
        changes.append("added next_node_id")
    summary = "; ".join(changes) or "fixed structural issue"
    # Upper-case only the first character: str.capitalize() would also
    # lower-case every later character of the summary.
    return summary[0].upper() + summary[1:]
# ── Prompt building ──
def _build_fix_prompt(
    tree: dict[str, Any],
    node_id: str,
    error_message: str,
    tree_name: str,
    tree_type: str,
) -> str:
    """Assemble the user-turn text for a single fix request.

    The message is made of these sections, in order:

    1. the flow name and type;
    2. an outline of the whole tree with the failing node flagged;
    3. the failing node as pretty-printed JSON (``{}`` when the node is
       missing or empty);
    4. the validator's error text;
    5. a closing instruction to return the fixed node as JSON.
    """
    outline_text = _serialize_tree_outline(tree, error_node_id=node_id)
    failing = _find_node_by_id(tree, node_id)
    if failing:
        failing_json = json.dumps(failing, indent=2)
    else:
        failing_json = "{}"
    sections = [
        f"Flow name: {tree_name}\n",
        f"Flow type: {tree_type}\n\n",
        f"## Full flow outline\n```\n{outline_text}\n```\n\n",
        f"## Failing node (ID: {node_id})\n```json\n{failing_json}\n```\n\n",
        f"## Validation error\n{error_message}\n\n",
        "Return the fixed version of this node as JSON.",
    ]
    return "".join(sections)
# ── Main entry point ──
async def _request_fix(provider: Any, messages: list[dict[str, str]]) -> tuple[str, int, int]:
    """Send one fix request; returns ``(raw_text, input_tokens, output_tokens)``."""
    return await provider.generate_json(
        system_prompt=[
            {"type": "text", "text": FIX_SYSTEM_PROMPT},
            # cacheable: stable constant across all fix attempts; the
            # corrective retry reads the cached system block written by
            # the first attempt
        ],
        messages=messages,
        max_tokens=2048,
    )


def _parse_fixed_node(text: str) -> dict[str, Any]:
    """Parse the model reply into a node dict.

    Raises:
        ValueError: if the reply is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``) or is JSON but not an object.
    """
    parsed = json.loads(_strip_markdown_fences(text))
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def _check_fixed_node(
    tree_structure: dict[str, Any], node_id: str, fixed_node: dict[str, Any]
) -> tuple[list[str], bool]:
    """Substitute *fixed_node* into a copy of the tree and re-validate it.

    Returns:
        ``(remaining_errors, still_has_error)``. ``still_has_error`` is True
        when any remaining error message mentions *node_id*.
    """
    tree_copy = copy.deepcopy(tree_structure)
    _replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node))
    remaining_errors = validate_generated_tree(tree_copy)
    # NOTE(review): plain substring match — "step-1" also matches an error
    # about "step-10"; tighten once the validator's message format is pinned.
    return remaining_errors, any(node_id in e for e in remaining_errors)


async def generate_fixes(
    tree_structure: dict[str, Any],
    tree_name: str,
    tree_type: str,
    validation_errors: list[dict[str, str]],
) -> tuple[list[dict[str, Any]], int, int]:
    """Generate AI-powered fixes for tree validation errors.

    Each error is handled independently against the unmodified
    *tree_structure* (which is never mutated). The flow for one error is:

    1. Ask the model for a fixed node.
    2. Substitute the fix into a copy of the tree and re-validate.
    3. If an error still mentions the node, make one corrective retry.

    Errors the model cannot fix are logged at WARNING and skipped. These
    include errors whose reply is not a JSON object, provider/parse
    exceptions, and errors still failing after the retry. One bad node
    therefore never aborts the batch.

    Args:
        tree_structure: Full tree structure dict.
        tree_name: Name of the flow.
        tree_type: Type of flow (troubleshooting, procedural, maintenance).
        validation_errors: List of dicts with "node_id" and "message" keys.
            Entries without a "node_id" are skipped with a warning.

    Returns:
        Tuple of (fixes_list, total_input_tokens, total_output_tokens).
        Each fix dict has: target_node_id, error_message, description,
        original_node, fixed_node. Token totals include failed attempts.
    """
    provider = get_ai_provider()
    fixes: list[dict[str, Any]] = []
    total_input_tokens = 0
    total_output_tokens = 0

    for error in validation_errors:
        node_id = error.get("node_id")
        error_message = error.get("message", "")
        if not node_id:
            logger.warning("Validation error has no node_id, skipping fix: %s", error_message)
            continue
        original_node = _find_node_by_id(tree_structure, node_id)
        if original_node is None:
            logger.warning("Node %s not found in tree, skipping fix", node_id)
            continue
        original_snapshot = copy.deepcopy(original_node)

        user_message = _build_fix_prompt(
            tree_structure, node_id, error_message, tree_name, tree_type
        )
        messages = [{"role": "user", "content": user_message}]

        # First attempt. Provider and parse failures are best-effort skips;
        # tokens are counted before parsing so failed replies are still billed.
        try:
            text, in_tok, out_tok = await _request_fix(provider, messages)
            total_input_tokens += in_tok
            total_output_tokens += out_tok
            fixed_node = _parse_fixed_node(text)
        except Exception as exc:
            logger.warning("AI fix failed for node %s: %s", node_id, exc)
            continue

        remaining_errors, still_has_error = _check_fixed_node(
            tree_structure, node_id, fixed_node
        )

        if still_has_error:
            # One corrective retry, continuing the same conversation so the
            # model sees its previous answer and the errors it produced.
            error_list = "\n".join(remaining_errors)
            retry_message = (
                "Your previous fix still has validation errors:\n"
                f"{error_list}\n\n"
                "Please fix the node again. Return ONLY the corrected JSON."
            )
            messages.append({"role": "assistant", "content": text})
            messages.append({"role": "user", "content": retry_message})
            try:
                text2, in_tok2, out_tok2 = await _request_fix(provider, messages)
                total_input_tokens += in_tok2
                total_output_tokens += out_tok2
                fixed_node = _parse_fixed_node(text2)
                _, still_has_error = _check_fixed_node(tree_structure, node_id, fixed_node)
            except Exception as exc:
                logger.warning("AI retry fix failed for node %s: %s", node_id, exc)
                continue
            if still_has_error:
                logger.warning("AI could not fix node %s after retry", node_id)
                continue

        fixes.append(
            {
                "target_node_id": node_id,
                "error_message": error_message,
                "description": _describe_fix(original_snapshot, fixed_node),
                "original_node": original_snapshot,
                "fixed_node": fixed_node,
            }
        )

    return fixes, total_input_tokens, total_output_tokens