diff --git a/backend/app/api/endpoints/ai_builder.py b/backend/app/api/endpoints/ai_builder.py index 5ec8d55a..f3740d07 100644 --- a/backend/app/api/endpoints/ai_builder.py +++ b/backend/app/api/endpoints/ai_builder.py @@ -10,7 +10,6 @@ import logging from typing import Annotated -import anthropic from fastapi import APIRouter, Depends, HTTPException, Request, status from sqlalchemy.ext.asyncio import AsyncSession @@ -52,10 +51,9 @@ def _require_ai_enabled() -> None: if not settings.ai_enabled: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.", + detail="AI flow builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", ) - @router.get("/quota", response_model=AIQuotaStatusResponse) async def get_quota( current_user: Annotated[User, Depends(get_current_active_user)], @@ -174,27 +172,6 @@ async def scaffold( branches, input_tokens, output_tokens, cost = await scaffold_branches( conversation.wizard_state, ) - except anthropic.APIError as e: - await record_ai_usage( - user_id=current_user.id, - account_id=current_user.account_id, - conversation_id=conversation.id, - generation_type="scaffold", - tier=plan, - input_tokens=0, - output_tokens=0, - estimated_cost=0, - succeeded=False, - counts_toward_quota=False, - error_code=type(e).__name__, - extra_data={"error": str(e)}, - db=db, - ) - await db.commit() - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail="AI provider error. Please try again.", - ) except ValueError as e: await record_ai_usage( user_id=current_user.id, @@ -216,6 +193,28 @@ async def scaffold( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"AI returned invalid output: {e}", ) + except Exception as e: + logger.exception("AI scaffold failed: %s: %s", type(e).__name__, e) + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="scaffold", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"error": str(e)}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"AI provider error ({type(e).__name__}). Please try again.", + ) # Record successful usage await record_ai_usage( @@ -293,27 +292,6 @@ async def branch_detail( existing_branches, ) ) - except anthropic.APIError as e: - await record_ai_usage( - user_id=current_user.id, - account_id=current_user.account_id, - conversation_id=conversation.id, - generation_type="branch_detail", - tier=plan, - input_tokens=0, - output_tokens=0, - estimated_cost=0, - succeeded=False, - counts_toward_quota=False, - error_code=type(e).__name__, - extra_data={"error": str(e), "branch_name": data.branch_name}, - db=db, - ) - await db.commit() - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail="AI provider error. Please try again.", - ) except ValueError as e: await record_ai_usage( user_id=current_user.id, @@ -335,6 +313,28 @@ async def branch_detail( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"AI returned invalid output: {e}", ) + except Exception as e: + logger.exception("AI branch_detail failed: %s: %s", type(e).__name__, e) + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="branch_detail", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"error": str(e), "branch_name": data.branch_name}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"AI provider error ({type(e).__name__}). Please try again.", + ) # Record successful usage await record_ai_usage( diff --git a/backend/app/api/endpoints/ai_fix.py b/backend/app/api/endpoints/ai_fix.py new file mode 100644 index 00000000..97ecbaf9 --- /dev/null +++ b/backend/app/api/endpoints/ai_fix.py @@ -0,0 +1,78 @@ +"""AI auto-fix endpoint for tree validation errors. + +POST /ai/fix-tree — accepts a tree with validation errors and returns +AI-generated fix proposals for each error. +""" + +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_db, require_engineer_or_admin +from app.core.config import settings +from app.core.rate_limit import limiter +from app.core.ai_fix_service import generate_fixes +from app.models.user import User +from app.schemas.ai_fix import ( + AIFixTreeRequest, + AIFixTreeResponse, + AIFixProposal, + AIFixTokenUsage, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ai", tags=["ai-fix"]) + + +def _require_ai_enabled() -> None: + """Raise 503 if AI is not configured.""" + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI fix is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) + + +@router.post("/fix-tree", response_model=AIFixTreeResponse) +@limiter.limit("10/minute") +async def fix_tree( + request: Request, + body: AIFixTreeRequest, + user: Annotated[User, Depends(require_engineer_or_admin)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Generate AI-powered fixes for tree validation errors.""" + _require_ai_enabled() + + validation_errors = [ + {"node_id": e.node_id, "message": e.message} + for e in body.validation_errors + ] + + try: + fixes, input_tokens, output_tokens = await generate_fixes( + tree_structure=body.tree_structure, + tree_name=body.tree_name, + tree_type=body.tree_type, + validation_errors=validation_errors, + ) + except RuntimeError as exc: + logger.error("AI provider not available: %s", exc) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(exc), + ) + except Exception as exc: + logger.exception("Unexpected error in AI fix service") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred while generating fixes.", + ) + + return AIFixTreeResponse( + fixes=[AIFixProposal(**f) for f in fixes], + tokens_used=AIFixTokenUsage(input=input_tokens, output=output_tokens), + ) diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 2c79e039..27963a1f 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -6,6 +6,7 @@ from app.api.endpoints import target_lists from app.api.endpoints import maintenance_schedules from app.api.endpoints import feedback from app.api.endpoints import ai_builder +from app.api.endpoints import ai_fix api_router = APIRouter() @@ -36,3 +37,4 @@ api_router.include_router(target_lists.router) api_router.include_router(maintenance_schedules.router) api_router.include_router(feedback.router) api_router.include_router(ai_builder.router) +api_router.include_router(ai_fix.router) diff --git a/backend/app/core/ai_fix_service.py b/backend/app/core/ai_fix_service.py new file mode 100644 index 00000000..02350a15 --- /dev/null +++ b/backend/app/core/ai_fix_service.py @@ -0,0 +1,273 @@ +"""AI-powered fix service for tree validation errors. + +Given a tree structure and validation errors, generates AI-powered +proposals to fix each structural issue while preserving existing content. +""" + +import copy +import json +import logging +import re +from typing import Any + +from app.core.ai_provider import get_ai_provider +from app.core.ai_tree_validator import validate_generated_tree + +logger = logging.getLogger(__name__) + + +FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers. + +You will receive: +1. A full flow outline showing the tree structure +2. The specific failing node with its full JSON +3. The validation error message + +Your task: Return a FIXED version of the failing node as valid JSON. Rules: +- Fix ONLY the structural issue described in the error message +- Keep ALL existing content (titles, descriptions, questions, options) unchanged +- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic +- Every new node must have a unique ID (use descriptive kebab-case IDs) +- Decision nodes must have at least 2 options and at least 2 children +- Action nodes must have a next_node_id pointing to a sibling node in the parent's children +- Solution nodes are leaf nodes (no children) +- Return ONLY the fixed node JSON, no explanation""" + + +# ── Pure helper functions ── + + +def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None: + """Recursively find a node by its ID in the tree structure.""" + if not isinstance(tree, dict): + return None + if tree.get("id") == node_id: + return tree + for child in tree.get("children", []): + result = _find_node_by_id(child, node_id) + if result is not None: + return result + return None + + +def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None: + """Find the parent of a node with the given ID.""" + if not isinstance(tree, dict): + return None + for child in tree.get("children", []): + if isinstance(child, dict) and child.get("id") == target_id: + return tree + result = _find_parent_node(child, target_id) + if result is not None: + return result + return None + + +def _serialize_tree_outline( + tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None +) -> str: + """Serialize tree as a readable outline for AI prompt context. + + Format: indented "- [type] label" with "<<< ERROR HERE" marker. + """ + if not isinstance(tree, dict): + return "" + + node_type = tree.get("type", "unknown") + label = tree.get("question") or tree.get("title") or tree.get("id", "?") + prefix = " " * indent + marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else "" + line = f"{prefix}- [{node_type}] {label}{marker}" + + lines = [line] + for child in tree.get("children", []): + lines.append(_serialize_tree_outline(child, indent + 1, error_node_id)) + + return "\n".join(lines) + + +def _strip_markdown_fences(text: str) -> str: + """Strip ```json...``` fences from AI response.""" + return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip( + "`" + ).strip() + + +def _replace_node_in_tree( + tree: dict[str, Any], target_id: str, replacement: dict[str, Any] +) -> bool: + """Replace a node in-place by ID. Returns True if found and replaced.""" + if not isinstance(tree, dict): + return False + if tree.get("id") == target_id: + tree.clear() + tree.update(replacement) + return True + for child in tree.get("children", []): + if _replace_node_in_tree(child, target_id, replacement): + return True + return False + + +def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str: + """Describe what changed between original and fixed node.""" + changes: list[str] = [] + + orig_children = len(original.get("children", [])) + fixed_children = len(fixed.get("children", [])) + if fixed_children > orig_children: + changes.append(f"added {fixed_children - orig_children} child node(s)") + + orig_options = len(original.get("options", [])) + fixed_options = len(fixed.get("options", [])) + if fixed_options > orig_options: + changes.append(f"added {fixed_options - orig_options} option(s)") + + if fixed.get("next_node_id") and not original.get("next_node_id"): + changes.append("added next_node_id") + + if not changes: + changes.append("fixed structural issue") + + return "; ".join(changes).capitalize() + + +# ── Prompt building ── + + +def _build_fix_prompt( + tree: dict[str, Any], + node_id: str, + error_message: str, + tree_name: str, + tree_type: str, +) -> str: + """Build the user message for the AI fix request.""" + outline = _serialize_tree_outline(tree, error_node_id=node_id) + node = _find_node_by_id(tree, node_id) + node_json = json.dumps(node, indent=2) if node else "{}" + + return ( + f"Flow name: {tree_name}\n" + f"Flow type: {tree_type}\n\n" + f"## Full flow outline\n```\n{outline}\n```\n\n" + f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n" + f"## Validation error\n{error_message}\n\n" + f"Return the fixed version of this node as JSON." + ) + + +# ── Main entry point ── + + +async def generate_fixes( + tree_structure: dict[str, Any], + tree_name: str, + tree_type: str, + validation_errors: list[dict[str, str]], +) -> tuple[list[dict[str, Any]], int, int]: + """Generate AI-powered fixes for tree validation errors. + + Args: + tree_structure: Full tree structure dict. + tree_name: Name of the flow. + tree_type: Type of flow (troubleshooting, procedural, maintenance). + validation_errors: List of dicts with "node_id" and "message" keys. + + Returns: + Tuple of (fixes_list, total_input_tokens, total_output_tokens). + Each fix dict has: target_node_id, error_message, description, + original_node, fixed_node. + """ + provider = get_ai_provider() + fixes: list[dict[str, Any]] = [] + total_input_tokens = 0 + total_output_tokens = 0 + + for error in validation_errors: + node_id = error["node_id"] + error_message = error["message"] + + original_node = _find_node_by_id(tree_structure, node_id) + if original_node is None: + logger.warning("Node %s not found in tree, skipping fix", node_id) + continue + + original_snapshot = copy.deepcopy(original_node) + + # Build prompt and call AI + user_message = _build_fix_prompt( + tree_structure, node_id, error_message, tree_name, tree_type + ) + messages = [{"role": "user", "content": user_message}] + + try: + text, in_tok, out_tok = await provider.generate_json( + system_prompt=FIX_SYSTEM_PROMPT, + messages=messages, + max_tokens=2048, + ) + total_input_tokens += in_tok + total_output_tokens += out_tok + + cleaned = _strip_markdown_fences(text) + fixed_node = json.loads(cleaned) + except (json.JSONDecodeError, Exception) as exc: + logger.warning("AI fix failed for node %s: %s", node_id, exc) + continue + + # Validate by substituting into a tree copy + tree_copy = copy.deepcopy(tree_structure) + _replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node)) + remaining_errors = validate_generated_tree(tree_copy) + + # Check if the specific error is still present + still_has_error = any(node_id in e for e in remaining_errors) + + if still_has_error: + # Retry once with corrective prompt + retry_message = ( + f"Your previous fix still has validation errors:\n" + f"{chr(10).join(remaining_errors)}\n\n" + f"Please fix the node again. Return ONLY the corrected JSON." + ) + messages.append({"role": "assistant", "content": text}) + messages.append({"role": "user", "content": retry_message}) + + try: + text2, in_tok2, out_tok2 = await provider.generate_json( + system_prompt=FIX_SYSTEM_PROMPT, + messages=messages, + max_tokens=2048, + ) + total_input_tokens += in_tok2 + total_output_tokens += out_tok2 + + cleaned2 = _strip_markdown_fences(text2) + fixed_node = json.loads(cleaned2) + + # Re-validate + tree_copy2 = copy.deepcopy(tree_structure) + _replace_node_in_tree(tree_copy2, node_id, copy.deepcopy(fixed_node)) + remaining2 = validate_generated_tree(tree_copy2) + still_has_error = any(node_id in e for e in remaining2) + except (json.JSONDecodeError, Exception) as exc: + logger.warning("AI retry fix failed for node %s: %s", node_id, exc) + continue + + if still_has_error: + logger.warning("AI could not fix node %s after retry", node_id) + continue + + description = _describe_fix(original_snapshot, fixed_node) + fixes.append( + { + "target_node_id": node_id, + "error_message": error_message, + "description": description, + "original_node": original_snapshot, + "fixed_node": fixed_node, + } + ) + + return fixes, total_input_tokens, total_output_tokens diff --git a/backend/app/core/ai_provider.py b/backend/app/core/ai_provider.py new file mode 100644 index 00000000..cb3f7178 --- /dev/null +++ b/backend/app/core/ai_provider.py @@ -0,0 +1,175 @@ +""" +AI Provider abstraction layer. + +Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable +backends for JSON generation used by the AI Flow Builder. +""" + +import logging +from abc import ABC, abstractmethod + +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +class AIProvider(ABC): + """Abstract base class for AI providers.""" + + @abstractmethod + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Generate a JSON response from the AI model. + + Args: + system_prompt: System-level instruction for the model. + messages: List of message dicts with "role" and "content" keys. + max_tokens: Maximum output tokens. + + Returns: + Tuple of (response_text, input_tokens, output_tokens). + """ + ... + + +class GeminiProvider(AIProvider): + """Google Gemini provider using the google-genai SDK.""" + + def __init__(self, api_key: str, model: str) -> None: + self._api_key = api_key + self._model = model + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + from google import genai + from google.genai import types as genai_types + + client = genai.Client(api_key=self._api_key) + + # Convert messages to Gemini Content format + contents: list[genai_types.Content] = [] + for msg in messages: + role = "model" if msg["role"] == "assistant" else "user" + contents.append( + genai_types.Content( + role=role, + parts=[genai_types.Part(text=msg["content"])], + ) + ) + + config = genai_types.GenerateContentConfig( + system_instruction=system_prompt, + max_output_tokens=max_tokens, + response_mime_type="application/json", + ) + + response = await client.aio.models.generate_content( + model=self._model, + contents=contents, + config=config, + ) + + # Log finish reason to detect truncation + if response.candidates: + finish_reason = getattr(response.candidates[0], "finish_reason", None) + logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model) + if str(finish_reason) == "MAX_TOKENS": + logger.warning( + "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d", + max_tokens, + ) + + text = response.text or "" + input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 + output_tokens = ( + getattr(response.usage_metadata, "candidates_token_count", 0) or 0 + ) + + return text, input_tokens, output_tokens + + +class AnthropicProvider(AIProvider): + """Anthropic Claude provider using the anthropic SDK.""" + + def __init__(self, api_key: str, model: str, timeout: int = 45) -> None: + self._api_key = api_key + self._model = model + self._timeout = timeout + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + import anthropic + + client = anthropic.AsyncAnthropic( + api_key=self._api_key, + timeout=self._timeout, + ) + + response = await client.messages.create( + model=self._model, + max_tokens=max_tokens, + system=system_prompt, + messages=messages, + ) + + text = response.content[0].text + input_tokens = response.usage.input_tokens + output_tokens = response.usage.output_tokens + + return text, input_tokens, output_tokens + + +def get_ai_provider() -> AIProvider: + """Factory that returns the configured AI provider. + + Selection logic: + 1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider + 2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider + 3. Fallback: if preferred provider key missing, try the other one + 4. If nothing configured -> raise RuntimeError + """ + provider = settings.AI_PROVIDER + + if provider == "gemini": + if settings.GOOGLE_AI_API_KEY: + return GeminiProvider( + api_key=settings.GOOGLE_AI_API_KEY, + model=settings.AI_MODEL_GEMINI, + ) + # Fallback to Anthropic + if settings.ANTHROPIC_API_KEY: + return AnthropicProvider( + api_key=settings.ANTHROPIC_API_KEY, + model=settings.AI_MODEL_ANTHROPIC, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + + elif provider == "anthropic": + if settings.ANTHROPIC_API_KEY: + return AnthropicProvider( + api_key=settings.ANTHROPIC_API_KEY, + model=settings.AI_MODEL_ANTHROPIC, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + # Fallback to Gemini + if settings.GOOGLE_AI_API_KEY: + return GeminiProvider( + api_key=settings.GOOGLE_AI_API_KEY, + model=settings.AI_MODEL_GEMINI, + ) + + raise RuntimeError( + "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY." + ) diff --git a/backend/app/core/ai_tree_generator_service.py b/backend/app/core/ai_tree_generator_service.py index 4d40e257..bf560874 100644 --- a/backend/app/core/ai_tree_generator_service.py +++ b/backend/app/core/ai_tree_generator_service.py @@ -1,11 +1,11 @@ -"""AI-powered tree generation service using Anthropic Claude API. +"""AI-powered tree generation service. Implements the 4-stage wizard flow: Stage 2 (scaffold): AI suggests 4-7 top-level branches Stage 3 (branch_detail): AI generates detailed nodes per branch Stage 4 (assemble): Pure assembly logic — zero AI calls -System prompts are static constants to enable Anthropic prompt caching. +Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic). """ import json import logging @@ -13,8 +13,7 @@ import re import uuid from typing import Any -import anthropic - +from app.core.ai_provider import get_ai_provider from app.core.config import settings from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats @@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str: return text -def _get_client() -> anthropic.AsyncAnthropic: - """Get configured async Anthropic client.""" - if not settings.ANTHROPIC_API_KEY: - raise RuntimeError("ANTHROPIC_API_KEY not configured") - return anthropic.AsyncAnthropic( - api_key=settings.ANTHROPIC_API_KEY, - timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, - ) - def _estimate_cost(input_tokens: int, output_tokens: int) -> float: """Estimate USD cost from token counts.""" @@ -146,7 +136,7 @@ async def scaffold_branches( Returns (branches, input_tokens, output_tokens, estimated_cost). Raises ValueError on invalid response. """ - client = _get_client() + provider = get_ai_provider() flow_type = wizard_state.get("flow_type", "troubleshooting") name = wizard_state.get("name", "") @@ -161,21 +151,26 @@ async def scaffold_branches( if tags: user_message += f"Environment: {', '.join(tags)}\n" - response = await client.messages.create( - model=settings.AI_MODEL, - max_tokens=1024, - system=SCAFFOLD_SYSTEM_PROMPT, + raw_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=SCAFFOLD_SYSTEM_PROMPT, messages=[{"role": "user", "content": user_message}], + max_tokens=2048, ) - raw_text = _strip_markdown_fences(response.content[0].text) - input_tokens = response.usage.input_tokens - output_tokens = response.usage.output_tokens + logger.info( + "scaffold raw response (tokens in=%d out=%d, len=%d): %s", + input_tokens, + output_tokens, + len(raw_text), + raw_text[:500], + ) + raw_text = _strip_markdown_fences(raw_text) cost = _estimate_cost(input_tokens, output_tokens) try: data = json.loads(raw_text) except json.JSONDecodeError as e: + logger.error("scaffold JSON parse failed. Full text (%d chars): %s", len(raw_text), raw_text) raise ValueError(f"AI returned invalid JSON: {e}") branches = data.get("branches", []) @@ -196,7 +191,7 @@ async def generate_branch_detail( On validation failure, retries once with corrective prompt. Raises ValueError if both attempts fail. """ - client = _get_client() + provider = get_ai_provider() flow_type = wizard_state.get("flow_type", "troubleshooting") name = wizard_state.get("name", "") @@ -217,35 +212,30 @@ async def generate_branch_detail( total_output = 0 for attempt in range(3): - response = await client.messages.create( - model=settings.AI_MODEL, - max_tokens=8192, - system=BRANCH_DETAIL_SYSTEM_PROMPT, + raw_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT, messages=messages, + max_tokens=8192, ) - total_input += response.usage.input_tokens - total_output += response.usage.output_tokens + total_input += input_tokens + total_output += output_tokens logger.debug( - "branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d", + "branch_detail attempt=%d output_tokens=%d", attempt, - response.stop_reason, - len(response.content), - response.usage.output_tokens, + output_tokens, ) - if response.stop_reason == "max_tokens": - logger.warning( - "branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated", - attempt, - response.usage.output_tokens, - ) - raw_text = _strip_markdown_fences(response.content[0].text) if response.content else "" + raw_text = _strip_markdown_fences(raw_text) if raw_text else "" if not raw_text: - logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason) + logger.warning("branch_detail attempt=%d returned empty text", attempt) try: branch_tree = json.loads(raw_text) except json.JSONDecodeError as e: + logger.error( + "branch_detail attempt=%d JSON parse failed (%d chars): %s", + attempt, len(raw_text), raw_text[:500], + ) if attempt < 2: messages.append({"role": "assistant", "content": raw_text}) messages.append({ diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 184795c0..912b2cf8 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -78,11 +78,16 @@ class Settings(BaseSettings): AI_CONVERSATION_TTL_HOURS: int = 24 AI_MAX_CALLS_PER_FLOW: int = 10 AI_REQUEST_TIMEOUT_SECONDS: int = 45 + # AI Provider selection + AI_PROVIDER: str = "gemini" # "gemini" or "anthropic" + GOOGLE_AI_API_KEY: Optional[str] = None + AI_MODEL_GEMINI: str = "gemini-2.5-flash" + AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001" @property def ai_enabled(self) -> bool: - """Check if AI Flow Builder is configured.""" - return self.ANTHROPIC_API_KEY is not None + """Check if any AI provider is configured.""" + return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None # Deployment – auto-seed test data on PR environments SEED_ON_DEPLOY: bool = False diff --git a/backend/app/schemas/ai_fix.py b/backend/app/schemas/ai_fix.py new file mode 100644 index 00000000..8c47f5a6 --- /dev/null +++ b/backend/app/schemas/ai_fix.py @@ -0,0 +1,52 @@ +"""Pydantic schemas for the AI auto-fix feature.""" +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +# ── Requests ── + + +class ValidationErrorInput(BaseModel): + """A single validation error to fix.""" + + node_id: str = Field(..., description="ID of the node with the error") + message: str = Field(..., description="The validation error message") + + +class AIFixTreeRequest(BaseModel): + """Request to generate AI fixes for validation errors.""" + + tree_structure: dict[str, Any] = Field(..., description="Full tree structure") + tree_name: str = Field("", max_length=255, description="Name of the flow") + tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field( + "troubleshooting", description="Type of flow" + ) + validation_errors: list[ValidationErrorInput] = Field( + ..., min_length=1, max_length=10, description="Errors to fix" + ) + + +# ── Responses ── + + +class AIFixProposal(BaseModel): + """A single proposed fix from the AI.""" + + target_node_id: str + error_message: str + description: str + original_node: dict[str, Any] + fixed_node: dict[str, Any] + + +class AIFixTokenUsage(BaseModel): + input: int = 0 + output: int = 0 + + +class AIFixTreeResponse(BaseModel): + """Response with proposed fixes.""" + + fixes: list[AIFixProposal] + tokens_used: AIFixTokenUsage diff --git a/backend/requirements.txt b/backend/requirements.txt index 7c8b5493..b51da249 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -33,6 +33,7 @@ httpx>=0.27.0 # AI Flow Builder anthropic>=0.40.0 +google-genai>=1.0.0 # Utilities python-dotenv==1.0.1 diff --git a/backend/tests/test_ai_endpoints.py b/backend/tests/test_ai_endpoints.py index 339448dd..1f91514e 100644 --- a/backend/tests/test_ai_endpoints.py +++ b/backend/tests/test_ai_endpoints.py @@ -1,6 +1,6 @@ """Integration tests for AI Flow Builder endpoints. -All Anthropic API calls are mocked — zero real API spend. +All AI provider calls are mocked — zero real API spend. """ import json from unittest.mock import AsyncMock, patch, MagicMock @@ -64,12 +64,11 @@ BRANCH_DETAIL_JSON = json.dumps({ }) -def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200): - """Create a mock Anthropic API response.""" - response = MagicMock() - response.content = [MagicMock(text=text)] - response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens) - return response +def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200): + """Create a mock AI provider whose generate_json returns the given text and token counts.""" + provider = MagicMock() + provider.generate_json = AsyncMock(return_value=(text, input_tokens, output_tokens)) + return provider @pytest.fixture @@ -194,11 +193,9 @@ async def test_scaffold_success(client, auth_headers, enable_ai): ) conversation_id = start_resp.json()["conversation_id"] - # Mock Anthropic - mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=mock_response) - + # Mock AI provider + mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=mock_provider): response = await client.post( "/api/v1/ai/scaffold", json={"conversation_id": conversation_id}, @@ -241,9 +238,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai): ) conversation_id = start_resp.json()["conversation_id"] - scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock) + scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider): await client.post( "/api/v1/ai/scaffold", json={"conversation_id": conversation_id}, @@ -251,10 +247,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai): ) # Now generate branch detail - detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock) - + detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=detail_provider): response = await client.post( "/api/v1/ai/branch-detail", json={ @@ -290,9 +284,8 @@ async def test_assemble_success(client, auth_headers, enable_ai): conversation_id = start_resp.json()["conversation_id"] # Scaffold - scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) - with patch("app.core.ai_tree_generator_service._get_client") as mock_client: - mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock) + scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider): await client.post( "/api/v1/ai/scaffold", json={"conversation_id": conversation_id}, diff --git a/backend/tests/test_ai_fix_endpoint.py b/backend/tests/test_ai_fix_endpoint.py new file mode 100644 index 00000000..a81a598e --- /dev/null +++ b/backend/tests/test_ai_fix_endpoint.py @@ -0,0 +1,169 @@ +"""Integration tests for the POST /ai/fix-tree endpoint.""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.core.config import settings + + +# ── Sample tree (has a decision node with only 1 option + 1 child) ── + +SAMPLE_TREE = { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs OK", + "description": "Issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}], + "children": [ + { + "id": "done", + "type": "solution", + "title": "Done", + "description": "Fixed.", + } + ], + }, + ], +} + +# Fixed version of the "restart" node — 2 options, 2 children +FIXED_RESTART_NODE = { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [ + {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"}, + {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"}, + ], + "children": [ + { + "id": "done", + "type": "solution", + "title": "Done", + "description": "Fixed.", + }, + { + "id": "escalate", + "type": "solution", + "title": "Escalate", + "description": "Escalate to senior engineer.", + }, + ], +} + +FIX_REQUEST_BODY = { + "tree_structure": SAMPLE_TREE, + "tree_name": "Server Troubleshooting", + "tree_type": "troubleshooting", + "validation_errors": [ + { + "node_id": "restart", + "message": "Decision node 'restart' must have at least 2 options", + } + ], +} + + +def _mock_ai_provider(response_text: str, input_tokens: int = 50, output_tokens: int = 100): + """Create a mock provider whose generate_json returns given text.""" + provider = MagicMock() + provider.generate_json = AsyncMock(return_value=(response_text, input_tokens, output_tokens)) + return provider + + +@pytest.fixture +def enable_ai(): + """Temporarily enable AI by setting a fake API key.""" + original = settings.GOOGLE_AI_API_KEY + settings.GOOGLE_AI_API_KEY = "test-key-fake" + yield + settings.GOOGLE_AI_API_KEY = original + + +@pytest.fixture +def disable_ai(): + """Ensure AI is disabled.""" + orig_google = settings.GOOGLE_AI_API_KEY + orig_anthropic = settings.ANTHROPIC_API_KEY + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + yield + settings.GOOGLE_AI_API_KEY = orig_google + settings.ANTHROPIC_API_KEY = orig_anthropic + + +# ── Tests ── + + +@pytest.mark.asyncio +async def test_returns_401_without_auth(client): + """POST /ai/fix-tree without auth token returns 401.""" + response = await client.post("/api/v1/ai/fix-tree", json=FIX_REQUEST_BODY) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_returns_503_when_ai_disabled(client, auth_headers, disable_ai): + """POST /ai/fix-tree returns 503 when no AI keys are configured.""" + response = await client.post( + "/api/v1/ai/fix-tree", + json=FIX_REQUEST_BODY, + headers=auth_headers, + ) + assert response.status_code == 503 + + +@pytest.mark.asyncio +async def test_returns_fixes_on_success(client, auth_headers, enable_ai): + """POST /ai/fix-tree returns fix proposals when AI succeeds.""" + mock_provider = _mock_ai_provider(json.dumps(FIXED_RESTART_NODE)) + + with patch( + "app.core.ai_fix_service.get_ai_provider", + return_value=mock_provider, + ): + response = await client.post( + "/api/v1/ai/fix-tree", + json=FIX_REQUEST_BODY, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert "fixes" in data + assert "tokens_used" in data + assert len(data["fixes"]) == 1 + + fix = data["fixes"][0] + assert fix["target_node_id"] == "restart" + assert fix["error_message"] == "Decision node 'restart' must have at least 2 options" + assert fix["original_node"]["id"] == "restart" + assert fix["fixed_node"]["id"] == "restart" + assert len(fix["fixed_node"]["options"]) == 2 + assert len(fix["fixed_node"]["children"]) == 2 + + assert data["tokens_used"]["input"] == 50 + assert data["tokens_used"]["output"] == 100 diff --git a/backend/tests/test_ai_fix_service.py b/backend/tests/test_ai_fix_service.py new file mode 100644 index 00000000..24410721 --- /dev/null +++ b/backend/tests/test_ai_fix_service.py @@ -0,0 +1,224 @@ +"""Unit tests for AI fix service helper functions. + +Tests pure Python helpers only — no AI mocking needed. +""" + +import pytest + +from app.core.ai_fix_service import ( + _find_node_by_id, + _find_parent_node, + _serialize_tree_outline, + _strip_markdown_fences, + _replace_node_in_tree, + _describe_fix, +) + + +# ── Sample tree ── + +SAMPLE_TREE = { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs OK", + "description": "Issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}], + "children": [ + { + "id": "done", + "type": "solution", + "title": "Done", + "description": "Fixed.", + } + ], + }, + ], +} + + +# ── _find_node_by_id ── + + +class TestFindNodeById: + def test_finds_root(self): + node = _find_node_by_id(SAMPLE_TREE, "root") + assert node is not None + assert node["id"] == "root" + assert node["type"] == "decision" + + def test_finds_nested_child(self): + node = _find_node_by_id(SAMPLE_TREE, "done") + assert node is not None + assert node["id"] == "done" + assert node["type"] == "solution" + + def test_finds_direct_child(self): + node = _find_node_by_id(SAMPLE_TREE, "check-logs") + assert node is not None + assert node["title"] == "Check Logs" + + def test_returns_none_for_missing(self): + node = _find_node_by_id(SAMPLE_TREE, "nonexistent") + assert node is None + + def test_returns_none_for_non_dict(self): + assert _find_node_by_id("not a dict", "root") is None + + +# ── _find_parent_node ── + + +class TestFindParentNode: + def test_root_has_no_parent(self): + parent = _find_parent_node(SAMPLE_TREE, "root") + assert parent is None + + def test_finds_parent_of_direct_child(self): + parent = _find_parent_node(SAMPLE_TREE, "check-logs") + assert parent is not None + assert parent["id"] == "root" + + def test_finds_parent_of_deeply_nested(self): + parent = _find_parent_node(SAMPLE_TREE, "done") + assert parent is not None + assert parent["id"] == "restart" + + def test_returns_none_for_missing(self): + parent = _find_parent_node(SAMPLE_TREE, "nonexistent") + assert parent is None + + def test_returns_none_for_non_dict(self): + assert _find_parent_node("not a dict", "root") is None + + +# ── _serialize_tree_outline ── + + +class TestSerializeTreeOutline: + def test_produces_readable_outline(self): + outline = _serialize_tree_outline(SAMPLE_TREE) + assert "- [decision] Is the server up?" in outline + assert " - [action] Check Logs" in outline + assert " - [solution] Logs OK" in outline + assert " - [solution] Done" in outline + + def test_marks_error_node(self): + outline = _serialize_tree_outline(SAMPLE_TREE, error_node_id="restart") + assert "<<< ERROR HERE" in outline + # Only the restart node should be marked + lines = outline.split("\n") + error_lines = [l for l in lines if "ERROR HERE" in l] + assert len(error_lines) == 1 + assert "Did restart work?" in error_lines[0] + + def test_no_error_marker_when_none(self): + outline = _serialize_tree_outline(SAMPLE_TREE) + assert "ERROR HERE" not in outline + + def test_handles_non_dict(self): + assert _serialize_tree_outline("not a dict") == "" + + def test_indentation_increases_with_depth(self): + outline = _serialize_tree_outline(SAMPLE_TREE) + lines = outline.split("\n") + # Root has no indentation + assert lines[0].startswith("- [decision]") + # Children have 2-space indent + child_lines = [l for l in lines if "Check Logs" in l] + assert child_lines[0].startswith(" - ") + + +# ── _strip_markdown_fences ── + + +class TestStripMarkdownFences: + def test_strips_json_fences(self): + text = '```json\n{"key": "value"}\n```' + assert _strip_markdown_fences(text) == '{"key": "value"}' + + def test_strips_plain_fences(self): + text = '```\n{"key": "value"}\n```' + assert _strip_markdown_fences(text) == '{"key": "value"}' + + def test_passes_through_plain_json(self): + text = '{"key": "value"}' + assert _strip_markdown_fences(text) == '{"key": "value"}' + + +# ── _replace_node_in_tree ── + + +class TestReplaceNodeInTree: + def test_replaces_root(self): + import copy + + tree = copy.deepcopy(SAMPLE_TREE) + replacement = {"id": "root", "type": "decision", "question": "New question"} + assert _replace_node_in_tree(tree, "root", replacement) is True + assert tree["question"] == "New question" + assert "children" not in tree # cleared and replaced + + def test_replaces_nested_node(self): + import copy + + tree = copy.deepcopy(SAMPLE_TREE) + replacement = {"id": "done", "type": "solution", "title": "All Done", "description": "Complete."} + assert _replace_node_in_tree(tree, "done", replacement) is True + found = _find_node_by_id(tree, "done") + assert found["title"] == "All Done" + + def test_returns_false_for_missing(self): + import copy + + tree = copy.deepcopy(SAMPLE_TREE) + assert _replace_node_in_tree(tree, "nonexistent", {"id": "x"}) is False + + +# ── _describe_fix ── + + +class TestDescribeFix: + def test_describes_added_children(self): + original = {"id": "n1", "children": [{"id": "c1"}]} + fixed = {"id": "n1", "children": [{"id": "c1"}, {"id": "c2"}]} + desc = _describe_fix(original, fixed) + assert "1 child node" in desc + + def test_describes_added_options(self): + original = {"id": "n1", "options": [{"id": "o1"}]} + fixed = {"id": "n1", "options": [{"id": "o1"}, {"id": "o2"}]} + desc = _describe_fix(original, fixed) + assert "1 option" in desc + + def test_describes_added_next_node_id(self): + original = {"id": "n1", "type": "action"} + fixed = {"id": "n1", "type": "action", "next_node_id": "n2"} + desc = _describe_fix(original, fixed) + assert "next_node_id" in desc + + def test_fallback_description(self): + original = {"id": "n1", "type": "solution"} + fixed = {"id": "n1", "type": "solution"} + desc = _describe_fix(original, fixed) + assert "fixed structural issue" in desc.lower() diff --git a/backend/tests/test_ai_provider.py b/backend/tests/test_ai_provider.py new file mode 100644 index 00000000..611c8e7b --- /dev/null +++ b/backend/tests/test_ai_provider.py @@ -0,0 +1,216 @@ +"""Tests for the AI provider abstraction layer.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import sys + +from app.core.ai_provider import ( + AIProvider, + AnthropicProvider, + GeminiProvider, + get_ai_provider, +) +from app.core.config import settings + + +class TestGetAIProvider: + """Tests for the get_ai_provider factory function.""" + + def test_returns_gemini_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.GOOGLE_AI_API_KEY + try: + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = "test-gemini-key" + provider = get_ai_provider() + assert isinstance(provider, GeminiProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_key + + def test_returns_anthropic_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "anthropic" + settings.ANTHROPIC_API_KEY = "test-anthropic-key" + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.ANTHROPIC_API_KEY = original_key + + def test_fallback_to_anthropic_when_gemini_key_missing(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = "test-anthropic-key" + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + def test_fallback_to_gemini_when_anthropic_key_missing(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "anthropic" + settings.ANTHROPIC_API_KEY = None + settings.GOOGLE_AI_API_KEY = "test-gemini-key" + provider = get_ai_provider() + assert isinstance(provider, GeminiProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + def test_raises_when_no_provider_configured(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + try: + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + with pytest.raises(RuntimeError, match="No AI provider configured"): + get_ai_provider() + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + +class TestAnthropicProvider: + """Tests for AnthropicProvider.generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json(self): + provider = AnthropicProvider( + api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30 + ) + + mock_response = MagicMock() + mock_response.content = [MagicMock(text='{"result": "ok"}')] + mock_response.usage = MagicMock(input_tokens=100, output_tokens=50) + + mock_client = AsyncMock() + mock_client.messages.create = AsyncMock(return_value=mock_response) + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + text, input_tokens, output_tokens = await provider.generate_json( + system_prompt="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=1024, + ) + + assert text == '{"result": "ok"}' + assert input_tokens == 100 + assert output_tokens == 50 + + mock_client.messages.create.assert_called_once_with( + model="claude-haiku-4-5-20251001", + max_tokens=1024, + system="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + ) + + +class TestGeminiProvider: + """Tests for GeminiProvider.generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json(self): + provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") + + mock_usage = MagicMock() + mock_usage.prompt_token_count = 80 + mock_usage.candidates_token_count = 40 + + mock_response = MagicMock() + mock_response.text = '{"answer": 42}' + mock_response.usage_metadata = mock_usage + + mock_client = MagicMock() + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response + ) + + mock_genai_module = MagicMock() + mock_genai_module.Client.return_value = mock_client + + mock_types = MagicMock() + mock_types.Content.side_effect = lambda **kw: kw + mock_types.Part.side_effect = lambda **kw: kw + mock_types.GenerateContentConfig.side_effect = lambda **kw: kw + + mock_google = MagicMock() + mock_google.genai = mock_genai_module + mock_genai_module.types = mock_types + + with patch.dict(sys.modules, { + "google": mock_google, + "google.genai": mock_genai_module, + "google.genai.types": mock_types, + }): + text, input_tokens, output_tokens = await provider.generate_json( + system_prompt="Generate JSON.", + messages=[ + {"role": "user", "content": "Give me data"}, + {"role": "assistant", "content": "Here it is"}, + {"role": "user", "content": "More please"}, + ], + max_tokens=2048, + ) + + assert text == '{"answer": 42}' + assert input_tokens == 80 + assert output_tokens == 40 + + mock_client.aio.models.generate_content.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_json_handles_none_usage(self): + """Token counts default to 0 when usage_metadata attributes are None.""" + provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash") + + mock_usage = MagicMock(spec=[]) # No attributes at all + mock_response = MagicMock() + mock_response.text = "{}" + mock_response.usage_metadata = mock_usage + + mock_client = MagicMock() + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response + ) + + mock_genai_module = MagicMock() + mock_genai_module.Client.return_value = mock_client + + mock_types = MagicMock() + mock_types.Content.side_effect = lambda **kw: kw + mock_types.Part.side_effect = lambda **kw: kw + mock_types.GenerateContentConfig.side_effect = lambda **kw: kw + + mock_google = MagicMock() + mock_google.genai = mock_genai_module + mock_genai_module.types = mock_types + + with patch.dict(sys.modules, { + "google": mock_google, + "google.genai": mock_genai_module, + "google.genai.types": mock_types, + }): + text, input_tokens, output_tokens = await provider.generate_json( + system_prompt="test", + messages=[{"role": "user", "content": "test"}], + ) + + assert text == "{}" + assert input_tokens == 0 + assert output_tokens == 0 diff --git a/docs/plans/2026-02-26-ai-autofix-gemini-design.md b/docs/plans/2026-02-26-ai-autofix-gemini-design.md new file mode 100644 index 00000000..7d241a00 --- /dev/null +++ b/docs/plans/2026-02-26-ai-autofix-gemini-design.md @@ -0,0 +1,209 @@ +# AI Auto-Fix & Gemini Flash Provider Design + +> **Date:** 2026-02-26 +> **Status:** Approved + +--- + +## Overview + +Two combined features: + +1. **AI Provider Abstraction** — Add Gemini 2.5 Flash as the default AI provider with Claude as fallback, behind a unified interface. +2. **AI Auto-Fix for Validation Errors** — When a flow fails validation, offer an AI-powered "Fix with AI" button that generates structural fixes for review. + +--- + +## Section 1: AI Provider Abstraction + +### Design + +New `backend/app/core/ai_provider.py` with a unified interface: + +```python +class AIProvider(ABC): + async def generate_json( + self, + system_prompt: str, + messages: list[dict], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Returns (text, input_tokens, output_tokens)""" +``` + +Two implementations: + +| Provider | Model | SDK | Role | +|----------|-------|-----|------| +| `GeminiProvider` | `gemini-2.5-flash` | `google-genai` | Default | +| `AnthropicProvider` | `claude-haiku-4-5-20251001` | `anthropic` | Fallback | + +### Provider Selection + +- `get_ai_provider()` factory reads `AI_PROVIDER` env var (default: `"gemini"`) +- Falls back to Anthropic if Gemini key is missing +- Existing `ai_tree_generator_service.py` swaps direct Anthropic calls for `get_ai_provider()` + +### New Environment Variables + +| Variable | Default | Purpose | +|----------|---------|---------| +| `AI_PROVIDER` | `"gemini"` | Which provider to use (`gemini` or `anthropic`) | +| `GOOGLE_AI_API_KEY` | — | Gemini API key | + +Existing `ANTHROPIC_API_KEY` remains for fallback. + +### Config Changes (`core/config.py`) + +```python +AI_PROVIDER: str = "gemini" +GOOGLE_AI_API_KEY: str | None = None +AI_MODEL_GEMINI: str = "gemini-2.5-flash" +AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001" +``` + +--- + +## Section 2: AI Auto-Fix Feature + +### Backend Endpoint + +**`POST /api/v1/ai/fix-tree`** + +Request: +```json +{ + "tree_structure": { /* full tree */ }, + "tree_name": "Router Troubleshooting", + "tree_type": "troubleshooting", + "validation_errors": [ + { + "node_id": "node_abc", + "message": "Decision node must have at least 2 children (branches)" + } + ] +} +``` + +Response: +```json +{ + "fixes": [ + { + "target_node_id": "node_abc", + "error_message": "Decision node must have at least 2 children (branches)", + "description": "Added second branch 'Check firmware version' with solution node", + "original_node": { /* snapshot before fix */ }, + "fixed_node": { /* replacement node with corrected subtree */ } + } + ], + "tokens_used": { "input": 1200, "output": 800 } +} +``` + +### How It Works + +1. For each validation error tied to a `node_id`, extract that node + its parent + siblings from the tree. +2. Build a prompt with: + - The **full tree structure** serialized as a simplified outline (node titles + types + structure) for context + - The **specific failing node** highlighted with full JSON detail + - The **validation error message** + - Instructions: "Fix ONLY this node's structural issue. Keep all existing content. Generate domain-relevant additions that fit the flow's topic." +3. AI returns a corrected version of that node (with children/options adjusted). +4. Backend re-validates the fixed node before returning it. +5. If re-validation fails, retry once with the error fed back (corrective prompt pattern). + +### Prompt Strategy + +The prompt gives the AI the full tree as a compact outline, then zooms into the failing node: + +``` +You are fixing a validation error in a troubleshooting flow called "Router Troubleshooting". + +FULL FLOW OUTLINE: +- [decision] Is the router powered on? + - [action] Check power cable → [solution] Power restored + - [decision] Are lights blinking? ← ERROR HERE + - [solution] Contact ISP + +ERROR: Decision node "Are lights blinking?" must have at least 2 children (branches). + +FAILING NODE (full detail): +{...json...} + +Fix this node by adding the minimum structure needed to resolve the error. +Return ONLY the fixed node as JSON. +``` + +### Frontend UX + +1. **Trigger**: "Fix with AI" button in `ValidationSummary` — appears when there are fixable errors (structural errors with a `node_id`). +2. **Loading state**: Button shows spinner + "Generating fixes..." — disabled during request. +3. **Review modal** (`AIFixReviewModal`): Shows each proposed fix as a card: + - Error message at top + - Before/after view of the node change + - "Apply" / "Skip" buttons per fix + - "Apply All" button in footer +4. **Apply**: Each accepted fix calls `updateNode(targetNodeId, fixedNode)` in the tree editor store. +5. **Re-validate**: After applying fixes, auto-run `validate()` to confirm resolution. + +--- + +## Section 3: Scope & Constraints + +### Fixable Errors (Auto-Fix Scope) + +Only structural validation errors with a `node_id`: +- Decision node missing children/branches +- Decision node missing options +- Action node missing `next_node_id` +- Dead-end decision nodes (no children) + +### NOT Fixable + +- Global checks (tree too small/large, not enough solutions) — require rethinking the whole tree +- Content quality issues — out of scope +- Errors without a `node_id` (root-level issues) + +Non-fixable errors still show in ValidationSummary but without the "Fix with AI" option. + +### Token Budget + +- Tree outline: ~50-100 tokens for a typical 15-node tree +- Failing node detail: ~100-200 tokens +- System prompt + instructions: ~300 tokens +- **Total input per fix: ~500-600 tokens** +- One API call per failing node (not batched) + +### Error Handling + +- Provider failure (rate limit, network): toast error, user can retry +- Fix fails re-validation: "AI couldn't generate a valid fix" with retry option +- Max 1 retry with corrective prompt per attempt +- Both provider and fallback fail: surface error to user + +### Auth + +- Requires `engineer` role or above (`require_engineer_or_admin`) + +--- + +## New Files + +| File | Purpose | +|------|---------| +| `backend/app/core/ai_provider.py` | Provider abstraction + Gemini/Anthropic implementations | +| `backend/app/core/ai_fix_service.py` | Fix generation logic + prompt building | +| `backend/app/api/endpoints/ai.py` | `POST /ai/fix-tree` endpoint | +| `backend/app/schemas/ai.py` | Request/response schemas for AI endpoints | +| `frontend/src/components/tree-editor/AIFixReviewModal.tsx` | Review modal for proposed fixes | + +## Modified Files + +| File | Change | +|------|--------| +| `backend/app/core/config.py` | Add Gemini config vars | +| `backend/app/core/ai_tree_generator_service.py` | Swap Anthropic calls for provider abstraction | +| `backend/app/api/router.py` | Register `/ai` routes | +| `frontend/src/api/trees.ts` | Add `fixTree()` API call | +| `frontend/src/components/tree-editor/ValidationSummary.tsx` | Add "Fix with AI" button | diff --git a/docs/plans/2026-02-26-ai-autofix-gemini-plan.md b/docs/plans/2026-02-26-ai-autofix-gemini-plan.md new file mode 100644 index 00000000..4406024a --- /dev/null +++ b/docs/plans/2026-02-26-ai-autofix-gemini-plan.md @@ -0,0 +1,1707 @@ +# AI Auto-Fix & Gemini Flash Provider Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add Gemini 2.5 Flash as primary AI provider (with Claude fallback), then build an AI-powered "Fix with AI" feature that generates structural fixes for validation errors in the tree editor. + +**Architecture:** A provider abstraction layer (`ai_provider.py`) wraps Gemini and Anthropic SDKs behind a unified `generate_json()` interface. The existing `ai_tree_generator_service.py` swaps its direct Anthropic calls for this abstraction. A new `ai_fix_service.py` builds prompts from validation errors + tree context and returns proposed node patches. The frontend adds a "Fix with AI" button to `ValidationSummary` and a review modal for applying fixes. + +**Tech Stack:** Python FastAPI, google-genai SDK, anthropic SDK, Pydantic v2, React 19, TypeScript, Zustand, Tailwind CSS + +**Design Doc:** `docs/plans/2026-02-26-ai-autofix-gemini-design.md` + +--- + +## Task 1: Install google-genai SDK + +**Files:** +- Modify: `backend/requirements.txt` + +**Step 1: Add the dependency** + +Add `google-genai` to `backend/requirements.txt`: + +``` +google-genai>=1.0.0 +``` + +**Step 2: Install it** + +Run: `cd backend && pip install google-genai` + +**Step 3: Commit** + +```bash +git add backend/requirements.txt +git commit -m "chore: add google-genai SDK dependency" +``` + +--- + +## Task 2: Add Gemini config vars to Settings + +**Files:** +- Modify: `backend/app/core/config.py:75-85` + +**Step 1: Add new config variables** + +In `backend/app/core/config.py`, after line 80 (`AI_REQUEST_TIMEOUT_SECONDS`), add: + +```python + # AI Provider selection + AI_PROVIDER: str = "gemini" # "gemini" or "anthropic" + GOOGLE_AI_API_KEY: Optional[str] = None + AI_MODEL_GEMINI: str = "gemini-2.5-flash" + AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001" +``` + +**Step 2: Update `ai_enabled` property** + +Replace the existing `ai_enabled` property (lines 82-85) with: + +```python + @property + def ai_enabled(self) -> bool: + """Check if any AI provider is configured.""" + return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None +``` + +**Step 3: Verify no tests break** + +Run: `cd backend && python -m pytest tests/test_ai_tree_validator.py -v` +Expected: All pass (config changes don't affect validator tests). + +**Step 4: Commit** + +```bash +git add backend/app/core/config.py +git commit -m "feat: add Gemini Flash config vars to Settings" +``` + +--- + +## Task 3: Build the AI provider abstraction + +**Files:** +- Create: `backend/app/core/ai_provider.py` +- Test: `backend/tests/test_ai_provider.py` + +**Step 1: Write tests for the provider abstraction** + +Create `backend/tests/test_ai_provider.py`: + +```python +"""Tests for AI provider abstraction layer.""" +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from app.core.config import settings + + +class TestGetAIProvider: + """Test provider factory function.""" + + def test_returns_gemini_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.GOOGLE_AI_API_KEY + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = "fake-gemini-key" + try: + from app.core.ai_provider import get_ai_provider, GeminiProvider + provider = get_ai_provider() + assert isinstance(provider, GeminiProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_key + + def test_returns_anthropic_when_configured(self): + original_provider = settings.AI_PROVIDER + original_key = settings.ANTHROPIC_API_KEY + settings.AI_PROVIDER = "anthropic" + settings.ANTHROPIC_API_KEY = "fake-anthropic-key" + try: + from app.core.ai_provider import get_ai_provider, AnthropicProvider + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.ANTHROPIC_API_KEY = original_key + + def test_falls_back_to_anthropic_when_gemini_key_missing(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = "fake-anthropic-key" + try: + from app.core.ai_provider import get_ai_provider, AnthropicProvider + provider = get_ai_provider() + assert isinstance(provider, AnthropicProvider) + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + def test_raises_when_no_provider_configured(self): + original_provider = settings.AI_PROVIDER + original_gemini_key = settings.GOOGLE_AI_API_KEY + original_anthropic_key = settings.ANTHROPIC_API_KEY + settings.AI_PROVIDER = "gemini" + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + try: + from app.core.ai_provider import get_ai_provider + with pytest.raises(RuntimeError, match="No AI provider configured"): + get_ai_provider() + finally: + settings.AI_PROVIDER = original_provider + settings.GOOGLE_AI_API_KEY = original_gemini_key + settings.ANTHROPIC_API_KEY = original_anthropic_key + + +class TestAnthropicProvider: + """Test Anthropic provider generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json_returns_text_and_tokens(self): + from app.core.ai_provider import AnthropicProvider + + mock_response = MagicMock() + mock_response.content = [MagicMock(text='{"key": "value"}')] + mock_response.usage = MagicMock(input_tokens=100, output_tokens=50) + + mock_client = AsyncMock() + mock_client.messages.create = AsyncMock(return_value=mock_response) + + with patch("app.core.ai_provider.anthropic.AsyncAnthropic", return_value=mock_client): + provider = AnthropicProvider(api_key="fake-key") + text, inp, out = await provider.generate_json( + system_prompt="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=1024, + ) + assert text == '{"key": "value"}' + assert inp == 100 + assert out == 50 + + +class TestGeminiProvider: + """Test Gemini provider generate_json.""" + + @pytest.mark.asyncio + async def test_generate_json_returns_text_and_tokens(self): + from app.core.ai_provider import GeminiProvider + + mock_response = MagicMock() + mock_response.text = '{"key": "value"}' + mock_response.usage_metadata = MagicMock( + prompt_token_count=100, + candidates_token_count=50, + ) + + mock_client = MagicMock() + mock_model = MagicMock() + mock_model.generate_content_async = AsyncMock(return_value=mock_response) + mock_client.models = mock_model + + with patch("app.core.ai_provider.genai.Client", return_value=mock_client): + provider = GeminiProvider(api_key="fake-key") + text, inp, out = await provider.generate_json( + system_prompt="You are a helper.", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=1024, + ) + assert text == '{"key": "value"}' + assert inp == 100 + assert out == 50 +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && python -m pytest tests/test_ai_provider.py -v` +Expected: ImportError — `ai_provider` module doesn't exist yet. + +**Step 3: Implement the provider abstraction** + +Create `backend/app/core/ai_provider.py`: + +```python +"""AI provider abstraction layer. + +Supports Gemini (default) and Anthropic (fallback) behind a unified interface. +""" +import logging +from abc import ABC, abstractmethod +from typing import Any + +import anthropic +from google import genai +from google.genai import types as genai_types + +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +class AIProvider(ABC): + """Base class for AI providers.""" + + @abstractmethod + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Generate a JSON response from the AI model. + + Args: + system_prompt: System instructions for the model. + messages: List of {"role": "user"|"assistant", "content": str} dicts. + max_tokens: Maximum output tokens. + + Returns: + (response_text, input_tokens, output_tokens) + """ + ... + + +class GeminiProvider(AIProvider): + """Google Gemini provider using google-genai SDK.""" + + def __init__(self, api_key: str, model: str | None = None): + self._api_key = api_key + self._model = model or settings.AI_MODEL_GEMINI + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + client = genai.Client(api_key=self._api_key) + + # Build contents: system instruction is separate in Gemini API + contents: list[genai_types.Content] = [] + for msg in messages: + role = "user" if msg["role"] == "user" else "model" + contents.append(genai_types.Content( + role=role, + parts=[genai_types.Part(text=msg["content"])], + )) + + config = genai_types.GenerateContentConfig( + system_instruction=system_prompt, + max_output_tokens=max_tokens, + response_mime_type="application/json", + ) + + response = await client.models.generate_content_async( + model=self._model, + contents=contents, + config=config, + ) + + text = response.text or "" + input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 + output_tokens = getattr(response.usage_metadata, "candidates_token_count", 0) or 0 + + return text, input_tokens, output_tokens + + +class AnthropicProvider(AIProvider): + """Anthropic Claude provider.""" + + def __init__(self, api_key: str, model: str | None = None): + self._api_key = api_key + self._model = model or settings.AI_MODEL_ANTHROPIC + + async def generate_json( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + client = anthropic.AsyncAnthropic( + api_key=self._api_key, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + + response = await client.messages.create( + model=self._model, + max_tokens=max_tokens, + system=system_prompt, + messages=messages, + ) + + text = response.content[0].text + input_tokens = response.usage.input_tokens + output_tokens = response.usage.output_tokens + + return text, input_tokens, output_tokens + + +def get_ai_provider() -> AIProvider: + """Factory: return the configured AI provider. + + Falls back to Anthropic if Gemini key is missing. + Raises RuntimeError if no provider is configured. + """ + if settings.AI_PROVIDER == "gemini" and settings.GOOGLE_AI_API_KEY: + logger.info("Using Gemini provider (%s)", settings.AI_MODEL_GEMINI) + return GeminiProvider(api_key=settings.GOOGLE_AI_API_KEY) + + if settings.AI_PROVIDER == "anthropic" and settings.ANTHROPIC_API_KEY: + logger.info("Using Anthropic provider (%s)", settings.AI_MODEL_ANTHROPIC) + return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY) + + # Fallback: if Gemini requested but key missing, try Anthropic + if settings.ANTHROPIC_API_KEY: + logger.warning("Gemini key missing, falling back to Anthropic provider") + return AnthropicProvider(api_key=settings.ANTHROPIC_API_KEY) + + raise RuntimeError( + "No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY." + ) +``` + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && python -m pytest tests/test_ai_provider.py -v` +Expected: All pass. + +**Step 5: Commit** + +```bash +git add backend/app/core/ai_provider.py backend/tests/test_ai_provider.py +git commit -m "feat: add AI provider abstraction with Gemini and Anthropic support" +``` + +--- + +## Task 4: Migrate ai_tree_generator_service to use provider abstraction + +**Files:** +- Modify: `backend/app/core/ai_tree_generator_service.py` +- Modify: `backend/app/api/endpoints/ai_builder.py` + +**Step 1: Update ai_tree_generator_service.py** + +In `backend/app/core/ai_tree_generator_service.py`: + +Replace the import at line 16: +```python +# OLD +import anthropic +# NEW +from app.core.ai_provider import get_ai_provider +``` + +Remove the `_get_client()` function (lines 124-131). + +Update `scaffold_branches()` (starting at line 141). Replace the client creation and API call: + +```python +async def scaffold_branches( + wizard_state: dict[str, Any], +) -> tuple[list[dict[str, str]], int, int, float]: + """Stage 2: AI suggests top-level branches.""" + provider = get_ai_provider() + + flow_type = wizard_state.get("flow_type", "troubleshooting") + name = wizard_state.get("name", "") + description = wizard_state.get("description", "") + tags = wizard_state.get("environment_tags", []) + + user_message = ( + f"Flow type: {flow_type}\n" + f"Name: {name}\n" + f"Description: {description}\n" + ) + if tags: + user_message += f"Environment: {', '.join(tags)}\n" + + raw_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=SCAFFOLD_SYSTEM_PROMPT, + messages=[{"role": "user", "content": user_message}], + max_tokens=1024, + ) +``` + +Then update the rest of the function to use `raw_text` instead of `response.content[0].text`, and `input_tokens`/`output_tokens` directly instead of `response.usage.*`. + +Do the same for `generate_branch_detail()` — replace `_get_client()` + `client.messages.create()` with `provider.generate_json()`. The retry loop structure stays the same; just swap the API call and response parsing. + +**Step 2: Update ai_builder.py endpoint** + +In `backend/app/api/endpoints/ai_builder.py`, line 13: + +```python +# OLD +import anthropic +# NEW (remove this import entirely — no longer needed in the endpoint file) +``` + +Update the `_require_ai_enabled()` function (line 50) to check the new config: + +```python +def _require_ai_enabled() -> None: + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI Flow Builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) +``` + +Update any `except anthropic.APIError` catch blocks in the endpoint to catch generic `Exception` or a broader error type, since the provider abstraction may raise different errors depending on backend. + +**Step 3: Run existing AI endpoint tests** + +Run: `cd backend && python -m pytest tests/test_ai_endpoints.py -v` + +These tests mock `anthropic.AsyncAnthropic` — they need to be updated to mock `app.core.ai_provider.get_ai_provider` instead. Update the mock targets in the test file: + +```python +# Replace patches like: +# @patch("app.core.ai_tree_generator_service._get_client") +# With: +# @patch("app.core.ai_tree_generator_service.get_ai_provider") +``` + +The mock provider should return an `AsyncMock` that returns `(json_text, input_tokens, output_tokens)` from `generate_json`. + +**Step 4: Verify all tests pass** + +Run: `cd backend && python -m pytest tests/test_ai_endpoints.py tests/test_ai_provider.py tests/test_ai_tree_validator.py -v` +Expected: All pass. + +**Step 5: Commit** + +```bash +git add backend/app/core/ai_tree_generator_service.py backend/app/api/endpoints/ai_builder.py backend/tests/test_ai_endpoints.py +git commit -m "refactor: migrate AI tree generator to provider abstraction" +``` + +--- + +## Task 5: Create AI fix schemas + +**Files:** +- Create: `backend/app/schemas/ai_fix.py` + +**Step 1: Create the schemas** + +Create `backend/app/schemas/ai_fix.py`: + +```python +"""Pydantic schemas for the AI auto-fix feature.""" +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class ValidationErrorInput(BaseModel): + """A single validation error to fix.""" + + node_id: str = Field(..., description="ID of the node with the error") + message: str = Field(..., description="The validation error message") + + +class AIFixTreeRequest(BaseModel): + """Request to generate AI fixes for validation errors.""" + + tree_structure: dict[str, Any] = Field(..., description="Full tree structure") + tree_name: str = Field("", max_length=255, description="Name of the flow") + tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field( + "troubleshooting", description="Type of flow" + ) + validation_errors: list[ValidationErrorInput] = Field( + ..., min_length=1, max_length=10, description="Errors to fix" + ) + + +class AIFixProposal(BaseModel): + """A single proposed fix from the AI.""" + + target_node_id: str + error_message: str + description: str + original_node: dict[str, Any] + fixed_node: dict[str, Any] + + +class AIFixTokenUsage(BaseModel): + input: int = 0 + output: int = 0 + + +class AIFixTreeResponse(BaseModel): + """Response with proposed fixes.""" + + fixes: list[AIFixProposal] + tokens_used: AIFixTokenUsage +``` + +**Step 2: Commit** + +```bash +git add backend/app/schemas/ai_fix.py +git commit -m "feat: add Pydantic schemas for AI fix-tree endpoint" +``` + +--- + +## Task 6: Build the AI fix service + +**Files:** +- Create: `backend/app/core/ai_fix_service.py` +- Test: `backend/tests/test_ai_fix_service.py` + +**Step 1: Write tests** + +Create `backend/tests/test_ai_fix_service.py`: + +```python +"""Tests for AI fix service — prompt building and node extraction.""" +import json +import pytest +from app.core.ai_fix_service import _serialize_tree_outline, _find_node_by_id, _find_parent_node + + +def _make_tree(): + return { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review event logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs Resolved", + "description": "Found issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart fix it?", + "options": [ + {"id": "opt-r-yes", "label": "Yes", "next_node_id": "restart-ok"}, + ], + "children": [ + { + "id": "restart-ok", + "type": "solution", + "title": "Restart Worked", + "description": "Server is back.", + }, + ], + }, + ], + } + + +class TestFindNodeById: + def test_finds_root(self): + tree = _make_tree() + node = _find_node_by_id(tree, "root") + assert node is not None + assert node["id"] == "root" + + def test_finds_nested_child(self): + tree = _make_tree() + node = _find_node_by_id(tree, "restart-ok") + assert node is not None + assert node["id"] == "restart-ok" + + def test_returns_none_for_missing(self): + tree = _make_tree() + assert _find_node_by_id(tree, "nonexistent") is None + + +class TestFindParentNode: + def test_root_has_no_parent(self): + tree = _make_tree() + assert _find_parent_node(tree, "root") is None + + def test_finds_parent_of_child(self): + tree = _make_tree() + parent = _find_parent_node(tree, "restart") + assert parent is not None + assert parent["id"] == "root" + + def test_finds_parent_of_deeply_nested(self): + tree = _make_tree() + parent = _find_parent_node(tree, "restart-ok") + assert parent is not None + assert parent["id"] == "restart" + + +class TestSerializeTreeOutline: + def test_produces_readable_outline(self): + tree = _make_tree() + outline = _serialize_tree_outline(tree) + assert "[decision] Is the server up?" in outline + assert "[action] Check Logs" in outline + assert "[solution] Restart Worked" in outline + + def test_marks_error_node(self): + tree = _make_tree() + outline = _serialize_tree_outline(tree, error_node_id="restart") + assert "ERROR HERE" in outline +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && python -m pytest tests/test_ai_fix_service.py -v` +Expected: ImportError — module doesn't exist. + +**Step 3: Implement the fix service** + +Create `backend/app/core/ai_fix_service.py`: + +```python +"""AI-powered fix generation for tree validation errors. + +Builds targeted prompts for each failing node, sends to the AI provider, +and validates the fix before returning it. +""" +import copy +import json +import logging +import re +from typing import Any + +from app.core.ai_provider import get_ai_provider +from app.core.ai_tree_validator import validate_generated_tree + +logger = logging.getLogger(__name__) + +# ── Helpers ── + + +def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None: + """Recursively find a node by ID in the tree.""" + if not isinstance(tree, dict): + return None + if tree.get("id") == node_id: + return tree + for child in tree.get("children", []): + result = _find_node_by_id(child, node_id) + if result is not None: + return result + return None + + +def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None: + """Find the parent of the node with target_id.""" + if not isinstance(tree, dict): + return None + for child in tree.get("children", []): + if isinstance(child, dict) and child.get("id") == target_id: + return tree + result = _find_parent_node(child, target_id) + if result is not None: + return result + return None + + +def _serialize_tree_outline( + tree: dict[str, Any], + indent: int = 0, + error_node_id: str | None = None, +) -> str: + """Serialize tree as a compact readable outline for the AI prompt.""" + if not isinstance(tree, dict): + return "" + + node_type = tree.get("type", "?") + label = tree.get("question") or tree.get("title") or tree.get("id", "?") + prefix = " " * indent + marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else "" + + line = f"{prefix}- [{node_type}] {label}{marker}" + lines = [line] + + for child in tree.get("children", []): + lines.append(_serialize_tree_outline(child, indent + 1, error_node_id)) + + return "\n".join(lines) + + +def _strip_markdown_fences(text: str) -> str: + """Strip ```json ... ``` fences from AI response.""" + match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL) + if match: + return match.group(1).strip() + return text + + +# ── Prompt ── + +FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers. + +You will receive: +1. A full flow outline showing the tree structure +2. The specific failing node with its full JSON +3. The validation error message + +Your task: Return a FIXED version of the failing node as valid JSON. Rules: +- Fix ONLY the structural issue described in the error message +- Keep ALL existing content (titles, descriptions, questions, options) unchanged +- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic +- Every new node must have a unique ID (use descriptive kebab-case IDs) +- Decision nodes must have at least 2 options and at least 2 children +- Action nodes must have a next_node_id pointing to a sibling node in the parent's children +- Solution nodes are leaf nodes (no children) +- Return ONLY the fixed node JSON, no explanation""" + + +def _build_fix_prompt( + tree: dict[str, Any], + node_id: str, + error_message: str, + tree_name: str, + tree_type: str, +) -> str: + """Build the user message for fixing a specific node.""" + outline = _serialize_tree_outline(tree, error_node_id=node_id) + failing_node = _find_node_by_id(tree, node_id) + node_json = json.dumps(failing_node, indent=2) if failing_node else "{}" + + return ( + f"Flow name: {tree_name}\n" + f"Flow type: {tree_type}\n\n" + f"FULL FLOW OUTLINE:\n{outline}\n\n" + f"ERROR: {error_message}\n\n" + f"FAILING NODE (full JSON):\n{node_json}\n\n" + f"Return the fixed version of this node as JSON." + ) + + +# ── Main Service ── + + +async def generate_fixes( + tree_structure: dict[str, Any], + tree_name: str, + tree_type: str, + validation_errors: list[dict[str, str]], +) -> tuple[list[dict[str, Any]], int, int]: + """Generate AI fixes for each validation error. + + Args: + tree_structure: Full tree JSON. + tree_name: Name of the flow. + tree_type: Type of flow (troubleshooting/procedural/maintenance). + validation_errors: List of {"node_id": str, "message": str}. + + Returns: + (fixes, total_input_tokens, total_output_tokens) + Each fix: {"target_node_id", "error_message", "description", "original_node", "fixed_node"} + """ + provider = get_ai_provider() + fixes: list[dict[str, Any]] = [] + total_input = 0 + total_output = 0 + + for error in validation_errors: + node_id = error["node_id"] + error_message = error["message"] + + original_node = _find_node_by_id(tree_structure, node_id) + if original_node is None: + logger.warning("Node %s not found in tree, skipping fix", node_id) + continue + + original_snapshot = copy.deepcopy(original_node) + user_message = _build_fix_prompt( + tree_structure, node_id, error_message, tree_name, tree_type + ) + + # Attempt fix (with 1 retry using corrective prompt) + messages: list[dict[str, str]] = [{"role": "user", "content": user_message}] + + for attempt in range(2): + try: + raw_text, inp_tokens, out_tokens = await provider.generate_json( + system_prompt=FIX_SYSTEM_PROMPT, + messages=messages, + max_tokens=4096, + ) + total_input += inp_tokens + total_output += out_tokens + + cleaned = _strip_markdown_fences(raw_text) + fixed_node = json.loads(cleaned) + + # Quick validation: check that the fix actually addresses the error + # by substituting into tree and re-validating that specific error is gone + test_tree = copy.deepcopy(tree_structure) + _replace_node_in_tree(test_tree, node_id, fixed_node) + remaining_errors = validate_generated_tree(test_tree) + still_has_error = any( + node_id in e and error_message.split(":")[0].lower() in e.lower() + for e in remaining_errors + ) + + if still_has_error and attempt == 0: + # Retry with corrective prompt + messages.append({"role": "assistant", "content": raw_text}) + messages.append({ + "role": "user", + "content": ( + f"The fix still has the same validation error. " + f"Remaining errors: {remaining_errors}. " + f"Please try again." + ), + }) + continue + + # Extract description from the fix + description = _describe_fix(original_snapshot, fixed_node) + + fixes.append({ + "target_node_id": node_id, + "error_message": error_message, + "description": description, + "original_node": original_snapshot, + "fixed_node": fixed_node, + }) + break + + except (json.JSONDecodeError, KeyError, TypeError) as exc: + logger.warning( + "Fix attempt %d for node %s failed: %s", attempt + 1, node_id, exc + ) + if attempt == 0: + messages.append({"role": "assistant", "content": raw_text if 'raw_text' in dir() else ""}) + messages.append({ + "role": "user", + "content": f"Invalid JSON response. Return ONLY valid JSON for the fixed node.", + }) + else: + logger.error("Failed to generate fix for node %s after 2 attempts", node_id) + + return fixes, total_input, total_output + + +def _replace_node_in_tree( + tree: dict[str, Any], target_id: str, replacement: dict[str, Any] +) -> bool: + """Replace a node in the tree by ID. Returns True if found and replaced.""" + if tree.get("id") == target_id: + tree.clear() + tree.update(replacement) + return True + for child in tree.get("children", []): + if isinstance(child, dict): + if child.get("id") == target_id: + child.clear() + child.update(replacement) + return True + if _replace_node_in_tree(child, target_id, replacement): + return True + return False + + +def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str: + """Generate a human-readable description of what changed.""" + orig_children = len(original.get("children", [])) + fixed_children = len(fixed.get("children", [])) + orig_options = len(original.get("options", [])) + fixed_options = len(fixed.get("options", [])) + + parts: list[str] = [] + if fixed_children > orig_children: + added = fixed_children - orig_children + parts.append(f"Added {added} child node{'s' if added > 1 else ''}") + if fixed_options > orig_options: + added = fixed_options - orig_options + parts.append(f"Added {added} option{'s' if added > 1 else ''}") + if "next_node_id" in fixed and "next_node_id" not in original: + parts.append(f"Added next_node_id '{fixed['next_node_id']}'") + + return "; ".join(parts) if parts else "Structural fix applied" +``` + +**Step 4: Run tests** + +Run: `cd backend && python -m pytest tests/test_ai_fix_service.py -v` +Expected: All pass. + +**Step 5: Commit** + +```bash +git add backend/app/core/ai_fix_service.py backend/tests/test_ai_fix_service.py +git commit -m "feat: add AI fix service with prompt building and validation" +``` + +--- + +## Task 7: Create the fix-tree endpoint + +**Files:** +- Create: `backend/app/api/endpoints/ai_fix.py` +- Modify: `backend/app/api/router.py` +- Test: `backend/tests/test_ai_fix_endpoint.py` + +**Step 1: Write the endpoint test** + +Create `backend/tests/test_ai_fix_endpoint.py`: + +```python +"""Integration tests for AI fix-tree endpoint.""" +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from app.core.config import settings + + +SAMPLE_TREE = { + "id": "root", + "type": "decision", + "question": "Is the server up?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart"}, + ], + "children": [ + { + "id": "check-logs", + "type": "action", + "title": "Check Logs", + "description": "Review logs.", + "next_node_id": "logs-ok", + }, + { + "id": "logs-ok", + "type": "solution", + "title": "Logs OK", + "description": "Issue in logs.", + }, + { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}], + "children": [ + {"id": "done", "type": "solution", "title": "Done", "description": "Fixed."}, + ], + }, + ], +} + + +@pytest.fixture +def enable_ai(): + original_key = settings.GOOGLE_AI_API_KEY + original_provider = settings.AI_PROVIDER + settings.GOOGLE_AI_API_KEY = "fake-key" + settings.AI_PROVIDER = "gemini" + yield + settings.GOOGLE_AI_API_KEY = original_key + settings.AI_PROVIDER = original_provider + + +@pytest.mark.asyncio +class TestFixTreeEndpoint: + + async def test_returns_401_without_auth(self, client): + response = await client.post("/api/v1/ai/fix-tree", json={ + "tree_structure": SAMPLE_TREE, + "tree_name": "Test", + "tree_type": "troubleshooting", + "validation_errors": [{"node_id": "restart", "message": "Need 2 options"}], + }) + assert response.status_code == 401 + + async def test_returns_503_when_ai_disabled(self, client, auth_headers): + original = settings.GOOGLE_AI_API_KEY + orig_anthropic = settings.ANTHROPIC_API_KEY + settings.GOOGLE_AI_API_KEY = None + settings.ANTHROPIC_API_KEY = None + try: + response = await client.post( + "/api/v1/ai/fix-tree", + json={ + "tree_structure": SAMPLE_TREE, + "tree_name": "Test", + "tree_type": "troubleshooting", + "validation_errors": [{"node_id": "restart", "message": "test"}], + }, + headers=auth_headers, + ) + assert response.status_code == 503 + finally: + settings.GOOGLE_AI_API_KEY = original + settings.ANTHROPIC_API_KEY = orig_anthropic + + async def test_returns_fixes_on_success(self, client, auth_headers, enable_ai): + fixed_node = { + "id": "restart", + "type": "decision", + "question": "Did restart work?", + "options": [ + {"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"}, + {"id": "opt-r-no", "label": "No", "next_node_id": "escalate"}, + ], + "children": [ + {"id": "done", "type": "solution", "title": "Done", "description": "Fixed."}, + {"id": "escalate", "type": "solution", "title": "Escalate", "description": "Escalate to vendor."}, + ], + } + + mock_provider = AsyncMock() + mock_provider.generate_json = AsyncMock( + return_value=(json.dumps(fixed_node), 500, 300) + ) + + with patch("app.core.ai_fix_service.get_ai_provider", return_value=mock_provider): + response = await client.post( + "/api/v1/ai/fix-tree", + json={ + "tree_structure": SAMPLE_TREE, + "tree_name": "Server Flow", + "tree_type": "troubleshooting", + "validation_errors": [ + {"node_id": "restart", "message": "Decision node must have at least 2 options"}, + ], + }, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["fixes"]) == 1 + assert data["fixes"][0]["target_node_id"] == "restart" + assert data["tokens_used"]["input"] == 500 + assert data["tokens_used"]["output"] == 300 +``` + +**Step 2: Run to verify it fails** + +Run: `cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v` +Expected: Fail — endpoint doesn't exist. + +**Step 3: Create the endpoint** + +Create `backend/app/api/endpoints/ai_fix.py`: + +```python +"""AI auto-fix endpoint for tree validation errors.""" +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.rate_limit import limiter +from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin +from app.core.config import settings +from app.core.ai_fix_service import generate_fixes +from app.models.user import User +from app.schemas.ai_fix import AIFixTreeRequest, AIFixTreeResponse, AIFixProposal, AIFixTokenUsage + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ai", tags=["ai-fix"]) + + +def _require_ai_enabled() -> None: + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) + + +@router.post("/fix-tree", response_model=AIFixTreeResponse) +@limiter.limit("10/minute") +async def fix_tree( + request: Request, + body: AIFixTreeRequest, + user: Annotated[User, Depends(require_engineer_or_admin)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Generate AI-powered fixes for tree validation errors.""" + _require_ai_enabled() + + try: + fixes, total_input, total_output = await generate_fixes( + tree_structure=body.tree_structure, + tree_name=body.tree_name, + tree_type=body.tree_type, + validation_errors=[ + {"node_id": e.node_id, "message": e.message} + for e in body.validation_errors + ], + ) + except RuntimeError as exc: + logger.error("AI fix generation failed: %s", exc) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(exc), + ) + except Exception as exc: + logger.exception("Unexpected error during AI fix generation") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to generate fixes. Please try again.", + ) + + return AIFixTreeResponse( + fixes=[ + AIFixProposal( + target_node_id=f["target_node_id"], + error_message=f["error_message"], + description=f["description"], + original_node=f["original_node"], + fixed_node=f["fixed_node"], + ) + for f in fixes + ], + tokens_used=AIFixTokenUsage(input=total_input, output=total_output), + ) +``` + +**Step 4: Register in router** + +In `backend/app/api/router.py`, add: + +After line 8 (`from app.api.endpoints import ai_builder`): +```python +from app.api.endpoints import ai_fix +``` + +After line 38 (`api_router.include_router(ai_builder.router)`): +```python +api_router.include_router(ai_fix.router) +``` + +**Step 5: Run tests** + +Run: `cd backend && python -m pytest tests/test_ai_fix_endpoint.py -v` +Expected: All pass. + +**Step 6: Run full test suite** + +Run: `cd backend && python -m pytest --override-ini="addopts=" -v` +Expected: All pass (100+ tests). + +**Step 7: Commit** + +```bash +git add backend/app/api/endpoints/ai_fix.py backend/app/api/router.py backend/app/schemas/ai_fix.py backend/tests/test_ai_fix_endpoint.py +git commit -m "feat: add POST /ai/fix-tree endpoint for AI-powered validation fixes" +``` + +--- + +## Task 8: Add frontend API client for fix-tree + +**Files:** +- Modify: `frontend/src/api/trees.ts` +- Create: `frontend/src/types/ai-fix.ts` +- Modify: `frontend/src/types/index.ts` + +**Step 1: Create the types** + +Create `frontend/src/types/ai-fix.ts`: + +```typescript +export interface AIFixValidationError { + node_id: string + message: string +} + +export interface AIFixProposal { + target_node_id: string + error_message: string + description: string + original_node: Record + fixed_node: Record +} + +export interface AIFixTreeRequest { + tree_structure: Record + tree_name: string + tree_type: 'troubleshooting' | 'procedural' | 'maintenance' + validation_errors: AIFixValidationError[] +} + +export interface AIFixTreeResponse { + fixes: AIFixProposal[] + tokens_used: { input: number; output: number } +} +``` + +**Step 2: Export from types/index.ts** + +Add to `frontend/src/types/index.ts`: + +```typescript +export type { AIFixTreeRequest, AIFixTreeResponse, AIFixProposal, AIFixValidationError } from './ai-fix' +``` + +**Step 3: Add API method to trees.ts** + +In `frontend/src/api/trees.ts`, add a new method to the `treesApi` object: + +```typescript + async fixTree(request: AIFixTreeRequest): Promise { + const response = await apiClient.post('/ai/fix-tree', request) + return response.data + }, +``` + +Import the types at the top of the file. + +**Step 4: Commit** + +```bash +git add frontend/src/types/ai-fix.ts frontend/src/types/index.ts frontend/src/api/trees.ts +git commit -m "feat: add frontend API client and types for AI fix-tree" +``` + +--- + +## Task 9: Add "Fix with AI" button to ValidationSummary + +**Files:** +- Modify: `frontend/src/components/tree-editor/ValidationSummary.tsx` + +**Step 1: Update the component** + +Add new props and the button to `ValidationSummary`: + +```typescript +import { useState } from 'react' +import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react' +import { cn } from '@/lib/utils' +import type { ValidationError } from '@/store/treeEditorStore' + +interface ValidationSummaryProps { + errors: ValidationError[] + onSelectNode: (nodeId: string) => void + onFixWithAI?: () => void + isFixing?: boolean +} + +export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) { +``` + +In the header button area (after the expand/collapse chevron at line 62), add the "Fix with AI" button. It should appear between the error count text and the chevron icon. Restructure the header to include the button: + +After the `` closing the error/warning count (around line 60), add: + +```tsx + {/* Fix with AI button — only when there are fixable errors */} + {onFixWithAI && errorItems.some(e => e.nodeId) && ( + + )} +``` + +**Step 2: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build passes. + +**Step 3: Commit** + +```bash +git add frontend/src/components/tree-editor/ValidationSummary.tsx +git commit -m "feat: add Fix with AI button to ValidationSummary" +``` + +--- + +## Task 10: Build the AIFixReviewModal + +**Files:** +- Create: `frontend/src/components/tree-editor/AIFixReviewModal.tsx` + +**Step 1: Create the review modal** + +Create `frontend/src/components/tree-editor/AIFixReviewModal.tsx`: + +```typescript +import { useState } from 'react' +import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react' +import { cn } from '@/lib/utils' +import type { AIFixProposal } from '@/types' + +interface AIFixReviewModalProps { + fixes: AIFixProposal[] + onApply: (fix: AIFixProposal) => void + onApplyAll: () => void + onClose: () => void +} + +export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) { + const [appliedIds, setAppliedIds] = useState>(new Set()) + const [skippedIds, setSkippedIds] = useState>(new Set()) + const [expandedIds, setExpandedIds] = useState>(new Set(fixes.map(f => f.target_node_id))) + + const handleApply = (fix: AIFixProposal) => { + onApply(fix) + setAppliedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const handleSkip = (fix: AIFixProposal) => { + setSkippedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const toggleExpanded = (id: string) => { + setExpandedIds(prev => { + const next = new Set(prev) + if (next.has(id)) next.delete(id) + else next.add(id) + return next + }) + } + + const pendingFixes = fixes.filter( + f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id) + ) + const allHandled = pendingFixes.length === 0 + + return ( +
+
+ {/* Header */} +
+
+ +

