From 6527b33d05225a6b02ba8e701087085001611896 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 16:28:33 -0500 Subject: [PATCH 01/14] docs: add AI auto-fix and Gemini Flash provider design Design for two combined features: Gemini 2.5 Flash as primary AI provider with Claude fallback, and AI-powered auto-fix for validation errors in the tree editor. Co-Authored-By: Claude Opus 4.6 --- .../2026-02-26-ai-autofix-gemini-design.md | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 docs/plans/2026-02-26-ai-autofix-gemini-design.md diff --git a/docs/plans/2026-02-26-ai-autofix-gemini-design.md b/docs/plans/2026-02-26-ai-autofix-gemini-design.md new file mode 100644 index 00000000..7d241a00 --- /dev/null +++ b/docs/plans/2026-02-26-ai-autofix-gemini-design.md @@ -0,0 +1,209 @@ +# AI Auto-Fix & Gemini Flash Provider Design + +> **Date:** 2026-02-26 +> **Status:** Approved + +--- + +## Overview + +Two combined features: + +1. **AI Provider Abstraction** — Add Gemini 2.5 Flash as the default AI provider with Claude as fallback, behind a unified interface. +2. **AI Auto-Fix for Validation Errors** — When a flow fails validation, offer an AI-powered "Fix with AI" button that generates structural fixes for review. 
+ +--- + +## Section 1: AI Provider Abstraction + +### Design + +New `backend/app/core/ai_provider.py` with a unified interface: + +```python +class AIProvider(ABC): + async def generate_json( + self, + system_prompt: str, + messages: list[dict], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Returns (text, input_tokens, output_tokens)""" +``` + +Two implementations: + +| Provider | Model | SDK | Role | +|----------|-------|-----|------| +| `GeminiProvider` | `gemini-2.5-flash` | `google-genai` | Default | +| `AnthropicProvider` | `claude-haiku-4-5-20251001` | `anthropic` | Fallback | + +### Provider Selection + +- `get_ai_provider()` factory reads `AI_PROVIDER` env var (default: `"gemini"`) +- Falls back to Anthropic if Gemini key is missing +- Existing `ai_tree_generator_service.py` swaps direct Anthropic calls for `get_ai_provider()` + +### New Environment Variables + +| Variable | Default | Purpose | +|----------|---------|---------| +| `AI_PROVIDER` | `"gemini"` | Which provider to use (`gemini` or `anthropic`) | +| `GOOGLE_AI_API_KEY` | — | Gemini API key | + +Existing `ANTHROPIC_API_KEY` remains for fallback. 
+ +### Config Changes (`core/config.py`) + +```python +AI_PROVIDER: str = "gemini" +GOOGLE_AI_API_KEY: str | None = None +AI_MODEL_GEMINI: str = "gemini-2.5-flash" +AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001" +``` + +--- + +## Section 2: AI Auto-Fix Feature + +### Backend Endpoint + +**`POST /api/v1/ai/fix-tree`** + +Request: +```json +{ + "tree_structure": { /* full tree */ }, + "tree_name": "Router Troubleshooting", + "tree_type": "troubleshooting", + "validation_errors": [ + { + "node_id": "node_abc", + "message": "Decision node must have at least 2 children (branches)" + } + ] +} +``` + +Response: +```json +{ + "fixes": [ + { + "target_node_id": "node_abc", + "error_message": "Decision node must have at least 2 children (branches)", + "description": "Added second branch 'Check firmware version' with solution node", + "original_node": { /* snapshot before fix */ }, + "fixed_node": { /* replacement node with corrected subtree */ } + } + ], + "tokens_used": { "input": 1200, "output": 800 } +} +``` + +### How It Works + +1. For each validation error tied to a `node_id`, extract that node + its parent + siblings from the tree. +2. Build a prompt with: + - The **full tree structure** serialized as a simplified outline (node titles + types + structure) for context + - The **specific failing node** highlighted with full JSON detail + - The **validation error message** + - Instructions: "Fix ONLY this node's structural issue. Keep all existing content. Generate domain-relevant additions that fit the flow's topic." +3. AI returns a corrected version of that node (with children/options adjusted). +4. Backend re-validates the fixed node before returning it. +5. If re-validation fails, retry once with the error fed back (corrective prompt pattern). 
+ +### Prompt Strategy + +The prompt gives the AI the full tree as a compact outline, then zooms into the failing node: + +``` +You are fixing a validation error in a troubleshooting flow called "Router Troubleshooting". + +FULL FLOW OUTLINE: +- [decision] Is the router powered on? + - [action] Check power cable → [solution] Power restored + - [decision] Are lights blinking? ← ERROR HERE + - [solution] Contact ISP + +ERROR: Decision node "Are lights blinking?" must have at least 2 children (branches). + +FAILING NODE (full detail): +{...json...} + +Fix this node by adding the minimum structure needed to resolve the error. +Return ONLY the fixed node as JSON. +``` + +### Frontend UX + +1. **Trigger**: "Fix with AI" button in `ValidationSummary` — appears when there are fixable errors (structural errors with a `node_id`). +2. **Loading state**: Button shows spinner + "Generating fixes..." — disabled during request. +3. **Review modal** (`AIFixReviewModal`): Shows each proposed fix as a card: + - Error message at top + - Before/after view of the node change + - "Apply" / "Skip" buttons per fix + - "Apply All" button in footer +4. **Apply**: Each accepted fix calls `updateNode(targetNodeId, fixedNode)` in the tree editor store. +5. **Re-validate**: After applying fixes, auto-run `validate()` to confirm resolution. + +--- + +## Section 3: Scope & Constraints + +### Fixable Errors (Auto-Fix Scope) + +Only structural validation errors with a `node_id`: +- Decision node missing children/branches +- Decision node missing options +- Action node missing `next_node_id` +- Dead-end decision nodes (no children) + +### NOT Fixable + +- Global checks (tree too small/large, not enough solutions) — require rethinking the whole tree +- Content quality issues — out of scope +- Errors without a `node_id` (root-level issues) + +Non-fixable errors still show in ValidationSummary but without the "Fix with AI" option. 
+ +### Token Budget + +- Tree outline: ~50-100 tokens for a typical 15-node tree +- Failing node detail: ~100-200 tokens +- System prompt + instructions: ~300 tokens +- **Total input per fix: ~500-600 tokens** +- One API call per failing node (not batched) + +### Error Handling + +- Provider failure (rate limit, network): toast error, user can retry +- Fix fails re-validation: "AI couldn't generate a valid fix" with retry option +- At most 1 automatic retry (with a corrective prompt) per failing node +- Both provider and fallback fail: surface error to user + +### Auth + +- Requires `engineer` role or above (`require_engineer_or_admin`) + +--- + +## New Files + +| File | Purpose | +|------|---------| +| `backend/app/core/ai_provider.py` | Provider abstraction + Gemini/Anthropic implementations | +| `backend/app/core/ai_fix_service.py` | Fix generation logic + prompt building | +| `backend/app/api/endpoints/ai.py` | `POST /ai/fix-tree` endpoint | +| `backend/app/schemas/ai.py` | Request/response schemas for AI endpoints | +| `frontend/src/components/tree-editor/AIFixReviewModal.tsx` | Review modal for proposed fixes | + +## Modified Files + +| File | Change | +|------|--------| +| `backend/app/core/config.py` | Add Gemini config vars | +| `backend/app/core/ai_tree_generator_service.py` | Swap Anthropic calls for provider abstraction | +| `backend/app/api/router.py` | Register `/ai` routes | +| `frontend/src/api/trees.ts` | Add `fixTree()` API call | +| `frontend/src/components/tree-editor/ValidationSummary.tsx` | Add "Fix with AI" button | -- 2.49.1 From 5df32aa9dad2ad052694f1feb9ee265eef317fb6 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 16:33:34 -0500 Subject: [PATCH 02/14] docs: add implementation plan for AI auto-fix and Gemini provider 12-task plan covering: SDK install, config, provider abstraction, service migration, fix service, endpoint, frontend types/API, ValidationSummary button, review modal, and TreeEditorPage wiring. 
Co-Authored-By: Claude Opus 4.6 --- .../2026-02-26-ai-autofix-gemini-plan.md | 1707 +++++++++++++++++ 1 file changed, 1707 insertions(+) create mode 100644 docs/plans/2026-02-26-ai-autofix-gemini-plan.md diff --git a/docs/plans/2026-02-26-ai-autofix-gemini-plan.md b/docs/plans/2026-02-26-ai-autofix-gemini-plan.md new file mode 100644 index 00000000..4406024a --- /dev/null +++ b/docs/plans/2026-02-26-ai-autofix-gemini-plan.md @@ -0,0 +1,1707 @@ +# AI Auto-Fix & Gemini Flash Provider Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add Gemini 2.5 Flash as primary AI provider (with Claude fallback), then build an AI-powered "Fix with AI" feature that generates structural fixes for validation errors in the tree editor. + +**Architecture:** A provider abstraction layer (`ai_provider.py`) wraps Gemini and Anthropic SDKs behind a unified `generate_json()` interface. The existing `ai_tree_generator_service.py` swaps its direct Anthropic calls for this abstraction. A new `ai_fix_service.py` builds prompts from validation errors + tree context and returns proposed node patches. The frontend adds a "Fix with AI" button to `ValidationSummary` and a review modal for applying fixes. 
+ +**Tech Stack:** Python FastAPI, google-genai SDK, anthropic SDK, Pydantic v2, React 19, TypeScript, Zustand, Tailwind CSS + +**Design Doc:** `docs/plans/2026-02-26-ai-autofix-gemini-design.md` + +--- + +## Task 1: Install google-genai SDK + +**Files:** +- Modify: `backend/requirements.txt` + +**Step 1: Add the dependency** + +Add `google-genai` to `backend/requirements.txt`: + +``` +google-genai>=1.0.0 +``` + +**Step 2: Install it** + +Run: `cd backend && pip install google-genai` + +**Step 3: Commit** + +```bash +git add backend/requirements.txt +git commit -m "chore: add google-genai SDK dependency" +``` + +--- + +## Task 2: Add Gemini config vars to Settings + +**Files:** +- Modify: `backend/app/core/config.py:75-85` + +**Step 1: Add new config variables** + +In `backend/app/core/config.py`, after line 80 (`AI_REQUEST_TIMEOUT_SECONDS`), add: + +```python + # AI Provider selection + AI_PROVIDER: str = "gemini" # "gemini" or "anthropic" + GOOGLE_AI_API_KEY: Optional[str] = None + AI_MODEL_GEMINI: str = "gemini-2.5-flash" + AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001" +``` + +**Step 2: Update `ai_enabled` property** + +Replace the existing `ai_enabled` property (lines 82-85) with: + +```python + @property + def ai_enabled(self) -> bool: + """Check if any AI provider is configured.""" + return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None +``` + +**Step 3: Verify no tests break** + +Run: `cd backend && python -m pytest tests/test_ai_tree_validator.py -v` +Expected: All pass (config changes don't affect validator tests). 
+ +**Step 4: Commit** + +```bash +git add backend/app/core/config.py +git commit -m "feat: add Gemini Flash config vars to Settings" +``` + +--- + +## Task 3: Build the AI provider abstraction + +**Files:** +- Create: `backend/app/core/ai_provider.py` +- Test: `backend/tests/test_ai_provider.py` + +**Step 1: Write tests for the provider abstraction** + +Create `backend/tests/test_ai_provider.py`: + +```python +"""Tests for AI provider abstraction layer.""" +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from app.core.config import settings + + +class TestGetAIProvider: + """Test provider factory function.""" + + def test_returns_gemini_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.GOOGLE_AI_API_KEY + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = "fake-gemini-key" + try: + from app.core.ai_provider import get_ai_provider, GeminiProvider + provider = get_ai_provider() + assert isinstance(provider, GeminiProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_key + + def test_returns_anthropic_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.ANTHROPIC_API_KEY + settings.AI_PROVIDER = "anthropic" + settings.ANTHROPIC_API_KEY = "fake-anthropic-key" + try: + from app.core.ai_provider import get_ai_provider, AnthropicProvider + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.ANTHROPIC_API_KEY = original_key + + def test_falls_back_to_anthropic_when_gemini_key_missing(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = "fake-anthropic-key" + try: + from app.core.ai_provider import get_ai_provider, 
AnthropicProvider + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + def test_raises_when_no_provider_configured(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + try: + from app.core.ai_provider import get_ai_provider + with pytest.raises(RuntimeError, match="No AI provider configured"): + get_ai_provider() + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + +class TestAnthropicProvider: + """Test Anthropic provider generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json_returns_text_and_tokens(self): + from app.core.ai_provider import AnthropicProvider + + mock_response = MagicMock() + mock_response.content = [MagicMock(text='{"key": "value"}')] + mock_response.usage = MagicMock(input_tokens=100, output_tokens=50) + + mock_client = AsyncMock() + mock_client.messages.create = AsyncMock(return_value=mock_response) + + with patch("app.core.ai_provider.anthropic.AsyncAnthropic", return_value=mock_client): + provider = AnthropicProvider(api_key="fake-key") + text, inp, out = await provider.generate_json( + system_prompt="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=1024, + ) + assert text == '{"key": "value"}' + assert inp == 100 + assert out == 50 + + +class TestGeminiProvider: + """Test Gemini provider generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json_returns_text_and_tokens(self): + from app.core.ai_provider import GeminiProvider + + mock_response = MagicMock() + 
mock_response.text = '{"key": "value"}' + mock_response.usage_metadata = MagicMock( + prompt_token_count=100, + candidates_token_count=50, + ) + + mock_client = MagicMock() + mock_model = MagicMock() + mock_model.generate_content_async = AsyncMock(return_value=mock_response) + mock_client.models = mock_model + + with patch("app.core.ai_provider.genai.Client", return_value=mock_client): + provider = GeminiProvider(api_key="fake-key") + text, inp, out = await provider.generate_json( + system_prompt="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=1024, + ) + assert text == '{"key": "value"}' + assert inp == 100 + assert out == 50 +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && python -m pytest tests/test_ai_provider.py -v` +Expected: ImportError — `ai_provider` module doesn't exist yet. + +**Step 3: Implement the provider abstraction** + +Create `backend/app/core/ai_provider.py`: + +```python +"""AI provider abstraction layer. + +Supports Gemini (default) and Anthropic (fallback) behind a unified interface. +""" +import logging +from abc import ABC, abstractmethod +from typing import Any + +import anthropic +from google import genai +from google.genai import types as genai_types + +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +class AIProvider(ABC): + """Base class for AI providers.""" + + @abstractmethod + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Generate a JSON response from the AI model. + + Args: + system_prompt: System instructions for the model. + messages: List of {"role": "user"|"assistant", "content": str} dicts. + max_tokens: Maximum output tokens. + + Returns: + (response_text, input_tokens, output_tokens) + """ + ... 
+ + +class GeminiProvider(AIProvider): + """Google Gemini provider using google-genai SDK.""" + + def __init__(self, api_key: str, model: str | None = None): + self._api_key = api_key + self._model = model or settings.AI_MODEL_GEMINI + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + client = genai.Client(api_key=self._api_key) + + # Build contents: system instruction is separate in Gemini API + contents: list[genai_types.Content] = [] + for msg in messages: + role = "user" if msg["role"] == "user" else "model" + contents.append(genai_types.Content( + role=role, + parts=[genai_types.Part(text=msg["content"])], + )) + + config = genai_types.GenerateContentConfig( + system_instruction=system_prompt, + max_output_tokens=max_tokens, + response_mime_type="application/json", + ) + + response = await client.models.generate_content_async( + model=self._model, + contents=contents, + config=config, + ) + + text = response.text or "" + input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 + output_tokens = getattr(response.usage_metadata, "candidates_token_count", 0) or 0 + + return text, input_tokens, output_tokens + + +class AnthropicProvider(AIProvider): + """Anthropic Claude provider.""" + + def __init__(self, api_key: str, model: str | None = None): + self._api_key = api_key + self._model = model or settings.AI_MODEL_ANTHROPIC + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + client = anthropic.AsyncAnthropic( + api_key=self._api_key, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + + response = await client.messages.create( + model=self._model, + max_tokens=max_tokens, + system=system_prompt, + messages=messages, + ) + + text = response.content[0].text + input_tokens = response.usage.input_tokens + output_tokens = response.usage.output_tokens + + 
return text, input_tokens, output_tokens + + +def get_ai_provider() -> AIProvider: + """Factory: return the configured AI provider. + + Falls back to Anthropic if Gemini key is missing. + Raises RuntimeError if no provider is configured. + """ + if settings.AI_PROVIDER == "gemini" and settings.GOOGLE_AI_API_KEY: + logger.info("Using Gemini provider (%s)", settings.AI_MODEL_GEMINI) + return GeminiProvider(api_key=settings.GOOGLE_AI_API_KEY) + + if settings.AI_PROVIDER == "anthropic" and settings.ANTHROPIC_API_KEY: + logger.info("Using Anthropic provider (%s)", settings.AI_MODEL_ANTHROPIC) + return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY) + + # Fallback: if Gemini requested but key missing, try Anthropic + if settings.ANTHROPIC_API_KEY: + logger.warning("Gemini key missing, falling back to Anthropic provider") + return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY) + + raise RuntimeError( + "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY." + ) +``` + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && python -m pytest tests/test_ai_provider.py -v` +Expected: All pass. + +**Step 5: Commit** + +```bash +git add backend/app/core/ai_provider.py backend/tests/test_ai_provider.py +git commit -m "feat: add AI provider abstraction with Gemini and Anthropic support" +``` + +--- + +## Task 4: Migrate ai_tree_generator_service to use provider abstraction + +**Files:** +- Modify: `backend/app/core/ai_tree_generator_service.py` +- Modify: `backend/app/api/endpoints/ai_builder.py` + +**Step 1: Update ai_tree_generator_service.py** + +In `backend/app/core/ai_tree_generator_service.py`: + +Replace the import at line 16: +```python +# OLD +import anthropic +# NEW +from app.core.ai_provider import get_ai_provider +``` + +Remove the `_get_client()` function (lines 124-131). + +Update `scaffold_branches()` (starting at line 141). 
Replace the client creation and API call: + +```python +async def scaffold_branches( + wizard_state: dict[str, Any], +) -> tuple[list[dict[str, str]], int, int, float]: + """Stage 2: AI suggests top-level branches.""" + provider = get_ai_provider() + + flow_type = wizard_state.get("flow_type", "troubleshooting") + name = wizard_state.get("name", "") + description = wizard_state.get("description", "") + tags = wizard_state.get("environment_tags", []) + + user_message = ( + f"Flow type: {flow_type}\n" + f"Name: {name}\n" + f"Description: {description}\n" + ) + if tags: + user_message += f"Environment: {', '.join(tags)}\n" + + raw_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=SCAFFOLD_SYSTEM_PROMPT, + messages=[{"role": "user", "content": user_message}], + max_tokens=1024, + ) +``` + +Then update the rest of the function to use `raw_text` instead of `response.content[0].text`, and `input_tokens`/`output_tokens` directly instead of `response.usage.*`. + +Do the same for `generate_branch_detail()` — replace `_get_client()` + `client.messages.create()` with `provider.generate_json()`. The retry loop structure stays the same; just swap the API call and response parsing. + +**Step 2: Update ai_builder.py endpoint** + +In `backend/app/api/endpoints/ai_builder.py`, line 13: + +```python +# OLD +import anthropic +# NEW (remove this import entirely — no longer needed in the endpoint file) +``` + +Update the `_require_ai_enabled()` function (line 50) to check the new config: + +```python +def _require_ai_enabled() -> None: + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI Flow Builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) +``` + +Update any `except anthropic.APIError` catch blocks in the endpoint to catch generic `Exception` or a broader error type, since the provider abstraction may raise different errors depending on backend. 
+ +**Step 3: Run existing AI endpoint tests** + +Run: `cd backend && python -m pytest tests/test_ai_endpoints.py -v` + +These tests mock `anthropic.AsyncAnthropic` — they need to be updated to patch `get_ai_provider` where the service imports it (`app.core.ai_tree_generator_service.get_ai_provider`) instead, because the service binds the name at import time. Update the mock targets in the test file: + +```python +# Replace patches like: +# @patch("app.core.ai_tree_generator_service._get_client") +# With: +# @patch("app.core.ai_tree_generator_service.get_ai_provider") +``` + +The patched `get_ai_provider` should return a mock provider whose `generate_json` is an `AsyncMock` returning `(json_text, input_tokens, output_tokens)`. + +**Step 4: Verify all tests pass** + +Run: `cd backend && python -m pytest tests/test_ai_endpoints.py tests/test_ai_provider.py tests/test_ai_tree_validator.py -v` +Expected: All pass. + +**Step 5: Commit** + +```bash +git add backend/app/core/ai_tree_generator_service.py backend/app/api/endpoints/ai_builder.py backend/tests/test_ai_endpoints.py +git commit -m "refactor: migrate AI tree generator to provider abstraction" +``` + +--- + +## Task 5: Create AI fix schemas + +**Files:** +- Create: `backend/app/schemas/ai_fix.py` + +**Step 1: Create the schemas** + +Create `backend/app/schemas/ai_fix.py`: + +```python +"""Pydantic schemas for the AI auto-fix feature.""" +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class ValidationErrorInput(BaseModel): + """A single validation error to fix.""" + + node_id: str = Field(..., description="ID of the node with the error") + message: str = Field(..., description="The validation error message") + + +class AIFixTreeRequest(BaseModel): + """Request to generate AI fixes for validation errors.""" + + tree_structure: dict[str, Any] = Field(..., description="Full tree structure") + tree_name: str = Field("", max_length=255, description="Name of the flow") + tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field( + "troubleshooting", description="Type of flow" + ) + validation_errors: 
list[ValidationErrorInput] = Field( + ..., min_length=1, max_length=10, description="Errors to fix" + ) + + +class AIFixProposal(BaseModel): + """A single proposed fix from the AI.""" + + target_node_id: str + error_message: str + description: str + original_node: dict[str, Any] + fixed_node: dict[str, Any] + + +class AIFixTokenUsage(BaseModel): + input: int = 0 + output: int = 0 + + +class AIFixTreeResponse(BaseModel): + """Response with proposed fixes.""" + + fixes: list[AIFixProposal] + tokens_used: AIFixTokenUsage +``` + +**Step 2: Commit** + +```bash +git add backend/app/schemas/ai_fix.py +git commit -m "feat: add Pydantic schemas for AI fix-tree endpoint" +``` + +--- + +## Task 6: Build the AI fix service + +**Files:** +- Create: `backend/app/core/ai_fix_service.py` +- Test: `backend/tests/test_ai_fix_service.py` + +**Step 1: Write tests** + +Create `backend/tests/test_ai_fix_service.py`: + +```python +"""Tests for AI fix service — prompt building and node extraction.""" +import json +import pytest +from app.core.ai_fix_service import _serialize_tree_outline, _find_node_by_id, _find_parent_node + + +def _make_tree(): + return { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review event logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs Resolved", + "description": "Found issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart fix it?", + "options": [ + {"id": "opt-r-yes", "label": "Yes", "next_node_id": "restart-ok"}, + ], + "children": [ + { + "id": "restart-ok", + "type": "solution", + "title": "Restart Worked", + "description": "Server is back.", + }, + ], + }, + ], + } + + +class 
TestFindNodeById: + def test_finds_root(self): + tree = _make_tree() + node = _find_node_by_id(tree, "root") + assert node is not None + assert node["id"] == "root" + + def test_finds_nested_child(self): + tree = _make_tree() + node = _find_node_by_id(tree, "restart-ok") + assert node is not None + assert node["id"] == "restart-ok" + + def test_returns_none_for_missing(self): + tree = _make_tree() + assert _find_node_by_id(tree, "nonexistent") is None + + +class TestFindParentNode: + def test_root_has_no_parent(self): + tree = _make_tree() + assert _find_parent_node(tree, "root") is None + + def test_finds_parent_of_child(self): + tree = _make_tree() + parent = _find_parent_node(tree, "restart") + assert parent is not None + assert parent["id"] == "root" + + def test_finds_parent_of_deeply_nested(self): + tree = _make_tree() + parent = _find_parent_node(tree, "restart-ok") + assert parent is not None + assert parent["id"] == "restart" + + +class TestSerializeTreeOutline: + def test_produces_readable_outline(self): + tree = _make_tree() + outline = _serialize_tree_outline(tree) + assert "[decision] Is the server up?" in outline + assert "[action] Check Logs" in outline + assert "[solution] Restart Worked" in outline + + def test_marks_error_node(self): + tree = _make_tree() + outline = _serialize_tree_outline(tree, error_node_id="restart") + assert "ERROR HERE" in outline +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && python -m pytest tests/test_ai_fix_service.py -v` +Expected: ImportError — module doesn't exist. + +**Step 3: Implement the fix service** + +Create `backend/app/core/ai_fix_service.py`: + +```python +"""AI-powered fix generation for tree validation errors. + +Builds targeted prompts for each failing node, sends to the AI provider, +and validates the fix before returning it. 
+""" +import copy +import json +import logging +import re +from typing import Any + +from app.core.ai_provider import get_ai_provider +from app.core.ai_tree_validator import validate_generated_tree + +logger = logging.getLogger(__name__) + +# ── Helpers ── + + +def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None: + """Recursively find a node by ID in the tree.""" + if not isinstance(tree, dict): + return None + if tree.get("id") == node_id: + return tree + for child in tree.get("children", []): + result = _find_node_by_id(child, node_id) + if result is not None: + return result + return None + + +def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None: + """Find the parent of the node with target_id.""" + if not isinstance(tree, dict): + return None + for child in tree.get("children", []): + if isinstance(child, dict) and child.get("id") == target_id: + return tree + result = _find_parent_node(child, target_id) + if result is not None: + return result + return None + + +def _serialize_tree_outline( + tree: dict[str, Any], + indent: int = 0, + error_node_id: str | None = None, +) -> str: + """Serialize tree as a compact readable outline for the AI prompt.""" + if not isinstance(tree, dict): + return "" + + node_type = tree.get("type", "?") + label = tree.get("question") or tree.get("title") or tree.get("id", "?") + prefix = " " * indent + marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else "" + + line = f"{prefix}- [{node_type}] {label}{marker}" + lines = [line] + + for child in tree.get("children", []): + lines.append(_serialize_tree_outline(child, indent + 1, error_node_id)) + + return "\n".join(lines) + + +def _strip_markdown_fences(text: str) -> str: + """Strip ```json ... 
``` fences from AI response.""" + match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL) + if match: + return match.group(1).strip() + return text + + +# ── Prompt ── + +FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers. + +You will receive: +1. A full flow outline showing the tree structure +2. The specific failing node with its full JSON +3. The validation error message + +Your task: Return a FIXED version of the failing node as valid JSON. Rules: +- Fix ONLY the structural issue described in the error message +- Keep ALL existing content (titles, descriptions, questions, options) unchanged +- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic +- Every new node must have a unique ID (use descriptive kebab-case IDs) +- Decision nodes must have at least 2 options and at least 2 children +- Action nodes must have a next_node_id pointing to a sibling node in the parent's children +- Solution nodes are leaf nodes (no children) +- Return ONLY the fixed node JSON, no explanation""" + + +def _build_fix_prompt( + tree: dict[str, Any], + node_id: str, + error_message: str, + tree_name: str, + tree_type: str, +) -> str: + """Build the user message for fixing a specific node.""" + outline = _serialize_tree_outline(tree, error_node_id=node_id) + failing_node = _find_node_by_id(tree, node_id) + node_json = json.dumps(failing_node, indent=2) if failing_node else "{}" + + return ( + f"Flow name: {tree_name}\n" + f"Flow type: {tree_type}\n\n" + f"FULL FLOW OUTLINE:\n{outline}\n\n" + f"ERROR: {error_message}\n\n" + f"FAILING NODE (full JSON):\n{node_json}\n\n" + f"Return the fixed version of this node as JSON." 
+ ) + + +# ── Main Service ── + + +async def generate_fixes( + tree_structure: dict[str, Any], + tree_name: str, + tree_type: str, + validation_errors: list[dict[str, str]], +) -> tuple[list[dict[str, Any]], int, int]: + """Generate AI fixes for each validation error. + + Args: + tree_structure: Full tree JSON. + tree_name: Name of the flow. + tree_type: Type of flow (troubleshooting/procedural/maintenance). + validation_errors: List of {"node_id": str, "message": str}. + + Returns: + (fixes, total_input_tokens, total_output_tokens) + Each fix: {"target_node_id", "error_message", "description", "original_node", "fixed_node"} + """ + provider = get_ai_provider() + fixes: list[dict[str, Any]] = [] + total_input = 0 + total_output = 0 + + for error in validation_errors: + node_id = error["node_id"] + error_message = error["message"] + + original_node = _find_node_by_id(tree_structure, node_id) + if original_node is None: + logger.warning("Node %s not found in tree, skipping fix", node_id) + continue + + original_snapshot = copy.deepcopy(original_node) + user_message = _build_fix_prompt( + tree_structure, node_id, error_message, tree_name, tree_type + ) + + # Attempt fix (with 1 retry using corrective prompt) + messages: list[dict[str, str]] = [{"role": "user", "content": user_message}] + + for attempt in range(2): + try: + raw_text, inp_tokens, out_tokens = await provider.generate_json( + system_prompt=FIX_SYSTEM_PROMPT, + messages=messages, + max_tokens=4096, + ) + total_input += inp_tokens + total_output += out_tokens + + cleaned = _strip_markdown_fences(raw_text) + fixed_node = json.loads(cleaned) + + # Quick validation: check that the fix actually addresses the error + # by substituting into tree and re-validating that specific error is gone + test_tree = copy.deepcopy(tree_structure) + _replace_node_in_tree(test_tree, node_id, fixed_node) + remaining_errors = validate_generated_tree(test_tree) + still_has_error = any( + node_id in e and 
error_message.split(":")[0].lower() in e.lower() + for e in remaining_errors + ) + + if still_has_error and attempt == 0: + # Retry with corrective prompt + messages.append({"role": "assistant", "content": raw_text}) + messages.append({ + "role": "user", + "content": ( + f"The fix still has the same validation error. " + f"Remaining errors: {remaining_errors}. " + f"Please try again." + ), + }) + continue + + # Extract description from the fix + description = _describe_fix(original_snapshot, fixed_node) + + fixes.append({ + "target_node_id": node_id, + "error_message": error_message, + "description": description, + "original_node": original_snapshot, + "fixed_node": fixed_node, + }) + break + + except (json.JSONDecodeError, KeyError, TypeError) as exc: + logger.warning( + "Fix attempt %d for node %s failed: %s", attempt + 1, node_id, exc + ) + if attempt == 0: + messages.append({"role": "assistant", "content": raw_text if 'raw_text' in dir() else ""}) + messages.append({ + "role": "user", + "content": f"Invalid JSON response. Return ONLY valid JSON for the fixed node.", + }) + else: + logger.error("Failed to generate fix for node %s after 2 attempts", node_id) + + return fixes, total_input, total_output + + +def _replace_node_in_tree( + tree: dict[str, Any], target_id: str, replacement: dict[str, Any] +) -> bool: + """Replace a node in the tree by ID. 
Returns True if found and replaced.""" + if tree.get("id") == target_id: + tree.clear() + tree.update(replacement) + return True + for child in tree.get("children", []): + if isinstance(child, dict): + if child.get("id") == target_id: + child.clear() + child.update(replacement) + return True + if _replace_node_in_tree(child, target_id, replacement): + return True + return False + + +def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str: + """Generate a human-readable description of what changed.""" + orig_children = len(original.get("children", [])) + fixed_children = len(fixed.get("children", [])) + orig_options = len(original.get("options", [])) + fixed_options = len(fixed.get("options", [])) + + parts: list[str] = [] + if fixed_children > orig_children: + added = fixed_children - orig_children + parts.append(f"Added {added} child node{'s' if added > 1 else ''}") + if fixed_options > orig_options: + added = fixed_options - orig_options + parts.append(f"Added {added} option{'s' if added > 1 else ''}") + if "next_node_id" in fixed and "next_node_id" not in original: + parts.append(f"Added next_node_id '{fixed['next_node_id']}'") + + return "; ".join(parts) if parts else "Structural fix applied" +``` + +**Step 4: Run tests** + +Run: `cd backend && python -m pytest tests/test_ai_fix_service.py -v` +Expected: All pass. 
+ +**Step 5: Commit** + +```bash +git add backend/app/core/ai_fix_service.py backend/tests/test_ai_fix_service.py +git commit -m "feat: add AI fix service with prompt building and validation" +``` + +--- + +## Task 7: Create the fix-tree endpoint + +**Files:** +- Create: `backend/app/api/endpoints/ai_fix.py` +- Modify: `backend/app/api/router.py` +- Test: `backend/tests/test_ai_fix_endpoint.py` + +**Step 1: Write the endpoint test** + +Create `backend/tests/test_ai_fix_endpoint.py`: + +```python +"""Integration tests for AI fix-tree endpoint.""" +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from app.core.config import settings + + +SAMPLE_TREE = { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs OK", + "description": "Issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}], + "children": [ + {"id": "done", "type": "solution", "title": "Done", "description": "Fixed."}, + ], + }, + ], +} + + +@pytest.fixture +def enable_ai(): + original_key = settings.GOOGLE_AI_API_KEY + original_provider = settings.AI_PROVIDER + settings.GOOGLE_AI_API_KEY = "fake-key" + settings.AI_PROVIDER = "gemini" + yield + settings.GOOGLE_AI_API_KEY = original_key + settings.AI_PROVIDER = original_provider + + +@pytest.mark.asyncio +class TestFixTreeEndpoint: + + async def test_returns_401_without_auth(self, client): + response = await client.post("/api/v1/ai/fix-tree", json={ + "tree_structure": SAMPLE_TREE, + "tree_name": "Test", + "tree_type": "troubleshooting", + 
"validation_errors": [{"node_id": "restart", "message": "Need 2 options"}], + }) + assert response.status_code == 401 + + async def test_returns_503_when_ai_disabled(self, client, auth_headers): + original = settings.GOOGLE_AI_API_KEY + orig_anthropic = settings.ANTHROPIC_API_KEY + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + try: + response = await client.post( + "/api/v1/ai/fix-tree", + json={ + "tree_structure": SAMPLE_TREE, + "tree_name": "Test", + "tree_type": "troubleshooting", + "validation_errors": [{"node_id": "restart", "message": "test"}], + }, + headers=auth_headers, + ) + assert response.status_code == 503 + finally: + settings.GOOGLE_AI_API_KEY = original + settings.ANTHROPIC_API_KEY = orig_anthropic + + async def test_returns_fixes_on_success(self, client, auth_headers, enable_ai): + fixed_node = { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [ + {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"}, + {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"}, + ], + "children": [ + {"id": "done", "type": "solution", "title": "Done", "description": "Fixed."}, + {"id": "escalate", "type": "solution", "title": "Escalate", "description": "Escalate to vendor."}, + ], + } + + mock_provider = AsyncMock() + mock_provider.generate_json = AsyncMock( + return_value=(json.dumps(fixed_node), 500, 300) + ) + + with patch("app.core.ai_fix_service.get_ai_provider", return_value=mock_provider): + response = await client.post( + "/api/v1/ai/fix-tree", + json={ + "tree_structure": SAMPLE_TREE, + "tree_name": "Server Flow", + "tree_type": "troubleshooting", + "validation_errors": [ + {"node_id": "restart", "message": "Decision node must have at least 2 options"}, + ], + }, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["fixes"]) == 1 + assert data["fixes"][0]["target_node_id"] == "restart" + assert data["tokens_used"]["input"] 
== 500 + assert data["tokens_used"]["output"] == 300 +``` + +**Step 2: Run to verify it fails** + +Run: `cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v` +Expected: Fail — endpoint doesn't exist. + +**Step 3: Create the endpoint** + +Create `backend/app/api/endpoints/ai_fix.py`: + +```python +"""AI auto-fix endpoint for tree validation errors.""" +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.rate_limit import limiter +from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin +from app.core.config import settings +from app.core.ai_fix_service import generate_fixes +from app.models.user import User +from app.schemas.ai_fix import AIFixTreeRequest, AIFixTreeResponse, AIFixProposal, AIFixTokenUsage + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ai", tags=["ai-fix"]) + + +def _require_ai_enabled() -> None: + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI is not configured. 
Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) + + +@router.post("/fix-tree", response_model=AIFixTreeResponse) +@limiter.limit("10/minute") +async def fix_tree( + request: Request, + body: AIFixTreeRequest, + user: Annotated[User, Depends(require_engineer_or_admin)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Generate AI-powered fixes for tree validation errors.""" + _require_ai_enabled() + + try: + fixes, total_input, total_output = await generate_fixes( + tree_structure=body.tree_structure, + tree_name=body.tree_name, + tree_type=body.tree_type, + validation_errors=[ + {"node_id": e.node_id, "message": e.message} + for e in body.validation_errors + ], + ) + except RuntimeError as exc: + logger.error("AI fix generation failed: %s", exc) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(exc), + ) + except Exception as exc: + logger.exception("Unexpected error during AI fix generation") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to generate fixes. Please try again.", + ) + + return AIFixTreeResponse( + fixes=[ + AIFixProposal( + target_node_id=f["target_node_id"], + error_message=f["error_message"], + description=f["description"], + original_node=f["original_node"], + fixed_node=f["fixed_node"], + ) + for f in fixes + ], + tokens_used=AIFixTokenUsage(input=total_input, output=total_output), + ) +``` + +**Step 4: Register in router** + +In `backend/app/api/router.py`, add: + +After line 8 (`from app.api.endpoints import ai_builder`): +```python +from app.api.endpoints import ai_fix +``` + +After line 38 (`api_router.include_router(ai_builder.router)`): +```python +api_router.include_router(ai_fix.router) +``` + +**Step 5: Run tests** + +Run: `cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v` +Expected: All pass. + +**Step 6: Run full test suite** + +Run: `cd backend && python -m pytest --override-ini="addopts=" -v` +Expected: All pass (100+ tests). 
+ +**Step 7: Commit** + +```bash +git add backend/app/api/endpoints/ai_fix.py backend/app/api/router.py backend/app/schemas/ai_fix.py backend/tests/test_ai_fix_endpoint.py +git commit -m "feat: add POST /ai/fix-tree endpoint for AI-powered validation fixes" +``` + +--- + +## Task 8: Add frontend API client for fix-tree + +**Files:** +- Modify: `frontend/src/api/trees.ts` +- Create: `frontend/src/types/ai-fix.ts` +- Modify: `frontend/src/types/index.ts` + +**Step 1: Create the types** + +Create `frontend/src/types/ai-fix.ts`: + +```typescript +export interface AIFixValidationError { + node_id: string + message: string +} + +export interface AIFixProposal { + target_node_id: string + error_message: string + description: string + original_node: Record<string, unknown> + fixed_node: Record<string, unknown> +} + +export interface AIFixTreeRequest { + tree_structure: Record<string, unknown> + tree_name: string + tree_type: 'troubleshooting' | 'procedural' | 'maintenance' + validation_errors: AIFixValidationError[] +} + +export interface AIFixTreeResponse { + fixes: AIFixProposal[] + tokens_used: { input: number; output: number } +} +``` + +**Step 2: Export from types/index.ts** + +Add to `frontend/src/types/index.ts`: + +```typescript +export type { AIFixTreeRequest, AIFixTreeResponse, AIFixProposal, AIFixValidationError } from './ai-fix' +``` + +**Step 3: Add API method to trees.ts** + +In `frontend/src/api/trees.ts`, add a new method to the `treesApi` object: + +```typescript + async fixTree(request: AIFixTreeRequest): Promise<AIFixTreeResponse> { + const response = await apiClient.post('/ai/fix-tree', request) + return response.data + }, +``` + +Import the types at the top of the file. 
+ +**Step 4: Commit** + +```bash +git add frontend/src/types/ai-fix.ts frontend/src/types/index.ts frontend/src/api/trees.ts +git commit -m "feat: add frontend API client and types for AI fix-tree" +``` + +--- + +## Task 9: Add "Fix with AI" button to ValidationSummary + +**Files:** +- Modify: `frontend/src/components/tree-editor/ValidationSummary.tsx` + +**Step 1: Update the component** + +Add new props and the button to `ValidationSummary`: + +```typescript +import { useState } from 'react' +import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react' +import { cn } from '@/lib/utils' +import type { ValidationError } from '@/store/treeEditorStore' + +interface ValidationSummaryProps { + errors: ValidationError[] + onSelectNode: (nodeId: string) => void + onFixWithAI?: () => void + isFixing?: boolean +} + +export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) { +``` + +In the header button area (after the expand/collapse chevron at line 62), add the "Fix with AI" button. It should appear between the error count text and the chevron icon. Restructure the header to include the button: + +After the `` closing the error/warning count (around line 60), add: + +```tsx + {/* Fix with AI button — only when there are fixable errors */} + {onFixWithAI && errorItems.some(e => e.nodeId) && ( + + )} +``` + +**Step 2: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build passes. 
+ +**Step 3: Commit** + +```bash +git add frontend/src/components/tree-editor/ValidationSummary.tsx +git commit -m "feat: add Fix with AI button to ValidationSummary" +``` + +--- + +## Task 10: Build the AIFixReviewModal + +**Files:** +- Create: `frontend/src/components/tree-editor/AIFixReviewModal.tsx` + +**Step 1: Create the review modal** + +Create `frontend/src/components/tree-editor/AIFixReviewModal.tsx`: + +```typescript +import { useState } from 'react' +import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react' +import { cn } from '@/lib/utils' +import type { AIFixProposal } from '@/types' + +interface AIFixReviewModalProps { + fixes: AIFixProposal[] + onApply: (fix: AIFixProposal) => void + onApplyAll: () => void + onClose: () => void +} + +export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) { + const [appliedIds, setAppliedIds] = useState<Set<string>>(new Set()) + const [skippedIds, setSkippedIds] = useState<Set<string>>(new Set()) + const [expandedIds, setExpandedIds] = useState<Set<string>>(new Set(fixes.map(f => f.target_node_id))) + + const handleApply = (fix: AIFixProposal) => { + onApply(fix) + setAppliedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const handleSkip = (fix: AIFixProposal) => { + setSkippedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const toggleExpanded = (id: string) => { + setExpandedIds(prev => { + const next = new Set(prev) + if (next.has(id)) next.delete(id) + else next.add(id) + return next + }) + } + + const pendingFixes = fixes.filter( + f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id) + ) + const allHandled = pendingFixes.length === 0 + + return ( +
+
+ {/* Header */} +
+
+ +

+ AI Fix Proposals ({fixes.length}) +

+
+ +
+ + {/* Body */} +
+ {fixes.map((fix) => { + const isApplied = appliedIds.has(fix.target_node_id) + const isSkipped = skippedIds.has(fix.target_node_id) + const isExpanded = expandedIds.has(fix.target_node_id) + + return ( +
+ {/* Fix header */} +
+
+

{fix.error_message}

+

{fix.description}

+

+ Node: {fix.target_node_id} +

+
+ {isApplied && ( + + Applied + + )} + {isSkipped && ( + Skipped + )} +
+ + {/* Expand/collapse detail */} + {!isApplied && !isSkipped && ( + <> + + + {isExpanded && ( +
+
+

Before

+
+                            {JSON.stringify(fix.original_node, null, 2)}
+                          
+
+
+

After

+
+                            {JSON.stringify(fix.fixed_node, null, 2)}
+                          
+
+
+ )} + + {/* Action buttons */} +
+ + +
+ + )} +
+ ) + })} +
+ + {/* Footer */} +
+ + {!allHandled && ( + + )} +
+
+
+ ) +} +``` + +**Step 2: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build passes. + +**Step 3: Commit** + +```bash +git add frontend/src/components/tree-editor/AIFixReviewModal.tsx +git commit -m "feat: add AIFixReviewModal component for reviewing AI-proposed fixes" +``` + +--- + +## Task 11: Wire everything together in TreeEditorPage + +**Files:** +- Modify: `frontend/src/pages/TreeEditorPage.tsx` + +This is the integration task. The page needs to: + +1. Import the new components and API +2. Add state for fix flow (`isFixing`, `fixProposals`) +3. Handle "Fix with AI" button click — call `treesApi.fixTree()` +4. Show `AIFixReviewModal` when proposals are available +5. Handle apply/skip — call `updateNode()` on the tree editor store +6. Re-run `validate()` after applying fixes + +**Step 1: Find where ValidationSummary is rendered in TreeEditorPage** + +Search for `(null) +``` + +**Step 4: Add handler for "Fix with AI"** + +```typescript +const handleFixWithAI = async () => { + const store = useTreeEditorStore.getState() + if (!store.treeStructure) return + + // Get only fixable errors (structural errors with nodeId) + const fixableErrors = store.validationErrors + .filter(e => e.severity === 'error' && e.nodeId) + .map(e => ({ node_id: e.nodeId!, message: e.message })) + + if (fixableErrors.length === 0) return + + setIsFixing(true) + try { + const result = await treesApi.fixTree({ + tree_structure: store.treeStructure as Record, + tree_name: store.name, + tree_type: (store.treeType || 'troubleshooting') as 'troubleshooting' | 'procedural' | 'maintenance', + validation_errors: fixableErrors, + }) + if (result.fixes.length > 0) { + setFixProposals(result.fixes) + } else { + toast.info('AI could not generate fixes for these errors') + } + } catch { + toast.error('Failed to generate AI fixes. 
Please try again.') + } finally { + setIsFixing(false) + } +} +``` + +**Step 5: Add handlers for apply/close** + +```typescript +const handleApplyFix = (fix: AIFixProposal) => { + const store = useTreeEditorStore.getState() + store.updateNode(fix.target_node_id, fix.fixed_node as Partial) +} + +const handleApplyAllFixes = () => { + if (!fixProposals) return + for (const fix of fixProposals) { + handleApplyFix(fix) + } + setFixProposals(null) + // Re-validate after applying all fixes + setTimeout(() => { + useTreeEditorStore.getState().validate() + }, 100) +} + +const handleCloseFixModal = () => { + setFixProposals(null) + // Re-validate in case some fixes were applied + useTreeEditorStore.getState().validate() +} +``` + +**Step 6: Pass props to ValidationSummary** + +Update the `` JSX to include the new props: + +```tsx + +``` + +**Step 7: Add the review modal** + +After ``, add: + +```tsx +{fixProposals && ( + +)} +``` + +**Step 8: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build passes. + +**Step 9: Commit** + +```bash +git add frontend/src/pages/TreeEditorPage.tsx +git commit -m "feat: wire AI fix flow into TreeEditorPage" +``` + +--- + +## Task 12: Final verification + +**Step 1: Run all backend tests** + +Run: `cd backend && python -m pytest --override-ini="addopts=" -v` +Expected: All pass. + +**Step 2: Run frontend build** + +Run: `cd frontend && npm run build` +Expected: Build passes with no errors. + +**Step 3: Final commit if any adjustments needed** + +Fix any issues found during verification and commit. 
-- 2.49.1 From bbf6e2a33b81ef21333ac104c7e193d4ef2a0436 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:08:10 -0500 Subject: [PATCH 03/14] chore: add google-genai SDK dependency Co-Authored-By: Claude Opus 4.6 --- backend/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/requirements.txt b/backend/requirements.txt index 7c8b5493..b51da249 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -33,6 +33,7 @@ httpx>=0.27.0 # AI Flow Builder anthropic>=0.40.0 +google-genai>=1.0.0 # Utilities python-dotenv==1.0.1 -- 2.49.1 From be041d0d2996a80cdfd4758cd6575bdfbb4588e1 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:08:53 -0500 Subject: [PATCH 04/14] feat: add Gemini Flash config vars to Settings Adds AI_PROVIDER, GOOGLE_AI_API_KEY, AI_MODEL_GEMINI, and AI_MODEL_ANTHROPIC config vars. Updates ai_enabled to check either provider key. Co-Authored-By: Claude Opus 4.6 --- backend/app/core/config.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 184795c0..912b2cf8 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -78,11 +78,16 @@ class Settings(BaseSettings): AI_CONVERSATION_TTL_HOURS: int = 24 AI_MAX_CALLS_PER_FLOW: int = 10 AI_REQUEST_TIMEOUT_SECONDS: int = 45 + # AI Provider selection + AI_PROVIDER: str = "gemini" # "gemini" or "anthropic" + GOOGLE_AI_API_KEY: Optional[str] = None + AI_MODEL_GEMINI: str = "gemini-2.5-flash" + AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001" @property def ai_enabled(self) -> bool: - """Check if AI Flow Builder is configured.""" - return self.ANTHROPIC_API_KEY is not None + """Check if any AI provider is configured.""" + return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None # Deployment – auto-seed test data on PR environments SEED_ON_DEPLOY: bool = False -- 2.49.1 From 55be033ecb9eb6457d9ea9664335d2af5dd44164 Mon Sep 
17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:16:45 -0500 Subject: [PATCH 05/14] feat: add AI provider abstraction with Gemini and Anthropic support Co-Authored-By: Claude Opus 4.6 --- backend/app/core/ai_provider.py | 162 ++++++++++++++++++++++ backend/tests/test_ai_provider.py | 216 ++++++++++++++++++++++++++++++ 2 files changed, 378 insertions(+) create mode 100644 backend/app/core/ai_provider.py create mode 100644 backend/tests/test_ai_provider.py diff --git a/backend/app/core/ai_provider.py b/backend/app/core/ai_provider.py new file mode 100644 index 00000000..b3cf16e4 --- /dev/null +++ b/backend/app/core/ai_provider.py @@ -0,0 +1,162 @@ +""" +AI Provider abstraction layer. + +Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable +backends for JSON generation used by the AI Flow Builder. +""" + +from abc import ABC, abstractmethod + +from app.core.config import settings + + +class AIProvider(ABC): + """Abstract base class for AI providers.""" + + @abstractmethod + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Generate a JSON response from the AI model. + + Args: + system_prompt: System-level instruction for the model. + messages: List of message dicts with "role" and "content" keys. + max_tokens: Maximum output tokens. + + Returns: + Tuple of (response_text, input_tokens, output_tokens). + """ + ... 
+ + +class GeminiProvider(AIProvider): + """Google Gemini provider using the google-genai SDK.""" + + def __init__(self, api_key: str, model: str) -> None: + self._api_key = api_key + self._model = model + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + from google import genai + from google.genai import types as genai_types + + client = genai.Client(api_key=self._api_key) + + # Convert messages to Gemini Content format + contents: list[genai_types.Content] = [] + for msg in messages: + role = "model" if msg["role"] == "assistant" else "user" + contents.append( + genai_types.Content( + role=role, + parts=[genai_types.Part(text=msg["content"])], + ) + ) + + config = genai_types.GenerateContentConfig( + system_instruction=system_prompt, + max_output_tokens=max_tokens, + response_mime_type="application/json", + ) + + response = await client.models.generate_content_async( + model=self._model, + contents=contents, + config=config, + ) + + text = response.text or "" + input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 + output_tokens = ( + getattr(response.usage_metadata, "candidates_token_count", 0) or 0 + ) + + return text, input_tokens, output_tokens + + +class AnthropicProvider(AIProvider): + """Anthropic Claude provider using the anthropic SDK.""" + + def __init__(self, api_key: str, model: str, timeout: int = 45) -> None: + self._api_key = api_key + self._model = model + self._timeout = timeout + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + import anthropic + + client = anthropic.AsyncAnthropic( + api_key=self._api_key, + timeout=self._timeout, + ) + + response = await client.messages.create( + model=self._model, + max_tokens=max_tokens, + system=system_prompt, + messages=messages, + ) + + text = response.content[0].text + input_tokens = 
response.usage.input_tokens + output_tokens = response.usage.output_tokens + + return text, input_tokens, output_tokens + + +def get_ai_provider() -> AIProvider: + """Factory that returns the configured AI provider. + + Selection logic: + 1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider + 2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider + 3. Fallback: if preferred provider key missing, try the other one + 4. If nothing configured -> raise RuntimeError + """ + provider = settings.AI_PROVIDER + + if provider == "gemini": + if settings.GOOGLE_AI_API_KEY: + return GeminiProvider( + api_key=settings.GOOGLE_AI_API_KEY, + model=settings.AI_MODEL_GEMINI, + ) + # Fallback to Anthropic + if settings.ANTHROPIC_API_KEY: + return AnthropicProvider( + api_key=settings.ANTHROPIC_API_KEY, + model=settings.AI_MODEL_ANTHROPIC, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + + elif provider == "anthropic": + if settings.ANTHROPIC_API_KEY: + return AnthropicProvider( + api_key=settings.ANTHROPIC_API_KEY, + model=settings.AI_MODEL_ANTHROPIC, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + # Fallback to Gemini + if settings.GOOGLE_AI_API_KEY: + return GeminiProvider( + api_key=settings.GOOGLE_AI_API_KEY, + model=settings.AI_MODEL_GEMINI, + ) + + raise RuntimeError( + "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY." 
+ ) diff --git a/backend/tests/test_ai_provider.py b/backend/tests/test_ai_provider.py new file mode 100644 index 00000000..a263d5e3 --- /dev/null +++ b/backend/tests/test_ai_provider.py @@ -0,0 +1,216 @@ +"""Tests for the AI provider abstraction layer.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import sys + +from app.core.ai_provider import ( + AIProvider, + AnthropicProvider, + GeminiProvider, + get_ai_provider, +) +from app.core.config import settings + + +class TestGetAIProvider: + """Tests for the get_ai_provider factory function.""" + + def test_returns_gemini_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.GOOGLE_AI_API_KEY + try: + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = "test-gemini-key" + provider = get_ai_provider() + assert isinstance(provider, GeminiProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_key + + def test_returns_anthropic_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "anthropic" + settings.ANTHROPIC_API_KEY = "test-anthropic-key" + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.ANTHROPIC_API_KEY = original_key + + def test_fallback_to_anthropic_when_gemini_key_missing(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = "test-anthropic-key" + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + def 
test_fallback_to_gemini_when_anthropic_key_missing(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "anthropic" + settings.ANTHROPIC_API_KEY = None + settings.GOOGLE_AI_API_KEY = "test-gemini-key" + provider = get_ai_provider() + assert isinstance(provider, GeminiProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + def test_raises_when_no_provider_configured(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + with pytest.raises(RuntimeError, match="No AI provider configured"): + get_ai_provider() + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + +class TestAnthropicProvider: + """Tests for AnthropicProvider.generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json(self): + provider = AnthropicProvider( + api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30 + ) + + mock_response = MagicMock() + mock_response.content = [MagicMock(text='{"result": "ok"}')] + mock_response.usage = MagicMock(input_tokens=100, output_tokens=50) + + mock_client = AsyncMock() + mock_client.messages.create = AsyncMock(return_value=mock_response) + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + text, input_tokens, output_tokens = await provider.generate_json( + system_prompt="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=1024, + ) + + assert text == '{"result": "ok"}' + assert input_tokens == 100 + assert output_tokens == 50 
+ + mock_client.messages.create.assert_called_once_with( + model="claude-haiku-4-5-20251001", + max_tokens=1024, + system="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + ) + + +class TestGeminiProvider: + """Tests for GeminiProvider.generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json(self): + provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") + + mock_usage = MagicMock() + mock_usage.prompt_token_count = 80 + mock_usage.candidates_token_count = 40 + + mock_response = MagicMock() + mock_response.text = '{"answer": 42}' + mock_response.usage_metadata = mock_usage + + mock_client = MagicMock() + mock_client.models.generate_content_async = AsyncMock( + return_value=mock_response + ) + + mock_genai_module = MagicMock() + mock_genai_module.Client.return_value = mock_client + + mock_types = MagicMock() + mock_types.Content.side_effect = lambda **kw: kw + mock_types.Part.side_effect = lambda **kw: kw + mock_types.GenerateContentConfig.side_effect = lambda **kw: kw + + mock_google = MagicMock() + mock_google.genai = mock_genai_module + mock_genai_module.types = mock_types + + with patch.dict(sys.modules, { + "google": mock_google, + "google.genai": mock_genai_module, + "google.genai.types": mock_types, + }): + text, input_tokens, output_tokens = await provider.generate_json( + system_prompt="Generate JSON.", + messages=[ + {"role": "user", "content": "Give me data"}, + {"role": "assistant", "content": "Here it is"}, + {"role": "user", "content": "More please"}, + ], + max_tokens=2048, + ) + + assert text == '{"answer": 42}' + assert input_tokens == 80 + assert output_tokens == 40 + + mock_client.models.generate_content_async.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_json_handles_none_usage(self): + """Token counts default to 0 when usage_metadata attributes are None.""" + provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") + + mock_usage = 
MagicMock(spec=[]) # No attributes at all + mock_response = MagicMock() + mock_response.text = "{}" + mock_response.usage_metadata = mock_usage + + mock_client = MagicMock() + mock_client.models.generate_content_async = AsyncMock( + return_value=mock_response + ) + + mock_genai_module = MagicMock() + mock_genai_module.Client.return_value = mock_client + + mock_types = MagicMock() + mock_types.Content.side_effect = lambda **kw: kw + mock_types.Part.side_effect = lambda **kw: kw + mock_types.GenerateContentConfig.side_effect = lambda **kw: kw + + mock_google = MagicMock() + mock_google.genai = mock_genai_module + mock_genai_module.types = mock_types + + with patch.dict(sys.modules, { + "google": mock_google, + "google.genai": mock_genai_module, + "google.genai.types": mock_types, + }): + text, input_tokens, output_tokens = await provider.generate_json( + system_prompt="test", + messages=[{"role": "user", "content": "test"}], + ) + + assert text == "{}" + assert input_tokens == 0 + assert output_tokens == 0 -- 2.49.1 From eb7ea7ddd9008cfea3694fca0abe07bdc1ac22a0 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:20:48 -0500 Subject: [PATCH 06/14] refactor: migrate AI tree generator to provider abstraction Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/ai_builder.py | 87 +++++++++---------- backend/app/core/ai_tree_generator_service.py | 58 ++++--------- backend/tests/test_ai_endpoints.py | 37 ++++---- 3 files changed, 76 insertions(+), 106 deletions(-) diff --git a/backend/app/api/endpoints/ai_builder.py b/backend/app/api/endpoints/ai_builder.py index 5ec8d55a..dcb0a966 100644 --- a/backend/app/api/endpoints/ai_builder.py +++ b/backend/app/api/endpoints/ai_builder.py @@ -10,7 +10,6 @@ import logging from typing import Annotated -import anthropic from fastapi import APIRouter, Depends, HTTPException, Request, status from sqlalchemy.ext.asyncio import AsyncSession @@ -52,7 +51,7 @@ def _require_ai_enabled() -> None: if not 
settings.ai_enabled: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.", + detail="AI flow builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", ) @@ -174,27 +173,6 @@ async def scaffold( branches, input_tokens, output_tokens, cost = await scaffold_branches( conversation.wizard_state, ) - except anthropic.APIError as e: - await record_ai_usage( - user_id=current_user.id, - account_id=current_user.account_id, - conversation_id=conversation.id, - generation_type="scaffold", - tier=plan, - input_tokens=0, - output_tokens=0, - estimated_cost=0, - succeeded=False, - counts_toward_quota=False, - error_code=type(e).__name__, - extra_data={"error": str(e)}, - db=db, - ) - await db.commit() - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail="AI provider error. Please try again.", - ) except ValueError as e: await record_ai_usage( user_id=current_user.id, @@ -216,6 +194,27 @@ async def scaffold( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"AI returned invalid output: {e}", ) + except Exception as e: + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="scaffold", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"error": str(e)}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="AI provider error. 
Please try again.", + ) # Record successful usage await record_ai_usage( @@ -293,27 +292,6 @@ async def branch_detail( existing_branches, ) ) - except anthropic.APIError as e: - await record_ai_usage( - user_id=current_user.id, - account_id=current_user.account_id, - conversation_id=conversation.id, - generation_type="branch_detail", - tier=plan, - input_tokens=0, - output_tokens=0, - estimated_cost=0, - succeeded=False, - counts_toward_quota=False, - error_code=type(e).__name__, - extra_data={"error": str(e), "branch_name": data.branch_name}, - db=db, - ) - await db.commit() - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail="AI provider error. Please try again.", - ) except ValueError as e: await record_ai_usage( user_id=current_user.id, @@ -335,6 +313,27 @@ async def branch_detail( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"AI returned invalid output: {e}", ) + except Exception as e: + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="branch_detail", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"error": str(e), "branch_name": data.branch_name}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="AI provider error. Please try again.", + ) # Record successful usage await record_ai_usage( diff --git a/backend/app/core/ai_tree_generator_service.py b/backend/app/core/ai_tree_generator_service.py index 4d40e257..7a562d1c 100644 --- a/backend/app/core/ai_tree_generator_service.py +++ b/backend/app/core/ai_tree_generator_service.py @@ -1,11 +1,11 @@ -"""AI-powered tree generation service using Anthropic Claude API. +"""AI-powered tree generation service. 
Implements the 4-stage wizard flow: Stage 2 (scaffold): AI suggests 4-7 top-level branches Stage 3 (branch_detail): AI generates detailed nodes per branch Stage 4 (assemble): Pure assembly logic — zero AI calls -System prompts are static constants to enable Anthropic prompt caching. +Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic). """ import json import logging @@ -13,8 +13,7 @@ import re import uuid from typing import Any -import anthropic - +from app.core.ai_provider import get_ai_provider from app.core.config import settings from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats @@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str: return text -def _get_client() -> anthropic.AsyncAnthropic: - """Get configured async Anthropic client.""" - if not settings.ANTHROPIC_API_KEY: - raise RuntimeError("ANTHROPIC_API_KEY not configured") - return anthropic.AsyncAnthropic( - api_key=settings.ANTHROPIC_API_KEY, - timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, - ) - def _estimate_cost(input_tokens: int, output_tokens: int) -> float: """Estimate USD cost from token counts.""" @@ -146,7 +136,7 @@ async def scaffold_branches( Returns (branches, input_tokens, output_tokens, estimated_cost). Raises ValueError on invalid response. 
""" - client = _get_client() + provider = get_ai_provider() flow_type = wizard_state.get("flow_type", "troubleshooting") name = wizard_state.get("name", "") @@ -161,16 +151,13 @@ async def scaffold_branches( if tags: user_message += f"Environment: {', '.join(tags)}\n" - response = await client.messages.create( - model=settings.AI_MODEL, - max_tokens=1024, - system=SCAFFOLD_SYSTEM_PROMPT, + raw_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=SCAFFOLD_SYSTEM_PROMPT, messages=[{"role": "user", "content": user_message}], + max_tokens=1024, ) - raw_text = _strip_markdown_fences(response.content[0].text) - input_tokens = response.usage.input_tokens - output_tokens = response.usage.output_tokens + raw_text = _strip_markdown_fences(raw_text) cost = _estimate_cost(input_tokens, output_tokens) try: @@ -196,7 +183,7 @@ async def generate_branch_detail( On validation failure, retries once with corrective prompt. Raises ValueError if both attempts fail. """ - client = _get_client() + provider = get_ai_provider() flow_type = wizard_state.get("flow_type", "troubleshooting") name = wizard_state.get("name", "") @@ -217,31 +204,22 @@ async def generate_branch_detail( total_output = 0 for attempt in range(3): - response = await client.messages.create( - model=settings.AI_MODEL, - max_tokens=8192, - system=BRANCH_DETAIL_SYSTEM_PROMPT, + raw_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT, messages=messages, + max_tokens=8192, ) - total_input += response.usage.input_tokens - total_output += response.usage.output_tokens + total_input += input_tokens + total_output += output_tokens logger.debug( - "branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d", + "branch_detail attempt=%d output_tokens=%d", attempt, - response.stop_reason, - len(response.content), - response.usage.output_tokens, + output_tokens, ) - if response.stop_reason == "max_tokens": - logger.warning( - 
"branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated", - attempt, - response.usage.output_tokens, - ) - raw_text = _strip_markdown_fences(response.content[0].text) if response.content else "" + raw_text = _strip_markdown_fences(raw_text) if raw_text else "" if not raw_text: - logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason) + logger.warning("branch_detail attempt=%d returned empty text", attempt) try: branch_tree = json.loads(raw_text) diff --git a/backend/tests/test_ai_endpoints.py b/backend/tests/test_ai_endpoints.py index 339448dd..1f91514e 100644 --- a/backend/tests/test_ai_endpoints.py +++ b/backend/tests/test_ai_endpoints.py @@ -1,6 +1,6 @@ """Integration tests for AI Flow Builder endpoints. -All Anthropic API calls are mocked — zero real API spend. +All AI provider calls are mocked — zero real API spend. """ import json from unittest.mock import AsyncMock, patch, MagicMock @@ -64,12 +64,11 @@ BRANCH_DETAIL_JSON = json.dumps({ }) -def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200): - """Create a mock Anthropic API response.""" - response = MagicMock() - response.content = [MagicMock(text=text)] - response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens) - return response +def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200): + """Create a mock AI provider whose generate_json returns the given text and token counts.""" + provider = MagicMock() + provider.generate_json = AsyncMock(return_value=(text, input_tokens, output_tokens)) + return provider @pytest.fixture @@ -194,11 +193,9 @@ async def test_scaffold_success(client, auth_headers, enable_ai): ) conversation_id = start_resp.json()["conversation_id"] - # Mock Anthropic - mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: 
- mock_client.return_value.messages.create = AsyncMock(return_value=mock_response) - + # Mock AI provider + mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=mock_provider): response = await client.post( "/api/v1/ai/scaffold", json={"conversation_id": conversation_id}, @@ -241,9 +238,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai): ) conversation_id = start_resp.json()["conversation_id"] - scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock) + scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider): await client.post( "/api/v1/ai/scaffold", json={"conversation_id": conversation_id}, @@ -251,10 +247,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai): ) # Now generate branch detail - detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock) - + detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=detail_provider): response = await client.post( "/api/v1/ai/branch-detail", json={ @@ -290,9 +284,8 @@ async def test_assemble_success(client, auth_headers, enable_ai): conversation_id = start_resp.json()["conversation_id"] # Scaffold - scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock) + scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON) + 
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider): await client.post( "/api/v1/ai/scaffold", json={"conversation_id": conversation_id}, -- 2.49.1 From 5f8653e48141e3aa96095acfca3afc463772b29c Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:21:36 -0500 Subject: [PATCH 07/14] feat: add Pydantic schemas for AI fix-tree endpoint Co-Authored-By: Claude Opus 4.6 --- backend/app/schemas/ai_fix.py | 52 +++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 backend/app/schemas/ai_fix.py diff --git a/backend/app/schemas/ai_fix.py b/backend/app/schemas/ai_fix.py new file mode 100644 index 00000000..8c47f5a6 --- /dev/null +++ b/backend/app/schemas/ai_fix.py @@ -0,0 +1,52 @@ +"""Pydantic schemas for the AI auto-fix feature.""" +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +# ── Requests ── + + +class ValidationErrorInput(BaseModel): + """A single validation error to fix.""" + + node_id: str = Field(..., description="ID of the node with the error") + message: str = Field(..., description="The validation error message") + + +class AIFixTreeRequest(BaseModel): + """Request to generate AI fixes for validation errors.""" + + tree_structure: dict[str, Any] = Field(..., description="Full tree structure") + tree_name: str = Field("", max_length=255, description="Name of the flow") + tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field( + "troubleshooting", description="Type of flow" + ) + validation_errors: list[ValidationErrorInput] = Field( + ..., min_length=1, max_length=10, description="Errors to fix" + ) + + +# ── Responses ── + + +class AIFixProposal(BaseModel): + """A single proposed fix from the AI.""" + + target_node_id: str + error_message: str + description: str + original_node: dict[str, Any] + fixed_node: dict[str, Any] + + +class AIFixTokenUsage(BaseModel): + input: int = 0 + output: int = 0 + + +class 
AIFixTreeResponse(BaseModel): + """Response with proposed fixes.""" + + fixes: list[AIFixProposal] + tokens_used: AIFixTokenUsage -- 2.49.1 From 373736c5946c3acc8374d7f5abf9453843725149 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:25:34 -0500 Subject: [PATCH 08/14] feat: add AI fix service with prompt building and validation Co-Authored-By: Claude Opus 4.6 --- backend/app/core/ai_fix_service.py | 273 +++++++++++++++++++++++++++ backend/tests/test_ai_fix_service.py | 224 ++++++++++++++++++++++ 2 files changed, 497 insertions(+) create mode 100644 backend/app/core/ai_fix_service.py create mode 100644 backend/tests/test_ai_fix_service.py diff --git a/backend/app/core/ai_fix_service.py b/backend/app/core/ai_fix_service.py new file mode 100644 index 00000000..02350a15 --- /dev/null +++ b/backend/app/core/ai_fix_service.py @@ -0,0 +1,273 @@ +"""AI-powered fix service for tree validation errors. + +Given a tree structure and validation errors, generates AI-powered +proposals to fix each structural issue while preserving existing content. +""" + +import copy +import json +import logging +import re +from typing import Any + +from app.core.ai_provider import get_ai_provider +from app.core.ai_tree_validator import validate_generated_tree + +logger = logging.getLogger(__name__) + + +FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers. + +You will receive: +1. A full flow outline showing the tree structure +2. The specific failing node with its full JSON +3. The validation error message + +Your task: Return a FIXED version of the failing node as valid JSON. 
Rules: +- Fix ONLY the structural issue described in the error message +- Keep ALL existing content (titles, descriptions, questions, options) unchanged +- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic +- Every new node must have a unique ID (use descriptive kebab-case IDs) +- Decision nodes must have at least 2 options and at least 2 children +- Action nodes must have a next_node_id pointing to a sibling node in the parent's children +- Solution nodes are leaf nodes (no children) +- Return ONLY the fixed node JSON, no explanation""" + + +# ── Pure helper functions ── + + +def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None: + """Recursively find a node by its ID in the tree structure.""" + if not isinstance(tree, dict): + return None + if tree.get("id") == node_id: + return tree + for child in tree.get("children", []): + result = _find_node_by_id(child, node_id) + if result is not None: + return result + return None + + +def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None: + """Find the parent of a node with the given ID.""" + if not isinstance(tree, dict): + return None + for child in tree.get("children", []): + if isinstance(child, dict) and child.get("id") == target_id: + return tree + result = _find_parent_node(child, target_id) + if result is not None: + return result + return None + + +def _serialize_tree_outline( + tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None +) -> str: + """Serialize tree as a readable outline for AI prompt context. + + Format: indented "- [type] label" with "<<< ERROR HERE" marker. 
+ """ + if not isinstance(tree, dict): + return "" + + node_type = tree.get("type", "unknown") + label = tree.get("question") or tree.get("title") or tree.get("id", "?") + prefix = " " * indent + marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else "" + line = f"{prefix}- [{node_type}] {label}{marker}" + + lines = [line] + for child in tree.get("children", []): + lines.append(_serialize_tree_outline(child, indent + 1, error_node_id)) + + return "\n".join(lines) + + +def _strip_markdown_fences(text: str) -> str: + """Strip ```json...``` fences from AI response.""" + return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip( + "`" + ).strip() + + +def _replace_node_in_tree( + tree: dict[str, Any], target_id: str, replacement: dict[str, Any] +) -> bool: + """Replace a node in-place by ID. Returns True if found and replaced.""" + if not isinstance(tree, dict): + return False + if tree.get("id") == target_id: + tree.clear() + tree.update(replacement) + return True + for child in tree.get("children", []): + if _replace_node_in_tree(child, target_id, replacement): + return True + return False + + +def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str: + """Describe what changed between original and fixed node.""" + changes: list[str] = [] + + orig_children = len(original.get("children", [])) + fixed_children = len(fixed.get("children", [])) + if fixed_children > orig_children: + changes.append(f"added {fixed_children - orig_children} child node(s)") + + orig_options = len(original.get("options", [])) + fixed_options = len(fixed.get("options", [])) + if fixed_options > orig_options: + changes.append(f"added {fixed_options - orig_options} option(s)") + + if fixed.get("next_node_id") and not original.get("next_node_id"): + changes.append("added next_node_id") + + if not changes: + changes.append("fixed structural issue") + + return "; ".join(changes).capitalize() + + +# ── Prompt building ── + + +def 
_build_fix_prompt( + tree: dict[str, Any], + node_id: str, + error_message: str, + tree_name: str, + tree_type: str, +) -> str: + """Build the user message for the AI fix request.""" + outline = _serialize_tree_outline(tree, error_node_id=node_id) + node = _find_node_by_id(tree, node_id) + node_json = json.dumps(node, indent=2) if node else "{}" + + return ( + f"Flow name: {tree_name}\n" + f"Flow type: {tree_type}\n\n" + f"## Full flow outline\n```\n{outline}\n```\n\n" + f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n" + f"## Validation error\n{error_message}\n\n" + f"Return the fixed version of this node as JSON." + ) + + +# ── Main entry point ── + + +async def generate_fixes( + tree_structure: dict[str, Any], + tree_name: str, + tree_type: str, + validation_errors: list[dict[str, str]], +) -> tuple[list[dict[str, Any]], int, int]: + """Generate AI-powered fixes for tree validation errors. + + Args: + tree_structure: Full tree structure dict. + tree_name: Name of the flow. + tree_type: Type of flow (troubleshooting, procedural, maintenance). + validation_errors: List of dicts with "node_id" and "message" keys. + + Returns: + Tuple of (fixes_list, total_input_tokens, total_output_tokens). + Each fix dict has: target_node_id, error_message, description, + original_node, fixed_node. 
+ """ + provider = get_ai_provider() + fixes: list[dict[str, Any]] = [] + total_input_tokens = 0 + total_output_tokens = 0 + + for error in validation_errors: + node_id = error["node_id"] + error_message = error["message"] + + original_node = _find_node_by_id(tree_structure, node_id) + if original_node is None: + logger.warning("Node %s not found in tree, skipping fix", node_id) + continue + + original_snapshot = copy.deepcopy(original_node) + + # Build prompt and call AI + user_message = _build_fix_prompt( + tree_structure, node_id, error_message, tree_name, tree_type + ) + messages = [{"role": "user", "content": user_message}] + + try: + text, in_tok, out_tok = await provider.generate_json( + system_prompt=FIX_SYSTEM_PROMPT, + messages=messages, + max_tokens=2048, + ) + total_input_tokens += in_tok + total_output_tokens += out_tok + + cleaned = _strip_markdown_fences(text) + fixed_node = json.loads(cleaned) + except (json.JSONDecodeError, Exception) as exc: + logger.warning("AI fix failed for node %s: %s", node_id, exc) + continue + + # Validate by substituting into a tree copy + tree_copy = copy.deepcopy(tree_structure) + _replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node)) + remaining_errors = validate_generated_tree(tree_copy) + + # Check if the specific error is still present + still_has_error = any(node_id in e for e in remaining_errors) + + if still_has_error: + # Retry once with corrective prompt + retry_message = ( + f"Your previous fix still has validation errors:\n" + f"{chr(10).join(remaining_errors)}\n\n" + f"Please fix the node again. Return ONLY the corrected JSON." 
+ ) + messages.append({"role": "assistant", "content": text}) + messages.append({"role": "user", "content": retry_message}) + + try: + text2, in_tok2, out_tok2 = await provider.generate_json( + system_prompt=FIX_SYSTEM_PROMPT, + messages=messages, + max_tokens=2048, + ) + total_input_tokens += in_tok2 + total_output_tokens += out_tok2 + + cleaned2 = _strip_markdown_fences(text2) + fixed_node = json.loads(cleaned2) + + # Re-validate + tree_copy2 = copy.deepcopy(tree_structure) + _replace_node_in_tree(tree_copy2, node_id, copy.deepcopy(fixed_node)) + remaining2 = validate_generated_tree(tree_copy2) + still_has_error = any(node_id in e for e in remaining2) + except (json.JSONDecodeError, Exception) as exc: + logger.warning("AI retry fix failed for node %s: %s", node_id, exc) + continue + + if still_has_error: + logger.warning("AI could not fix node %s after retry", node_id) + continue + + description = _describe_fix(original_snapshot, fixed_node) + fixes.append( + { + "target_node_id": node_id, + "error_message": error_message, + "description": description, + "original_node": original_snapshot, + "fixed_node": fixed_node, + } + ) + + return fixes, total_input_tokens, total_output_tokens diff --git a/backend/tests/test_ai_fix_service.py b/backend/tests/test_ai_fix_service.py new file mode 100644 index 00000000..24410721 --- /dev/null +++ b/backend/tests/test_ai_fix_service.py @@ -0,0 +1,224 @@ +"""Unit tests for AI fix service helper functions. + +Tests pure Python helpers only — no AI mocking needed. 
+""" + +import pytest + +from app.core.ai_fix_service import ( + _find_node_by_id, + _find_parent_node, + _serialize_tree_outline, + _strip_markdown_fences, + _replace_node_in_tree, + _describe_fix, +) + + +# ── Sample tree ── + +SAMPLE_TREE = { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs OK", + "description": "Issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}], + "children": [ + { + "id": "done", + "type": "solution", + "title": "Done", + "description": "Fixed.", + } + ], + }, + ], +} + + +# ── _find_node_by_id ── + + +class TestFindNodeById: + def test_finds_root(self): + node = _find_node_by_id(SAMPLE_TREE, "root") + assert node is not None + assert node["id"] == "root" + assert node["type"] == "decision" + + def test_finds_nested_child(self): + node = _find_node_by_id(SAMPLE_TREE, "done") + assert node is not None + assert node["id"] == "done" + assert node["type"] == "solution" + + def test_finds_direct_child(self): + node = _find_node_by_id(SAMPLE_TREE, "check-logs") + assert node is not None + assert node["title"] == "Check Logs" + + def test_returns_none_for_missing(self): + node = _find_node_by_id(SAMPLE_TREE, "nonexistent") + assert node is None + + def test_returns_none_for_non_dict(self): + assert _find_node_by_id("not a dict", "root") is None + + +# ── _find_parent_node ── + + +class TestFindParentNode: + def test_root_has_no_parent(self): + parent = _find_parent_node(SAMPLE_TREE, "root") + assert parent is None + + def 
test_finds_parent_of_direct_child(self): + parent = _find_parent_node(SAMPLE_TREE, "check-logs") + assert parent is not None + assert parent["id"] == "root" + + def test_finds_parent_of_deeply_nested(self): + parent = _find_parent_node(SAMPLE_TREE, "done") + assert parent is not None + assert parent["id"] == "restart" + + def test_returns_none_for_missing(self): + parent = _find_parent_node(SAMPLE_TREE, "nonexistent") + assert parent is None + + def test_returns_none_for_non_dict(self): + assert _find_parent_node("not a dict", "root") is None + + +# ── _serialize_tree_outline ── + + +class TestSerializeTreeOutline: + def test_produces_readable_outline(self): + outline = _serialize_tree_outline(SAMPLE_TREE) + assert "- [decision] Is the server up?" in outline + assert " - [action] Check Logs" in outline + assert " - [solution] Logs OK" in outline + assert " - [solution] Done" in outline + + def test_marks_error_node(self): + outline = _serialize_tree_outline(SAMPLE_TREE, error_node_id="restart") + assert "<<< ERROR HERE" in outline + # Only the restart node should be marked + lines = outline.split("\n") + error_lines = [l for l in lines if "ERROR HERE" in l] + assert len(error_lines) == 1 + assert "Did restart work?" 
in error_lines[0] + + def test_no_error_marker_when_none(self): + outline = _serialize_tree_outline(SAMPLE_TREE) + assert "ERROR HERE" not in outline + + def test_handles_non_dict(self): + assert _serialize_tree_outline("not a dict") == "" + + def test_indentation_increases_with_depth(self): + outline = _serialize_tree_outline(SAMPLE_TREE) + lines = outline.split("\n") + # Root has no indentation + assert lines[0].startswith("- [decision]") + # Children have 2-space indent + child_lines = [l for l in lines if "Check Logs" in l] + assert child_lines[0].startswith(" - ") + + +# ── _strip_markdown_fences ── + + +class TestStripMarkdownFences: + def test_strips_json_fences(self): + text = '```json\n{"key": "value"}\n```' + assert _strip_markdown_fences(text) == '{"key": "value"}' + + def test_strips_plain_fences(self): + text = '```\n{"key": "value"}\n```' + assert _strip_markdown_fences(text) == '{"key": "value"}' + + def test_passes_through_plain_json(self): + text = '{"key": "value"}' + assert _strip_markdown_fences(text) == '{"key": "value"}' + + +# ── _replace_node_in_tree ── + + +class TestReplaceNodeInTree: + def test_replaces_root(self): + import copy + + tree = copy.deepcopy(SAMPLE_TREE) + replacement = {"id": "root", "type": "decision", "question": "New question"} + assert _replace_node_in_tree(tree, "root", replacement) is True + assert tree["question"] == "New question" + assert "children" not in tree # cleared and replaced + + def test_replaces_nested_node(self): + import copy + + tree = copy.deepcopy(SAMPLE_TREE) + replacement = {"id": "done", "type": "solution", "title": "All Done", "description": "Complete."} + assert _replace_node_in_tree(tree, "done", replacement) is True + found = _find_node_by_id(tree, "done") + assert found["title"] == "All Done" + + def test_returns_false_for_missing(self): + import copy + + tree = copy.deepcopy(SAMPLE_TREE) + assert _replace_node_in_tree(tree, "nonexistent", {"id": "x"}) is False + + +# ── _describe_fix ── + + 
+class TestDescribeFix: + def test_describes_added_children(self): + original = {"id": "n1", "children": [{"id": "c1"}]} + fixed = {"id": "n1", "children": [{"id": "c1"}, {"id": "c2"}]} + desc = _describe_fix(original, fixed) + assert "1 child node" in desc + + def test_describes_added_options(self): + original = {"id": "n1", "options": [{"id": "o1"}]} + fixed = {"id": "n1", "options": [{"id": "o1"}, {"id": "o2"}]} + desc = _describe_fix(original, fixed) + assert "1 option" in desc + + def test_describes_added_next_node_id(self): + original = {"id": "n1", "type": "action"} + fixed = {"id": "n1", "type": "action", "next_node_id": "n2"} + desc = _describe_fix(original, fixed) + assert "next_node_id" in desc + + def test_fallback_description(self): + original = {"id": "n1", "type": "solution"} + fixed = {"id": "n1", "type": "solution"} + desc = _describe_fix(original, fixed) + assert "fixed structural issue" in desc.lower() -- 2.49.1 From b3925150d71c26a5572a4ecdb82a1e559ab57f2c Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:25:38 -0500 Subject: [PATCH 09/14] feat: add POST /ai/fix-tree endpoint for AI-powered validation fixes Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/ai_fix.py | 78 ++++++++++++ backend/app/api/router.py | 2 + backend/tests/test_ai_fix_endpoint.py | 169 ++++++++++++++++++++++++++ 3 files changed, 249 insertions(+) create mode 100644 backend/app/api/endpoints/ai_fix.py create mode 100644 backend/tests/test_ai_fix_endpoint.py diff --git a/backend/app/api/endpoints/ai_fix.py b/backend/app/api/endpoints/ai_fix.py new file mode 100644 index 00000000..97ecbaf9 --- /dev/null +++ b/backend/app/api/endpoints/ai_fix.py @@ -0,0 +1,78 @@ +"""AI auto-fix endpoint for tree validation errors. + +POST /ai/fix-tree — accepts a tree with validation errors and returns +AI-generated fix proposals for each error. 
+""" + +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_db, require_engineer_or_admin +from app.core.config import settings +from app.core.rate_limit import limiter +from app.core.ai_fix_service import generate_fixes +from app.models.user import User +from app.schemas.ai_fix import ( + AIFixTreeRequest, + AIFixTreeResponse, + AIFixProposal, + AIFixTokenUsage, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ai", tags=["ai-fix"]) + + +def _require_ai_enabled() -> None: + """Raise 503 if AI is not configured.""" + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI fix is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) + + +@router.post("/fix-tree", response_model=AIFixTreeResponse) +@limiter.limit("10/minute") +async def fix_tree( + request: Request, + body: AIFixTreeRequest, + user: Annotated[User, Depends(require_engineer_or_admin)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Generate AI-powered fixes for tree validation errors.""" + _require_ai_enabled() + + validation_errors = [ + {"node_id": e.node_id, "message": e.message} + for e in body.validation_errors + ] + + try: + fixes, input_tokens, output_tokens = await generate_fixes( + tree_structure=body.tree_structure, + tree_name=body.tree_name, + tree_type=body.tree_type, + validation_errors=validation_errors, + ) + except RuntimeError as exc: + logger.error("AI provider not available: %s", exc) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(exc), + ) + except Exception as exc: + logger.exception("Unexpected error in AI fix service") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred while generating fixes.", + ) + + return AIFixTreeResponse( 
+ fixes=[AIFixProposal(**f) for f in fixes], + tokens_used=AIFixTokenUsage(input=input_tokens, output=output_tokens), + ) diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 2c79e039..27963a1f 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -6,6 +6,7 @@ from app.api.endpoints import target_lists from app.api.endpoints import maintenance_schedules from app.api.endpoints import feedback from app.api.endpoints import ai_builder +from app.api.endpoints import ai_fix api_router = APIRouter() @@ -36,3 +37,4 @@ api_router.include_router(target_lists.router) api_router.include_router(maintenance_schedules.router) api_router.include_router(feedback.router) api_router.include_router(ai_builder.router) +api_router.include_router(ai_fix.router) diff --git a/backend/tests/test_ai_fix_endpoint.py b/backend/tests/test_ai_fix_endpoint.py new file mode 100644 index 00000000..a81a598e --- /dev/null +++ b/backend/tests/test_ai_fix_endpoint.py @@ -0,0 +1,169 @@ +"""Integration tests for the POST /ai/fix-tree endpoint.""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.core.config import settings + + +# ── Sample tree (has a decision node with only 1 option + 1 child) ── + +SAMPLE_TREE = { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs OK", + "description": "Issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}], + "children": [ + { + "id": "done", + "type": "solution", + "title": "Done", + 
"description": "Fixed.", + } + ], + }, + ], +} + +# Fixed version of the "restart" node — 2 options, 2 children +FIXED_RESTART_NODE = { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [ + {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"}, + {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"}, + ], + "children": [ + { + "id": "done", + "type": "solution", + "title": "Done", + "description": "Fixed.", + }, + { + "id": "escalate", + "type": "solution", + "title": "Escalate", + "description": "Escalate to senior engineer.", + }, + ], +} + +FIX_REQUEST_BODY = { + "tree_structure": SAMPLE_TREE, + "tree_name": "Server Troubleshooting", + "tree_type": "troubleshooting", + "validation_errors": [ + { + "node_id": "restart", + "message": "Decision node 'restart' must have at least 2 options", + } + ], +} + + +def _mock_ai_provider(response_text: str, input_tokens: int = 50, output_tokens: int = 100): + """Create a mock provider whose generate_json returns given text.""" + provider = MagicMock() + provider.generate_json = AsyncMock(return_value=(response_text, input_tokens, output_tokens)) + return provider + + +@pytest.fixture +def enable_ai(): + """Temporarily enable AI by setting a fake API key.""" + original = settings.GOOGLE_AI_API_KEY + settings.GOOGLE_AI_API_KEY = "test-key-fake" + yield + settings.GOOGLE_AI_API_KEY = original + + +@pytest.fixture +def disable_ai(): + """Ensure AI is disabled.""" + orig_google = settings.GOOGLE_AI_API_KEY + orig_anthropic = settings.ANTHROPIC_API_KEY + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + yield + settings.GOOGLE_AI_API_KEY = orig_google + settings.ANTHROPIC_API_KEY = orig_anthropic + + +# ── Tests ── + + +@pytest.mark.asyncio +async def test_returns_401_without_auth(client): + """POST /ai/fix-tree without auth token returns 401.""" + response = await client.post("/api/v1/ai/fix-tree", json=FIX_REQUEST_BODY) + assert response.status_code == 
401 + + +@pytest.mark.asyncio +async def test_returns_503_when_ai_disabled(client, auth_headers, disable_ai): + """POST /ai/fix-tree returns 503 when no AI keys are configured.""" + response = await client.post( + "/api/v1/ai/fix-tree", + json=FIX_REQUEST_BODY, + headers=auth_headers, + ) + assert response.status_code == 503 + + +@pytest.mark.asyncio +async def test_returns_fixes_on_success(client, auth_headers, enable_ai): + """POST /ai/fix-tree returns fix proposals when AI succeeds.""" + mock_provider = _mock_ai_provider(json.dumps(FIXED_RESTART_NODE)) + + with patch( + "app.core.ai_fix_service.get_ai_provider", + return_value=mock_provider, + ): + response = await client.post( + "/api/v1/ai/fix-tree", + json=FIX_REQUEST_BODY, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert "fixes" in data + assert "tokens_used" in data + assert len(data["fixes"]) == 1 + + fix = data["fixes"][0] + assert fix["target_node_id"] == "restart" + assert fix["error_message"] == "Decision node 'restart' must have at least 2 options" + assert fix["original_node"]["id"] == "restart" + assert fix["fixed_node"]["id"] == "restart" + assert len(fix["fixed_node"]["options"]) == 2 + assert len(fix["fixed_node"]["children"]) == 2 + + assert data["tokens_used"]["input"] == 50 + assert data["tokens_used"]["output"] == 100 -- 2.49.1 From 29dc95e920fd6e2dabbf998696cc2648baaead81 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 17:29:53 -0500 Subject: [PATCH 10/14] =?UTF-8?q?feat:=20add=20AI=20auto-fix=20UI=20?= =?UTF-8?q?=E2=80=94=20types,=20API=20client,=20ValidationSummary=20button?= =?UTF-8?q?,=20review=20modal,=20and=20TreeEditorPage=20wiring?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - New ai-fix.ts types for request/response - fixTree() method on treesApi - "Fix with AI" button in ValidationSummary (shows for fixable errors) - AIFixReviewModal with per-fix apply/skip and 
apply-all - TreeEditorPage orchestrates the fix flow Co-Authored-By: Claude Opus 4.6 --- frontend/src/api/trees.ts | 8 +- .../tree-editor/AIFixReviewModal.tsx | 170 ++++++++++++++++++ .../tree-editor/ValidationSummary.tsx | 50 +++++- frontend/src/pages/TreeEditorPage.tsx | 65 ++++++- frontend/src/types/ai-fix.ts | 24 +++ frontend/src/types/index.ts | 7 + 6 files changed, 313 insertions(+), 11 deletions(-) create mode 100644 frontend/src/components/tree-editor/AIFixReviewModal.tsx create mode 100644 frontend/src/types/ai-fix.ts diff --git a/frontend/src/api/trees.ts b/frontend/src/api/trees.ts index 8b21840c..5f12603d 100644 --- a/frontend/src/api/trees.ts +++ b/frontend/src/api/trees.ts @@ -1,5 +1,5 @@ import apiClient from './client' -import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse } from '@/types' +import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse, AIFixTreeRequest, AIFixTreeResponse } from '@/types' export const treesApi = { async list(params?: TreeFilters): Promise { @@ -65,6 +65,12 @@ export const treesApi = { const response = await apiClient.post(`/trees/${id}/can-publish`) return response.data }, + + // AI auto-fix + async fixTree(request: AIFixTreeRequest): Promise { + const response = await apiClient.post('/ai/fix-tree', request) + return response.data + }, } export default treesApi diff --git a/frontend/src/components/tree-editor/AIFixReviewModal.tsx b/frontend/src/components/tree-editor/AIFixReviewModal.tsx new file mode 100644 index 00000000..acda5769 --- /dev/null +++ b/frontend/src/components/tree-editor/AIFixReviewModal.tsx @@ -0,0 +1,170 @@ +import { useState } from 'react' +import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react' +import { cn } from '@/lib/utils' +import type { AIFixProposal } from '@/types' + +interface 
AIFixReviewModalProps { + fixes: AIFixProposal[] + onApply: (fix: AIFixProposal) => void + onApplyAll: () => void + onClose: () => void +} + +export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) { + const [appliedIds, setAppliedIds] = useState>(new Set()) + const [skippedIds, setSkippedIds] = useState>(new Set()) + const [expandedIds, setExpandedIds] = useState>(new Set(fixes.map(f => f.target_node_id))) + + const handleApply = (fix: AIFixProposal) => { + onApply(fix) + setAppliedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const handleSkip = (fix: AIFixProposal) => { + setSkippedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const toggleExpanded = (id: string) => { + setExpandedIds(prev => { + const next = new Set(prev) + if (next.has(id)) next.delete(id) + else next.add(id) + return next + }) + } + + const pendingFixes = fixes.filter( + f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id) + ) + const allHandled = pendingFixes.length === 0 + + return ( +
+
+ {/* Header */} +
+
+ +

+ AI Fix Proposals ({fixes.length}) +

+
+ +
+ + {/* Body */} +
+ {fixes.map((fix) => { + const isApplied = appliedIds.has(fix.target_node_id) + const isSkipped = skippedIds.has(fix.target_node_id) + const isExpanded = expandedIds.has(fix.target_node_id) + + return ( +
+ {/* Fix header */} +
+
+

{fix.error_message}

+

{fix.description}

+

+ Node: {fix.target_node_id} +

+
+ {isApplied && ( + + Applied + + )} + {isSkipped && ( + Skipped + )} +
+ + {/* Expand/collapse detail */} + {!isApplied && !isSkipped && ( + <> + + + {isExpanded && ( +
+
+

Before

+
+                            {JSON.stringify(fix.original_node, null, 2)}
+                          
+
+
+

After

+
+                            {JSON.stringify(fix.fixed_node, null, 2)}
+                          
+
+
+ )} + + {/* Action buttons */} +
+ + +
+ + )} +
+ ) + })} +
+ + {/* Footer */} +
+ + {!allHandled && ( + + )} +
+
+
+ ) +} diff --git a/frontend/src/components/tree-editor/ValidationSummary.tsx b/frontend/src/components/tree-editor/ValidationSummary.tsx index fcf87bad..987be180 100644 --- a/frontend/src/components/tree-editor/ValidationSummary.tsx +++ b/frontend/src/components/tree-editor/ValidationSummary.tsx @@ -1,14 +1,16 @@ import { useState } from 'react' -import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp } from 'lucide-react' +import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react' import { cn } from '@/lib/utils' import type { ValidationError } from '@/store/treeEditorStore' interface ValidationSummaryProps { errors: ValidationError[] onSelectNode: (nodeId: string) => void + onFixWithAI?: () => void + isFixing?: boolean } -export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryProps) { +export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) { const [isExpanded, setIsExpanded] = useState(true) const errorItems = errors.filter(e => e.severity === 'error') @@ -22,6 +24,8 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro } } + const hasFixableErrors = errorItems.some(e => e.nodeId) + return (
{/* Header */} -
- {isExpanded ? : } - + {isExpanded ? : } + + + {/* Fix with AI button */} + {onFixWithAI && hasFixableErrors && ( + + )} + {/* Error/Warning List */} {isExpanded && ( diff --git a/frontend/src/pages/TreeEditorPage.tsx b/frontend/src/pages/TreeEditorPage.tsx index a21097b4..fa3ba215 100644 --- a/frontend/src/pages/TreeEditorPage.tsx +++ b/frontend/src/pages/TreeEditorPage.tsx @@ -5,10 +5,11 @@ import { Undo2, Redo2, Save, CheckCircle2, Monitor, FileText, Code2, LayoutList, import { getMonacoEditor } from '@/components/tree-editor/code-mode' import { treesApi } from '@/api/trees' import { treeMarkdownApi } from '@/api/treeMarkdown' -import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure } from '@/types' +import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure, AIFixProposal } from '@/types' import { useTreeEditorStore, useTreeEditorTemporal } from '@/store/treeEditorStore' import { TreeEditorLayout } from '@/components/tree-editor/TreeEditorLayout' import { ValidationSummary } from '@/components/tree-editor/ValidationSummary' +import { AIFixReviewModal } from '@/components/tree-editor/AIFixReviewModal' import { useKeyboardShortcuts } from '@/hooks/useKeyboardShortcuts' import { usePermissions } from '@/hooks/usePermissions' import { Spinner } from '@/components/common/Spinner' @@ -58,6 +59,8 @@ export function TreeEditorPage() { const [showAnalytics, setShowAnalytics] = useState(false) const [isMetadataOpen, setIsMetadataOpen] = useState(false) const [editingNodeId, setEditingNodeId] = useState(null) + const [isFixing, setIsFixing] = useState(false) + const [fixProposals, setFixProposals] = useState(null) // Mobile detection const [isMobile, setIsMobile] = useState(false) @@ -217,6 +220,54 @@ export function TreeEditorPage() { selectNode(nodeId) } + const handleFixWithAI = async () => { + const store = useTreeEditorStore.getState() + if (!store.treeStructure) return + + const fixableErrors = store.validationErrors + .filter(e => e.severity === 
'error' && e.nodeId) + .map(e => ({ node_id: e.nodeId!, message: e.message })) + + if (fixableErrors.length === 0) return + + setIsFixing(true) + try { + const result = await treesApi.fixTree({ + tree_structure: store.treeStructure as unknown as Record, + tree_name: store.name, + tree_type: 'troubleshooting', + validation_errors: fixableErrors, + }) + if (result.fixes.length > 0) { + setFixProposals(result.fixes) + } else { + toast.info('AI could not generate fixes for these errors') + } + } catch { + toast.error('Failed to generate AI fixes. Please try again.') + } finally { + setIsFixing(false) + } + } + + const handleApplyFix = (fix: AIFixProposal) => { + updateNode(fix.target_node_id, fix.fixed_node as Partial) + } + + const handleApplyAllFixes = () => { + if (!fixProposals) return + for (const fix of fixProposals) { + handleApplyFix(fix) + } + setFixProposals(null) + setTimeout(() => { validate() }, 100) + } + + const handleCloseFixModal = () => { + setFixProposals(null) + validate() + } + const handleNodeSelect = useCallback((nodeId: string | null) => { if (nodeId) { setIsMetadataOpen(false) // close metadata when opening node editor @@ -685,6 +736,8 @@ export function TreeEditorPage() { )} @@ -705,6 +758,16 @@ export function TreeEditorPage() { )} + + {/* AI Fix Review Modal */} + {fixProposals && ( + + )} ) } diff --git a/frontend/src/types/ai-fix.ts b/frontend/src/types/ai-fix.ts new file mode 100644 index 00000000..a12a6fdf --- /dev/null +++ b/frontend/src/types/ai-fix.ts @@ -0,0 +1,24 @@ +export interface AIFixValidationError { + node_id: string + message: string +} + +export interface AIFixProposal { + target_node_id: string + error_message: string + description: string + original_node: Record + fixed_node: Record +} + +export interface AIFixTreeRequest { + tree_structure: Record + tree_name: string + tree_type: 'troubleshooting' | 'procedural' | 'maintenance' + validation_errors: AIFixValidationError[] +} + +export interface AIFixTreeResponse { + 
fixes: AIFixProposal[] + tokens_used: { input: number; output: number } +} diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 1bc91dc2..2c388169 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -45,3 +45,10 @@ export type { AIAssembleResponse, AIWizardPhase, } from './ai' + +export type { + AIFixTreeRequest, + AIFixTreeResponse, + AIFixProposal, + AIFixValidationError, +} from './ai-fix' -- 2.49.1 From 0fb3126fd22b23fdf3bc966e33ab8a8b654b42ca Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 26 Feb 2026 23:49:07 -0500 Subject: [PATCH 11/14] fix: add error logging and error type to AI builder 502 responses The generic "AI provider error" message made debugging impossible. Now logs the full exception traceback and includes the error class name in the 502 response detail. Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/ai_builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/app/api/endpoints/ai_builder.py b/backend/app/api/endpoints/ai_builder.py index dcb0a966..c8771d16 100644 --- a/backend/app/api/endpoints/ai_builder.py +++ b/backend/app/api/endpoints/ai_builder.py @@ -195,6 +195,7 @@ async def scaffold( detail=f"AI returned invalid output: {e}", ) except Exception as e: + logger.exception("AI scaffold failed: %s: %s", type(e).__name__, e) await record_ai_usage( user_id=current_user.id, account_id=current_user.account_id, @@ -213,7 +214,7 @@ async def scaffold( await db.commit() raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, - detail="AI provider error. Please try again.", + detail=f"AI provider error ({type(e).__name__}). 
Please try again.", ) # Record successful usage @@ -314,6 +315,7 @@ async def branch_detail( detail=f"AI returned invalid output: {e}", ) except Exception as e: + logger.exception("AI branch_detail failed: %s: %s", type(e).__name__, e) await record_ai_usage( user_id=current_user.id, account_id=current_user.account_id, @@ -332,7 +334,7 @@ async def branch_detail( await db.commit() raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, - detail="AI provider error. Please try again.", + detail=f"AI provider error ({type(e).__name__}). Please try again.", ) # Record successful usage -- 2.49.1 From dc68d992a48cb46564c131f2bd5cf6bfd345a806 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Fri, 27 Feb 2026 00:02:47 -0500 Subject: [PATCH 12/14] debug: add temporary /ai/provider-debug endpoint Shows which provider is selected and whether keys are loaded. Remove after debugging the 502 on PR env. Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/ai_builder.py | 27 +++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/backend/app/api/endpoints/ai_builder.py b/backend/app/api/endpoints/ai_builder.py index c8771d16..b0c7aa41 100644 --- a/backend/app/api/endpoints/ai_builder.py +++ b/backend/app/api/endpoints/ai_builder.py @@ -55,6 +55,33 @@ def _require_ai_enabled() -> None: ) +@router.get("/provider-debug") +async def provider_debug( + current_user: Annotated[User, Depends(get_current_active_user)], + _: None = Depends(require_engineer_or_admin), +): + """Temporary debug endpoint — shows which AI provider would be selected.""" + from app.core.ai_provider import get_ai_provider + has_gemini_key = bool(settings.GOOGLE_AI_API_KEY) + has_anthropic_key = bool(settings.ANTHROPIC_API_KEY) + provider_setting = settings.AI_PROVIDER + try: + provider = get_ai_provider() + provider_type = type(provider).__name__ + except RuntimeError as e: + provider_type = f"ERROR: {e}" + return { + "ai_provider_setting": provider_setting, + "has_gemini_key": 
has_gemini_key, + "gemini_key_prefix": settings.GOOGLE_AI_API_KEY[:8] + "..." if settings.GOOGLE_AI_API_KEY else None, + "has_anthropic_key": has_anthropic_key, + "anthropic_key_prefix": settings.ANTHROPIC_API_KEY[:8] + "..." if settings.ANTHROPIC_API_KEY else None, + "selected_provider": provider_type, + "gemini_model": settings.AI_MODEL_GEMINI, + "anthropic_model": settings.AI_MODEL_ANTHROPIC, + } + + @router.get("/quota", response_model=AIQuotaStatusResponse) async def get_quota( current_user: Annotated[User, Depends(get_current_active_user)], -- 2.49.1 From 957f13b99371ef167866af9655fa1da4637da70d Mon Sep 17 00:00:00 2001 From: chihlasm Date: Fri, 27 Feb 2026 00:08:20 -0500 Subject: [PATCH 13/14] fix: use correct google-genai async API and remove debug endpoint The google-genai SDK uses `client.aio.models.generate_content()` for async calls, not `client.models.generate_content_async()` which doesn't exist. Also removes the temporary /ai/provider-debug endpoint. Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/ai_builder.py | 28 ------------------------- backend/app/core/ai_provider.py | 2 +- backend/tests/test_ai_provider.py | 6 +++--- 3 files changed, 4 insertions(+), 32 deletions(-) diff --git a/backend/app/api/endpoints/ai_builder.py b/backend/app/api/endpoints/ai_builder.py index b0c7aa41..f3740d07 100644 --- a/backend/app/api/endpoints/ai_builder.py +++ b/backend/app/api/endpoints/ai_builder.py @@ -54,34 +54,6 @@ def _require_ai_enabled() -> None: detail="AI flow builder is not configured. 
Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", ) - -@router.get("/provider-debug") -async def provider_debug( - current_user: Annotated[User, Depends(get_current_active_user)], - _: None = Depends(require_engineer_or_admin), -): - """Temporary debug endpoint — shows which AI provider would be selected.""" - from app.core.ai_provider import get_ai_provider - has_gemini_key = bool(settings.GOOGLE_AI_API_KEY) - has_anthropic_key = bool(settings.ANTHROPIC_API_KEY) - provider_setting = settings.AI_PROVIDER - try: - provider = get_ai_provider() - provider_type = type(provider).__name__ - except RuntimeError as e: - provider_type = f"ERROR: {e}" - return { - "ai_provider_setting": provider_setting, - "has_gemini_key": has_gemini_key, - "gemini_key_prefix": settings.GOOGLE_AI_API_KEY[:8] + "..." if settings.GOOGLE_AI_API_KEY else None, - "has_anthropic_key": has_anthropic_key, - "anthropic_key_prefix": settings.ANTHROPIC_API_KEY[:8] + "..." if settings.ANTHROPIC_API_KEY else None, - "selected_provider": provider_type, - "gemini_model": settings.AI_MODEL_GEMINI, - "anthropic_model": settings.AI_MODEL_ANTHROPIC, - } - - @router.get("/quota", response_model=AIQuotaStatusResponse) async def get_quota( current_user: Annotated[User, Depends(get_current_active_user)], diff --git a/backend/app/core/ai_provider.py b/backend/app/core/ai_provider.py index b3cf16e4..cea16269 100644 --- a/backend/app/core/ai_provider.py +++ b/backend/app/core/ai_provider.py @@ -68,7 +68,7 @@ class GeminiProvider(AIProvider): response_mime_type="application/json", ) - response = await client.models.generate_content_async( + response = await client.aio.models.generate_content( model=self._model, contents=contents, config=config, diff --git a/backend/tests/test_ai_provider.py b/backend/tests/test_ai_provider.py index a263d5e3..611c8e7b 100644 --- a/backend/tests/test_ai_provider.py +++ b/backend/tests/test_ai_provider.py @@ -137,7 +137,7 @@ class TestGeminiProvider: mock_response.usage_metadata = 
mock_usage mock_client = MagicMock() - mock_client.models.generate_content_async = AsyncMock( + mock_client.aio.models.generate_content = AsyncMock( return_value=mock_response ) @@ -172,7 +172,7 @@ class TestGeminiProvider: assert input_tokens == 80 assert output_tokens == 40 - mock_client.models.generate_content_async.assert_called_once() + mock_client.aio.models.generate_content.assert_called_once() @pytest.mark.asyncio async def test_generate_json_handles_none_usage(self): @@ -185,7 +185,7 @@ class TestGeminiProvider: mock_response.usage_metadata = mock_usage mock_client = MagicMock() - mock_client.models.generate_content_async = AsyncMock( + mock_client.aio.models.generate_content = AsyncMock( return_value=mock_response ) -- 2.49.1 From 6fc76187c0d76e4d4c8120238efcbd2c4eb71cb1 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Fri, 27 Feb 2026 00:15:01 -0500 Subject: [PATCH 14/14] fix: add diagnostic logging and increase scaffold max_tokens to 2048 The "Unterminated string" JSON parse error is likely caused by Gemini output truncation at 1024 tokens. Increases scaffold max_tokens to 2048 and adds logging for: raw response text, finish_reason (truncation detection), and JSON parse failures. Co-Authored-By: Claude Opus 4.6 --- backend/app/core/ai_provider.py | 13 +++++++++++++ backend/app/core/ai_tree_generator_service.py | 14 +++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/backend/app/core/ai_provider.py b/backend/app/core/ai_provider.py index cea16269..cb3f7178 100644 --- a/backend/app/core/ai_provider.py +++ b/backend/app/core/ai_provider.py @@ -5,10 +5,13 @@ Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable backends for JSON generation used by the AI Flow Builder. 
""" +import logging from abc import ABC, abstractmethod from app.core.config import settings +logger = logging.getLogger(__name__) + class AIProvider(ABC): """Abstract base class for AI providers.""" @@ -74,6 +77,16 @@ class GeminiProvider(AIProvider): config=config, ) + # Log finish reason to detect truncation + if response.candidates: + finish_reason = getattr(response.candidates[0], "finish_reason", None) + logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model) + if str(finish_reason) == "MAX_TOKENS": + logger.warning( + "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d", + max_tokens, + ) + text = response.text or "" input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 output_tokens = ( diff --git a/backend/app/core/ai_tree_generator_service.py b/backend/app/core/ai_tree_generator_service.py index 7a562d1c..bf560874 100644 --- a/backend/app/core/ai_tree_generator_service.py +++ b/backend/app/core/ai_tree_generator_service.py @@ -154,15 +154,23 @@ async def scaffold_branches( raw_text, input_tokens, output_tokens = await provider.generate_json( system_prompt=SCAFFOLD_SYSTEM_PROMPT, messages=[{"role": "user", "content": user_message}], - max_tokens=1024, + max_tokens=2048, ) + logger.info( + "scaffold raw response (tokens in=%d out=%d, len=%d): %s", + input_tokens, + output_tokens, + len(raw_text), + raw_text[:500], + ) raw_text = _strip_markdown_fences(raw_text) cost = _estimate_cost(input_tokens, output_tokens) try: data = json.loads(raw_text) except json.JSONDecodeError as e: + logger.error("scaffold JSON parse failed. 
Full text (%d chars): %s", len(raw_text), raw_text) raise ValueError(f"AI returned invalid JSON: {e}") branches = data.get("branches", []) @@ -224,6 +232,10 @@ async def generate_branch_detail( try: branch_tree = json.loads(raw_text) except json.JSONDecodeError as e: + logger.error( + "branch_detail attempt=%d JSON parse failed (%d chars): %s", + attempt, len(raw_text), raw_text[:500], + ) if attempt < 2: messages.append({"role": "assistant", "content": raw_text}) messages.append({ -- 2.49.1