358 lines
12 KiB
Python
358 lines
12 KiB
Python
"""Integration tests for AI Flow Builder endpoints.

All AI provider calls are mocked — zero real API spend.
"""
|
|
import json
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.core.config import settings
|
|
|
|
|
|
# ── Sample AI responses ──
|
|
|
|
# Canned scaffold payload: JSON text fed to the mocked provider in the
# scaffold tests. Each pair below is (branch name, branch description).
_SCAFFOLD_BRANCHES = (
    ("Service Not Running", "The target service is stopped or crashed."),
    ("Authentication Failures", "Users cannot authenticate against the service."),
    ("Network Connectivity", "Network-level issues preventing access."),
    ("Configuration Errors", "Misconfiguration of the service or its dependencies."),
)
SCAFFOLD_RESPONSE_JSON = json.dumps(
    {
        "branches": [
            {"name": branch_name, "description": branch_description}
            for branch_name, branch_description in _SCAFFOLD_BRANCHES
        ]
    }
)
|
|
|
|
# Canned branch-detail payload: JSON text fed to the mocked provider in the
# branch-detail test and parsed back into a dict by the assemble test.
# The root is a decision node; its options point (by id) at the nodes listed
# under "children".
_BRANCH_DETAIL_CHILDREN = [
    {
        "id": "svc-check-logs",
        "type": "action",
        "title": "Check Event Logs",
        "description": "Check Windows Event Viewer for errors.",
        "commands": ["Get-EventLog -LogName Application -Newest 20"],
        "next_node_id": "svc-logs-resolved",
    },
    {
        "id": "svc-logs-resolved",
        "type": "solution",
        "title": "Issue Found in Logs",
        "description": "Error identified and resolved.",
        "resolution_steps": ["Fix the error", "Restart service"],
    },
    {
        "id": "svc-restart",
        "type": "action",
        "title": "Restart Service",
        "description": "Attempt to restart the service.",
        "commands": ["Restart-Service -Name 'TestService'"],
        "next_node_id": "svc-restart-ok",
    },
    {
        "id": "svc-restart-ok",
        "type": "solution",
        "title": "Service Restored",
        "description": "Service is running after restart.",
        "resolution_steps": ["Verify connectivity", "Document in ticket"],
    },
]
BRANCH_DETAIL_JSON = json.dumps(
    {
        "id": "svc-root",
        "type": "decision",
        "question": "Is the service running?",
        "options": [
            {"id": "opt-yes", "label": "Yes", "next_node_id": "svc-check-logs"},
            {"id": "opt-no", "label": "No", "next_node_id": "svc-restart"},
        ],
        "children": _BRANCH_DETAIL_CHILDREN,
    }
)
|
|
|
|
|
|
def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200):
    """Build a stand-in AI provider for tests.

    The returned ``MagicMock`` exposes an awaitable ``generate_json`` that
    resolves to the tuple ``(text, input_tokens, output_tokens)``.
    """
    canned_reply = (text, input_tokens, output_tokens)
    return MagicMock(generate_json=AsyncMock(return_value=canned_reply))
|
|
|
|
|
|
@pytest.fixture
def enable_ai():
    """Run the test with a fake Anthropic key set and the Google key cleared.

    Both settings are put back to their previous values after the test.
    """
    saved_keys = (settings.ANTHROPIC_API_KEY, settings.GOOGLE_AI_API_KEY)
    settings.ANTHROPIC_API_KEY, settings.GOOGLE_AI_API_KEY = "test-key-fake", None
    yield
    settings.ANTHROPIC_API_KEY, settings.GOOGLE_AI_API_KEY = saved_keys
|
|
|
|
|
|
@pytest.fixture
def disable_ai():
    """Run the test with both AI provider keys cleared.

    Both settings are put back to their previous values after the test.
    """
    saved_keys = (settings.ANTHROPIC_API_KEY, settings.GOOGLE_AI_API_KEY)
    settings.ANTHROPIC_API_KEY, settings.GOOGLE_AI_API_KEY = None, None
    yield
    settings.ANTHROPIC_API_KEY, settings.GOOGLE_AI_API_KEY = saved_keys
|
|
|
|
|
|
# ── Quota endpoint ──
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_quota_returns_disabled_when_no_key(client, auth_headers, disable_ai):
    """With no API key configured, GET /ai/quota reports ai_enabled and allowed as false."""
    quota_resp = await client.get("/api/v1/ai/quota", headers=auth_headers)
    assert quota_resp.status_code == 200

    quota = quota_resp.json()
    assert quota["ai_enabled"] is False
    assert quota["allowed"] is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_quota_returns_enabled_with_key(client, auth_headers, enable_ai):
    """With an API key configured, GET /ai/quota reports ai_enabled and allowed as true."""
    quota_resp = await client.get("/api/v1/ai/quota", headers=auth_headers)
    assert quota_resp.status_code == 200

    quota = quota_resp.json()
    assert quota["ai_enabled"] is True
    assert quota["allowed"] is True