+ AI Fix Proposals ({fixes.length}) +

+
+ +
+ + {/* Body */} +
+ {fixes.map((fix) => { + const isApplied = appliedIds.has(fix.target_node_id) + const isSkipped = skippedIds.has(fix.target_node_id) + const isExpanded = expandedIds.has(fix.target_node_id) + + return ( +
+ {/* Fix header */} +
+
+

{fix.error_message}

+

{fix.description}

+

+ Node: {fix.target_node_id} +

+
+ {isApplied && ( + + Applied + + )} + {isSkipped && ( + Skipped + )} +
+ + {/* Expand/collapse detail */} + {!isApplied && !isSkipped && ( + <> + + + {isExpanded && ( +
+
+

Before

+
+                            {JSON.stringify(fix.original_node, null, 2)}
+                          
+
+
+

After

+
+                            {JSON.stringify(fix.fixed_node, null, 2)}
+                          
+
+
+ )} + + {/* Action buttons */} +
+ + +
+ + )} +
+ ) + })} +
+ + {/* Footer */} +
+ + {!allHandled && ( + + )} +
+
+
+ ) +} +``` + +**Step 2: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build passes. + +**Step 3: Commit** + +```bash +git add frontend/src/components/tree-editor/AIFixReviewModal.tsx +git commit -m "feat: add AIFixReviewModal component for reviewing AI-proposed fixes" +``` + +--- + +## Task 11: Wire everything together in TreeEditorPage + +**Files:** +- Modify: `frontend/src/pages/TreeEditorPage.tsx` + +This is the integration task. The page needs to: + +1. Import the new components and API +2. Add state for fix flow (`isFixing`, `fixProposals`) +3. Handle "Fix with AI" button click — call `treesApi.fixTree()` +4. Show `AIFixReviewModal` when proposals are available +5. Handle apply/skip — call `updateNode()` on the tree editor store +6. Re-run `validate()` after applying fixes + +**Step 1: Find where ValidationSummary is rendered in TreeEditorPage** + +Search for `(null) +``` + +**Step 4: Add handler for "Fix with AI"** + +```typescript +const handleFixWithAI = async () => { + const store = useTreeEditorStore.getState() + if (!store.treeStructure) return + + // Get only fixable errors (structural errors with nodeId) + const fixableErrors = store.validationErrors + .filter(e => e.severity === 'error' && e.nodeId) + .map(e => ({ node_id: e.nodeId!, message: e.message })) + + if (fixableErrors.length === 0) return + + setIsFixing(true) + try { + const result = await treesApi.fixTree({ + tree_structure: store.treeStructure as Record, + tree_name: store.name, + tree_type: (store.treeType || 'troubleshooting') as 'troubleshooting' | 'procedural' | 'maintenance', + validation_errors: fixableErrors, + }) + if (result.fixes.length > 0) { + setFixProposals(result.fixes) + } else { + toast.info('AI could not generate fixes for these errors') + } + } catch { + toast.error('Failed to generate AI fixes. Please try again.') + } finally { + setIsFixing(false) + } +} +``` + +**Step 5: Add handlers for apply/close** + +```typescript +const handleApplyFix = (fix: AIFixProposal) => { + const store = useTreeEditorStore.getState() + store.updateNode(fix.target_node_id, fix.fixed_node as Partial) +} + +const handleApplyAllFixes = () => { + if (!fixProposals) return + for (const fix of fixProposals) { + handleApplyFix(fix) + } + setFixProposals(null) + // Re-validate after applying all fixes + setTimeout(() => { + useTreeEditorStore.getState().validate() + }, 100) +} + +const handleCloseFixModal = () => { + setFixProposals(null) + // Re-validate in case some fixes were applied + useTreeEditorStore.getState().validate() +} +``` + +**Step 6: Pass props to ValidationSummary** + +Update the `` JSX to include the new props: + +```tsx + +``` + +**Step 7: Add the review modal** + +After ``, add: + +```tsx +{fixProposals && ( + +)} +``` + +**Step 8: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build passes. + +**Step 9: Commit** + +```bash +git add frontend/src/pages/TreeEditorPage.tsx +git commit -m "feat: wire AI fix flow into TreeEditorPage" +``` + +--- + +## Task 12: Final verification + +**Step 1: Run all backend tests** + +Run: `cd backend && python -m pytest --override-ini="addopts=" -v` +Expected: All pass. + +**Step 2: Run frontend build** + +Run: `cd frontend && npm run build` +Expected: Build passes with no errors. + +**Step 3: Final commit if any adjustments needed** + +Fix any issues found during verification and commit. diff --git a/frontend/src/api/trees.ts b/frontend/src/api/trees.ts index 8b21840c..5f12603d 100644 --- a/frontend/src/api/trees.ts +++ b/frontend/src/api/trees.ts @@ -1,5 +1,5 @@ import apiClient from './client' -import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse } from '@/types' +import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse, AIFixTreeRequest, AIFixTreeResponse } from '@/types' export const treesApi = { async list(params?: TreeFilters): Promise { @@ -65,6 +65,12 @@ export const treesApi = { const response = await apiClient.post(`/trees/${id}/can-publish`) return response.data }, + + // AI auto-fix + async fixTree(request: AIFixTreeRequest): Promise { + const response = await apiClient.post('/ai/fix-tree', request) + return response.data + }, } export default treesApi diff --git a/frontend/src/components/tree-editor/AIFixReviewModal.tsx b/frontend/src/components/tree-editor/AIFixReviewModal.tsx new file mode 100644 index 00000000..acda5769 --- /dev/null +++ b/frontend/src/components/tree-editor/AIFixReviewModal.tsx @@ -0,0 +1,170 @@ +import { useState } from 'react' +import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react' +import { cn } from '@/lib/utils' +import type { AIFixProposal } from '@/types' + +interface AIFixReviewModalProps { + fixes: AIFixProposal[] + onApply: (fix: AIFixProposal) => void + onApplyAll: () => void + onClose: () => void +} + +export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) { + const [appliedIds, setAppliedIds] = useState>(new Set()) + const [skippedIds, setSkippedIds] = useState>(new Set()) + const [expandedIds, setExpandedIds] = useState>(new Set(fixes.map(f => f.target_node_id))) + + const handleApply = (fix: AIFixProposal) => { + onApply(fix) + setAppliedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const handleSkip = (fix: AIFixProposal) => { + setSkippedIds(prev => new Set(prev).add(fix.target_node_id)) + } + + const toggleExpanded = (id: string) => { + setExpandedIds(prev => { + const next = new Set(prev) + if (next.has(id)) next.delete(id) + else next.add(id) + return next + }) + } + + const pendingFixes = fixes.filter( + f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id) + ) + const allHandled = pendingFixes.length === 0 + + return ( +
+
+ {/* Header */} +
+
+ +

+ AI Fix Proposals ({fixes.length}) +

+
+ +
+ + {/* Body */} +
+ {fixes.map((fix) => { + const isApplied = appliedIds.has(fix.target_node_id) + const isSkipped = skippedIds.has(fix.target_node_id) + const isExpanded = expandedIds.has(fix.target_node_id) + + return ( +
+ {/* Fix header */} +
+
+

{fix.error_message}

+

{fix.description}

+

+ Node: {fix.target_node_id} +

+
+ {isApplied && ( + + Applied + + )} + {isSkipped && ( + Skipped + )} +
+ + {/* Expand/collapse detail */} + {!isApplied && !isSkipped && ( + <> + + + {isExpanded && ( +
+
+

Before

+
+                            {JSON.stringify(fix.original_node, null, 2)}
+                          
+
+
+

After

+
+                            {JSON.stringify(fix.fixed_node, null, 2)}
+                          
+
+
+ )} + + {/* Action buttons */} +
+ + +
+ + )} +
+ ) + })} +
+ + {/* Footer */} +
+ + {!allHandled && ( + + )} +
+
+
+ ) +} diff --git a/frontend/src/components/tree-editor/ValidationSummary.tsx b/frontend/src/components/tree-editor/ValidationSummary.tsx index fcf87bad..987be180 100644 --- a/frontend/src/components/tree-editor/ValidationSummary.tsx +++ b/frontend/src/components/tree-editor/ValidationSummary.tsx @@ -1,14 +1,16 @@ import { useState } from 'react' -import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp } from 'lucide-react' +import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react' import { cn } from '@/lib/utils' import type { ValidationError } from '@/store/treeEditorStore' interface ValidationSummaryProps { errors: ValidationError[] onSelectNode: (nodeId: string) => void + onFixWithAI?: () => void + isFixing?: boolean } -export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryProps) { +export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) { const [isExpanded, setIsExpanded] = useState(true) const errorItems = errors.filter(e => e.severity === 'error') @@ -22,6 +24,8 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro } } + const hasFixableErrors = errorItems.some(e => e.nodeId) + return (
{/* Header */} -
- {isExpanded ? : } - + {isExpanded ? : } + + + {/* Fix with AI button */} + {onFixWithAI && hasFixableErrors && ( + + )} + {/* Error/Warning List */} {isExpanded && ( diff --git a/frontend/src/pages/TreeEditorPage.tsx b/frontend/src/pages/TreeEditorPage.tsx index a21097b4..fa3ba215 100644 --- a/frontend/src/pages/TreeEditorPage.tsx +++ b/frontend/src/pages/TreeEditorPage.tsx @@ -5,10 +5,11 @@ import { Undo2, Redo2, Save, CheckCircle2, Monitor, FileText, Code2, LayoutList, import { getMonacoEditor } from '@/components/tree-editor/code-mode' import { treesApi } from '@/api/trees' import { treeMarkdownApi } from '@/api/treeMarkdown' -import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure } from '@/types' +import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure, AIFixProposal } from '@/types' import { useTreeEditorStore, useTreeEditorTemporal } from '@/store/treeEditorStore' import { TreeEditorLayout } from '@/components/tree-editor/TreeEditorLayout' import { ValidationSummary } from '@/components/tree-editor/ValidationSummary' +import { AIFixReviewModal } from '@/components/tree-editor/AIFixReviewModal' import { useKeyboardShortcuts } from '@/hooks/useKeyboardShortcuts' import { usePermissions } from '@/hooks/usePermissions' import { Spinner } from '@/components/common/Spinner' @@ -58,6 +59,8 @@ export function TreeEditorPage() { const [showAnalytics, setShowAnalytics] = useState(false) const [isMetadataOpen, setIsMetadataOpen] = useState(false) const [editingNodeId, setEditingNodeId] = useState(null) + const [isFixing, setIsFixing] = useState(false) + const [fixProposals, setFixProposals] = useState(null) // Mobile detection const [isMobile, setIsMobile] = useState(false) @@ -217,6 +220,54 @@ export function TreeEditorPage() { selectNode(nodeId) } + const handleFixWithAI = async () => { + const store = useTreeEditorStore.getState() + if (!store.treeStructure) return + + const fixableErrors = store.validationErrors + .filter(e => e.severity === 'error' && e.nodeId) + .map(e => ({ node_id: e.nodeId!, message: e.message })) + + if (fixableErrors.length === 0) return + + setIsFixing(true) + try { + const result = await treesApi.fixTree({ + tree_structure: store.treeStructure as unknown as Record, + tree_name: store.name, + tree_type: 'troubleshooting', + validation_errors: fixableErrors, + }) + if (result.fixes.length > 0) { + setFixProposals(result.fixes) + } else { + toast.info('AI could not generate fixes for these errors') + } + } catch { + toast.error('Failed to generate AI fixes. Please try again.') + } finally { + setIsFixing(false) + } + } + + const handleApplyFix = (fix: AIFixProposal) => { + updateNode(fix.target_node_id, fix.fixed_node as Partial) + } + + const handleApplyAllFixes = () => { + if (!fixProposals) return + for (const fix of fixProposals) { + handleApplyFix(fix) + } + setFixProposals(null) + setTimeout(() => { validate() }, 100) + } + + const handleCloseFixModal = () => { + setFixProposals(null) + validate() + } + const handleNodeSelect = useCallback((nodeId: string | null) => { if (nodeId) { setIsMetadataOpen(false) // close metadata when opening node editor @@ -685,6 +736,8 @@ export function TreeEditorPage() { )} @@ -705,6 +758,16 @@ export function TreeEditorPage() { )} + + {/* AI Fix Review Modal */} + {fixProposals && ( + + )} ) } diff --git a/frontend/src/types/ai-fix.ts b/frontend/src/types/ai-fix.ts new file mode 100644 index 00000000..a12a6fdf --- /dev/null +++ b/frontend/src/types/ai-fix.ts @@ -0,0 +1,24 @@ +export interface AIFixValidationError { + node_id: string + message: string +} + +export interface AIFixProposal { + target_node_id: string + error_message: string + description: string + original_node: Record + fixed_node: Record +} + +export interface AIFixTreeRequest { + tree_structure: Record + tree_name: string + tree_type: 'troubleshooting' | 'procedural' | 'maintenance' + validation_errors: AIFixValidationError[] +} + +export interface AIFixTreeResponse { + fixes: AIFixProposal[] + tokens_used: { input: number; output: number } +} diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 1bc91dc2..2c388169 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -45,3 +45,10 @@ export type { AIAssembleResponse, AIWizardPhase, } from './ai' + +export type { + AIFixTreeRequest, + AIFixTreeResponse, + AIFixProposal, + AIFixValidationError, +} from './ai-fix'