# Files
# resolutionflow/backend/tests/test_ai_tree_validator.py
# 2026-02-28 19:18:02 -05:00
#
# 225 lines
# 8.5 KiB
# Python
"""Tests for AI-generated tree structure validation."""
import pytest
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
def _make_valid_tree():
    """Return a fresh, minimal tree that the validator is expected to accept.

    Shape (2 decisions, 1 action, 3 solutions -> 6 nodes)::

        root (decision)
        |-- check-logs (decision)
        |   |-- fix-errors (solution)
        |   `-- escalate (solution)
        |-- restart-service (action, next_node_id -> service-resolved)
        `-- service-resolved (solution)

    The action node carries no ``children``; it hands off to a sibling under
    the same parent via ``next_node_id``. Every dict and list is rebuilt on
    each call, so tests may mutate the result without leaking into others.
    """
    # Leaves reachable from the "check-logs" decision.
    log_outcomes = [
        {
            "id": "fix-errors",
            "type": "solution",
            "title": "Fix Errors",
            "description": "Apply the fix for the errors found.",
        },
        {
            "id": "escalate",
            "type": "solution",
            "title": "Escalate",
            "description": "No errors found; escalate to Tier 2.",
        },
    ]
    check_logs = {
        "id": "check-logs",
        "type": "decision",
        "question": "Are there errors in the logs?",
        "options": [
            {"id": "opt-errors", "label": "Yes", "next_node_id": "fix-errors"},
            {"id": "opt-clean", "label": "No", "next_node_id": "escalate"},
        ],
        "children": log_outcomes,
    }

    # Action branch: no children of its own; continues at its sibling below.
    restart_service = {
        "id": "restart-service",
        "type": "action",
        "title": "Restart the Service",
        "description": "Restart the service and verify.",
        "commands": ["Restart-Service -Name 'TestService'"],
        "next_node_id": "service-resolved",
    }
    service_resolved = {
        "id": "service-resolved",
        "type": "solution",
        "title": "Service Restored",
        "description": "Service is running after restart.",
    }

    root_options = [
        {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
        {"id": "opt-no", "label": "No", "next_node_id": "restart-service"},
    ]
    return {
        "id": "root",
        "type": "decision",
        "question": "Is the service running?",
        "options": root_options,
        "children": [check_logs, restart_service, service_resolved],
    }
class TestValidTree:
    """Whole-tree acceptance and top-level rejection cases."""

    def test_valid_tree_passes(self):
        # The reference fixture must validate with no errors at all.
        assert validate_generated_tree(_make_valid_tree()) == []

    def test_not_a_dict(self):
        # A non-dict payload is rejected outright.
        messages = validate_generated_tree("not a dict")
        assert any("must be a JSON object" in msg for msg in messages)

    def test_root_not_decision(self):
        bad_root = _make_valid_tree()
        # Re-type the root as an action; also give it a title, as the
        # fixture's action node carries one.
        bad_root.update(type="action", title="Fake")
        messages = validate_generated_tree(bad_root)
        assert any("Root node must be type 'decision'" in msg for msg in messages)
class TestNodeValidation:
    """Per-node rules: ids, node types, required fields, option counts."""

    def test_missing_id(self):
        tree = _make_valid_tree()
        check_logs = tree["children"][0]
        check_logs.pop("id")
        messages = validate_generated_tree(tree)
        assert any("missing 'id'" in msg for msg in messages)

    def test_duplicate_ids(self):
        tree = _make_valid_tree()
        first, second = tree["children"][0], tree["children"][1]
        # restart-service takes over the id of its sibling check-logs.
        second["id"] = first["id"]
        messages = validate_generated_tree(tree)
        assert any("Duplicate node ID" in msg for msg in messages)

    def test_invalid_node_type(self):
        tree = _make_valid_tree()
        tree["children"][0]["type"] = "unknown"
        messages = validate_generated_tree(tree)
        assert any("invalid type" in msg for msg in messages)

    def test_decision_missing_options(self):
        tree = _make_valid_tree()
        tree["children"][0].pop("options")
        messages = validate_generated_tree(tree)
        assert any("missing fields" in msg for msg in messages)

    def test_decision_less_than_2_options(self):
        tree = _make_valid_tree()
        lone_option = {"id": "opt-1", "label": "Only", "next_node_id": "fix-errors"}
        tree["children"][0]["options"] = [lone_option]
        messages = validate_generated_tree(tree)
        assert any("at least 2 options" in msg for msg in messages)

    def test_action_missing_next_node_id(self):
        tree = _make_valid_tree()
        restart_service = tree["children"][1]
        del restart_service["next_node_id"]
        messages = validate_generated_tree(tree)
        assert any("missing 'next_node_id'" in msg for msg in messages)
class TestReferenceIntegrity:
    """Every ``next_node_id`` must resolve, and option ids must be unique."""

    def test_option_references_nonexistent_child(self):
        """A root option pointing at an id absent from the whole tree fails."""
        tree = _make_valid_tree()
        yes_option = tree["options"][0]
        yes_option["next_node_id"] = "nonexistent"
        messages = validate_generated_tree(tree)
        assert any("does not exist" in msg for msg in messages)

    def test_action_next_node_id_references_nonexistent_node(self):
        """Action next_node_id pointing to a node that doesn't exist anywhere in the tree."""
        tree = _make_valid_tree()
        restart_service = tree["children"][1]
        restart_service["next_node_id"] = "ghost-node"
        messages = validate_generated_tree(tree)
        # The offending id should be named in at least one error.
        assert any("ghost-node" in msg for msg in messages)

    def test_duplicate_option_ids(self):
        tree = _make_valid_tree()
        # Both root options end up sharing a single id.
        for option in tree["options"]:
            option["id"] = "same"
        messages = validate_generated_tree(tree)
        assert any("Duplicate option ID" in msg for msg in messages)
class TestGlobalChecks:
    """Whole-tree limits: minimum node count and minimum solution count."""

    def test_too_few_nodes(self):
        # One decision with two solution leaves: 3 nodes, below the minimum.
        tree = {
            "id": "root",
            "type": "decision",
            "question": "Test?",
            "options": [
                {"id": "o1", "label": "A", "next_node_id": "s1"},
                {"id": "o2", "label": "B", "next_node_id": "s2"},
            ],
            "children": [
                {"id": "s1", "type": "solution", "title": "S1", "description": "D1"},
                {"id": "s2", "type": "solution", "title": "S2", "description": "D2"},
            ],
        }
        errors = validate_generated_tree(tree)
        assert any("Minimum 5 required" in e for e in errors)

    def test_too_few_solutions(self):
        """A tree whose only global problem is having one solution is rejected.

        The fixture is reshaped so that it keeps 5 nodes and every
        ``next_node_id`` resolves. A "solution" error therefore cannot come
        from the node-count rule or from a dangling reference. No node id
        contains the word "solution", so the loose substring check below
        cannot be satisfied just by an error that names a node.
        """
        tree = _make_valid_tree()
        check_logs = tree["children"][0]

        # Replace the "escalate" solution with an action that hands off to
        # its sibling "fix-errors", which becomes the only solution left.
        check_logs["children"][1] = {
            "id": "collect-diagnostics",
            "type": "action",
            "title": "Collect Diagnostics",
            "description": "Gather diagnostics before applying the fix.",
            "commands": ["Get-EventLog -LogName Application -Newest 20"],
            "next_node_id": "fix-errors",
        }
        check_logs["options"][1]["next_node_id"] = "collect-diagnostics"

        # Drop the "service-resolved" solution. Re-point restart-service at a
        # sibling that still exists, so no reference dangles.
        tree["children"].pop(2)
        tree["children"][1]["next_node_id"] = "check-logs"

        # The remaining nodes are root, check-logs, fix-errors,
        # collect-diagnostics and restart-service: 5 nodes, 1 solution.
        errors = validate_generated_tree(tree)
        assert not any("does not exist" in e for e in errors)
        assert not any("Minimum 5 required" in e for e in errors)
        assert any("solution" in e.lower() for e in errors)
class TestDeadEndDetection:
    """Decision nodes must lead somewhere."""

    def test_dead_end_decision_node(self):
        """A decision node with no children is a dead end."""
        tree = _make_valid_tree()
        check_logs = tree["children"][0]
        # Empty out the check-logs decision's children.
        check_logs["children"].clear()
        messages = validate_generated_tree(tree)
        assert any("dead end" in msg for msg in messages)
class TestCrossReferenceSupport:
    """References may target any existing node, not only direct children."""

    def test_option_referencing_non_child_node_in_tree_is_valid(self):
        """A decision option can reference any node in the tree, not just direct children."""
        tree = _make_valid_tree()
        # fix-errors is a grandchild of root (under check-logs).
        tree["options"][0]["next_node_id"] = "fix-errors"
        messages = validate_generated_tree(tree)
        for forbidden in ("non-existent child", "does not exist"):
            assert not any(forbidden in msg for msg in messages)

    def test_option_referencing_nonexistent_node_still_fails(self):
        """Cross-references must still point to nodes that exist in the tree."""
        tree = _make_valid_tree()
        tree["options"][0]["next_node_id"] = "totally-fake-id"
        messages = validate_generated_tree(tree)
        assert any("does not exist" in msg for msg in messages)

    def test_action_next_node_id_to_ancestor_is_valid(self):
        """Action node can loop back to an ancestor node."""
        tree = _make_valid_tree()
        restart_service = tree["children"][1]
        restart_service["next_node_id"] = "root"
        messages = validate_generated_tree(tree)
        assert not any("does not exist" in msg for msg in messages)
class TestCountTreeStats:
    """Counting statistics over the reference fixture."""

    def test_stats_correct(self):
        stats = count_tree_stats(_make_valid_tree())
        expected_counts = {
            "node_count": 6,
            "decision_count": 2,
            "action_count": 1,
            "solution_count": 3,
        }
        for key, value in expected_counts.items():
            assert stats[key] == value
        # root -> check-logs -> fix-errors spans three levels. Only a lower
        # bound is asserted; whether depth counts nodes or edges is not pinned
        # down here.
        assert stats["depth"] >= 3