diff --git a/backend/app/core/ai_tree_validator.py b/backend/app/core/ai_tree_validator.py index e57767b4..351a223f 100644 --- a/backend/app/core/ai_tree_validator.py +++ b/backend/app/core/ai_tree_validator.py @@ -40,7 +40,7 @@ def validate_generated_tree(tree: dict[str, Any]) -> list[str]: # Collect all node IDs and validate structure all_ids: set[str] = set() - all_referenced_ids: set[str] = set() # option next_node_ids (already checked locally) + all_referenced_ids: set[str] = set() # option next_node_ids (checked globally below) action_next_ids: set[str] = set() # action next_node_ids (checked globally below) node_count = 0 solution_count = 0 @@ -111,11 +111,6 @@ def validate_generated_tree(tree: dict[str, Any]) -> list[str]: next_id = opt.get("next_node_id") if next_id: all_referenced_ids.add(next_id) - if child_ids and next_id not in child_ids: - errors.append( - f"Option '{opt.get('label', '?')}' in node '{node_id}' " - f"references non-existent child '{next_id}'" - ) elif node_type == "action": next_id = node.get("next_node_id") @@ -144,6 +139,13 @@ def validate_generated_tree(tree: dict[str, Any]) -> list[str]: f"Action next_node_id '{ref_id}' references a node that does not exist in the tree" ) + # Check that all option next_node_ids exist in the tree (allows cross-references) + for ref_id in all_referenced_ids - action_next_ids: + if ref_id not in all_ids: + errors.append( + f"Option next_node_id '{ref_id}' references a node that does not exist in the tree" + ) + # Global checks if node_count < 5: errors.append( diff --git a/backend/tests/test_ai_tree_validator.py b/backend/tests/test_ai_tree_validator.py index f8f3f4d7..cfa58ff0 100644 --- a/backend/tests/test_ai_tree_validator.py +++ b/backend/tests/test_ai_tree_validator.py @@ -122,7 +122,7 @@ class TestReferenceIntegrity: tree = _make_valid_tree() tree["options"][0]["next_node_id"] = "nonexistent" errors = validate_generated_tree(tree) - assert any("non-existent child" in e for e in errors) + assert any("does not exist" in e for e in errors) def test_action_next_node_id_references_nonexistent_node(self): """Action next_node_id pointing to a node that doesn't exist anywhere in the tree.""" @@ -188,6 +188,31 @@ class TestDeadEndDetection: assert any("dead end" in e for e in errors) +class TestCrossReferenceSupport: + def test_option_referencing_non_child_node_in_tree_is_valid(self): + """A decision option can reference any node in the tree, not just direct children.""" + tree = _make_valid_tree() + # Make root option point to a grandchild (not a direct child) — cross-reference + tree["options"][0]["next_node_id"] = "fix-errors" # grandchild of root + errors = validate_generated_tree(tree) + assert not any("non-existent child" in e for e in errors) + assert not any("does not exist" in e for e in errors) + + def test_option_referencing_nonexistent_node_still_fails(self): + """Cross-references must still point to nodes that exist in the tree.""" + tree = _make_valid_tree() + tree["options"][0]["next_node_id"] = "totally-fake-id" + errors = validate_generated_tree(tree) + assert any("does not exist" in e for e in errors) + + def test_action_next_node_id_to_ancestor_is_valid(self): + """Action node can loop back to an ancestor node.""" + tree = _make_valid_tree() + tree["children"][1]["next_node_id"] = "root" + errors = validate_generated_tree(tree) + assert not any("does not exist" in e for e in errors) + + class TestCountTreeStats: def test_stats_correct(self): tree = _make_valid_tree()