diff --git a/backend/app/services/ai_tree_builder.py b/backend/app/services/ai_tree_builder.py new file mode 100644 index 00000000..c9f44c4b --- /dev/null +++ b/backend/app/services/ai_tree_builder.py @@ -0,0 +1,154 @@ +"""Constrained, node-by-node L1 decision-tree generation (spec §4/§5/§6.1). + +Each call produces ONE node given the problem, category, and full walked path. +Generation is constrained to safe/reversible L1 steps and biased to escalate +early. normalize_walked_path() turns a resolved walk into a valid tree object +for flywheel capture. +""" +import logging +from typing import Any, Optional + +from app.core.ai_provider import get_ai_provider +from app.core.config import settings +from app.services.l1_category_service import HARD_FLOOR_TEXT_PATTERNS +from app.services.llm_utils import parse_llm_json + +logger = logging.getLogger(__name__) + +MAX_DEPTH = 12 +VALID_NODE_TYPES = {"question", "instruction", "resolved", "escalate"} + + +class UnsafeNodeError(ValueError): + """Raised when a generated node violates the hard floor or is malformed.""" + + +SYSTEM_PROMPT = """\ +You are an L1 helpdesk troubleshooting guide builder. Given a problem and the +steps already tried, produce the SINGLE next node of a yes/no decision tree. + +HARD RULES: +- Only safe, reversible, observe-or-restart-class steps: checking status, toggling, + restarting, reconnecting, re-entering credentials the USER already knows. +- NEVER produce steps that: edit the registry/system files/boot config; delete or + format data/disks; change credentials/MFA/security/firewall/AV; run elevated or + admin scripts; touch domain controllers/DNS/DHCP or production servers; or have + billing/license impact. These are out of L1 scope. +- When you run out of safe in-scope steps, DO NOT GUESS. Emit an "escalate" node. + +Return ONLY a JSON object for ONE node, one of: +{"node_type":"question","text":""} +{"node_type":"instruction","text":""} +{"node_type":"resolved","text":""} +{"node_type":"escalate","reason_category":"exhausted_safe_steps","text":""} +No prose, no markdown fences. +""" + + +def _build_context(problem_text: str, category: str, walked_path: list[dict]) -> str: + lines = [f"PROBLEM: {problem_text}", f"CATEGORY: {category}", "STEPS SO FAR:"] + if not walked_path: + lines.append("(none yet — produce the first diagnostic question)") + for i, step in enumerate(walked_path, 1): + ans = step.get("answer") + suffix = f" -> {ans}" if ans else "" + lines.append(f"{i}. [{step.get('node_type','?')}] {step.get('text','')}{suffix}") + return "\n".join(lines) + + +def validate_node(node: dict[str, Any]) -> dict[str, Any]: + """Shape + hard-floor validation. Raises UnsafeNodeError on violation.""" + if not isinstance(node, dict) or node.get("node_type") not in VALID_NODE_TYPES: + raise UnsafeNodeError(f"invalid node_type: {node!r}") + text = (node.get("text") or "").lower() + for pat in HARD_FLOOR_TEXT_PATTERNS: + if pat in text: + raise UnsafeNodeError(f"hard-floor pattern '{pat}' in node text") + return node + + +def escalate_if_depth_exceeded(walked_path: list[dict]) -> Optional[dict[str, Any]]: + if len(walked_path) >= MAX_DEPTH: + return { + "node_type": "escalate", + "reason_category": "depth_cap", + "text": "Reached the L1 troubleshooting depth limit — escalating to engineering.", + } + return None + + +async def generate_next_node( + problem_text: str, category: str, walked_path: list[dict] +) -> dict[str, Any]: + """Generate + validate the next node. Regenerate once on failure, then escalate.""" + capped = escalate_if_depth_exceeded(walked_path) + if capped: + return capped + + provider = get_ai_provider(settings.get_model_for_action("l1_realtime_build")) + context = _build_context(problem_text, category, walked_path) + + for attempt in range(2): + try: + raw, _, _ = await provider.generate_json( + system_prompt=SYSTEM_PROMPT, + messages=[{"role": "user", "content": context}], + max_tokens=1024, + ) + node = parse_llm_json(raw) + return validate_node(node) + except Exception as e: + logger.warning("ai_tree_builder node attempt %d failed: %s", attempt + 1, e) + continue + + return { + "node_type": "escalate", + "reason_category": "generation_failed", + "text": "Could not generate a safe next step — escalating to engineering.", + } + + +def normalize_walked_path(walked_path: list[dict]) -> dict[str, Any]: + """Turn a resolved walk into a valid troubleshooting tree (spec §6.1). + + Root = first node's id; question nodes' traversed branch points to the next + node, the untraversed branch to a needs_review stub; terminal node ends it. + Returns {id, nodes: {id: node}} — a dict with an id (passes the proposal + approval guard). + """ + nodes: dict[str, Any] = {} + if not walked_path: + root_id = "root" + nodes[root_id] = {"id": root_id, "node_type": "needs_review", + "text": "Empty walk — needs authoring."} + return {"id": root_id, "nodes": nodes} + + stub_seq = 0 + for i, step in enumerate(walked_path): + nid = step.get("id") or f"n{i+1}" + ntype = step.get("node_type", "question") + nxt = walked_path[i + 1].get("id", f"n{i+2}") if i + 1 < len(walked_path) else None + node: dict[str, Any] = {"id": nid, "node_type": ntype, "text": step.get("text", "")} + if step.get("reason_category"): + node["reason_category"] = step["reason_category"] + if ntype == "question": + answer = (step.get("answer") or "").lower() + stub_seq += 1 + stub_id = f"review-{stub_seq}" + nodes[stub_id] = {"id": stub_id, "node_type": "needs_review", + "text": "Branch not explored during the originating call."} + traversed_next = nxt + if traversed_next is None: + # Walk ended on this question (no terminal recorded) — stub the + # branch the tech actually took so the tree has no dangling edge. + stub_seq += 1 + traversed_next = f"review-{stub_seq}" + nodes[traversed_next] = {"id": traversed_next, "node_type": "needs_review", + "text": "Walk ended here before a terminal step was reached."} + node["yes_next"] = traversed_next if answer == "yes" else stub_id + node["no_next"] = traversed_next if answer == "no" else stub_id + elif ntype == "instruction": + node["next"] = nxt + nodes[nid] = node + + return {"id": walked_path[0].get("id", "n1"), "nodes": nodes} diff --git a/backend/tests/test_ai_tree_builder.py b/backend/tests/test_ai_tree_builder.py new file mode 100644 index 00000000..49d4854a --- /dev/null +++ b/backend/tests/test_ai_tree_builder.py @@ -0,0 +1,58 @@ +import pytest +from app.services import ai_tree_builder as atb + + +def test_validate_node_rejects_hard_floor_text(): + node = {"node_type": "instruction", "id": "n1", "text": "Open regedit and change the key", "next": "generate"} + with pytest.raises(atb.UnsafeNodeError): + atb.validate_node(node) + + +def test_validate_node_accepts_safe_instruction(): + node = {"node_type": "instruction", "id": "n1", "text": "Restart the printer.", "next": "generate"} + assert atb.validate_node(node)["node_type"] == "instruction" + + +def test_depth_cap_forces_escalate(): + walked = [{"node_type": "question", "id": f"n{i}", "text": "?", "answer": "no"} for i in range(atb.MAX_DEPTH)] + node = atb.escalate_if_depth_exceeded(walked) + assert node is not None and node["node_type"] == "escalate" + + +def test_normalize_walked_path_builds_valid_tree(): + walked = [ + {"node_type": "question", "id": "n1", "text": "Powered on?", "answer": "no"}, + {"node_type": "instruction", "id": "n2", "text": "Power it on.", "answer": "ack"}, + {"node_type": "resolved", "id": "n3", "text": "Fixed."}, + ] + tree = atb.normalize_walked_path(walked) + assert isinstance(tree, dict) and tree.get("id") == "n1" + # untraversed 'yes' branch of n1 became a needs_review stub + assert any(n["node_type"] == "needs_review" for n in tree["nodes"].values()) + + +def test_normalize_walk_ending_on_question_has_no_none_branches(): + walked = [ + {"node_type": "question", "id": "n1", "text": "Powered on?", "answer": "no"}, + ] + tree = atb.normalize_walked_path(walked) + n1 = tree["nodes"]["n1"] + assert n1["yes_next"] is not None and n1["no_next"] is not None + # both branches must reference real nodes present in the tree + assert n1["yes_next"] in tree["nodes"] and n1["no_next"] in tree["nodes"] + + +def test_normalize_preserves_escalate_reason_category(): + walked = [ + {"node_type": "question", "id": "n1", "text": "On?", "answer": "no"}, + {"node_type": "escalate", "id": "n2", "text": "Beyond L1.", + "reason_category": "exhausted_safe_steps"}, + ] + tree = atb.normalize_walked_path(walked) + assert tree["nodes"]["n2"]["reason_category"] == "exhausted_safe_steps" + + +def test_normalize_empty_walk_returns_needs_review_root(): + tree = atb.normalize_walked_path([]) + assert tree["id"] in tree["nodes"] + assert tree["nodes"][tree["id"]]["node_type"] == "needs_review"