|
|
|
|
|
|
# ── Start endpoint ──
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_requires_auth(client, enable_ai):
    """POST /ai/start without credentials is rejected with 401."""
    payload = {
        "flow_type": "troubleshooting",
        "name": "Test Flow",
        "description": "Test",
    }
    # Deliberately sent without auth headers.
    anonymous_resp = await client.post("/api/v1/ai/start", json=payload)
    assert anonymous_resp.status_code == 401
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_returns_503_when_disabled(client, auth_headers, disable_ai):
    """POST /ai/start answers 503 when no AI API key is configured."""
    payload = {
        "flow_type": "troubleshooting",
        "name": "Test Flow",
        "description": "Test description",
    }
    start_resp = await client.post("/api/v1/ai/start", json=payload, headers=auth_headers)
    assert start_resp.status_code == 503
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_creates_conversation(client, auth_headers, enable_ai):
    """POST /ai/start answers 201 with a conversation_id and status "foundation"."""
    payload = {
        "flow_type": "troubleshooting",
        "name": "DNS Issues",
        "description": "Troubleshooting DNS resolution failures",
        "environment_tags": ["Windows Server", "Active Directory"],
    }
    start_resp = await client.post("/api/v1/ai/start", json=payload, headers=auth_headers)
    assert start_resp.status_code == 201

    created = start_resp.json()
    assert "conversation_id" in created
    assert created["status"] == "foundation"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_validates_input(client, auth_headers, enable_ai):
    """POST /ai/start answers 422 when the flow name is an empty string."""
    invalid_payload = {
        "flow_type": "troubleshooting",
        "name": "",  # the invalid field under test
        "description": "Test",
    }
    start_resp = await client.post(
        "/api/v1/ai/start", json=invalid_payload, headers=auth_headers
    )
    assert start_resp.status_code == 422
|
|
|
|
|
|
# ── Scaffold endpoint ──
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_scaffold_success(client, auth_headers, enable_ai):
    """POST /ai/scaffold returns the branches produced by the mocked AI provider.

    Flow: start a conversation, then scaffold it while ``get_ai_provider`` is
    patched to return a mock whose ``generate_json`` yields
    ``SCAFFOLD_RESPONSE_JSON``.
    """
    # Setup: a conversation has to exist before it can be scaffolded.
    start_resp = await client.post(
        "/api/v1/ai/start",
        json={
            "flow_type": "troubleshooting",
            "name": "DNS Issues",
            "description": "DNS resolution failures",
        },
        headers=auth_headers,
    )
    # Fail on the setup step itself instead of with a KeyError on the next line.
    assert start_resp.status_code == 201, "setup failed: POST /ai/start"
    conversation_id = start_resp.json()["conversation_id"]

    # Swap the real provider for a mock so no external AI call is made.
    mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
    with patch(
        "app.core.ai_tree_generator_service.get_ai_provider",
        return_value=mock_provider,
    ):
        response = await client.post(
            "/api/v1/ai/scaffold",
            json={"conversation_id": conversation_id},
            headers=auth_headers,
        )

    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "scaffolding"
    assert len(data["branches"]) == 4
    assert data["branches"][0]["name"] == "Service Not Running"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_scaffold_invalid_conversation(client, auth_headers, enable_ai):
    """POST /ai/scaffold answers 404 when the conversation id does not exist."""
    unknown_id = "00000000-0000-0000-0000-000000000000"
    scaffold_resp = await client.post(
        "/api/v1/ai/scaffold",
        json={"conversation_id": unknown_id},
        headers=auth_headers,
    )
    assert scaffold_resp.status_code == 404
