# Source listing header (file-viewer residue): 170 lines, 4.8 KiB, Python
"""Integration tests for the POST /ai/fix-tree endpoint."""
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.core.config import settings
|
|
|
|
|
|
# ── Sample tree (has a decision node with only 1 option + 1 child) ──
# The root decision node is well-formed (two options). The nested "restart"
# decision node is the deliberately invalid one: it offers a single option and
# a single child, matching the validation error sent in FIX_REQUEST_BODY.
SAMPLE_TREE = {
    "id": "root",
    "type": "decision",
    "question": "Is the server up?",
    "options": [
        {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
        {"id": "opt-no", "label": "No", "next_node_id": "restart"},
    ],
    "children": [
        {
            "id": "check-logs",
            "type": "action",
            "title": "Check Logs",
            "description": "Review logs.",
            "next_node_id": "logs-ok",
        },
        {
            "id": "logs-ok",
            "type": "solution",
            "title": "Logs OK",
            "description": "Issue in logs.",
        },
        {
            # Invalid on purpose: only one option and one child.
            "id": "restart",
            "type": "decision",
            "question": "Did restart work?",
            "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
            "children": [
                {
                    "id": "done",
                    "type": "solution",
                    "title": "Done",
                    "description": "Fixed.",
                }
            ],
        },
    ],
}
|
|
|
|
# Fixed version of the "restart" node — 2 options, 2 children.
# Serialized with json.dumps and used as the mocked AI provider's response
# text in test_returns_fixes_on_success.
FIXED_RESTART_NODE = {
    "id": "restart",
    "type": "decision",
    "question": "Did restart work?",
    "options": [
        {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
        {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
    ],
    "children": [
        {
            "id": "done",
            "type": "solution",
            "title": "Done",
            "description": "Fixed.",
        },
        {
            "id": "escalate",
            "type": "solution",
            "title": "Escalate",
            "description": "Escalate to senior engineer.",
        },
    ],
}
|
|
|
|
# JSON body posted to /api/v1/ai/fix-tree by every test in this module: the
# sample tree plus the single validation error raised against its "restart"
# node.
FIX_REQUEST_BODY = {
    "tree_structure": SAMPLE_TREE,
    "tree_name": "Server Troubleshooting",
    "tree_type": "troubleshooting",
    "validation_errors": [
        {
            "node_id": "restart",
            "message": "Decision node 'restart' must have at least 2 options",
        }
    ],
}
|
|
|
|
|
|
def _mock_ai_provider(
    response_text: str, input_tokens: int = 50, output_tokens: int = 100
) -> MagicMock:
    """Create a mock provider whose generate_json returns given text.

    Args:
        response_text: Text the mocked ``generate_json`` hands back as the
            first element of its result tuple.
        input_tokens: Value reported as the second tuple element.
        output_tokens: Value reported as the third tuple element.

    Returns:
        A ``MagicMock`` whose ``generate_json`` attribute is an ``AsyncMock``
        resolving to ``(response_text, input_tokens, output_tokens)``.
    """
    provider = MagicMock()
    # AsyncMock so that the code under test can ``await`` the call.
    provider.generate_json = AsyncMock(
        return_value=(response_text, input_tokens, output_tokens)
    )
    return provider
|
|
|
|
|
|
@pytest.fixture
def enable_ai():
    """Temporarily enable AI by setting a fake API key.

    Overwrites ``settings.GOOGLE_AI_API_KEY`` with a placeholder value while
    the test runs and puts the previous value back afterwards.
    """
    original = settings.GOOGLE_AI_API_KEY
    settings.GOOGLE_AI_API_KEY = "test-key-fake"
    try:
        yield
    finally:
        # ``settings`` is a shared module-level object; restore it in
        # ``finally`` so it is reset even if the fixture generator is closed
        # without being resumed normally.
        settings.GOOGLE_AI_API_KEY = original
|
|
|
|
|
|
@pytest.fixture
def disable_ai():
    """Ensure AI is disabled.

    Sets both ``settings.GOOGLE_AI_API_KEY`` and ``settings.ANTHROPIC_API_KEY``
    to ``None`` while the test runs, then restores the previous values.
    """
    orig_google = settings.GOOGLE_AI_API_KEY
    orig_anthropic = settings.ANTHROPIC_API_KEY
    settings.GOOGLE_AI_API_KEY = None
    settings.ANTHROPIC_API_KEY = None
    try:
        yield
    finally:
        # ``settings`` is a shared module-level object; restore it in
        # ``finally`` so it is reset even if the fixture generator is closed
        # without being resumed normally.
        settings.GOOGLE_AI_API_KEY = orig_google
        settings.ANTHROPIC_API_KEY = orig_anthropic
|
|
|
|
|
|
# ── Tests ──
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_returns_401_without_auth(client):
    """POST /ai/fix-tree without auth token returns 401."""
    # No ``headers`` argument: the request is sent without the auth headers
    # the other tests pass.
    response = await client.post("/api/v1/ai/fix-tree", json=FIX_REQUEST_BODY)

    assert response.status_code == 401
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_returns_503_when_ai_disabled(client, auth_headers, disable_ai):
    """POST /ai/fix-tree returns 503 when no AI keys are configured."""
    # ``disable_ai`` has already set both API keys on ``settings`` to None.
    response = await client.post(
        "/api/v1/ai/fix-tree",
        json=FIX_REQUEST_BODY,
        headers=auth_headers,
    )

    assert response.status_code == 503
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_returns_fixes_on_success(client, auth_headers, enable_ai):
    """POST /ai/fix-tree returns fix proposals when AI succeeds."""
    # The mocked provider answers with the repaired "restart" node as JSON
    # text and reports the helper's default token counts (50 in, 100 out).
    mock_provider = _mock_ai_provider(json.dumps(FIXED_RESTART_NODE))

    # Swap the provider lookup for the mock while the request is in flight.
    with patch(
        "app.core.ai_fix_service.get_ai_provider",
        return_value=mock_provider,
    ):
        response = await client.post(
            "/api/v1/ai/fix-tree",
            json=FIX_REQUEST_BODY,
            headers=auth_headers,
        )

    assert response.status_code == 200
    data = response.json()
    assert "fixes" in data
    assert "tokens_used" in data
    # One validation error was submitted, so exactly one fix is expected.
    assert len(data["fixes"]) == 1

    fix = data["fixes"][0]
    assert fix["target_node_id"] == "restart"
    assert fix["error_message"] == "Decision node 'restart' must have at least 2 options"
    assert fix["original_node"]["id"] == "restart"
    assert fix["fixed_node"]["id"] == "restart"
    assert len(fix["fixed_node"]["options"]) == 2
    assert len(fix["fixed_node"]["children"]) == 2

    # Token counts come straight from the mocked provider's return tuple.
    assert data["tokens_used"]["input"] == 50
    assert data["tokens_used"]["output"] == 100
|