feat: add POST /ai/fix-tree endpoint for AI-powered validation fixes
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
78
backend/app/api/endpoints/ai_fix.py
Normal file
78
backend/app/api/endpoints/ai_fix.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""AI auto-fix endpoint for tree validation errors.
|
||||
|
||||
POST /ai/fix-tree — accepts a tree with validation errors and returns
|
||||
AI-generated fix proposals for each error.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_db, require_engineer_or_admin
|
||||
from app.core.config import settings
|
||||
from app.core.rate_limit import limiter
|
||||
from app.core.ai_fix_service import generate_fixes
|
||||
from app.models.user import User
|
||||
from app.schemas.ai_fix import (
|
||||
AIFixTreeRequest,
|
||||
AIFixTreeResponse,
|
||||
AIFixProposal,
|
||||
AIFixTokenUsage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ai", tags=["ai-fix"])
|
||||
|
||||
|
||||
def _require_ai_enabled() -> None:
|
||||
"""Raise 503 if AI is not configured."""
|
||||
if not settings.ai_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="AI fix is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/fix-tree", response_model=AIFixTreeResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def fix_tree(
|
||||
request: Request,
|
||||
body: AIFixTreeRequest,
|
||||
user: Annotated[User, Depends(require_engineer_or_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Generate AI-powered fixes for tree validation errors."""
|
||||
_require_ai_enabled()
|
||||
|
||||
validation_errors = [
|
||||
{"node_id": e.node_id, "message": e.message}
|
||||
for e in body.validation_errors
|
||||
]
|
||||
|
||||
try:
|
||||
fixes, input_tokens, output_tokens = await generate_fixes(
|
||||
tree_structure=body.tree_structure,
|
||||
tree_name=body.tree_name,
|
||||
tree_type=body.tree_type,
|
||||
validation_errors=validation_errors,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
logger.error("AI provider not available: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=str(exc),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Unexpected error in AI fix service")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred while generating fixes.",
|
||||
)
|
||||
|
||||
return AIFixTreeResponse(
|
||||
fixes=[AIFixProposal(**f) for f in fixes],
|
||||
tokens_used=AIFixTokenUsage(input=input_tokens, output=output_tokens),
|
||||
)
|
||||
@@ -6,6 +6,7 @@ from app.api.endpoints import target_lists
|
||||
from app.api.endpoints import maintenance_schedules
|
||||
from app.api.endpoints import feedback
|
||||
from app.api.endpoints import ai_builder
|
||||
from app.api.endpoints import ai_fix
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -36,3 +37,4 @@ api_router.include_router(target_lists.router)
|
||||
api_router.include_router(maintenance_schedules.router)
|
||||
api_router.include_router(feedback.router)
|
||||
api_router.include_router(ai_builder.router)
|
||||
api_router.include_router(ai_fix.router)
|
||||
|
||||
169
backend/tests/test_ai_fix_endpoint.py
Normal file
169
backend/tests/test_ai_fix_endpoint.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Integration tests for the POST /ai/fix-tree endpoint."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
# ── Sample tree (has a decision node with only 1 option + 1 child) ──
|
||||
|
||||
SAMPLE_TREE = {
|
||||
"id": "root",
|
||||
"type": "decision",
|
||||
"question": "Is the server up?",
|
||||
"options": [
|
||||
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
|
||||
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
|
||||
],
|
||||
"children": [
|
||||
{
|
||||
"id": "check-logs",
|
||||
"type": "action",
|
||||
"title": "Check Logs",
|
||||
"description": "Review logs.",
|
||||
"next_node_id": "logs-ok",
|
||||
},
|
||||
{
|
||||
"id": "logs-ok",
|
||||
"type": "solution",
|
||||
"title": "Logs OK",
|
||||
"description": "Issue in logs.",
|
||||
},
|
||||
{
|
||||
"id": "restart",
|
||||
"type": "decision",
|
||||
"question": "Did restart work?",
|
||||
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
|
||||
"children": [
|
||||
{
|
||||
"id": "done",
|
||||
"type": "solution",
|
||||
"title": "Done",
|
||||
"description": "Fixed.",
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Fixed version of the "restart" node — 2 options, 2 children
|
||||
FIXED_RESTART_NODE = {
|
||||
"id": "restart",
|
||||
"type": "decision",
|
||||
"question": "Did restart work?",
|
||||
"options": [
|
||||
{"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
|
||||
{"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
|
||||
],
|
||||
"children": [
|
||||
{
|
||||
"id": "done",
|
||||
"type": "solution",
|
||||
"title": "Done",
|
||||
"description": "Fixed.",
|
||||
},
|
||||
{
|
||||
"id": "escalate",
|
||||
"type": "solution",
|
||||
"title": "Escalate",
|
||||
"description": "Escalate to senior engineer.",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
FIX_REQUEST_BODY = {
|
||||
"tree_structure": SAMPLE_TREE,
|
||||
"tree_name": "Server Troubleshooting",
|
||||
"tree_type": "troubleshooting",
|
||||
"validation_errors": [
|
||||
{
|
||||
"node_id": "restart",
|
||||
"message": "Decision node 'restart' must have at least 2 options",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _mock_ai_provider(response_text: str, input_tokens: int = 50, output_tokens: int = 100):
|
||||
"""Create a mock provider whose generate_json returns given text."""
|
||||
provider = MagicMock()
|
||||
provider.generate_json = AsyncMock(return_value=(response_text, input_tokens, output_tokens))
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_ai():
|
||||
"""Temporarily enable AI by setting a fake API key."""
|
||||
original = settings.GOOGLE_AI_API_KEY
|
||||
settings.GOOGLE_AI_API_KEY = "test-key-fake"
|
||||
yield
|
||||
settings.GOOGLE_AI_API_KEY = original
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_ai():
|
||||
"""Ensure AI is disabled."""
|
||||
orig_google = settings.GOOGLE_AI_API_KEY
|
||||
orig_anthropic = settings.ANTHROPIC_API_KEY
|
||||
settings.GOOGLE_AI_API_KEY = None
|
||||
settings.ANTHROPIC_API_KEY = None
|
||||
yield
|
||||
settings.GOOGLE_AI_API_KEY = orig_google
|
||||
settings.ANTHROPIC_API_KEY = orig_anthropic
|
||||
|
||||
|
||||
# ── Tests ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_401_without_auth(client):
|
||||
"""POST /ai/fix-tree without auth token returns 401."""
|
||||
response = await client.post("/api/v1/ai/fix-tree", json=FIX_REQUEST_BODY)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_503_when_ai_disabled(client, auth_headers, disable_ai):
|
||||
"""POST /ai/fix-tree returns 503 when no AI keys are configured."""
|
||||
response = await client.post(
|
||||
"/api/v1/ai/fix-tree",
|
||||
json=FIX_REQUEST_BODY,
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_fixes_on_success(client, auth_headers, enable_ai):
|
||||
"""POST /ai/fix-tree returns fix proposals when AI succeeds."""
|
||||
mock_provider = _mock_ai_provider(json.dumps(FIXED_RESTART_NODE))
|
||||
|
||||
with patch(
|
||||
"app.core.ai_fix_service.get_ai_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/ai/fix-tree",
|
||||
json=FIX_REQUEST_BODY,
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "fixes" in data
|
||||
assert "tokens_used" in data
|
||||
assert len(data["fixes"]) == 1
|
||||
|
||||
fix = data["fixes"][0]
|
||||
assert fix["target_node_id"] == "restart"
|
||||
assert fix["error_message"] == "Decision node 'restart' must have at least 2 options"
|
||||
assert fix["original_node"]["id"] == "restart"
|
||||
assert fix["fixed_node"]["id"] == "restart"
|
||||
assert len(fix["fixed_node"]["options"]) == 2
|
||||
assert len(fix["fixed_node"]["children"]) == 2
|
||||
|
||||
assert data["tokens_used"]["input"] == 50
|
||||
assert data["tokens_used"]["output"] == 100
|
||||
Reference in New Issue
Block a user