|
|
|
|
|
|
# ── Branch detail endpoint ──
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_branch_detail_success(client, auth_headers, enable_ai):
    """POST /ai/branch-detail returns the node tree produced by the mocked AI provider.

    Flow: start a conversation, scaffold it (mocked), then request detail for
    one branch while the provider mock yields ``BRANCH_DETAIL_JSON``.
    """
    # Setup step 1: create the conversation.
    start_resp = await client.post(
        "/api/v1/ai/start",
        json={
            "flow_type": "troubleshooting",
            "name": "Service Issues",
            "description": "Service troubleshooting",
        },
        headers=auth_headers,
    )
    # Fail on the setup step itself instead of with a KeyError on the next line.
    assert start_resp.status_code == 201, "setup failed: POST /ai/start"
    conversation_id = start_resp.json()["conversation_id"]

    # Setup step 2: scaffold the conversation with a mocked provider.
    scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
    with patch(
        "app.core.ai_tree_generator_service.get_ai_provider",
        return_value=scaffold_provider,
    ):
        scaffold_resp = await client.post(
            "/api/v1/ai/scaffold",
            json={"conversation_id": conversation_id},
            headers=auth_headers,
        )
    # Without this check a scaffold failure would be misreported as a
    # branch-detail failure further down.
    assert scaffold_resp.status_code == 200, "setup failed: POST /ai/scaffold"

    # Step under test: generate the detail for one scaffolded branch.
    detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON)
    with patch(
        "app.core.ai_tree_generator_service.get_ai_provider",
        return_value=detail_provider,
    ):
        response = await client.post(
            "/api/v1/ai/branch-detail",
            json={
                "conversation_id": conversation_id,
                "branch_name": "Service Not Running",
            },
            headers=auth_headers,
        )

    assert response.status_code == 200
    data = response.json()
    assert data["branch_name"] == "Service Not Running"
    assert data["steps"]["id"] == "svc-root"
    assert data["steps"]["type"] == "decision"
|
|
|
|
|
|
# ── Assemble endpoint ──
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_assemble_success(client, auth_headers, enable_ai):
    """POST /ai/assemble builds a completed tree from two branches that carry detail.

    Flow: start a conversation, scaffold it (mocked), then assemble two
    selected branches whose ``steps`` are the parsed ``BRANCH_DETAIL_JSON``.
    The assemble request itself is made without patching the provider.
    """
    # Setup step 1: create the conversation.
    start_resp = await client.post(
        "/api/v1/ai/start",
        json={
            "flow_type": "troubleshooting",
            "name": "Service Issues",
            "description": "Service troubleshooting",
        },
        headers=auth_headers,
    )
    # Fail on the setup step itself instead of with a KeyError on the next line.
    assert start_resp.status_code == 201, "setup failed: POST /ai/start"
    conversation_id = start_resp.json()["conversation_id"]

    # Setup step 2: scaffold the conversation with a mocked provider.
    scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
    with patch(
        "app.core.ai_tree_generator_service.get_ai_provider",
        return_value=scaffold_provider,
    ):
        scaffold_resp = await client.post(
            "/api/v1/ai/scaffold",
            json={"conversation_id": conversation_id},
            headers=auth_headers,
        )
    # Without this check a scaffold failure would be misreported as an
    # assemble failure further down.
    assert scaffold_resp.status_code == 200, "setup failed: POST /ai/scaffold"

    # Step under test: both selected branches reuse the same canned detail tree.
    branch_tree = json.loads(BRANCH_DETAIL_JSON)
    response = await client.post(
        "/api/v1/ai/assemble",
        json={
            "conversation_id": conversation_id,
            "selected_branches": [
                {
                    "name": "Service Not Running",
                    "description": "The target service is stopped.",
                    "steps": branch_tree,
                },
                {
                    "name": "Authentication Failures",
                    "description": "Users cannot authenticate.",
                    "steps": branch_tree,
                },
            ],
        },
        headers=auth_headers,
    )

    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "completed"
    assert data["suggested_name"] == "Service Issues"
    assert "tree_structure" in data
    assert data["tree_structure"]["type"] == "decision"
    assert data["summary"]["node_count"] > 0
    assert data["summary"]["solution_count"] >= 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_assemble_requires_min_2_branches(client, auth_headers, enable_ai):
    """POST /ai/assemble answers 422 when only one branch is selected."""
    # Setup: create a conversation to assemble against.
    start_resp = await client.post(
        "/api/v1/ai/start",
        json={
            "flow_type": "troubleshooting",
            "name": "Test",
            "description": "Test",
        },
        headers=auth_headers,
    )
    # Confirm setup worked, so the 422 below is about the branch count and the
    # test does not die on a KeyError when reading conversation_id.
    assert start_resp.status_code == 201, "setup failed: POST /ai/start"
    conversation_id = start_resp.json()["conversation_id"]

    # A single selected branch is the invalid input under test.
    response = await client.post(
        "/api/v1/ai/assemble",
        json={
            "conversation_id": conversation_id,
            "selected_branches": [
                {"name": "Only Branch", "description": "Just one"},
            ],
        },
        headers=auth_headers,
    )
    assert response.status_code == 422
|