feat: AI auto-fix + Gemini Flash provider #93
@@ -10,7 +10,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import anthropic
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -52,10 +51,9 @@ def _require_ai_enabled() -> None:
|
|||||||
if not settings.ai_enabled:
|
if not settings.ai_enabled:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.",
|
detail="AI flow builder is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/quota", response_model=AIQuotaStatusResponse)
|
@router.get("/quota", response_model=AIQuotaStatusResponse)
|
||||||
async def get_quota(
|
async def get_quota(
|
||||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
@@ -174,27 +172,6 @@ async def scaffold(
|
|||||||
branches, input_tokens, output_tokens, cost = await scaffold_branches(
|
branches, input_tokens, output_tokens, cost = await scaffold_branches(
|
||||||
conversation.wizard_state,
|
conversation.wizard_state,
|
||||||
)
|
)
|
||||||
except anthropic.APIError as e:
|
|
||||||
await record_ai_usage(
|
|
||||||
user_id=current_user.id,
|
|
||||||
account_id=current_user.account_id,
|
|
||||||
conversation_id=conversation.id,
|
|
||||||
generation_type="scaffold",
|
|
||||||
tier=plan,
|
|
||||||
input_tokens=0,
|
|
||||||
output_tokens=0,
|
|
||||||
estimated_cost=0,
|
|
||||||
succeeded=False,
|
|
||||||
counts_toward_quota=False,
|
|
||||||
error_code=type(e).__name__,
|
|
||||||
extra_data={"error": str(e)},
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
||||||
detail="AI provider error. Please try again.",
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -216,6 +193,28 @@ async def scaffold(
|
|||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
detail=f"AI returned invalid output: {e}",
|
detail=f"AI returned invalid output: {e}",
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("AI scaffold failed: %s: %s", type(e).__name__, e)
|
||||||
|
await record_ai_usage(
|
||||||
|
user_id=current_user.id,
|
||||||
|
account_id=current_user.account_id,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
generation_type="scaffold",
|
||||||
|
tier=plan,
|
||||||
|
input_tokens=0,
|
||||||
|
output_tokens=0,
|
||||||
|
estimated_cost=0,
|
||||||
|
succeeded=False,
|
||||||
|
counts_toward_quota=False,
|
||||||
|
error_code=type(e).__name__,
|
||||||
|
extra_data={"error": str(e)},
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"AI provider error ({type(e).__name__}). Please try again.",
|
||||||
|
)
|
||||||
|
|
||||||
# Record successful usage
|
# Record successful usage
|
||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
@@ -293,27 +292,6 @@ async def branch_detail(
|
|||||||
existing_branches,
|
existing_branches,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except anthropic.APIError as e:
|
|
||||||
await record_ai_usage(
|
|
||||||
user_id=current_user.id,
|
|
||||||
account_id=current_user.account_id,
|
|
||||||
conversation_id=conversation.id,
|
|
||||||
generation_type="branch_detail",
|
|
||||||
tier=plan,
|
|
||||||
input_tokens=0,
|
|
||||||
output_tokens=0,
|
|
||||||
estimated_cost=0,
|
|
||||||
succeeded=False,
|
|
||||||
counts_toward_quota=False,
|
|
||||||
error_code=type(e).__name__,
|
|
||||||
extra_data={"error": str(e), "branch_name": data.branch_name},
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
||||||
detail="AI provider error. Please try again.",
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -335,6 +313,28 @@ async def branch_detail(
|
|||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
detail=f"AI returned invalid output: {e}",
|
detail=f"AI returned invalid output: {e}",
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("AI branch_detail failed: %s: %s", type(e).__name__, e)
|
||||||
|
await record_ai_usage(
|
||||||
|
user_id=current_user.id,
|
||||||
|
account_id=current_user.account_id,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
generation_type="branch_detail",
|
||||||
|
tier=plan,
|
||||||
|
input_tokens=0,
|
||||||
|
output_tokens=0,
|
||||||
|
estimated_cost=0,
|
||||||
|
succeeded=False,
|
||||||
|
counts_toward_quota=False,
|
||||||
|
error_code=type(e).__name__,
|
||||||
|
extra_data={"error": str(e), "branch_name": data.branch_name},
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"AI provider error ({type(e).__name__}). Please try again.",
|
||||||
|
)
|
||||||
|
|
||||||
# Record successful usage
|
# Record successful usage
|
||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
|
|||||||
78
backend/app/api/endpoints/ai_fix.py
Normal file
78
backend/app/api/endpoints/ai_fix.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""AI auto-fix endpoint for tree validation errors.
|
||||||
|
|
||||||
|
POST /ai/fix-tree — accepts a tree with validation errors and returns
|
||||||
|
AI-generated fix proposals for each error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_db, require_engineer_or_admin
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.rate_limit import limiter
|
||||||
|
from app.core.ai_fix_service import generate_fixes
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.ai_fix import (
|
||||||
|
AIFixTreeRequest,
|
||||||
|
AIFixTreeResponse,
|
||||||
|
AIFixProposal,
|
||||||
|
AIFixTokenUsage,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/ai", tags=["ai-fix"])
|
||||||
|
|
||||||
|
|
||||||
|
def _require_ai_enabled() -> None:
|
||||||
|
"""Raise 503 if AI is not configured."""
|
||||||
|
if not settings.ai_enabled:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="AI fix is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/fix-tree", response_model=AIFixTreeResponse)
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
async def fix_tree(
|
||||||
|
request: Request,
|
||||||
|
body: AIFixTreeRequest,
|
||||||
|
user: Annotated[User, Depends(require_engineer_or_admin)],
|
||||||
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
):
|
||||||
|
"""Generate AI-powered fixes for tree validation errors."""
|
||||||
|
_require_ai_enabled()
|
||||||
|
|
||||||
|
validation_errors = [
|
||||||
|
{"node_id": e.node_id, "message": e.message}
|
||||||
|
for e in body.validation_errors
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
fixes, input_tokens, output_tokens = await generate_fixes(
|
||||||
|
tree_structure=body.tree_structure,
|
||||||
|
tree_name=body.tree_name,
|
||||||
|
tree_type=body.tree_type,
|
||||||
|
validation_errors=validation_errors,
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
logger.error("AI provider not available: %s", exc)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=str(exc),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Unexpected error in AI fix service")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred while generating fixes.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return AIFixTreeResponse(
|
||||||
|
fixes=[AIFixProposal(**f) for f in fixes],
|
||||||
|
tokens_used=AIFixTokenUsage(input=input_tokens, output=output_tokens),
|
||||||
|
)
|
||||||
@@ -6,6 +6,7 @@ from app.api.endpoints import target_lists
|
|||||||
from app.api.endpoints import maintenance_schedules
|
from app.api.endpoints import maintenance_schedules
|
||||||
from app.api.endpoints import feedback
|
from app.api.endpoints import feedback
|
||||||
from app.api.endpoints import ai_builder
|
from app.api.endpoints import ai_builder
|
||||||
|
from app.api.endpoints import ai_fix
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
|
|
||||||
@@ -36,3 +37,4 @@ api_router.include_router(target_lists.router)
|
|||||||
api_router.include_router(maintenance_schedules.router)
|
api_router.include_router(maintenance_schedules.router)
|
||||||
api_router.include_router(feedback.router)
|
api_router.include_router(feedback.router)
|
||||||
api_router.include_router(ai_builder.router)
|
api_router.include_router(ai_builder.router)
|
||||||
|
api_router.include_router(ai_fix.router)
|
||||||
|
|||||||
273
backend/app/core/ai_fix_service.py
Normal file
273
backend/app/core/ai_fix_service.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
"""AI-powered fix service for tree validation errors.
|
||||||
|
|
||||||
|
Given a tree structure and validation errors, generates AI-powered
|
||||||
|
proposals to fix each structural issue while preserving existing content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.ai_provider import get_ai_provider
|
||||||
|
from app.core.ai_tree_validator import validate_generated_tree
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
FIX_SYSTEM_PROMPT = """You are ResolutionFlow AI, fixing structural validation errors in IT troubleshooting and maintenance flows used by MSP engineers.
|
||||||
|
|
||||||
|
You will receive:
|
||||||
|
1. A full flow outline showing the tree structure
|
||||||
|
2. The specific failing node with its full JSON
|
||||||
|
3. The validation error message
|
||||||
|
|
||||||
|
Your task: Return a FIXED version of the failing node as valid JSON. Rules:
|
||||||
|
- Fix ONLY the structural issue described in the error message
|
||||||
|
- Keep ALL existing content (titles, descriptions, questions, options) unchanged
|
||||||
|
- When adding new nodes (e.g., missing branches), generate domain-relevant content that fits the flow's topic
|
||||||
|
- Every new node must have a unique ID (use descriptive kebab-case IDs)
|
||||||
|
- Decision nodes must have at least 2 options and at least 2 children
|
||||||
|
- Action nodes must have a next_node_id pointing to a sibling node in the parent's children
|
||||||
|
- Solution nodes are leaf nodes (no children)
|
||||||
|
- Return ONLY the fixed node JSON, no explanation"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pure helper functions ──
|
||||||
|
|
||||||
|
|
||||||
|
def _find_node_by_id(tree: dict[str, Any], node_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Recursively find a node by its ID in the tree structure."""
|
||||||
|
if not isinstance(tree, dict):
|
||||||
|
return None
|
||||||
|
if tree.get("id") == node_id:
|
||||||
|
return tree
|
||||||
|
for child in tree.get("children", []):
|
||||||
|
result = _find_node_by_id(child, node_id)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _find_parent_node(tree: dict[str, Any], target_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Find the parent of a node with the given ID."""
|
||||||
|
if not isinstance(tree, dict):
|
||||||
|
return None
|
||||||
|
for child in tree.get("children", []):
|
||||||
|
if isinstance(child, dict) and child.get("id") == target_id:
|
||||||
|
return tree
|
||||||
|
result = _find_parent_node(child, target_id)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_tree_outline(
|
||||||
|
tree: dict[str, Any], indent: int = 0, error_node_id: str | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Serialize tree as a readable outline for AI prompt context.
|
||||||
|
|
||||||
|
Format: indented "- [type] label" with "<<< ERROR HERE" marker.
|
||||||
|
"""
|
||||||
|
if not isinstance(tree, dict):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
node_type = tree.get("type", "unknown")
|
||||||
|
label = tree.get("question") or tree.get("title") or tree.get("id", "?")
|
||||||
|
prefix = " " * indent
|
||||||
|
marker = " <<< ERROR HERE" if tree.get("id") == error_node_id else ""
|
||||||
|
line = f"{prefix}- [{node_type}] {label}{marker}"
|
||||||
|
|
||||||
|
lines = [line]
|
||||||
|
for child in tree.get("children", []):
|
||||||
|
lines.append(_serialize_tree_outline(child, indent + 1, error_node_id))
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_markdown_fences(text: str) -> str:
|
||||||
|
"""Strip ```json...``` fences from AI response."""
|
||||||
|
return re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.MULTILINE).rstrip(
|
||||||
|
"`"
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_node_in_tree(
|
||||||
|
tree: dict[str, Any], target_id: str, replacement: dict[str, Any]
|
||||||
|
) -> bool:
|
||||||
|
"""Replace a node in-place by ID. Returns True if found and replaced."""
|
||||||
|
if not isinstance(tree, dict):
|
||||||
|
return False
|
||||||
|
if tree.get("id") == target_id:
|
||||||
|
tree.clear()
|
||||||
|
tree.update(replacement)
|
||||||
|
return True
|
||||||
|
for child in tree.get("children", []):
|
||||||
|
if _replace_node_in_tree(child, target_id, replacement):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _describe_fix(original: dict[str, Any], fixed: dict[str, Any]) -> str:
|
||||||
|
"""Describe what changed between original and fixed node."""
|
||||||
|
changes: list[str] = []
|
||||||
|
|
||||||
|
orig_children = len(original.get("children", []))
|
||||||
|
fixed_children = len(fixed.get("children", []))
|
||||||
|
if fixed_children > orig_children:
|
||||||
|
changes.append(f"added {fixed_children - orig_children} child node(s)")
|
||||||
|
|
||||||
|
orig_options = len(original.get("options", []))
|
||||||
|
fixed_options = len(fixed.get("options", []))
|
||||||
|
if fixed_options > orig_options:
|
||||||
|
changes.append(f"added {fixed_options - orig_options} option(s)")
|
||||||
|
|
||||||
|
if fixed.get("next_node_id") and not original.get("next_node_id"):
|
||||||
|
changes.append("added next_node_id")
|
||||||
|
|
||||||
|
if not changes:
|
||||||
|
changes.append("fixed structural issue")
|
||||||
|
|
||||||
|
return "; ".join(changes).capitalize()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt building ──
|
||||||
|
|
||||||
|
|
||||||
|
def _build_fix_prompt(
|
||||||
|
tree: dict[str, Any],
|
||||||
|
node_id: str,
|
||||||
|
error_message: str,
|
||||||
|
tree_name: str,
|
||||||
|
tree_type: str,
|
||||||
|
) -> str:
|
||||||
|
"""Build the user message for the AI fix request."""
|
||||||
|
outline = _serialize_tree_outline(tree, error_node_id=node_id)
|
||||||
|
node = _find_node_by_id(tree, node_id)
|
||||||
|
node_json = json.dumps(node, indent=2) if node else "{}"
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"Flow name: {tree_name}\n"
|
||||||
|
f"Flow type: {tree_type}\n\n"
|
||||||
|
f"## Full flow outline\n```\n{outline}\n```\n\n"
|
||||||
|
f"## Failing node (ID: {node_id})\n```json\n{node_json}\n```\n\n"
|
||||||
|
f"## Validation error\n{error_message}\n\n"
|
||||||
|
f"Return the fixed version of this node as JSON."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main entry point ──
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_fixes(
|
||||||
|
tree_structure: dict[str, Any],
|
||||||
|
tree_name: str,
|
||||||
|
tree_type: str,
|
||||||
|
validation_errors: list[dict[str, str]],
|
||||||
|
) -> tuple[list[dict[str, Any]], int, int]:
|
||||||
|
"""Generate AI-powered fixes for tree validation errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tree_structure: Full tree structure dict.
|
||||||
|
tree_name: Name of the flow.
|
||||||
|
tree_type: Type of flow (troubleshooting, procedural, maintenance).
|
||||||
|
validation_errors: List of dicts with "node_id" and "message" keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (fixes_list, total_input_tokens, total_output_tokens).
|
||||||
|
Each fix dict has: target_node_id, error_message, description,
|
||||||
|
original_node, fixed_node.
|
||||||
|
"""
|
||||||
|
provider = get_ai_provider()
|
||||||
|
fixes: list[dict[str, Any]] = []
|
||||||
|
total_input_tokens = 0
|
||||||
|
total_output_tokens = 0
|
||||||
|
|
||||||
|
for error in validation_errors:
|
||||||
|
node_id = error["node_id"]
|
||||||
|
error_message = error["message"]
|
||||||
|
|
||||||
|
original_node = _find_node_by_id(tree_structure, node_id)
|
||||||
|
if original_node is None:
|
||||||
|
logger.warning("Node %s not found in tree, skipping fix", node_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
original_snapshot = copy.deepcopy(original_node)
|
||||||
|
|
||||||
|
# Build prompt and call AI
|
||||||
|
user_message = _build_fix_prompt(
|
||||||
|
tree_structure, node_id, error_message, tree_name, tree_type
|
||||||
|
)
|
||||||
|
messages = [{"role": "user", "content": user_message}]
|
||||||
|
|
||||||
|
try:
|
||||||
|
text, in_tok, out_tok = await provider.generate_json(
|
||||||
|
system_prompt=FIX_SYSTEM_PROMPT,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=2048,
|
||||||
|
)
|
||||||
|
total_input_tokens += in_tok
|
||||||
|
total_output_tokens += out_tok
|
||||||
|
|
||||||
|
cleaned = _strip_markdown_fences(text)
|
||||||
|
fixed_node = json.loads(cleaned)
|
||||||
|
except (json.JSONDecodeError, Exception) as exc:
|
||||||
|
logger.warning("AI fix failed for node %s: %s", node_id, exc)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Validate by substituting into a tree copy
|
||||||
|
tree_copy = copy.deepcopy(tree_structure)
|
||||||
|
_replace_node_in_tree(tree_copy, node_id, copy.deepcopy(fixed_node))
|
||||||
|
remaining_errors = validate_generated_tree(tree_copy)
|
||||||
|
|
||||||
|
# Check if the specific error is still present
|
||||||
|
still_has_error = any(node_id in e for e in remaining_errors)
|
||||||
|
|
||||||
|
if still_has_error:
|
||||||
|
# Retry once with corrective prompt
|
||||||
|
retry_message = (
|
||||||
|
f"Your previous fix still has validation errors:\n"
|
||||||
|
f"{chr(10).join(remaining_errors)}\n\n"
|
||||||
|
f"Please fix the node again. Return ONLY the corrected JSON."
|
||||||
|
)
|
||||||
|
messages.append({"role": "assistant", "content": text})
|
||||||
|
messages.append({"role": "user", "content": retry_message})
|
||||||
|
|
||||||
|
try:
|
||||||
|
text2, in_tok2, out_tok2 = await provider.generate_json(
|
||||||
|
system_prompt=FIX_SYSTEM_PROMPT,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=2048,
|
||||||
|
)
|
||||||
|
total_input_tokens += in_tok2
|
||||||
|
total_output_tokens += out_tok2
|
||||||
|
|
||||||
|
cleaned2 = _strip_markdown_fences(text2)
|
||||||
|
fixed_node = json.loads(cleaned2)
|
||||||
|
|
||||||
|
# Re-validate
|
||||||
|
tree_copy2 = copy.deepcopy(tree_structure)
|
||||||
|
_replace_node_in_tree(tree_copy2, node_id, copy.deepcopy(fixed_node))
|
||||||
|
remaining2 = validate_generated_tree(tree_copy2)
|
||||||
|
still_has_error = any(node_id in e for e in remaining2)
|
||||||
|
except (json.JSONDecodeError, Exception) as exc:
|
||||||
|
logger.warning("AI retry fix failed for node %s: %s", node_id, exc)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if still_has_error:
|
||||||
|
logger.warning("AI could not fix node %s after retry", node_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
description = _describe_fix(original_snapshot, fixed_node)
|
||||||
|
fixes.append(
|
||||||
|
{
|
||||||
|
"target_node_id": node_id,
|
||||||
|
"error_message": error_message,
|
||||||
|
"description": description,
|
||||||
|
"original_node": original_snapshot,
|
||||||
|
"fixed_node": fixed_node,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return fixes, total_input_tokens, total_output_tokens
|
||||||
175
backend/app/core/ai_provider.py
Normal file
175
backend/app/core/ai_provider.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""
|
||||||
|
AI Provider abstraction layer.
|
||||||
|
|
||||||
|
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
|
||||||
|
backends for JSON generation used by the AI Flow Builder.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AIProvider(ABC):
|
||||||
|
"""Abstract base class for AI providers."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate_json(
|
||||||
|
self,
|
||||||
|
system_prompt: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
) -> tuple[str, int, int]:
|
||||||
|
"""Generate a JSON response from the AI model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_prompt: System-level instruction for the model.
|
||||||
|
messages: List of message dicts with "role" and "content" keys.
|
||||||
|
max_tokens: Maximum output tokens.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (response_text, input_tokens, output_tokens).
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiProvider(AIProvider):
|
||||||
|
"""Google Gemini provider using the google-genai SDK."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, model: str) -> None:
|
||||||
|
self._api_key = api_key
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
async def generate_json(
|
||||||
|
self,
|
||||||
|
system_prompt: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
) -> tuple[str, int, int]:
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types as genai_types
|
||||||
|
|
||||||
|
client = genai.Client(api_key=self._api_key)
|
||||||
|
|
||||||
|
# Convert messages to Gemini Content format
|
||||||
|
contents: list[genai_types.Content] = []
|
||||||
|
for msg in messages:
|
||||||
|
role = "model" if msg["role"] == "assistant" else "user"
|
||||||
|
contents.append(
|
||||||
|
genai_types.Content(
|
||||||
|
role=role,
|
||||||
|
parts=[genai_types.Part(text=msg["content"])],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
config = genai_types.GenerateContentConfig(
|
||||||
|
system_instruction=system_prompt,
|
||||||
|
max_output_tokens=max_tokens,
|
||||||
|
response_mime_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.aio.models.generate_content(
|
||||||
|
model=self._model,
|
||||||
|
contents=contents,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log finish reason to detect truncation
|
||||||
|
if response.candidates:
|
||||||
|
finish_reason = getattr(response.candidates[0], "finish_reason", None)
|
||||||
|
logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model)
|
||||||
|
if str(finish_reason) == "MAX_TOKENS":
|
||||||
|
logger.warning(
|
||||||
|
"Gemini output truncated (MAX_TOKENS). max_output_tokens=%d",
|
||||||
|
max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
text = response.text or ""
|
||||||
|
input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
|
||||||
|
output_tokens = (
|
||||||
|
getattr(response.usage_metadata, "candidates_token_count", 0) or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return text, input_tokens, output_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(AIProvider):
|
||||||
|
"""Anthropic Claude provider using the anthropic SDK."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, model: str, timeout: int = 45) -> None:
|
||||||
|
self._api_key = api_key
|
||||||
|
self._model = model
|
||||||
|
self._timeout = timeout
|
||||||
|
|
||||||
|
async def generate_json(
|
||||||
|
self,
|
||||||
|
system_prompt: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
) -> tuple[str, int, int]:
|
||||||
|
import anthropic
|
||||||
|
|
||||||
|
client = anthropic.AsyncAnthropic(
|
||||||
|
api_key=self._api_key,
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.messages.create(
|
||||||
|
model=self._model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
system=system_prompt,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
text = response.content[0].text
|
||||||
|
input_tokens = response.usage.input_tokens
|
||||||
|
output_tokens = response.usage.output_tokens
|
||||||
|
|
||||||
|
return text, input_tokens, output_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def get_ai_provider() -> AIProvider:
|
||||||
|
"""Factory that returns the configured AI provider.
|
||||||
|
|
||||||
|
Selection logic:
|
||||||
|
1. If AI_PROVIDER == "gemini" and GOOGLE_AI_API_KEY is set -> GeminiProvider
|
||||||
|
2. If AI_PROVIDER == "anthropic" and ANTHROPIC_API_KEY is set -> AnthropicProvider
|
||||||
|
3. Fallback: if preferred provider key missing, try the other one
|
||||||
|
4. If nothing configured -> raise RuntimeError
|
||||||
|
"""
|
||||||
|
provider = settings.AI_PROVIDER
|
||||||
|
|
||||||
|
if provider == "gemini":
|
||||||
|
if settings.GOOGLE_AI_API_KEY:
|
||||||
|
return GeminiProvider(
|
||||||
|
api_key=settings.GOOGLE_AI_API_KEY,
|
||||||
|
model=settings.AI_MODEL_GEMINI,
|
||||||
|
)
|
||||||
|
# Fallback to Anthropic
|
||||||
|
if settings.ANTHROPIC_API_KEY:
|
||||||
|
return AnthropicProvider(
|
||||||
|
api_key=settings.ANTHROPIC_API_KEY,
|
||||||
|
model=settings.AI_MODEL_ANTHROPIC,
|
||||||
|
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif provider == "anthropic":
|
||||||
|
if settings.ANTHROPIC_API_KEY:
|
||||||
|
return AnthropicProvider(
|
||||||
|
api_key=settings.ANTHROPIC_API_KEY,
|
||||||
|
model=settings.AI_MODEL_ANTHROPIC,
|
||||||
|
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
# Fallback to Gemini
|
||||||
|
if settings.GOOGLE_AI_API_KEY:
|
||||||
|
return GeminiProvider(
|
||||||
|
api_key=settings.GOOGLE_AI_API_KEY,
|
||||||
|
model=settings.AI_MODEL_GEMINI,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
"No AI provider configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY."
|
||||||
|
)
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
"""AI-powered tree generation service using Anthropic Claude API.
|
"""AI-powered tree generation service.
|
||||||
|
|
||||||
Implements the 4-stage wizard flow:
|
Implements the 4-stage wizard flow:
|
||||||
Stage 2 (scaffold): AI suggests 4-7 top-level branches
|
Stage 2 (scaffold): AI suggests 4-7 top-level branches
|
||||||
Stage 3 (branch_detail): AI generates detailed nodes per branch
|
Stage 3 (branch_detail): AI generates detailed nodes per branch
|
||||||
Stage 4 (assemble): Pure assembly logic — zero AI calls
|
Stage 4 (assemble): Pure assembly logic — zero AI calls
|
||||||
|
|
||||||
System prompts are static constants to enable Anthropic prompt caching.
|
Uses the provider abstraction from ai_provider.py (supports Gemini + Anthropic).
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -13,8 +13,7 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import anthropic
|
from app.core.ai_provider import get_ai_provider
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
|
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
|
||||||
|
|
||||||
@@ -121,15 +120,6 @@ def _strip_markdown_fences(text: str) -> str:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def _get_client() -> anthropic.AsyncAnthropic:
|
|
||||||
"""Get configured async Anthropic client."""
|
|
||||||
if not settings.ANTHROPIC_API_KEY:
|
|
||||||
raise RuntimeError("ANTHROPIC_API_KEY not configured")
|
|
||||||
return anthropic.AsyncAnthropic(
|
|
||||||
api_key=settings.ANTHROPIC_API_KEY,
|
|
||||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
|
def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
|
||||||
"""Estimate USD cost from token counts."""
|
"""Estimate USD cost from token counts."""
|
||||||
@@ -146,7 +136,7 @@ async def scaffold_branches(
|
|||||||
Returns (branches, input_tokens, output_tokens, estimated_cost).
|
Returns (branches, input_tokens, output_tokens, estimated_cost).
|
||||||
Raises ValueError on invalid response.
|
Raises ValueError on invalid response.
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
provider = get_ai_provider()
|
||||||
|
|
||||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||||
name = wizard_state.get("name", "")
|
name = wizard_state.get("name", "")
|
||||||
@@ -161,21 +151,26 @@ async def scaffold_branches(
|
|||||||
if tags:
|
if tags:
|
||||||
user_message += f"Environment: {', '.join(tags)}\n"
|
user_message += f"Environment: {', '.join(tags)}\n"
|
||||||
|
|
||||||
response = await client.messages.create(
|
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||||
model=settings.AI_MODEL,
|
system_prompt=SCAFFOLD_SYSTEM_PROMPT,
|
||||||
max_tokens=1024,
|
|
||||||
system=SCAFFOLD_SYSTEM_PROMPT,
|
|
||||||
messages=[{"role": "user", "content": user_message}],
|
messages=[{"role": "user", "content": user_message}],
|
||||||
|
max_tokens=2048,
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_text = _strip_markdown_fences(response.content[0].text)
|
logger.info(
|
||||||
input_tokens = response.usage.input_tokens
|
"scaffold raw response (tokens in=%d out=%d, len=%d): %s",
|
||||||
output_tokens = response.usage.output_tokens
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
len(raw_text),
|
||||||
|
raw_text[:500],
|
||||||
|
)
|
||||||
|
raw_text = _strip_markdown_fences(raw_text)
|
||||||
cost = _estimate_cost(input_tokens, output_tokens)
|
cost = _estimate_cost(input_tokens, output_tokens)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(raw_text)
|
data = json.loads(raw_text)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error("scaffold JSON parse failed. Full text (%d chars): %s", len(raw_text), raw_text)
|
||||||
raise ValueError(f"AI returned invalid JSON: {e}")
|
raise ValueError(f"AI returned invalid JSON: {e}")
|
||||||
|
|
||||||
branches = data.get("branches", [])
|
branches = data.get("branches", [])
|
||||||
@@ -196,7 +191,7 @@ async def generate_branch_detail(
|
|||||||
On validation failure, retries once with corrective prompt.
|
On validation failure, retries once with corrective prompt.
|
||||||
Raises ValueError if both attempts fail.
|
Raises ValueError if both attempts fail.
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
provider = get_ai_provider()
|
||||||
|
|
||||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||||
name = wizard_state.get("name", "")
|
name = wizard_state.get("name", "")
|
||||||
@@ -217,35 +212,30 @@ async def generate_branch_detail(
|
|||||||
total_output = 0
|
total_output = 0
|
||||||
|
|
||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
response = await client.messages.create(
|
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||||
model=settings.AI_MODEL,
|
system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT,
|
||||||
max_tokens=8192,
|
|
||||||
system=BRANCH_DETAIL_SYSTEM_PROMPT,
|
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
max_tokens=8192,
|
||||||
)
|
)
|
||||||
|
|
||||||
total_input += response.usage.input_tokens
|
total_input += input_tokens
|
||||||
total_output += response.usage.output_tokens
|
total_output += output_tokens
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"branch_detail attempt=%d stop_reason=%s content_blocks=%d output_tokens=%d",
|
"branch_detail attempt=%d output_tokens=%d",
|
||||||
attempt,
|
attempt,
|
||||||
response.stop_reason,
|
output_tokens,
|
||||||
len(response.content),
|
|
||||||
response.usage.output_tokens,
|
|
||||||
)
|
)
|
||||||
if response.stop_reason == "max_tokens":
|
raw_text = _strip_markdown_fences(raw_text) if raw_text else ""
|
||||||
logger.warning(
|
|
||||||
"branch_detail attempt=%d hit max_tokens limit (%d output tokens) — response may be truncated",
|
|
||||||
attempt,
|
|
||||||
response.usage.output_tokens,
|
|
||||||
)
|
|
||||||
raw_text = _strip_markdown_fences(response.content[0].text) if response.content else ""
|
|
||||||
if not raw_text:
|
if not raw_text:
|
||||||
logger.warning("branch_detail attempt=%d returned empty text, stop_reason=%s", attempt, response.stop_reason)
|
logger.warning("branch_detail attempt=%d returned empty text", attempt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
branch_tree = json.loads(raw_text)
|
branch_tree = json.loads(raw_text)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(
|
||||||
|
"branch_detail attempt=%d JSON parse failed (%d chars): %s",
|
||||||
|
attempt, len(raw_text), raw_text[:500],
|
||||||
|
)
|
||||||
if attempt < 2:
|
if attempt < 2:
|
||||||
messages.append({"role": "assistant", "content": raw_text})
|
messages.append({"role": "assistant", "content": raw_text})
|
||||||
messages.append({
|
messages.append({
|
||||||
|
|||||||
@@ -78,11 +78,16 @@ class Settings(BaseSettings):
|
|||||||
AI_CONVERSATION_TTL_HOURS: int = 24
|
AI_CONVERSATION_TTL_HOURS: int = 24
|
||||||
AI_MAX_CALLS_PER_FLOW: int = 10
|
AI_MAX_CALLS_PER_FLOW: int = 10
|
||||||
AI_REQUEST_TIMEOUT_SECONDS: int = 45
|
AI_REQUEST_TIMEOUT_SECONDS: int = 45
|
||||||
|
# AI Provider selection
|
||||||
|
AI_PROVIDER: str = "gemini" # "gemini" or "anthropic"
|
||||||
|
GOOGLE_AI_API_KEY: Optional[str] = None
|
||||||
|
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
|
||||||
|
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ai_enabled(self) -> bool:
|
def ai_enabled(self) -> bool:
|
||||||
"""Check if AI Flow Builder is configured."""
|
"""Check if any AI provider is configured."""
|
||||||
return self.ANTHROPIC_API_KEY is not None
|
return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None
|
||||||
|
|
||||||
# Deployment – auto-seed test data on PR environments
|
# Deployment – auto-seed test data on PR environments
|
||||||
SEED_ON_DEPLOY: bool = False
|
SEED_ON_DEPLOY: bool = False
|
||||||
|
|||||||
52
backend/app/schemas/ai_fix.py
Normal file
52
backend/app/schemas/ai_fix.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Pydantic schemas for the AI auto-fix feature."""
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ── Requests ──
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationErrorInput(BaseModel):
|
||||||
|
"""A single validation error to fix."""
|
||||||
|
|
||||||
|
node_id: str = Field(..., description="ID of the node with the error")
|
||||||
|
message: str = Field(..., description="The validation error message")
|
||||||
|
|
||||||
|
|
||||||
|
class AIFixTreeRequest(BaseModel):
|
||||||
|
"""Request to generate AI fixes for validation errors."""
|
||||||
|
|
||||||
|
tree_structure: dict[str, Any] = Field(..., description="Full tree structure")
|
||||||
|
tree_name: str = Field("", max_length=255, description="Name of the flow")
|
||||||
|
tree_type: Literal["troubleshooting", "procedural", "maintenance"] = Field(
|
||||||
|
"troubleshooting", description="Type of flow"
|
||||||
|
)
|
||||||
|
validation_errors: list[ValidationErrorInput] = Field(
|
||||||
|
..., min_length=1, max_length=10, description="Errors to fix"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Responses ──
|
||||||
|
|
||||||
|
|
||||||
|
class AIFixProposal(BaseModel):
|
||||||
|
"""A single proposed fix from the AI."""
|
||||||
|
|
||||||
|
target_node_id: str
|
||||||
|
error_message: str
|
||||||
|
description: str
|
||||||
|
original_node: dict[str, Any]
|
||||||
|
fixed_node: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class AIFixTokenUsage(BaseModel):
|
||||||
|
input: int = 0
|
||||||
|
output: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AIFixTreeResponse(BaseModel):
|
||||||
|
"""Response with proposed fixes."""
|
||||||
|
|
||||||
|
fixes: list[AIFixProposal]
|
||||||
|
tokens_used: AIFixTokenUsage
|
||||||
@@ -33,6 +33,7 @@ httpx>=0.27.0
|
|||||||
|
|
||||||
# AI Flow Builder
|
# AI Flow Builder
|
||||||
anthropic>=0.40.0
|
anthropic>=0.40.0
|
||||||
|
google-genai>=1.0.0
|
||||||
|
|
||||||
# Utilities
|
# Utilities
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Integration tests for AI Flow Builder endpoints.
|
"""Integration tests for AI Flow Builder endpoints.
|
||||||
|
|
||||||
All Anthropic API calls are mocked — zero real API spend.
|
All AI provider calls are mocked — zero real API spend.
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
@@ -64,12 +64,11 @@ BRANCH_DETAIL_JSON = json.dumps({
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200):
|
def _mock_ai_provider(text: str, input_tokens: int = 100, output_tokens: int = 200):
|
||||||
"""Create a mock Anthropic API response."""
|
"""Create a mock AI provider whose generate_json returns the given text and token counts."""
|
||||||
response = MagicMock()
|
provider = MagicMock()
|
||||||
response.content = [MagicMock(text=text)]
|
provider.generate_json = AsyncMock(return_value=(text, input_tokens, output_tokens))
|
||||||
response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens)
|
return provider
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -194,11 +193,9 @@ async def test_scaffold_success(client, auth_headers, enable_ai):
|
|||||||
)
|
)
|
||||||
conversation_id = start_resp.json()["conversation_id"]
|
conversation_id = start_resp.json()["conversation_id"]
|
||||||
|
|
||||||
# Mock Anthropic
|
# Mock AI provider
|
||||||
mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
mock_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
|
||||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=mock_provider):
|
||||||
mock_client.return_value.messages.create = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/v1/ai/scaffold",
|
"/api/v1/ai/scaffold",
|
||||||
json={"conversation_id": conversation_id},
|
json={"conversation_id": conversation_id},
|
||||||
@@ -241,9 +238,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
|
|||||||
)
|
)
|
||||||
conversation_id = start_resp.json()["conversation_id"]
|
conversation_id = start_resp.json()["conversation_id"]
|
||||||
|
|
||||||
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
|
||||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
|
||||||
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
|
|
||||||
await client.post(
|
await client.post(
|
||||||
"/api/v1/ai/scaffold",
|
"/api/v1/ai/scaffold",
|
||||||
json={"conversation_id": conversation_id},
|
json={"conversation_id": conversation_id},
|
||||||
@@ -251,10 +247,8 @@ async def test_branch_detail_success(client, auth_headers, enable_ai):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Now generate branch detail
|
# Now generate branch detail
|
||||||
detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON)
|
detail_provider = _mock_ai_provider(BRANCH_DETAIL_JSON)
|
||||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=detail_provider):
|
||||||
mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock)
|
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/v1/ai/branch-detail",
|
"/api/v1/ai/branch-detail",
|
||||||
json={
|
json={
|
||||||
@@ -290,9 +284,8 @@ async def test_assemble_success(client, auth_headers, enable_ai):
|
|||||||
conversation_id = start_resp.json()["conversation_id"]
|
conversation_id = start_resp.json()["conversation_id"]
|
||||||
|
|
||||||
# Scaffold
|
# Scaffold
|
||||||
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
|
scaffold_provider = _mock_ai_provider(SCAFFOLD_RESPONSE_JSON)
|
||||||
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
|
with patch("app.core.ai_tree_generator_service.get_ai_provider", return_value=scaffold_provider):
|
||||||
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
|
|
||||||
await client.post(
|
await client.post(
|
||||||
"/api/v1/ai/scaffold",
|
"/api/v1/ai/scaffold",
|
||||||
json={"conversation_id": conversation_id},
|
json={"conversation_id": conversation_id},
|
||||||
|
|||||||
169
backend/tests/test_ai_fix_endpoint.py
Normal file
169
backend/tests/test_ai_fix_endpoint.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""Integration tests for the POST /ai/fix-tree endpoint."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
# ── Sample tree (has a decision node with only 1 option + 1 child) ──
|
||||||
|
|
||||||
|
SAMPLE_TREE = {
|
||||||
|
"id": "root",
|
||||||
|
"type": "decision",
|
||||||
|
"question": "Is the server up?",
|
||||||
|
"options": [
|
||||||
|
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
|
||||||
|
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
|
||||||
|
],
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"id": "check-logs",
|
||||||
|
"type": "action",
|
||||||
|
"title": "Check Logs",
|
||||||
|
"description": "Review logs.",
|
||||||
|
"next_node_id": "logs-ok",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "logs-ok",
|
||||||
|
"type": "solution",
|
||||||
|
"title": "Logs OK",
|
||||||
|
"description": "Issue in logs.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "restart",
|
||||||
|
"type": "decision",
|
||||||
|
"question": "Did restart work?",
|
||||||
|
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"id": "done",
|
||||||
|
"type": "solution",
|
||||||
|
"title": "Done",
|
||||||
|
"description": "Fixed.",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Fixed version of the "restart" node — 2 options, 2 children
|
||||||
|
FIXED_RESTART_NODE = {
|
||||||
|
"id": "restart",
|
||||||
|
"type": "decision",
|
||||||
|
"question": "Did restart work?",
|
||||||
|
"options": [
|
||||||
|
{"id": "opt-r-yes", "label": "Yes", "next_node_id": "done"},
|
||||||
|
{"id": "opt-r-no", "label": "No", "next_node_id": "escalate"},
|
||||||
|
],
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"id": "done",
|
||||||
|
"type": "solution",
|
||||||
|
"title": "Done",
|
||||||
|
"description": "Fixed.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "escalate",
|
||||||
|
"type": "solution",
|
||||||
|
"title": "Escalate",
|
||||||
|
"description": "Escalate to senior engineer.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
FIX_REQUEST_BODY = {
|
||||||
|
"tree_structure": SAMPLE_TREE,
|
||||||
|
"tree_name": "Server Troubleshooting",
|
||||||
|
"tree_type": "troubleshooting",
|
||||||
|
"validation_errors": [
|
||||||
|
{
|
||||||
|
"node_id": "restart",
|
||||||
|
"message": "Decision node 'restart' must have at least 2 options",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_ai_provider(response_text: str, input_tokens: int = 50, output_tokens: int = 100):
|
||||||
|
"""Create a mock provider whose generate_json returns given text."""
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.generate_json = AsyncMock(return_value=(response_text, input_tokens, output_tokens))
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def enable_ai():
|
||||||
|
"""Temporarily enable AI by setting a fake API key."""
|
||||||
|
original = settings.GOOGLE_AI_API_KEY
|
||||||
|
settings.GOOGLE_AI_API_KEY = "test-key-fake"
|
||||||
|
yield
|
||||||
|
settings.GOOGLE_AI_API_KEY = original
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def disable_ai():
|
||||||
|
"""Ensure AI is disabled."""
|
||||||
|
orig_google = settings.GOOGLE_AI_API_KEY
|
||||||
|
orig_anthropic = settings.ANTHROPIC_API_KEY
|
||||||
|
settings.GOOGLE_AI_API_KEY = None
|
||||||
|
settings.ANTHROPIC_API_KEY = None
|
||||||
|
yield
|
||||||
|
settings.GOOGLE_AI_API_KEY = orig_google
|
||||||
|
settings.ANTHROPIC_API_KEY = orig_anthropic
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ──
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_401_without_auth(client):
|
||||||
|
"""POST /ai/fix-tree without auth token returns 401."""
|
||||||
|
response = await client.post("/api/v1/ai/fix-tree", json=FIX_REQUEST_BODY)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_503_when_ai_disabled(client, auth_headers, disable_ai):
|
||||||
|
"""POST /ai/fix-tree returns 503 when no AI keys are configured."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/ai/fix-tree",
|
||||||
|
json=FIX_REQUEST_BODY,
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 503
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_fixes_on_success(client, auth_headers, enable_ai):
|
||||||
|
"""POST /ai/fix-tree returns fix proposals when AI succeeds."""
|
||||||
|
mock_provider = _mock_ai_provider(json.dumps(FIXED_RESTART_NODE))
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.core.ai_fix_service.get_ai_provider",
|
||||||
|
return_value=mock_provider,
|
||||||
|
):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/ai/fix-tree",
|
||||||
|
json=FIX_REQUEST_BODY,
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "fixes" in data
|
||||||
|
assert "tokens_used" in data
|
||||||
|
assert len(data["fixes"]) == 1
|
||||||
|
|
||||||
|
fix = data["fixes"][0]
|
||||||
|
assert fix["target_node_id"] == "restart"
|
||||||
|
assert fix["error_message"] == "Decision node 'restart' must have at least 2 options"
|
||||||
|
assert fix["original_node"]["id"] == "restart"
|
||||||
|
assert fix["fixed_node"]["id"] == "restart"
|
||||||
|
assert len(fix["fixed_node"]["options"]) == 2
|
||||||
|
assert len(fix["fixed_node"]["children"]) == 2
|
||||||
|
|
||||||
|
assert data["tokens_used"]["input"] == 50
|
||||||
|
assert data["tokens_used"]["output"] == 100
|
||||||
224
backend/tests/test_ai_fix_service.py
Normal file
224
backend/tests/test_ai_fix_service.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
"""Unit tests for AI fix service helper functions.
|
||||||
|
|
||||||
|
Tests pure Python helpers only — no AI mocking needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.ai_fix_service import (
|
||||||
|
_find_node_by_id,
|
||||||
|
_find_parent_node,
|
||||||
|
_serialize_tree_outline,
|
||||||
|
_strip_markdown_fences,
|
||||||
|
_replace_node_in_tree,
|
||||||
|
_describe_fix,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Sample tree ──
|
||||||
|
|
||||||
|
SAMPLE_TREE = {
|
||||||
|
"id": "root",
|
||||||
|
"type": "decision",
|
||||||
|
"question": "Is the server up?",
|
||||||
|
"options": [
|
||||||
|
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
|
||||||
|
{"id": "opt-no", "label": "No", "next_node_id": "restart"},
|
||||||
|
],
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"id": "check-logs",
|
||||||
|
"type": "action",
|
||||||
|
"title": "Check Logs",
|
||||||
|
"description": "Review logs.",
|
||||||
|
"next_node_id": "logs-ok",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "logs-ok",
|
||||||
|
"type": "solution",
|
||||||
|
"title": "Logs OK",
|
||||||
|
"description": "Issue in logs.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "restart",
|
||||||
|
"type": "decision",
|
||||||
|
"question": "Did restart work?",
|
||||||
|
"options": [{"id": "opt-r", "label": "Yes", "next_node_id": "done"}],
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"id": "done",
|
||||||
|
"type": "solution",
|
||||||
|
"title": "Done",
|
||||||
|
"description": "Fixed.",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── _find_node_by_id ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindNodeById:
|
||||||
|
def test_finds_root(self):
|
||||||
|
node = _find_node_by_id(SAMPLE_TREE, "root")
|
||||||
|
assert node is not None
|
||||||
|
assert node["id"] == "root"
|
||||||
|
assert node["type"] == "decision"
|
||||||
|
|
||||||
|
def test_finds_nested_child(self):
|
||||||
|
node = _find_node_by_id(SAMPLE_TREE, "done")
|
||||||
|
assert node is not None
|
||||||
|
assert node["id"] == "done"
|
||||||
|
assert node["type"] == "solution"
|
||||||
|
|
||||||
|
def test_finds_direct_child(self):
|
||||||
|
node = _find_node_by_id(SAMPLE_TREE, "check-logs")
|
||||||
|
assert node is not None
|
||||||
|
assert node["title"] == "Check Logs"
|
||||||
|
|
||||||
|
def test_returns_none_for_missing(self):
|
||||||
|
node = _find_node_by_id(SAMPLE_TREE, "nonexistent")
|
||||||
|
assert node is None
|
||||||
|
|
||||||
|
def test_returns_none_for_non_dict(self):
|
||||||
|
assert _find_node_by_id("not a dict", "root") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── _find_parent_node ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindParentNode:
|
||||||
|
def test_root_has_no_parent(self):
|
||||||
|
parent = _find_parent_node(SAMPLE_TREE, "root")
|
||||||
|
assert parent is None
|
||||||
|
|
||||||
|
def test_finds_parent_of_direct_child(self):
|
||||||
|
parent = _find_parent_node(SAMPLE_TREE, "check-logs")
|
||||||
|
assert parent is not None
|
||||||
|
assert parent["id"] == "root"
|
||||||
|
|
||||||
|
def test_finds_parent_of_deeply_nested(self):
|
||||||
|
parent = _find_parent_node(SAMPLE_TREE, "done")
|
||||||
|
assert parent is not None
|
||||||
|
assert parent["id"] == "restart"
|
||||||
|
|
||||||
|
def test_returns_none_for_missing(self):
|
||||||
|
parent = _find_parent_node(SAMPLE_TREE, "nonexistent")
|
||||||
|
assert parent is None
|
||||||
|
|
||||||
|
def test_returns_none_for_non_dict(self):
|
||||||
|
assert _find_parent_node("not a dict", "root") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── _serialize_tree_outline ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestSerializeTreeOutline:
|
||||||
|
def test_produces_readable_outline(self):
|
||||||
|
outline = _serialize_tree_outline(SAMPLE_TREE)
|
||||||
|
assert "- [decision] Is the server up?" in outline
|
||||||
|
assert " - [action] Check Logs" in outline
|
||||||
|
assert " - [solution] Logs OK" in outline
|
||||||
|
assert " - [solution] Done" in outline
|
||||||
|
|
||||||
|
def test_marks_error_node(self):
|
||||||
|
outline = _serialize_tree_outline(SAMPLE_TREE, error_node_id="restart")
|
||||||
|
assert "<<< ERROR HERE" in outline
|
||||||
|
# Only the restart node should be marked
|
||||||
|
lines = outline.split("\n")
|
||||||
|
error_lines = [l for l in lines if "ERROR HERE" in l]
|
||||||
|
assert len(error_lines) == 1
|
||||||
|
assert "Did restart work?" in error_lines[0]
|
||||||
|
|
||||||
|
def test_no_error_marker_when_none(self):
|
||||||
|
outline = _serialize_tree_outline(SAMPLE_TREE)
|
||||||
|
assert "ERROR HERE" not in outline
|
||||||
|
|
||||||
|
def test_handles_non_dict(self):
|
||||||
|
assert _serialize_tree_outline("not a dict") == ""
|
||||||
|
|
||||||
|
def test_indentation_increases_with_depth(self):
|
||||||
|
outline = _serialize_tree_outline(SAMPLE_TREE)
|
||||||
|
lines = outline.split("\n")
|
||||||
|
# Root has no indentation
|
||||||
|
assert lines[0].startswith("- [decision]")
|
||||||
|
# Children have 2-space indent
|
||||||
|
child_lines = [l for l in lines if "Check Logs" in l]
|
||||||
|
assert child_lines[0].startswith(" - ")
|
||||||
|
|
||||||
|
|
||||||
|
# ── _strip_markdown_fences ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestStripMarkdownFences:
|
||||||
|
def test_strips_json_fences(self):
|
||||||
|
text = '```json\n{"key": "value"}\n```'
|
||||||
|
assert _strip_markdown_fences(text) == '{"key": "value"}'
|
||||||
|
|
||||||
|
def test_strips_plain_fences(self):
|
||||||
|
text = '```\n{"key": "value"}\n```'
|
||||||
|
assert _strip_markdown_fences(text) == '{"key": "value"}'
|
||||||
|
|
||||||
|
def test_passes_through_plain_json(self):
|
||||||
|
text = '{"key": "value"}'
|
||||||
|
assert _strip_markdown_fences(text) == '{"key": "value"}'
|
||||||
|
|
||||||
|
|
||||||
|
# ── _replace_node_in_tree ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestReplaceNodeInTree:
|
||||||
|
def test_replaces_root(self):
|
||||||
|
import copy
|
||||||
|
|
||||||
|
tree = copy.deepcopy(SAMPLE_TREE)
|
||||||
|
replacement = {"id": "root", "type": "decision", "question": "New question"}
|
||||||
|
assert _replace_node_in_tree(tree, "root", replacement) is True
|
||||||
|
assert tree["question"] == "New question"
|
||||||
|
assert "children" not in tree # cleared and replaced
|
||||||
|
|
||||||
|
def test_replaces_nested_node(self):
|
||||||
|
import copy
|
||||||
|
|
||||||
|
tree = copy.deepcopy(SAMPLE_TREE)
|
||||||
|
replacement = {"id": "done", "type": "solution", "title": "All Done", "description": "Complete."}
|
||||||
|
assert _replace_node_in_tree(tree, "done", replacement) is True
|
||||||
|
found = _find_node_by_id(tree, "done")
|
||||||
|
assert found["title"] == "All Done"
|
||||||
|
|
||||||
|
def test_returns_false_for_missing(self):
|
||||||
|
import copy
|
||||||
|
|
||||||
|
tree = copy.deepcopy(SAMPLE_TREE)
|
||||||
|
assert _replace_node_in_tree(tree, "nonexistent", {"id": "x"}) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── _describe_fix ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestDescribeFix:
|
||||||
|
def test_describes_added_children(self):
|
||||||
|
original = {"id": "n1", "children": [{"id": "c1"}]}
|
||||||
|
fixed = {"id": "n1", "children": [{"id": "c1"}, {"id": "c2"}]}
|
||||||
|
desc = _describe_fix(original, fixed)
|
||||||
|
assert "1 child node" in desc
|
||||||
|
|
||||||
|
def test_describes_added_options(self):
|
||||||
|
original = {"id": "n1", "options": [{"id": "o1"}]}
|
||||||
|
fixed = {"id": "n1", "options": [{"id": "o1"}, {"id": "o2"}]}
|
||||||
|
desc = _describe_fix(original, fixed)
|
||||||
|
assert "1 option" in desc
|
||||||
|
|
||||||
|
def test_describes_added_next_node_id(self):
|
||||||
|
original = {"id": "n1", "type": "action"}
|
||||||
|
fixed = {"id": "n1", "type": "action", "next_node_id": "n2"}
|
||||||
|
desc = _describe_fix(original, fixed)
|
||||||
|
assert "next_node_id" in desc
|
||||||
|
|
||||||
|
def test_fallback_description(self):
|
||||||
|
original = {"id": "n1", "type": "solution"}
|
||||||
|
fixed = {"id": "n1", "type": "solution"}
|
||||||
|
desc = _describe_fix(original, fixed)
|
||||||
|
assert "fixed structural issue" in desc.lower()
|
||||||
216
backend/tests/test_ai_provider.py
Normal file
216
backend/tests/test_ai_provider.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""Tests for the AI provider abstraction layer."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from app.core.ai_provider import (
|
||||||
|
AIProvider,
|
||||||
|
AnthropicProvider,
|
||||||
|
GeminiProvider,
|
||||||
|
get_ai_provider,
|
||||||
|
)
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAIProvider:
|
||||||
|
"""Tests for the get_ai_provider factory function."""
|
||||||
|
|
||||||
|
def test_returns_gemini_when_configured(self):
|
||||||
|
original_provider = settings.AI_PROVIDER
|
||||||
|
original_key = settings.GOOGLE_AI_API_KEY
|
||||||
|
try:
|
||||||
|
settings.AI_PROVIDER = "gemini"
|
||||||
|
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
|
||||||
|
provider = get_ai_provider()
|
||||||
|
assert isinstance(provider, GeminiProvider)
|
||||||
|
finally:
|
||||||
|
settings.AI_PROVIDER = original_provider
|
||||||
|
settings.GOOGLE_AI_API_KEY = original_key
|
||||||
|
|
||||||
|
def test_returns_anthropic_when_configured(self):
|
||||||
|
original_provider = settings.AI_PROVIDER
|
||||||
|
original_key = settings.ANTHROPIC_API_KEY
|
||||||
|
try:
|
||||||
|
settings.AI_PROVIDER = "anthropic"
|
||||||
|
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
|
||||||
|
provider = get_ai_provider()
|
||||||
|
assert isinstance(provider, AnthropicProvider)
|
||||||
|
finally:
|
||||||
|
settings.AI_PROVIDER = original_provider
|
||||||
|
settings.ANTHROPIC_API_KEY = original_key
|
||||||
|
|
||||||
|
def test_fallback_to_anthropic_when_gemini_key_missing(self):
|
||||||
|
original_provider = settings.AI_PROVIDER
|
||||||
|
original_gemini_key = settings.GOOGLE_AI_API_KEY
|
||||||
|
original_anthropic_key = settings.ANTHROPIC_API_KEY
|
||||||
|
try:
|
||||||
|
settings.AI_PROVIDER = "gemini"
|
||||||
|
settings.GOOGLE_AI_API_KEY = None
|
||||||
|
settings.ANTHROPIC_API_KEY = "test-anthropic-key"
|
||||||
|
provider = get_ai_provider()
|
||||||
|
assert isinstance(provider, AnthropicProvider)
|
||||||
|
finally:
|
||||||
|
settings.AI_PROVIDER = original_provider
|
||||||
|
settings.GOOGLE_AI_API_KEY = original_gemini_key
|
||||||
|
settings.ANTHROPIC_API_KEY = original_anthropic_key
|
||||||
|
|
||||||
|
def test_fallback_to_gemini_when_anthropic_key_missing(self):
|
||||||
|
original_provider = settings.AI_PROVIDER
|
||||||
|
original_gemini_key = settings.GOOGLE_AI_API_KEY
|
||||||
|
original_anthropic_key = settings.ANTHROPIC_API_KEY
|
||||||
|
try:
|
||||||
|
settings.AI_PROVIDER = "anthropic"
|
||||||
|
settings.ANTHROPIC_API_KEY = None
|
||||||
|
settings.GOOGLE_AI_API_KEY = "test-gemini-key"
|
||||||
|
provider = get_ai_provider()
|
||||||
|
assert isinstance(provider, GeminiProvider)
|
||||||
|
finally:
|
||||||
|
settings.AI_PROVIDER = original_provider
|
||||||
|
settings.GOOGLE_AI_API_KEY = original_gemini_key
|
||||||
|
settings.ANTHROPIC_API_KEY = original_anthropic_key
|
||||||
|
|
||||||
|
def test_raises_when_no_provider_configured(self):
|
||||||
|
original_provider = settings.AI_PROVIDER
|
||||||
|
original_gemini_key = settings.GOOGLE_AI_API_KEY
|
||||||
|
original_anthropic_key = settings.ANTHROPIC_API_KEY
|
||||||
|
try:
|
||||||
|
settings.AI_PROVIDER = "gemini"
|
||||||
|
settings.GOOGLE_AI_API_KEY = None
|
||||||
|
settings.ANTHROPIC_API_KEY = None
|
||||||
|
with pytest.raises(RuntimeError, match="No AI provider configured"):
|
||||||
|
get_ai_provider()
|
||||||
|
finally:
|
||||||
|
settings.AI_PROVIDER = original_provider
|
||||||
|
settings.GOOGLE_AI_API_KEY = original_gemini_key
|
||||||
|
settings.ANTHROPIC_API_KEY = original_anthropic_key
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnthropicProvider:
|
||||||
|
"""Tests for AnthropicProvider.generate_json."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_json(self):
|
||||||
|
provider = AnthropicProvider(
|
||||||
|
api_key="test-key", model="claude-haiku-4-5-20251001", timeout=30
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = [MagicMock(text='{"result": "ok"}')]
|
||||||
|
mock_response.usage = MagicMock(input_tokens=100, output_tokens=50)
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.messages.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
|
||||||
|
text, input_tokens, output_tokens = await provider.generate_json(
|
||||||
|
system_prompt="You are a helper.",
|
||||||
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
max_tokens=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert text == '{"result": "ok"}'
|
||||||
|
assert input_tokens == 100
|
||||||
|
assert output_tokens == 50
|
||||||
|
|
||||||
|
mock_client.messages.create.assert_called_once_with(
|
||||||
|
model="claude-haiku-4-5-20251001",
|
||||||
|
max_tokens=1024,
|
||||||
|
system="You are a helper.",
|
||||||
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeminiProvider:
|
||||||
|
"""Tests for GeminiProvider.generate_json."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_json(self):
|
||||||
|
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
|
||||||
|
|
||||||
|
mock_usage = MagicMock()
|
||||||
|
mock_usage.prompt_token_count = 80
|
||||||
|
mock_usage.candidates_token_count = 40
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.text = '{"answer": 42}'
|
||||||
|
mock_response.usage_metadata = mock_usage
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.aio.models.generate_content = AsyncMock(
|
||||||
|
return_value=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_genai_module = MagicMock()
|
||||||
|
mock_genai_module.Client.return_value = mock_client
|
||||||
|
|
||||||
|
mock_types = MagicMock()
|
||||||
|
mock_types.Content.side_effect = lambda **kw: kw
|
||||||
|
mock_types.Part.side_effect = lambda **kw: kw
|
||||||
|
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
|
||||||
|
|
||||||
|
mock_google = MagicMock()
|
||||||
|
mock_google.genai = mock_genai_module
|
||||||
|
mock_genai_module.types = mock_types
|
||||||
|
|
||||||
|
with patch.dict(sys.modules, {
|
||||||
|
"google": mock_google,
|
||||||
|
"google.genai": mock_genai_module,
|
||||||
|
"google.genai.types": mock_types,
|
||||||
|
}):
|
||||||
|
text, input_tokens, output_tokens = await provider.generate_json(
|
||||||
|
system_prompt="Generate JSON.",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "Give me data"},
|
||||||
|
{"role": "assistant", "content": "Here it is"},
|
||||||
|
{"role": "user", "content": "More please"},
|
||||||
|
],
|
||||||
|
max_tokens=2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert text == '{"answer": 42}'
|
||||||
|
assert input_tokens == 80
|
||||||
|
assert output_tokens == 40
|
||||||
|
|
||||||
|
mock_client.aio.models.generate_content.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_json_handles_none_usage(self):
|
||||||
|
"""Token counts default to 0 when usage_metadata attributes are None."""
|
||||||
|
provider = GeminiProvider(api_key="test-key", model="gemini-2.5-flash")
|
||||||
|
|
||||||
|
mock_usage = MagicMock(spec=[]) # No attributes at all
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.text = "{}"
|
||||||
|
mock_response.usage_metadata = mock_usage
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.aio.models.generate_content = AsyncMock(
|
||||||
|
return_value=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_genai_module = MagicMock()
|
||||||
|
mock_genai_module.Client.return_value = mock_client
|
||||||
|
|
||||||
|
mock_types = MagicMock()
|
||||||
|
mock_types.Content.side_effect = lambda **kw: kw
|
||||||
|
mock_types.Part.side_effect = lambda **kw: kw
|
||||||
|
mock_types.GenerateContentConfig.side_effect = lambda **kw: kw
|
||||||
|
|
||||||
|
mock_google = MagicMock()
|
||||||
|
mock_google.genai = mock_genai_module
|
||||||
|
mock_genai_module.types = mock_types
|
||||||
|
|
||||||
|
with patch.dict(sys.modules, {
|
||||||
|
"google": mock_google,
|
||||||
|
"google.genai": mock_genai_module,
|
||||||
|
"google.genai.types": mock_types,
|
||||||
|
}):
|
||||||
|
text, input_tokens, output_tokens = await provider.generate_json(
|
||||||
|
system_prompt="test",
|
||||||
|
messages=[{"role": "user", "content": "test"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert text == "{}"
|
||||||
|
assert input_tokens == 0
|
||||||
|
assert output_tokens == 0
|
||||||
209
docs/plans/2026-02-26-ai-autofix-gemini-design.md
Normal file
209
docs/plans/2026-02-26-ai-autofix-gemini-design.md
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
# AI Auto-Fix & Gemini Flash Provider Design
|
||||||
|
|
||||||
|
> **Date:** 2026-02-26
|
||||||
|
> **Status:** Approved
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Two combined features:
|
||||||
|
|
||||||
|
1. **AI Provider Abstraction** — Add Gemini 2.5 Flash as the default AI provider with Claude as fallback, behind a unified interface.
|
||||||
|
2. **AI Auto-Fix for Validation Errors** — When a flow fails validation, offer an AI-powered "Fix with AI" button that generates structural fixes for review.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Section 1: AI Provider Abstraction
|
||||||
|
|
||||||
|
### Design
|
||||||
|
|
||||||
|
New `backend/app/core/ai_provider.py` with a unified interface:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AIProvider(ABC):
|
||||||
|
async def generate_json(
|
||||||
|
self,
|
||||||
|
system_prompt: str,
|
||||||
|
messages: list[dict],
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
) -> tuple[str, int, int]:
|
||||||
|
"""Returns (text, input_tokens, output_tokens)"""
|
||||||
|
```
|
||||||
|
|
||||||
|
Two implementations:
|
||||||
|
|
||||||
|
| Provider | Model | SDK | Role |
|
||||||
|
|----------|-------|-----|------|
|
||||||
|
| `GeminiProvider` | `gemini-2.5-flash` | `google-genai` | Default |
|
||||||
|
| `AnthropicProvider` | `claude-haiku-4-5-20251001` | `anthropic` | Fallback |
|
||||||
|
|
||||||
|
### Provider Selection
|
||||||
|
|
||||||
|
- `get_ai_provider()` factory reads `AI_PROVIDER` env var (default: `"gemini"`)
|
||||||
|
- Falls back to Anthropic if Gemini key is missing
|
||||||
|
- Existing `ai_tree_generator_service.py` swaps direct Anthropic calls for `get_ai_provider()`
|
||||||
|
|
||||||
|
### New Environment Variables
|
||||||
|
|
||||||
|
| Variable | Default | Purpose |
|
||||||
|
|----------|---------|---------|
|
||||||
|
| `AI_PROVIDER` | `"gemini"` | Which provider to use (`gemini` or `anthropic`) |
|
||||||
|
| `GOOGLE_AI_API_KEY` | — | Gemini API key |
|
||||||
|
|
||||||
|
Existing `ANTHROPIC_API_KEY` remains for fallback.
|
||||||
|
|
||||||
|
### Config Changes (`core/config.py`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
AI_PROVIDER: str = "gemini"
|
||||||
|
GOOGLE_AI_API_KEY: str | None = None
|
||||||
|
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
|
||||||
|
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Section 2: AI Auto-Fix Feature
|
||||||
|
|
||||||
|
### Backend Endpoint
|
||||||
|
|
||||||
|
**`POST /api/v1/ai/fix-tree`**
|
||||||
|
|
||||||
|
Request:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"tree_structure": { /* full tree */ },
|
||||||
|
"tree_name": "Router Troubleshooting",
|
||||||
|
"tree_type": "troubleshooting",
|
||||||
|
"validation_errors": [
|
||||||
|
{
|
||||||
|
"node_id": "node_abc",
|
||||||
|
"message": "Decision node must have at least 2 children (branches)"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Response:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"fixes": [
|
||||||
|
{
|
||||||
|
"target_node_id": "node_abc",
|
||||||
|
"error_message": "Decision node must have at least 2 children (branches)",
|
||||||
|
"description": "Added second branch 'Check firmware version' with solution node",
|
||||||
|
"original_node": { /* snapshot before fix */ },
|
||||||
|
"fixed_node": { /* replacement node with corrected subtree */ }
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tokens_used": { "input": 1200, "output": 800 }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. For each validation error tied to a `node_id`, extract that node + its parent + siblings from the tree.
|
||||||
|
2. Build a prompt with:
|
||||||
|
- The **full tree structure** serialized as a simplified outline (node titles + types + structure) for context
|
||||||
|
- The **specific failing node** highlighted with full JSON detail
|
||||||
|
- The **validation error message**
|
||||||
|
- Instructions: "Fix ONLY this node's structural issue. Keep all existing content. Generate domain-relevant additions that fit the flow's topic."
|
||||||
|
3. AI returns a corrected version of that node (with children/options adjusted).
|
||||||
|
4. Backend re-validates the fixed node before returning it.
|
||||||
|
5. If re-validation fails, retry once with the error fed back (corrective prompt pattern).
|
||||||
|
|
||||||
|
### Prompt Strategy
|
||||||
|
|
||||||
|
The prompt gives the AI the full tree as a compact outline, then zooms into the failing node:
|
||||||
|
|
||||||
|
```
|
||||||
|
You are fixing a validation error in a troubleshooting flow called "Router Troubleshooting".
|
||||||
|
|
||||||
|
FULL FLOW OUTLINE:
|
||||||
|
- [decision] Is the router powered on?
|
||||||
|
- [action] Check power cable → [solution] Power restored
|
||||||
|
- [decision] Are lights blinking? ← ERROR HERE
|
||||||
|
- [solution] Contact ISP
|
||||||
|
|
||||||
|
ERROR: Decision node "Are lights blinking?" must have at least 2 children (branches).
|
||||||
|
|
||||||
|
FAILING NODE (full detail):
|
||||||
|
{...json...}
|
||||||
|
|
||||||
|
Fix this node by adding the minimum structure needed to resolve the error.
|
||||||
|
Return ONLY the fixed node as JSON.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontend UX
|
||||||
|
|
||||||
|
1. **Trigger**: "Fix with AI" button in `ValidationSummary` — appears when there are fixable errors (structural errors with a `node_id`).
|
||||||
|
2. **Loading state**: Button shows spinner + "Generating fixes..." — disabled during request.
|
||||||
|
3. **Review modal** (`AIFixReviewModal`): Shows each proposed fix as a card:
|
||||||
|
- Error message at top
|
||||||
|
- Before/after view of the node change
|
||||||
|
- "Apply" / "Skip" buttons per fix
|
||||||
|
- "Apply All" button in footer
|
||||||
|
4. **Apply**: Each accepted fix calls `updateNode(targetNodeId, fixedNode)` in the tree editor store.
|
||||||
|
5. **Re-validate**: After applying fixes, auto-run `validate()` to confirm resolution.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Section 3: Scope & Constraints
|
||||||
|
|
||||||
|
### Fixable Errors (Auto-Fix Scope)
|
||||||
|
|
||||||
|
Only structural validation errors with a `node_id`:
|
||||||
|
- Decision node missing children/branches
|
||||||
|
- Decision node missing options
|
||||||
|
- Action node missing `next_node_id`
|
||||||
|
- Dead-end decision nodes (no children)
|
||||||
|
|
||||||
|
### NOT Fixable
|
||||||
|
|
||||||
|
- Global checks (tree too small/large, not enough solutions) — require rethinking the whole tree
|
||||||
|
- Content quality issues — out of scope
|
||||||
|
- Errors without a `node_id` (root-level issues)
|
||||||
|
|
||||||
|
Non-fixable errors still show in ValidationSummary but without the "Fix with AI" option.
|
||||||
|
|
||||||
|
### Token Budget
|
||||||
|
|
||||||
|
- Tree outline: ~50-100 tokens for a typical 15-node tree
|
||||||
|
- Failing node detail: ~100-200 tokens
|
||||||
|
- System prompt + instructions: ~300 tokens
|
||||||
|
- **Total input per fix: ~500-600 tokens**
|
||||||
|
- One API call per failing node (not batched)
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
|
||||||
|
- Provider failure (rate limit, network): toast error, user can retry
|
||||||
|
- Fix fails re-validation: "AI couldn't generate a valid fix" with retry option
|
||||||
|
- Max 1 retry with corrective prompt per attempt
|
||||||
|
- Both provider and fallback fail: surface error to user
|
||||||
|
|
||||||
|
### Auth
|
||||||
|
|
||||||
|
- Requires `engineer` role or above (`require_engineer_or_admin`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## New Files
|
||||||
|
|
||||||
|
| File | Purpose |
|
||||||
|
|------|---------|
|
||||||
|
| `backend/app/core/ai_provider.py` | Provider abstraction + Gemini/Anthropic implementations |
|
||||||
|
| `backend/app/core/ai_fix_service.py` | Fix generation logic + prompt building |
|
||||||
|
| `backend/app/api/endpoints/ai.py` | `POST /ai/fix-tree` endpoint |
|
||||||
|
| `backend/app/schemas/ai.py` | Request/response schemas for AI endpoints |
|
||||||
|
| `frontend/src/components/tree-editor/AIFixReviewModal.tsx` | Review modal for proposed fixes |
|
||||||
|
|
||||||
|
## Modified Files
|
||||||
|
|
||||||
|
| File | Change |
|
||||||
|
|------|--------|
|
||||||
|
| `backend/app/core/config.py` | Add Gemini config vars |
|
||||||
|
| `backend/app/core/ai_tree_generator_service.py` | Swap Anthropic calls for provider abstraction |
|
||||||
|
| `backend/app/api/router.py` | Register `/ai` routes |
|
||||||
|
| `frontend/src/api/trees.ts` | Add `fixTree()` API call |
|
||||||
|
| `frontend/src/components/tree-editor/ValidationSummary.tsx` | Add "Fix with AI" button |
|
||||||
1707
docs/plans/2026-02-26-ai-autofix-gemini-plan.md
Normal file
1707
docs/plans/2026-02-26-ai-autofix-gemini-plan.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
|||||||
import apiClient from './client'
|
import apiClient from './client'
|
||||||
import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse } from '@/types'
|
import type { Tree, TreeListItem, TreeCreate, TreeUpdate, TreeFilters, TreeShareCreate, TreeShare, TreeVisibilityUpdate, TreeValidationResponse, AIFixTreeRequest, AIFixTreeResponse } from '@/types'
|
||||||
|
|
||||||
export const treesApi = {
|
export const treesApi = {
|
||||||
async list(params?: TreeFilters): Promise<TreeListItem[]> {
|
async list(params?: TreeFilters): Promise<TreeListItem[]> {
|
||||||
@@ -65,6 +65,12 @@ export const treesApi = {
|
|||||||
const response = await apiClient.post<TreeValidationResponse>(`/trees/${id}/can-publish`)
|
const response = await apiClient.post<TreeValidationResponse>(`/trees/${id}/can-publish`)
|
||||||
return response.data
|
return response.data
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// AI auto-fix
|
||||||
|
async fixTree(request: AIFixTreeRequest): Promise<AIFixTreeResponse> {
|
||||||
|
const response = await apiClient.post<AIFixTreeResponse>('/ai/fix-tree', request)
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
export default treesApi
|
export default treesApi
|
||||||
|
|||||||
170
frontend/src/components/tree-editor/AIFixReviewModal.tsx
Normal file
170
frontend/src/components/tree-editor/AIFixReviewModal.tsx
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import { useState } from 'react'
|
||||||
|
import { X, Check, SkipForward, Sparkles, ChevronDown, ChevronUp } from 'lucide-react'
|
||||||
|
import { cn } from '@/lib/utils'
|
||||||
|
import type { AIFixProposal } from '@/types'
|
||||||
|
|
||||||
|
interface AIFixReviewModalProps {
|
||||||
|
fixes: AIFixProposal[]
|
||||||
|
onApply: (fix: AIFixProposal) => void
|
||||||
|
onApplyAll: () => void
|
||||||
|
onClose: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export function AIFixReviewModal({ fixes, onApply, onApplyAll, onClose }: AIFixReviewModalProps) {
|
||||||
|
const [appliedIds, setAppliedIds] = useState<Set<string>>(new Set())
|
||||||
|
const [skippedIds, setSkippedIds] = useState<Set<string>>(new Set())
|
||||||
|
const [expandedIds, setExpandedIds] = useState<Set<string>>(new Set(fixes.map(f => f.target_node_id)))
|
||||||
|
|
||||||
|
const handleApply = (fix: AIFixProposal) => {
|
||||||
|
onApply(fix)
|
||||||
|
setAppliedIds(prev => new Set(prev).add(fix.target_node_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleSkip = (fix: AIFixProposal) => {
|
||||||
|
setSkippedIds(prev => new Set(prev).add(fix.target_node_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
const toggleExpanded = (id: string) => {
|
||||||
|
setExpandedIds(prev => {
|
||||||
|
const next = new Set(prev)
|
||||||
|
if (next.has(id)) next.delete(id)
|
||||||
|
else next.add(id)
|
||||||
|
return next
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const pendingFixes = fixes.filter(
|
||||||
|
f => !appliedIds.has(f.target_node_id) && !skippedIds.has(f.target_node_id)
|
||||||
|
)
|
||||||
|
const allHandled = pendingFixes.length === 0
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/80 backdrop-blur-sm p-4">
|
||||||
|
<div className="relative flex h-[80vh] w-full max-w-2xl flex-col bg-card border border-border rounded-2xl shadow-lg">
|
||||||
|
{/* Header */}
|
||||||
|
<div className="flex items-center justify-between border-b border-border px-6 py-4">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<Sparkles className="h-5 w-5 text-primary" />
|
||||||
|
<h2 className="text-lg font-semibold text-foreground">
|
||||||
|
AI Fix Proposals ({fixes.length})
|
||||||
|
</h2>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={onClose}
|
||||||
|
className="rounded-md p-1 text-muted-foreground hover:bg-accent hover:text-foreground"
|
||||||
|
>
|
||||||
|
<X className="h-5 w-5" />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Body */}
|
||||||
|
<div className="flex-1 overflow-y-auto p-4 space-y-3">
|
||||||
|
{fixes.map((fix) => {
|
||||||
|
const isApplied = appliedIds.has(fix.target_node_id)
|
||||||
|
const isSkipped = skippedIds.has(fix.target_node_id)
|
||||||
|
const isExpanded = expandedIds.has(fix.target_node_id)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
key={fix.target_node_id}
|
||||||
|
className={cn(
|
||||||
|
'rounded-lg border p-4',
|
||||||
|
isApplied
|
||||||
|
? 'border-emerald-400/30 bg-emerald-400/5'
|
||||||
|
: isSkipped
|
||||||
|
? 'border-border bg-accent/30 opacity-60'
|
||||||
|
: 'border-border bg-card'
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{/* Fix header */}
|
||||||
|
<div className="flex items-start justify-between gap-3">
|
||||||
|
<div className="flex-1">
|
||||||
|
<p className="text-sm text-red-400 mb-1">{fix.error_message}</p>
|
||||||
|
<p className="text-sm text-foreground">{fix.description}</p>
|
||||||
|
<p className="text-xs text-muted-foreground mt-1">
|
||||||
|
Node: {fix.target_node_id}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
{isApplied && (
|
||||||
|
<span className="flex items-center gap-1 rounded-full bg-emerald-400/10 px-2 py-1 text-xs text-emerald-400">
|
||||||
|
<Check className="h-3 w-3" /> Applied
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
{isSkipped && (
|
||||||
|
<span className="text-xs text-muted-foreground">Skipped</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Expand/collapse detail */}
|
||||||
|
{!isApplied && !isSkipped && (
|
||||||
|
<>
|
||||||
|
<button
|
||||||
|
onClick={() => toggleExpanded(fix.target_node_id)}
|
||||||
|
className="mt-2 flex items-center gap-1 text-xs text-muted-foreground hover:text-foreground"
|
||||||
|
>
|
||||||
|
{isExpanded ? <ChevronUp className="h-3 w-3" /> : <ChevronDown className="h-3 w-3" />}
|
||||||
|
{isExpanded ? 'Hide' : 'Show'} details
|
||||||
|
</button>
|
||||||
|
|
||||||
|
{isExpanded && (
|
||||||
|
<div className="mt-3 grid grid-cols-2 gap-3">
|
||||||
|
<div>
|
||||||
|
<p className="text-xs font-medium text-muted-foreground mb-1">Before</p>
|
||||||
|
<pre className="overflow-x-auto rounded bg-accent/50 p-2 text-xs text-muted-foreground max-h-48 overflow-y-auto">
|
||||||
|
{JSON.stringify(fix.original_node, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<p className="text-xs font-medium text-emerald-400 mb-1">After</p>
|
||||||
|
<pre className="overflow-x-auto rounded bg-emerald-400/5 p-2 text-xs text-foreground max-h-48 overflow-y-auto">
|
||||||
|
{JSON.stringify(fix.fixed_node, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Action buttons */}
|
||||||
|
<div className="mt-3 flex gap-2">
|
||||||
|
<button
|
||||||
|
onClick={() => handleApply(fix)}
|
||||||
|
className="flex items-center gap-1 rounded-md bg-gradient-brand px-3 py-1.5 text-xs font-medium text-white shadow-sm shadow-primary/20 hover:opacity-90"
|
||||||
|
>
|
||||||
|
<Check className="h-3 w-3" />
|
||||||
|
Apply
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={() => handleSkip(fix)}
|
||||||
|
className="flex items-center gap-1 rounded-md border border-border px-3 py-1.5 text-xs font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
|
||||||
|
>
|
||||||
|
<SkipForward className="h-3 w-3" />
|
||||||
|
Skip
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Footer */}
|
||||||
|
<div className="flex items-center justify-between border-t border-border px-6 py-4">
|
||||||
|
<button
|
||||||
|
onClick={onClose}
|
||||||
|
className="rounded-md border border-border px-4 py-2 text-sm font-medium text-muted-foreground hover:bg-accent hover:text-foreground"
|
||||||
|
>
|
||||||
|
{allHandled ? 'Done' : 'Cancel'}
|
||||||
|
</button>
|
||||||
|
{!allHandled && (
|
||||||
|
<button
|
||||||
|
onClick={onApplyAll}
|
||||||
|
className="rounded-md bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20 hover:opacity-90"
|
||||||
|
>
|
||||||
|
Apply All ({pendingFixes.length})
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,14 +1,16 @@
|
|||||||
import { useState } from 'react'
|
import { useState } from 'react'
|
||||||
import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp } from 'lucide-react'
|
import { AlertCircle, AlertTriangle, ChevronDown, ChevronUp, Sparkles, Loader2 } from 'lucide-react'
|
||||||
import { cn } from '@/lib/utils'
|
import { cn } from '@/lib/utils'
|
||||||
import type { ValidationError } from '@/store/treeEditorStore'
|
import type { ValidationError } from '@/store/treeEditorStore'
|
||||||
|
|
||||||
interface ValidationSummaryProps {
|
interface ValidationSummaryProps {
|
||||||
errors: ValidationError[]
|
errors: ValidationError[]
|
||||||
onSelectNode: (nodeId: string) => void
|
onSelectNode: (nodeId: string) => void
|
||||||
|
onFixWithAI?: () => void
|
||||||
|
isFixing?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryProps) {
|
export function ValidationSummary({ errors, onSelectNode, onFixWithAI, isFixing }: ValidationSummaryProps) {
|
||||||
const [isExpanded, setIsExpanded] = useState(true)
|
const [isExpanded, setIsExpanded] = useState(true)
|
||||||
|
|
||||||
const errorItems = errors.filter(e => e.severity === 'error')
|
const errorItems = errors.filter(e => e.severity === 'error')
|
||||||
@@ -22,6 +24,8 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const hasFixableErrors = errorItems.some(e => e.nodeId)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
@@ -32,14 +36,16 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{/* Header */}
|
{/* Header */}
|
||||||
<button
|
<div
|
||||||
onClick={() => setIsExpanded(!isExpanded)}
|
|
||||||
className={cn(
|
className={cn(
|
||||||
'flex w-full items-center justify-between p-3 text-left transition-colors hover:bg-accent',
|
'flex w-full items-center justify-between p-3 transition-colors',
|
||||||
errorItems.length > 0 ? 'text-red-400' : 'text-yellow-400'
|
errorItems.length > 0 ? 'text-red-400' : 'text-yellow-400'
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<div className="flex items-center gap-2">
|
<button
|
||||||
|
onClick={() => setIsExpanded(!isExpanded)}
|
||||||
|
className="flex items-center gap-2 text-left hover:opacity-80"
|
||||||
|
>
|
||||||
{errorItems.length > 0 ? (
|
{errorItems.length > 0 ? (
|
||||||
<AlertCircle className="h-5 w-5" />
|
<AlertCircle className="h-5 w-5" />
|
||||||
) : (
|
) : (
|
||||||
@@ -58,9 +64,35 @@ export function ValidationSummary({ errors, onSelectNode }: ValidationSummaryPro
|
|||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
{isExpanded ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />}
|
||||||
{isExpanded ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />}
|
</button>
|
||||||
</button>
|
|
||||||
|
{/* Fix with AI button */}
|
||||||
|
{onFixWithAI && hasFixableErrors && (
|
||||||
|
<button
|
||||||
|
onClick={onFixWithAI}
|
||||||
|
disabled={isFixing}
|
||||||
|
className={cn(
|
||||||
|
'flex items-center gap-1.5 rounded-md px-3 py-1 text-xs font-medium transition-colors',
|
||||||
|
isFixing
|
||||||
|
? 'bg-primary/10 text-primary cursor-wait'
|
||||||
|
: 'bg-gradient-brand text-white shadow-sm shadow-primary/20 hover:opacity-90'
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isFixing ? (
|
||||||
|
<>
|
||||||
|
<Loader2 className="h-3 w-3 animate-spin" />
|
||||||
|
Generating fixes...
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<Sparkles className="h-3 w-3" />
|
||||||
|
Fix with AI
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
{/* Error/Warning List */}
|
{/* Error/Warning List */}
|
||||||
{isExpanded && (
|
{isExpanded && (
|
||||||
|
|||||||
@@ -5,10 +5,11 @@ import { Undo2, Redo2, Save, CheckCircle2, Monitor, FileText, Code2, LayoutList,
|
|||||||
import { getMonacoEditor } from '@/components/tree-editor/code-mode'
|
import { getMonacoEditor } from '@/components/tree-editor/code-mode'
|
||||||
import { treesApi } from '@/api/trees'
|
import { treesApi } from '@/api/trees'
|
||||||
import { treeMarkdownApi } from '@/api/treeMarkdown'
|
import { treeMarkdownApi } from '@/api/treeMarkdown'
|
||||||
import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure } from '@/types'
|
import type { TreeCreate, TreeUpdate, TreeStatus, TreeStructure, AIFixProposal } from '@/types'
|
||||||
import { useTreeEditorStore, useTreeEditorTemporal } from '@/store/treeEditorStore'
|
import { useTreeEditorStore, useTreeEditorTemporal } from '@/store/treeEditorStore'
|
||||||
import { TreeEditorLayout } from '@/components/tree-editor/TreeEditorLayout'
|
import { TreeEditorLayout } from '@/components/tree-editor/TreeEditorLayout'
|
||||||
import { ValidationSummary } from '@/components/tree-editor/ValidationSummary'
|
import { ValidationSummary } from '@/components/tree-editor/ValidationSummary'
|
||||||
|
import { AIFixReviewModal } from '@/components/tree-editor/AIFixReviewModal'
|
||||||
import { useKeyboardShortcuts } from '@/hooks/useKeyboardShortcuts'
|
import { useKeyboardShortcuts } from '@/hooks/useKeyboardShortcuts'
|
||||||
import { usePermissions } from '@/hooks/usePermissions'
|
import { usePermissions } from '@/hooks/usePermissions'
|
||||||
import { Spinner } from '@/components/common/Spinner'
|
import { Spinner } from '@/components/common/Spinner'
|
||||||
@@ -58,6 +59,8 @@ export function TreeEditorPage() {
|
|||||||
const [showAnalytics, setShowAnalytics] = useState(false)
|
const [showAnalytics, setShowAnalytics] = useState(false)
|
||||||
const [isMetadataOpen, setIsMetadataOpen] = useState(false)
|
const [isMetadataOpen, setIsMetadataOpen] = useState(false)
|
||||||
const [editingNodeId, setEditingNodeId] = useState<string | null>(null)
|
const [editingNodeId, setEditingNodeId] = useState<string | null>(null)
|
||||||
|
const [isFixing, setIsFixing] = useState(false)
|
||||||
|
const [fixProposals, setFixProposals] = useState<AIFixProposal[] | null>(null)
|
||||||
|
|
||||||
// Mobile detection
|
// Mobile detection
|
||||||
const [isMobile, setIsMobile] = useState(false)
|
const [isMobile, setIsMobile] = useState(false)
|
||||||
@@ -217,6 +220,54 @@ export function TreeEditorPage() {
|
|||||||
selectNode(nodeId)
|
selectNode(nodeId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleFixWithAI = async () => {
|
||||||
|
const store = useTreeEditorStore.getState()
|
||||||
|
if (!store.treeStructure) return
|
||||||
|
|
||||||
|
const fixableErrors = store.validationErrors
|
||||||
|
.filter(e => e.severity === 'error' && e.nodeId)
|
||||||
|
.map(e => ({ node_id: e.nodeId!, message: e.message }))
|
||||||
|
|
||||||
|
if (fixableErrors.length === 0) return
|
||||||
|
|
||||||
|
setIsFixing(true)
|
||||||
|
try {
|
||||||
|
const result = await treesApi.fixTree({
|
||||||
|
tree_structure: store.treeStructure as unknown as Record<string, unknown>,
|
||||||
|
tree_name: store.name,
|
||||||
|
tree_type: 'troubleshooting',
|
||||||
|
validation_errors: fixableErrors,
|
||||||
|
})
|
||||||
|
if (result.fixes.length > 0) {
|
||||||
|
setFixProposals(result.fixes)
|
||||||
|
} else {
|
||||||
|
toast.info('AI could not generate fixes for these errors')
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
toast.error('Failed to generate AI fixes. Please try again.')
|
||||||
|
} finally {
|
||||||
|
setIsFixing(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleApplyFix = (fix: AIFixProposal) => {
|
||||||
|
updateNode(fix.target_node_id, fix.fixed_node as Partial<TreeStructure>)
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleApplyAllFixes = () => {
|
||||||
|
if (!fixProposals) return
|
||||||
|
for (const fix of fixProposals) {
|
||||||
|
handleApplyFix(fix)
|
||||||
|
}
|
||||||
|
setFixProposals(null)
|
||||||
|
setTimeout(() => { validate() }, 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleCloseFixModal = () => {
|
||||||
|
setFixProposals(null)
|
||||||
|
validate()
|
||||||
|
}
|
||||||
|
|
||||||
const handleNodeSelect = useCallback((nodeId: string | null) => {
|
const handleNodeSelect = useCallback((nodeId: string | null) => {
|
||||||
if (nodeId) {
|
if (nodeId) {
|
||||||
setIsMetadataOpen(false) // close metadata when opening node editor
|
setIsMetadataOpen(false) // close metadata when opening node editor
|
||||||
@@ -685,6 +736,8 @@ export function TreeEditorPage() {
|
|||||||
<ValidationSummary
|
<ValidationSummary
|
||||||
errors={validationErrors}
|
errors={validationErrors}
|
||||||
onSelectNode={handleSelectNode}
|
onSelectNode={handleSelectNode}
|
||||||
|
onFixWithAI={handleFixWithAI}
|
||||||
|
isFixing={isFixing}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
@@ -705,6 +758,16 @@ export function TreeEditorPage() {
|
|||||||
<FlowAnalyticsPanel treeId={id} />
|
<FlowAnalyticsPanel treeId={id} />
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* AI Fix Review Modal */}
|
||||||
|
{fixProposals && (
|
||||||
|
<AIFixReviewModal
|
||||||
|
fixes={fixProposals}
|
||||||
|
onApply={handleApplyFix}
|
||||||
|
onApplyAll={handleApplyAllFixes}
|
||||||
|
onClose={handleCloseFixModal}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
24
frontend/src/types/ai-fix.ts
Normal file
24
frontend/src/types/ai-fix.ts
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
export interface AIFixValidationError {
|
||||||
|
node_id: string
|
||||||
|
message: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AIFixProposal {
|
||||||
|
target_node_id: string
|
||||||
|
error_message: string
|
||||||
|
description: string
|
||||||
|
original_node: Record<string, unknown>
|
||||||
|
fixed_node: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AIFixTreeRequest {
|
||||||
|
tree_structure: Record<string, unknown>
|
||||||
|
tree_name: string
|
||||||
|
tree_type: 'troubleshooting' | 'procedural' | 'maintenance'
|
||||||
|
validation_errors: AIFixValidationError[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AIFixTreeResponse {
|
||||||
|
fixes: AIFixProposal[]
|
||||||
|
tokens_used: { input: number; output: number }
|
||||||
|
}
|
||||||
@@ -45,3 +45,10 @@ export type {
|
|||||||
AIAssembleResponse,
|
AIAssembleResponse,
|
||||||
AIWizardPhase,
|
AIWizardPhase,
|
||||||
} from './ai'
|
} from './ai'
|
||||||
|
|
||||||
|
export type {
|
||||||
|
AIFixTreeRequest,
|
||||||
|
AIFixTreeResponse,
|
||||||
|
AIFixProposal,
|
||||||
|
AIFixValidationError,
|
||||||
|
} from './ai-fix'
|
||||||
|
|||||||
Reference in New Issue
Block a user