feat: AI-assisted flow builder with 4-stage wizard
Implements the complete AI flow builder feature using a guided 4-stage wizard (Foundation → Scaffold → Branch Detail → Review & Assemble). AI assists at bounded points using Claude Haiku for cost-efficient structured JSON generation (~$0.01-0.03/flow). Backend: new models (ai_conversations, ai_usage), Alembic migration, quota enforcement with billing anchor, Anthropic API integration with prompt caching, tree validation, conversation CRUD with 24h TTL, APScheduler cleanup job, 5 API endpoints, Pydantic schemas. Frontend: TypeScript types, API client, Zustand store for wizard state, 7 components (modal, step indicator, foundation form, branch selector, branch detail view, tree preview, quota display), MyTreesPage integration with "Build with AI" button (hidden when AI not configured). Tests: 14 validator unit tests + 11 endpoint integration tests with mocked Anthropic (zero real API spend). All 25 tests passing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
360
backend/tests/test_ai_endpoints.py
Normal file
360
backend/tests/test_ai_endpoints.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""Integration tests for AI Flow Builder endpoints.
|
||||
|
||||
All Anthropic API calls are mocked — zero real API spend.
|
||||
"""
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
# ── Sample AI responses ──
|
||||
|
||||
SCAFFOLD_RESPONSE_JSON = json.dumps({
|
||||
"branches": [
|
||||
{"name": "Service Not Running", "description": "The target service is stopped or crashed."},
|
||||
{"name": "Authentication Failures", "description": "Users cannot authenticate against the service."},
|
||||
{"name": "Network Connectivity", "description": "Network-level issues preventing access."},
|
||||
{"name": "Configuration Errors", "description": "Misconfiguration of the service or its dependencies."},
|
||||
]
|
||||
})
|
||||
|
||||
BRANCH_DETAIL_JSON = json.dumps({
|
||||
"id": "svc-root",
|
||||
"type": "decision",
|
||||
"question": "Is the service running?",
|
||||
"options": [
|
||||
{"id": "opt-yes", "label": "Yes", "next_node_id": "svc-check-logs"},
|
||||
{"id": "opt-no", "label": "No", "next_node_id": "svc-restart"},
|
||||
],
|
||||
"children": [
|
||||
{
|
||||
"id": "svc-check-logs",
|
||||
"type": "action",
|
||||
"title": "Check Event Logs",
|
||||
"description": "Check Windows Event Viewer for errors.",
|
||||
"commands": ["Get-EventLog -LogName Application -Newest 20"],
|
||||
"children": [
|
||||
{
|
||||
"id": "svc-logs-resolved",
|
||||
"type": "solution",
|
||||
"title": "Issue Found in Logs",
|
||||
"description": "Error identified and resolved.",
|
||||
"resolution_steps": ["Fix the error", "Restart service"],
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "svc-restart",
|
||||
"type": "action",
|
||||
"title": "Restart Service",
|
||||
"description": "Attempt to restart the service.",
|
||||
"commands": ["Restart-Service -Name 'TestService'"],
|
||||
"children": [
|
||||
{
|
||||
"id": "svc-restart-ok",
|
||||
"type": "solution",
|
||||
"title": "Service Restored",
|
||||
"description": "Service is running after restart.",
|
||||
"resolution_steps": ["Verify connectivity", "Document in ticket"],
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
|
||||
def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200):
|
||||
"""Create a mock Anthropic API response."""
|
||||
response = MagicMock()
|
||||
response.content = [MagicMock(text=text)]
|
||||
response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens)
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_ai():
|
||||
"""Temporarily enable AI by setting a fake API key."""
|
||||
original = settings.ANTHROPIC_API_KEY
|
||||
settings.ANTHROPIC_API_KEY = "test-key-fake"
|
||||
yield
|
||||
settings.ANTHROPIC_API_KEY = original
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_ai():
|
||||
"""Ensure AI is disabled."""
|
||||
original = settings.ANTHROPIC_API_KEY
|
||||
settings.ANTHROPIC_API_KEY = None
|
||||
yield
|
||||
settings.ANTHROPIC_API_KEY = original
|
||||
|
||||
|
||||
# ── Quota endpoint ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quota_returns_disabled_when_no_key(client, auth_headers, disable_ai):
|
||||
"""GET /ai/quota returns ai_enabled=false when no API key."""
|
||||
response = await client.get("/api/v1/ai/quota", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["ai_enabled"] is False
|
||||
assert data["allowed"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quota_returns_enabled_with_key(client, auth_headers, enable_ai):
|
||||
"""GET /ai/quota returns ai_enabled=true with API key configured."""
|
||||
response = await client.get("/api/v1/ai/quota", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["ai_enabled"] is True
|
||||
assert data["allowed"] is True
|
||||
|
||||
|
||||
# ── Start endpoint ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_requires_auth(client, enable_ai):
|
||||
"""POST /ai/start requires authentication."""
|
||||
response = await client.post("/api/v1/ai/start", json={
|
||||
"flow_type": "troubleshooting",
|
||||
"name": "Test Flow",
|
||||
"description": "Test",
|
||||
})
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_503_when_disabled(client, auth_headers, disable_ai):
|
||||
"""POST /ai/start returns 503 when AI is not configured."""
|
||||
response = await client.post(
|
||||
"/api/v1/ai/start",
|
||||
json={
|
||||
"flow_type": "troubleshooting",
|
||||
"name": "Test Flow",
|
||||
"description": "Test description",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_creates_conversation(client, auth_headers, enable_ai):
|
||||
"""POST /ai/start creates a conversation and returns conversation_id."""
|
||||
response = await client.post(
|
||||
"/api/v1/ai/start",
|
||||
json={
|
||||
"flow_type": "troubleshooting",
|
||||
"name": "DNS Issues",
|
||||
"description": "Troubleshooting DNS resolution failures",
|
||||
"environment_tags": ["Windows Server", "Active Directory"],
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert "conversation_id" in data
|
||||
assert data["status"] == "foundation"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_validates_input(client, auth_headers, enable_ai):
|
||||
"""POST /ai/start rejects invalid input."""
|
||||
response = await client.post(
|
||||
"/api/v1/ai/start",
|
||||
json={
|
||||
"flow_type": "troubleshooting",
|
||||
"name": "", # Empty name
|
||||
"description": "Test",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ── Scaffold endpoint ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scaffold_success(client, auth_headers, enable_ai):
|
||||
"""POST /ai/scaffold returns AI-generated branches."""
|
||||
# Create conversation first
|
||||
start_resp = await client.post(
|
||||
"/api/v1/ai/start",
|
||||
json={
|
||||
"flow_type": "troubleshooting",
|
||||
"name": "DNS Issues",
|
||||
"description": "DNS resolution failures",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
conversation_id = start_resp.json()["conversation_id"]
|
||||
|
||||
# Mock Anthropic
|
||||
mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/ai/scaffold",
|
||||
json={"conversation_id": conversation_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "scaffolding"
|
||||
assert len(data["branches"]) == 4
|
||||
assert data["branches"][0]["name"] == "Service Not Running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scaffold_invalid_conversation(client, auth_headers, enable_ai):
|
||||
"""POST /ai/scaffold returns 404 for nonexistent conversation."""
|
||||
response = await client.post(
|
||||
"/api/v1/ai/scaffold",
|
||||
json={"conversation_id": "00000000-0000-0000-0000-000000000000"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ── Branch detail endpoint ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_branch_detail_success(client, auth_headers, enable_ai):
|
||||
"""POST /ai/branch-detail returns AI-generated branch nodes."""
|
||||
# Create and scaffold first
|
||||
start_resp = await client.post(
|
||||
"/api/v1/ai/start",
|
||||
json={
|
||||
"flow_type": "troubleshooting",
|
||||
"name": "Service Issues",
|
||||
"description": "Service troubleshooting",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
conversation_id = start_resp.json()["conversation_id"]
|
||||
|
||||
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
|
||||
await client.post(
|
||||
"/api/v1/ai/scaffold",
|
||||
json={"conversation_id": conversation_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Now generate branch detail
|
||||
detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON)
|
||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock)
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/ai/branch-detail",
|
||||
json={
|
||||
"conversation_id": conversation_id,
|
||||
"branch_name": "Service Not Running",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["branch_name"] == "Service Not Running"
|
||||
assert data["steps"]["id"] == "svc-root"
|
||||
assert data["steps"]["type"] == "decision"
|
||||
|
||||
|
||||
# ── Assemble endpoint ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assemble_success(client, auth_headers, enable_ai):
|
||||
"""POST /ai/assemble returns assembled tree from branches with detail."""
|
||||
# Create conversation
|
||||
start_resp = await client.post(
|
||||
"/api/v1/ai/start",
|
||||
json={
|
||||
"flow_type": "troubleshooting",
|
||||
"name": "Service Issues",
|
||||
"description": "Service troubleshooting",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
conversation_id = start_resp.json()["conversation_id"]
|
||||
|
||||
# Scaffold
|
||||
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
|
||||
await client.post(
|
||||
"/api/v1/ai/scaffold",
|
||||
json={"conversation_id": conversation_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Assemble with branch detail included
|
||||
branch_tree = json.loads(BRANCH_DETAIL_JSON)
|
||||
response = await client.post(
|
||||
"/api/v1/ai/assemble",
|
||||
json={
|
||||
"conversation_id": conversation_id,
|
||||
"selected_branches": [
|
||||
{
|
||||
"name": "Service Not Running",
|
||||
"description": "The target service is stopped.",
|
||||
"steps": branch_tree,
|
||||
},
|
||||
{
|
||||
"name": "Authentication Failures",
|
||||
"description": "Users cannot authenticate.",
|
||||
"steps": branch_tree,
|
||||
},
|
||||
],
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "completed"
|
||||
assert data["suggested_name"] == "Service Issues"
|
||||
assert "tree_structure" in data
|
||||
assert data["tree_structure"]["type"] == "decision"
|
||||
assert data["summary"]["node_count"] > 0
|
||||
assert data["summary"]["solution_count"] >= 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assemble_requires_min_2_branches(client, auth_headers, enable_ai):
|
||||
"""POST /ai/assemble rejects fewer than 2 branches."""
|
||||
start_resp = await client.post(
|
||||
"/api/v1/ai/start",
|
||||
json={
|
||||
"flow_type": "troubleshooting",
|
||||
"name": "Test",
|
||||
"description": "Test",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
conversation_id = start_resp.json()["conversation_id"]
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/ai/assemble",
|
||||
json={
|
||||
"conversation_id": conversation_id,
|
||||
"selected_branches": [
|
||||
{"name": "Only Branch", "description": "Just one"},
|
||||
],
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 422
|
||||
183
backend/tests/test_ai_tree_validator.py
Normal file
183
backend/tests/test_ai_tree_validator.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Tests for AI-generated tree structure validation."""
|
||||
import pytest
|
||||
|
||||
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
|
||||
|
||||
|
||||
def _make_valid_tree():
|
||||
"""Helper: minimal valid tree for testing."""
|
||||
return {
|
||||
"id": "root",
|
||||
"type": "decision",
|
||||
"question": "Is the service running?",
|
||||
"options": [
|
||||
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
|
||||
{"id": "opt-no", "label": "No", "next_node_id": "restart-service"},
|
||||
],
|
||||
"children": [
|
||||
{
|
||||
"id": "check-logs",
|
||||
"type": "decision",
|
||||
"question": "Are there errors in the logs?",
|
||||
"options": [
|
||||
{"id": "opt-errors", "label": "Yes", "next_node_id": "fix-errors"},
|
||||
{"id": "opt-clean", "label": "No", "next_node_id": "escalate"},
|
||||
],
|
||||
"children": [
|
||||
{
|
||||
"id": "fix-errors",
|
||||
"type": "solution",
|
||||
"title": "Fix Errors",
|
||||
"description": "Apply the fix for the errors found.",
|
||||
},
|
||||
{
|
||||
"id": "escalate",
|
||||
"type": "solution",
|
||||
"title": "Escalate",
|
||||
"description": "No errors found; escalate to Tier 2.",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "restart-service",
|
||||
"type": "action",
|
||||
"title": "Restart the Service",
|
||||
"description": "Restart the service and verify.",
|
||||
"commands": ["Restart-Service -Name 'TestService'"],
|
||||
"children": [
|
||||
{
|
||||
"id": "service-resolved",
|
||||
"type": "solution",
|
||||
"title": "Service Restored",
|
||||
"description": "Service is running after restart.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class TestValidTree:
|
||||
def test_valid_tree_passes(self):
|
||||
errors = validate_generated_tree(_make_valid_tree())
|
||||
assert errors == []
|
||||
|
||||
def test_not_a_dict(self):
|
||||
errors = validate_generated_tree("not a dict")
|
||||
assert any("must be a JSON object" in e for e in errors)
|
||||
|
||||
def test_root_not_decision(self):
|
||||
tree = _make_valid_tree()
|
||||
tree["type"] = "action"
|
||||
tree["title"] = "Fake"
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("Root node must be type 'decision'" in e for e in errors)
|
||||
|
||||
|
||||
class TestNodeValidation:
|
||||
def test_missing_id(self):
|
||||
tree = _make_valid_tree()
|
||||
del tree["children"][0]["id"]
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("missing 'id'" in e for e in errors)
|
||||
|
||||
def test_duplicate_ids(self):
|
||||
tree = _make_valid_tree()
|
||||
tree["children"][1]["id"] = "check-logs" # same as sibling
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("Duplicate node ID" in e for e in errors)
|
||||
|
||||
def test_invalid_node_type(self):
|
||||
tree = _make_valid_tree()
|
||||
tree["children"][0]["type"] = "unknown"
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("invalid type" in e for e in errors)
|
||||
|
||||
def test_decision_missing_options(self):
|
||||
tree = _make_valid_tree()
|
||||
del tree["children"][0]["options"]
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("missing fields" in e for e in errors)
|
||||
|
||||
def test_decision_less_than_2_options(self):
|
||||
tree = _make_valid_tree()
|
||||
tree["children"][0]["options"] = [
|
||||
{"id": "opt-1", "label": "Only", "next_node_id": "fix-errors"}
|
||||
]
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("at least 2 options" in e for e in errors)
|
||||
|
||||
|
||||
class TestReferenceIntegrity:
|
||||
def test_option_references_nonexistent_child(self):
|
||||
tree = _make_valid_tree()
|
||||
tree["options"][0]["next_node_id"] = "nonexistent"
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("non-existent child" in e for e in errors)
|
||||
|
||||
def test_duplicate_option_ids(self):
|
||||
tree = _make_valid_tree()
|
||||
tree["options"][0]["id"] = "same"
|
||||
tree["options"][1]["id"] = "same"
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("Duplicate option ID" in e for e in errors)
|
||||
|
||||
|
||||
class TestGlobalChecks:
|
||||
def test_too_few_nodes(self):
|
||||
tree = {
|
||||
"id": "root",
|
||||
"type": "decision",
|
||||
"question": "Test?",
|
||||
"options": [
|
||||
{"id": "o1", "label": "A", "next_node_id": "s1"},
|
||||
{"id": "o2", "label": "B", "next_node_id": "s2"},
|
||||
],
|
||||
"children": [
|
||||
{"id": "s1", "type": "solution", "title": "S1", "description": "D1"},
|
||||
{"id": "s2", "type": "solution", "title": "S2", "description": "D2"},
|
||||
],
|
||||
}
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("Minimum 5 required" in e for e in errors)
|
||||
|
||||
def test_too_few_solutions(self):
|
||||
tree = _make_valid_tree()
|
||||
# Remove all solutions except one — replace children of check-logs
|
||||
tree["children"][0]["children"] = [
|
||||
{
|
||||
"id": "only-solution",
|
||||
"type": "solution",
|
||||
"title": "Only",
|
||||
"description": "Only solution",
|
||||
}
|
||||
]
|
||||
tree["children"][0]["options"] = [
|
||||
{"id": "o1", "label": "A", "next_node_id": "only-solution"},
|
||||
{"id": "o2", "label": "B", "next_node_id": "only-solution"},
|
||||
]
|
||||
# Now restart-service branch has 1 solution, check-logs has 1 = total 2
|
||||
# Remove one more to get to 1
|
||||
tree["children"][1]["children"] = []
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("solution" in e.lower() for e in errors)
|
||||
|
||||
|
||||
class TestDeadEndDetection:
|
||||
def test_dead_end_action_node(self):
|
||||
tree = _make_valid_tree()
|
||||
# Remove restart-service's children — becomes dead end
|
||||
tree["children"][1]["children"] = []
|
||||
errors = validate_generated_tree(tree)
|
||||
assert any("dead end" in e for e in errors)
|
||||
|
||||
|
||||
class TestCountTreeStats:
|
||||
def test_stats_correct(self):
|
||||
tree = _make_valid_tree()
|
||||
stats = count_tree_stats(tree)
|
||||
assert stats["node_count"] == 6
|
||||
assert stats["decision_count"] == 2
|
||||
assert stats["action_count"] == 1
|
||||
assert stats["solution_count"] == 3
|
||||
assert stats["depth"] >= 3
|
||||
Reference in New Issue
Block a user