From b3925150d71c26a5572a4ecdb82a1e559ab57f2c Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:25:38 -0500 Subject: [PATCH] feat: add POST /ai/fix-tree endpoint for AI-powered validation fixes Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/ai_fix.py | 78 ++++++++++++ backend/app/api/router.py | 2 + backend/tests/test_ai_fix_endpoint.py | 169 ++++++++++++++++++++++++++ 3 files changed, 249 insertions(+) create mode 100644 backend/app/api/endpoints/ai_fix.py create mode 100644 backend/tests/test_ai_fix_endpoint.py diff --git a/backend/app/api/endpoints/ai_fix.py b/backend/app/api/endpoints/ai_fix.py new file mode 100644 index 00000000..97ecbaf9 --- /dev/null +++ b/backend/app/api/endpoints/ai_fix.py @@ -0,0 +1,78 @@ +"""AI auto-fix endpoint for tree validation errors. + +POST /ai/fix-tree — accepts a tree with validation errors and returns +AI-generated fix proposals for each error. +""" + +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_db, require_engineer_or_admin +from app.core.config import settings +from app.core.rate_limit import limiter +from app.core.ai_fix_service import generate_fixes +from app.models.user import User +from app.schemas.ai_fix import ( + AIFixTreeRequest, + AIFixTreeResponse, + AIFixProposal, + AIFixTokenUsage, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ai", tags=["ai-fix"]) + + +def _require_ai_enabled() -> None: + """Raise 503 if AI is not configured.""" + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI fix is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) + + +@router.post("/fix-tree", response_model=AIFixTreeResponse) +@limiter.limit("10/minute") +async def fix_tree( + request: Request, + body: AIFixTreeRequest, + user: Annotated[User, Depends(require_engineer_or_admin)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Generate AI-powered fixes for tree validation errors.""" + _require_ai_enabled() + + validation_errors = [ + {"node_id": e.node_id, "message": e.message} + for e in body.validation_errors + ] + + try: + fixes, input_tokens, output_tokens = await generate_fixes( + tree_structure=body.tree_structure, + tree_name=body.tree_name, + tree_type=body.tree_type, + validation_errors=validation_errors, + ) + except RuntimeError as exc: + logger.error("AI provider not available: %s", exc) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(exc), + ) + except Exception as exc: + logger.exception("Unexpected error in AI fix service") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred while generating fixes.", + ) + + return AIFixTreeResponse( + fixes=[AIFixProposal(**f) for f in fixes], + tokens_used=AIFixTokenUsage(input=input_tokens, output=output_tokens), + ) diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 2c79e039..27963a1f 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -6,6 +6,7 @@ from app.api.endpoints import target_lists from app.api.endpoints import maintenance_schedules from app.api.endpoints import feedback from app.api.endpoints import ai_builder +from app.api.endpoints import ai_fix api_router = APIRouter() @@ -36,3 +37,4 @@ api_router.include_router(target_lists.router) api_router.include_router(maintenance_schedules.router) api_router.include_router(feedback.router) api_router.include_router(ai_builder.router) +api_router.include_router(ai_fix.router) diff --git a/backend/tests/test_ai_fix_endpoint.py b/backend/tests/test_ai_fix_endpoint.py new file mode 100644 index 00000000..a81a598e --- /dev/null +++ b/backend/tests/test_ai_fix_endpoint.py @@ -0,0 +1,169 @@ +"""Integration tests for the POST /ai/fix-tree endpoint.""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.core.config import settings + + +# ── Sample tree (has a decision node with only 1 option + 1 child) ── + +SAMPLE_TREE = { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs OK", + "description": "Issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}], + "children": [ + { + "id": "done", + "type": "solution", + "title": "Done", + "description": "Fixed.", + } + ], + }, + ], +} + +# Fixed version of the "restart" node — 2 options, 2 children +FIXED_RESTART_NODE = { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [ + {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"}, + {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"}, + ], + "children": [ + { + "id": "done", + "type": "solution", + "title": "Done", + "description": "Fixed.", + }, + { + "id": "escalate", + "type": "solution", + "title": "Escalate", + "description": "Escalate to senior engineer.", + }, + ], +} + +FIX_REQUEST_BODY = { + "tree_structure": SAMPLE_TREE, + "tree_name": "Server Troubleshooting", + "tree_type": "troubleshooting", + "validation_errors": [ + { + "node_id": "restart", + "message": "Decision node 'restart' must have at least 2 options", + } + ], +} + + +def _mock_ai_provider(response_text: str, input_tokens: int = 50, output_tokens: int = 100): + """Create a mock provider whose generate_json returns given text.""" + provider = MagicMock() + provider.generate_json = AsyncMock(return_value=(response_text, input_tokens, output_tokens)) + return provider + + +@pytest.fixture +def enable_ai(): + """Temporarily enable AI by setting a fake API key.""" + original = settings.GOOGLE_AI_API_KEY + settings.GOOGLE_AI_API_KEY = "test-key-fake" + yield + settings.GOOGLE_AI_API_KEY = original + + +@pytest.fixture +def disable_ai(): + """Ensure AI is disabled.""" + orig_google = settings.GOOGLE_AI_API_KEY + orig_anthropic = settings.ANTHROPIC_API_KEY + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + yield + settings.GOOGLE_AI_API_KEY = orig_google + settings.ANTHROPIC_API_KEY = orig_anthropic + + +# ── Tests ── + + +@pytest.mark.asyncio +async def test_returns_401_without_auth(client): + """POST /ai/fix-tree without auth token returns 401.""" + response = await client.post("/api/v1/ai/fix-tree", json=FIX_REQUEST_BODY) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_returns_503_when_ai_disabled(client, auth_headers, disable_ai): + """POST /ai/fix-tree returns 503 when no AI keys are configured.""" + response = await client.post( + "/api/v1/ai/fix-tree", + json=FIX_REQUEST_BODY, + headers=auth_headers, + ) + assert response.status_code == 503 + + +@pytest.mark.asyncio +async def test_returns_fixes_on_success(client, auth_headers, enable_ai): + """POST /ai/fix-tree returns fix proposals when AI succeeds.""" + mock_provider = _mock_ai_provider(json.dumps(FIXED_RESTART_NODE)) + + with patch( + "app.core.ai_fix_service.get_ai_provider", + return_value=mock_provider, + ): + response = await client.post( + "/api/v1/ai/fix-tree", + json=FIX_REQUEST_BODY, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert "fixes" in data + assert "tokens_used" in data + assert len(data["fixes"]) == 1 + + fix = data["fixes"][0] + assert fix["target_node_id"] == "restart" + assert fix["error_message"] == "Decision node 'restart' must have at least 2 options" + assert fix["original_node"]["id"] == "restart" + assert fix["fixed_node"]["id"] == "restart" + assert len(fix["fixed_node"]["options"]) == 2 + assert len(fix["fixed_node"]["children"]) == 2 + + assert data["tokens_used"]["input"] == 50 + assert data["tokens_used"]["output"] == 100