feat: AI-assisted flow builder with 4-stage wizard
Implements the complete AI flow builder feature using a guided 4-stage wizard (Foundation → Scaffold → Branch Detail → Review & Assemble). AI assists at bounded points using Claude Haiku for cost-efficient structured JSON generation (~$0.01-0.03/flow). Backend: new models (ai_conversations, ai_usage), Alembic migration, quota enforcement with billing anchor, Anthropic API integration with prompt caching, tree validation, conversation CRUD with 24h TTL, APScheduler cleanup job, 5 API endpoints, Pydantic schemas. Frontend: TypeScript types, API client, Zustand store for wizard state, 7 components (modal, step indicator, foundation form, branch selector, branch detail view, tree preview, quota display), MyTreesPage integration with "Build with AI" button (hidden when AI not configured). Tests: 14 validator unit tests + 11 endpoint integration tests with mocked Anthropic (zero real API spend). All 25 tests passing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
87
backend/app/core/ai_conversation_store.py
Normal file
87
backend/app/core/ai_conversation_store.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""DB-backed CRUD for AI wizard conversation state.
|
||||
|
||||
Conversations have a 24-hour TTL. Every access validates ownership and expiry.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.ai_conversation import AIConversation
|
||||
|
||||
|
||||
async def create_conversation(
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
wizard_state: dict[str, Any],
|
||||
db: AsyncSession,
|
||||
) -> AIConversation:
|
||||
"""Create a new AI wizard conversation."""
|
||||
conversation = AIConversation(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
status="foundation",
|
||||
wizard_state=wizard_state,
|
||||
messages=[],
|
||||
expires_at=datetime.now(timezone.utc)
|
||||
+ timedelta(hours=settings.AI_CONVERSATION_TTL_HOURS),
|
||||
)
|
||||
db.add(conversation)
|
||||
await db.flush()
|
||||
return conversation
|
||||
|
||||
|
||||
async def get_conversation(
|
||||
conversation_id: UUID,
|
||||
user_id: UUID,
|
||||
db: AsyncSession,
|
||||
) -> AIConversation:
|
||||
"""Get a conversation, validating ownership and expiry.
|
||||
|
||||
Raises HTTPException 410 if expired, 404 if not found or wrong owner.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(AIConversation).where(AIConversation.id == conversation_id)
|
||||
)
|
||||
conversation = result.scalar_one_or_none()
|
||||
|
||||
if not conversation or conversation.user_id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Conversation not found",
|
||||
)
|
||||
|
||||
if conversation.expires_at < datetime.now(timezone.utc):
|
||||
conversation.status = "expired"
|
||||
await db.flush()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_410_GONE,
|
||||
detail="Conversation expired. Please start a new AI build.",
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
async def update_conversation(
|
||||
conversation_id: UUID,
|
||||
user_id: UUID,
|
||||
updates: dict[str, Any],
|
||||
db: AsyncSession,
|
||||
) -> AIConversation:
|
||||
"""Update a conversation's fields.
|
||||
|
||||
Validates ownership and expiry before updating.
|
||||
"""
|
||||
conversation = await get_conversation(conversation_id, user_id, db)
|
||||
|
||||
for key, value in updates.items():
|
||||
if hasattr(conversation, key):
|
||||
setattr(conversation, key, value)
|
||||
|
||||
await db.flush()
|
||||
return conversation
|
||||
181
backend/app/core/ai_quota_service.py
Normal file
181
backend/app/core/ai_quota_service.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""AI generation quota management.
|
||||
|
||||
Enforces monthly and daily limits on AI flow builder usage.
|
||||
Monthly quota consumed only on successful tree assembly (counts_toward_quota=True).
|
||||
Daily limit is an anti-abuse guard consumed on conversation start.
|
||||
"""
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.ai_usage import AIUsage
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.models.account_limit_override import AccountLimitOverride
|
||||
from app.core.subscriptions import get_account_subscription, get_plan_limits
|
||||
|
||||
|
||||
async def get_user_plan(account_id: Optional[UUID], db: AsyncSession) -> str:
|
||||
"""Get the plan tier for an account."""
|
||||
if not account_id:
|
||||
return "free"
|
||||
sub = await get_account_subscription(account_id, db)
|
||||
if sub is None:
|
||||
return "free"
|
||||
return sub.plan if sub.plan else "free"
|
||||
|
||||
|
||||
async def _get_effective_limits(
|
||||
account_id: UUID, plan: str, db: AsyncSession
|
||||
) -> tuple[Optional[int], Optional[int]]:
|
||||
"""Get effective AI limits (monthly, daily), applying account overrides.
|
||||
|
||||
Returns (monthly_limit, daily_limit). None means unlimited.
|
||||
"""
|
||||
limits = await get_plan_limits(plan, db)
|
||||
monthly = limits.max_ai_builds_per_month if limits else None
|
||||
daily = limits.max_ai_builds_per_24h if limits else None
|
||||
|
||||
# Check for account-level overrides
|
||||
result = await db.execute(
|
||||
select(AccountLimitOverride).where(
|
||||
AccountLimitOverride.account_id == account_id
|
||||
)
|
||||
)
|
||||
override = result.scalar_one_or_none()
|
||||
if override:
|
||||
if override.override_max_ai_builds_per_month is not None:
|
||||
monthly = override.override_max_ai_builds_per_month
|
||||
if override.override_max_ai_builds_per_24h is not None:
|
||||
daily = override.override_max_ai_builds_per_24h
|
||||
|
||||
return monthly, daily
|
||||
|
||||
|
||||
def _get_billing_anchor_month_start(anchor: Optional[datetime]) -> datetime:
|
||||
"""Calculate the start of the current billing month from the anchor date.
|
||||
|
||||
If the anchor is day 15, the billing month runs from the 15th of each month.
|
||||
Falls back to calendar month if anchor is None.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
if not anchor:
|
||||
return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
anchor_day = min(anchor.day, 28) # Clamp to avoid month overflow
|
||||
this_month_anchor = now.replace(
|
||||
day=anchor_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
if now >= this_month_anchor:
|
||||
return this_month_anchor
|
||||
else:
|
||||
# We're before the anchor day, so billing month started last month
|
||||
if now.month == 1:
|
||||
return this_month_anchor.replace(year=now.year - 1, month=12)
|
||||
else:
|
||||
return this_month_anchor.replace(month=now.month - 1)
|
||||
|
||||
|
||||
async def check_ai_quota(
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
db: AsyncSession,
|
||||
billing_anchor: Optional[datetime] = None,
|
||||
) -> tuple[bool, dict]:
|
||||
"""Check if user can make an AI generation.
|
||||
|
||||
Returns (allowed, quota_status_dict).
|
||||
Monthly counts only rows with counts_toward_quota=True.
|
||||
Daily counts only rows with generation_type in ('scaffold', 'branch_detail').
|
||||
"""
|
||||
plan = await get_user_plan(account_id, db)
|
||||
monthly_limit, daily_limit = await _get_effective_limits(account_id, plan, db)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
month_start = _get_billing_anchor_month_start(billing_anchor)
|
||||
day_start = now - timedelta(hours=24)
|
||||
|
||||
# Monthly: count successful quota-consuming records
|
||||
monthly_count = await db.scalar(
|
||||
select(func.count(AIUsage.id)).where(
|
||||
AIUsage.user_id == user_id,
|
||||
AIUsage.counts_toward_quota == True, # noqa: E712
|
||||
AIUsage.created_at >= month_start,
|
||||
)
|
||||
) or 0
|
||||
|
||||
# Daily: count all AI API calls (scaffold + branch_detail) in last 24h
|
||||
daily_count = await db.scalar(
|
||||
select(func.count(AIUsage.id)).where(
|
||||
AIUsage.user_id == user_id,
|
||||
AIUsage.succeeded == True, # noqa: E712
|
||||
AIUsage.generation_type.in_(["scaffold", "branch_detail"]),
|
||||
AIUsage.created_at >= day_start,
|
||||
)
|
||||
) or 0
|
||||
|
||||
allowed = True
|
||||
deny_reason = None
|
||||
if monthly_limit is not None and monthly_count >= monthly_limit:
|
||||
allowed = False
|
||||
deny_reason = "monthly"
|
||||
if daily_limit is not None and daily_count >= daily_limit:
|
||||
allowed = False
|
||||
deny_reason = "daily"
|
||||
|
||||
# Calculate reset timestamps
|
||||
monthly_reset_at = month_start.replace(
|
||||
month=month_start.month % 12 + 1,
|
||||
year=month_start.year + (1 if month_start.month == 12 else 0),
|
||||
)
|
||||
daily_reset_at = day_start + timedelta(hours=24)
|
||||
|
||||
return allowed, {
|
||||
"plan": plan,
|
||||
"monthly_used": monthly_count,
|
||||
"monthly_limit": monthly_limit,
|
||||
"monthly_reset_at": monthly_reset_at.isoformat(),
|
||||
"daily_used": daily_count,
|
||||
"daily_limit": daily_limit,
|
||||
"daily_reset_at": daily_reset_at.isoformat(),
|
||||
"allowed": allowed,
|
||||
"deny_reason": deny_reason,
|
||||
}
|
||||
|
||||
|
||||
async def record_ai_usage(
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
conversation_id: Optional[UUID],
|
||||
generation_type: str,
|
||||
tier: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
estimated_cost: float,
|
||||
succeeded: bool,
|
||||
counts_toward_quota: bool,
|
||||
error_code: Optional[str],
|
||||
extra_data: Optional[dict],
|
||||
db: AsyncSession,
|
||||
) -> AIUsage:
|
||||
"""Record an AI usage entry."""
|
||||
usage = AIUsage(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
conversation_id=conversation_id,
|
||||
generation_type=generation_type,
|
||||
tier_at_time=tier,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
estimated_cost_usd=estimated_cost,
|
||||
succeeded=succeeded,
|
||||
counts_toward_quota=counts_toward_quota,
|
||||
error_code=error_code,
|
||||
extra_data=extra_data or {},
|
||||
)
|
||||
db.add(usage)
|
||||
await db.flush()
|
||||
return usage
|
||||
293
backend/app/core/ai_tree_generator_service.py
Normal file
293
backend/app/core/ai_tree_generator_service.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""AI-powered tree generation service using Anthropic Claude API.
|
||||
|
||||
Implements the 4-stage wizard flow:
|
||||
Stage 2 (scaffold): AI suggests 4-7 top-level branches
|
||||
Stage 3 (branch_detail): AI generates detailed nodes per branch
|
||||
Stage 4 (assemble): Pure assembly logic — zero AI calls
|
||||
|
||||
System prompts are static constants to enable Anthropic prompt caching.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import anthropic
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Cost estimation (Haiku 4.5 pricing) ──
|
||||
COST_PER_INPUT_TOKEN = 1.0 / 1_000_000 # $1.00 per 1M input tokens
|
||||
COST_PER_OUTPUT_TOKEN = 5.0 / 1_000_000 # $5.00 per 1M output tokens
|
||||
|
||||
|
||||
# ── System Prompts ──
|
||||
|
||||
SCAFFOLD_SYSTEM_PROMPT = """You are ResolutionFlow AI, assisting MSP engineers to build troubleshooting and procedural flows for IT service management.
|
||||
|
||||
Context: Your audience is technical MSP staff experienced with Windows Server, Active Directory, networking, and common MSP tooling (ConnectWise, Datto, SonicWall, etc.).
|
||||
|
||||
Task: Given a flow type, category, name, description, and environment tags, suggest 4-7 top-level branches for the flow.
|
||||
|
||||
For TROUBLESHOOTING flows:
|
||||
- Branches should be symptom-based categories (e.g., "Authentication Failures", "Connectivity Issues", "Performance Degradation")
|
||||
- Each branch represents a common way the problem manifests
|
||||
- Order from most common to least common
|
||||
|
||||
For PROCEDURE flows:
|
||||
- Branches should be phase-based stages (e.g., "Prerequisites", "Configuration", "Verification", "Documentation")
|
||||
- Each branch represents a major step in the process
|
||||
- Order in logical execution sequence
|
||||
|
||||
Rules:
|
||||
- Suggest 4-7 branches
|
||||
- Be specific to the technology/service described — avoid generic branches
|
||||
- Branch names should be concise (2-5 words)
|
||||
- Each branch needs a brief description (1 sentence)
|
||||
- Return ONLY valid JSON, no markdown, no explanation
|
||||
|
||||
Output format:
|
||||
{"branches": [{"name": "Branch Name", "description": "Brief description of what this covers"}]}"""
|
||||
|
||||
|
||||
BRANCH_DETAIL_SYSTEM_PROMPT = """You are ResolutionFlow AI generating step-by-step detail for one branch of a troubleshooting or procedural flow for MSP engineers.
|
||||
|
||||
Context: Your audience is technical MSP staff experienced with Windows Server, Active Directory, networking, and common MSP tooling.
|
||||
|
||||
You must return ONLY valid JSON — no markdown, no code fences, no explanation.
|
||||
|
||||
Required node schema:
|
||||
|
||||
Decision nodes (branching diagnostic questions):
|
||||
{"id": "unique-slug", "type": "decision", "question": "The diagnostic question", "help_text": "Optional context or command hint", "options": [{"id": "opt-id", "label": "Answer choice", "next_node_id": "child-node-id"}], "children": []}
|
||||
|
||||
Action nodes (investigation or remediation steps):
|
||||
{"id": "unique-slug", "type": "action", "title": "Short title", "description": "Detailed instructions", "commands": ["PowerShell or CMD commands"], "expected_outcome": "What success looks like", "children": []}
|
||||
|
||||
Solution nodes (leaf nodes — the resolution):
|
||||
{"id": "unique-slug", "type": "solution", "title": "Resolution title", "description": "Full resolution description", "resolution_steps": ["Step 1", "Step 2"]}
|
||||
|
||||
Rules:
|
||||
1. Generate 3-10 nodes for this branch
|
||||
2. Start with a decision node if troubleshooting, action node if procedure
|
||||
3. Every branch path MUST end in a solution node — no dead ends
|
||||
4. Include realistic MSP commands (PowerShell preferred for Windows)
|
||||
5. Use unique node IDs prefixed with the branch context (e.g., "dns-check-service")
|
||||
6. Every option's next_node_id must match an existing child node's id
|
||||
7. All option labels must be meaningful and specific
|
||||
8. Decision nodes must have at least 2 options
|
||||
9. Return a single root node with its children nested inside
|
||||
|
||||
Few-shot example (abbreviated):
|
||||
{"id": "dns-root", "type": "decision", "question": "Can the client resolve any DNS names?", "help_text": "Run: nslookup google.com", "options": [{"id": "dns-opt-none", "label": "No DNS resolution at all", "next_node_id": "dns-check-service"}, {"id": "dns-opt-partial", "label": "Some names resolve, others don't", "next_node_id": "dns-check-specific"}], "children": [{"id": "dns-check-service", "type": "action", "title": "Check DNS Service", "description": "Verify the DNS Client service is running", "commands": ["Get-Service -Name Dnscache"], "expected_outcome": "Service should be Running", "children": [{"id": "dns-resolved", "type": "solution", "title": "DNS Service Restored", "description": "DNS client service was stopped", "resolution_steps": ["Restart DNS Client service", "Flush DNS cache: ipconfig /flushdns", "Test resolution"]}]}, {"id": "dns-check-specific", "type": "solution", "title": "Selective DNS Failure", "description": "Specific records missing or stale", "resolution_steps": ["Check DNS server configuration", "Verify zone records", "Clear DNS cache"]}]}"""
|
||||
|
||||
|
||||
CORRECTIVE_PROMPT_TEMPLATE = """Your previous JSON was invalid for ResolutionFlow's tree schema.
|
||||
|
||||
Validation errors:
|
||||
{error_list}
|
||||
|
||||
Return a corrected full JSON object only. No markdown, no prose, no code fences.
|
||||
Fix ALL listed errors while maintaining the same troubleshooting/procedural logic."""
|
||||
|
||||
|
||||
def _get_client() -> anthropic.AsyncAnthropic:
|
||||
"""Get configured async Anthropic client."""
|
||||
if not settings.ANTHROPIC_API_KEY:
|
||||
raise RuntimeError("ANTHROPIC_API_KEY not configured")
|
||||
return anthropic.AsyncAnthropic(
|
||||
api_key=settings.ANTHROPIC_API_KEY,
|
||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
|
||||
"""Estimate USD cost from token counts."""
|
||||
return (input_tokens * COST_PER_INPUT_TOKEN) + (
|
||||
output_tokens * COST_PER_OUTPUT_TOKEN
|
||||
)
|
||||
|
||||
|
||||
async def scaffold_branches(
|
||||
wizard_state: dict[str, Any],
|
||||
) -> tuple[list[dict[str, str]], int, int, float]:
|
||||
"""Stage 2: AI suggests top-level branches.
|
||||
|
||||
Returns (branches, input_tokens, output_tokens, estimated_cost).
|
||||
Raises ValueError on invalid response.
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||
name = wizard_state.get("name", "")
|
||||
description = wizard_state.get("description", "")
|
||||
tags = wizard_state.get("environment_tags", [])
|
||||
|
||||
user_message = (
|
||||
f"Flow type: {flow_type}\n"
|
||||
f"Name: {name}\n"
|
||||
f"Description: {description}\n"
|
||||
)
|
||||
if tags:
|
||||
user_message += f"Environment: {', '.join(tags)}\n"
|
||||
|
||||
response = await client.messages.create(
|
||||
model=settings.AI_MODEL,
|
||||
max_tokens=1024,
|
||||
system=SCAFFOLD_SYSTEM_PROMPT,
|
||||
messages=[{"role": "user", "content": user_message}],
|
||||
)
|
||||
|
||||
raw_text = response.content[0].text
|
||||
input_tokens = response.usage.input_tokens
|
||||
output_tokens = response.usage.output_tokens
|
||||
cost = _estimate_cost(input_tokens, output_tokens)
|
||||
|
||||
try:
|
||||
data = json.loads(raw_text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"AI returned invalid JSON: {e}")
|
||||
|
||||
branches = data.get("branches", [])
|
||||
if not isinstance(branches, list) or len(branches) < 2:
|
||||
raise ValueError("AI returned fewer than 2 branches")
|
||||
|
||||
return branches, input_tokens, output_tokens, cost
|
||||
|
||||
|
||||
async def generate_branch_detail(
|
||||
wizard_state: dict[str, Any],
|
||||
branch_name: str,
|
||||
existing_branches: list[str],
|
||||
) -> tuple[dict[str, Any], int, int, float]:
|
||||
"""Stage 3: AI generates detailed nodes for one branch.
|
||||
|
||||
Returns (branch_tree, input_tokens, output_tokens, estimated_cost).
|
||||
On validation failure, retries once with corrective prompt.
|
||||
Raises ValueError if both attempts fail.
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||
name = wizard_state.get("name", "")
|
||||
description = wizard_state.get("description", "")
|
||||
|
||||
user_message = (
|
||||
f"Flow: {name} ({flow_type})\n"
|
||||
f"Description: {description}\n"
|
||||
f"Branch to detail: {branch_name}\n"
|
||||
)
|
||||
if existing_branches:
|
||||
other = [b for b in existing_branches if b != branch_name]
|
||||
if other:
|
||||
user_message += f"Other branches (avoid overlap): {', '.join(other)}\n"
|
||||
|
||||
messages = [{"role": "user", "content": user_message}]
|
||||
total_input = 0
|
||||
total_output = 0
|
||||
|
||||
for attempt in range(2):
|
||||
response = await client.messages.create(
|
||||
model=settings.AI_MODEL,
|
||||
max_tokens=4096,
|
||||
system=BRANCH_DETAIL_SYSTEM_PROMPT,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
raw_text = response.content[0].text
|
||||
total_input += response.usage.input_tokens
|
||||
total_output += response.usage.output_tokens
|
||||
|
||||
try:
|
||||
branch_tree = json.loads(raw_text)
|
||||
except json.JSONDecodeError as e:
|
||||
if attempt == 0:
|
||||
messages.append({"role": "assistant", "content": raw_text})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": CORRECTIVE_PROMPT_TEMPLATE.format(
|
||||
error_list=f"JSON parse error: {e}"
|
||||
),
|
||||
})
|
||||
continue
|
||||
raise ValueError(f"AI returned invalid JSON after retry: {e}")
|
||||
|
||||
# Validate the branch structure
|
||||
errors = validate_generated_tree(branch_tree)
|
||||
if not errors:
|
||||
cost = _estimate_cost(total_input, total_output)
|
||||
return branch_tree, total_input, total_output, cost
|
||||
|
||||
if attempt == 0:
|
||||
messages.append({"role": "assistant", "content": raw_text})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": CORRECTIVE_PROMPT_TEMPLATE.format(
|
||||
error_list="\n".join(f"- {e}" for e in errors)
|
||||
),
|
||||
})
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"AI tree validation failed after retry: {'; '.join(errors)}"
|
||||
)
|
||||
|
||||
# Should not reach here
|
||||
raise ValueError("Branch detail generation failed")
|
||||
|
||||
|
||||
def assemble_tree(
|
||||
wizard_state: dict[str, Any],
|
||||
branches: list[dict[str, Any]],
|
||||
) -> tuple[dict[str, Any], str, str, dict[str, int]]:
|
||||
"""Stage 4: Assemble branches into a complete tree.
|
||||
|
||||
Zero AI calls — pure assembly logic.
|
||||
Returns (tree_structure, suggested_name, suggested_description, summary_stats).
|
||||
"""
|
||||
flow_type = wizard_state.get("flow_type", "troubleshooting")
|
||||
name = wizard_state.get("name", "Untitled Flow")
|
||||
description = wizard_state.get("description", "")
|
||||
|
||||
# Build root decision node pointing to each branch
|
||||
options = []
|
||||
children = []
|
||||
for i, branch in enumerate(branches):
|
||||
branch_name = branch.get("name", f"Branch {i + 1}")
|
||||
branch_tree = branch.get("steps")
|
||||
|
||||
if not branch_tree or not isinstance(branch_tree, dict):
|
||||
# Skip branches without detail
|
||||
continue
|
||||
|
||||
branch_id = branch_tree.get("id", f"branch_{i}")
|
||||
options.append({
|
||||
"id": f"opt_{i + 1}",
|
||||
"label": branch_name,
|
||||
"next_node_id": branch_id,
|
||||
})
|
||||
children.append(branch_tree)
|
||||
|
||||
if len(options) < 2:
|
||||
raise ValueError("Need at least 2 branches with detail to assemble a tree")
|
||||
|
||||
# Determine root question based on flow type
|
||||
if flow_type == "troubleshooting":
|
||||
root_question = f"What issue is the user experiencing with {name}?"
|
||||
else:
|
||||
root_question = f"Which phase of {name} are you working on?"
|
||||
|
||||
tree_structure = {
|
||||
"id": "root",
|
||||
"type": "decision",
|
||||
"question": root_question,
|
||||
"options": options,
|
||||
"children": children,
|
||||
}
|
||||
|
||||
stats = count_tree_stats(tree_structure)
|
||||
|
||||
return tree_structure, name, description, stats
|
||||
199
backend/app/core/ai_tree_validator.py
Normal file
199
backend/app/core/ai_tree_validator.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Validation for AI-generated tree structures.
|
||||
|
||||
Ensures generated trees conform to ResolutionFlow's node schema
|
||||
before they are saved to the database.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
|
||||
VALID_NODE_TYPES = {"decision", "action", "solution"}
|
||||
|
||||
# Required fields per node type
|
||||
REQUIRED_FIELDS = {
|
||||
"decision": {"id", "type", "question", "options", "children"},
|
||||
"action": {"id", "type", "title", "description"},
|
||||
"solution": {"id", "type", "title", "description"},
|
||||
}
|
||||
|
||||
|
||||
class TreeValidationError(Exception):
|
||||
"""Raised when a generated tree fails validation."""
|
||||
|
||||
def __init__(self, errors: list[str]):
|
||||
self.errors = errors
|
||||
super().__init__(f"Tree validation failed: {'; '.join(errors)}")
|
||||
|
||||
|
||||
def validate_generated_tree(tree: dict[str, Any]) -> list[str]:
|
||||
"""Validate an AI-generated tree structure.
|
||||
|
||||
Returns a list of error strings. Empty list means valid.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
|
||||
if not isinstance(tree, dict):
|
||||
return ["Tree must be a JSON object"]
|
||||
|
||||
# Root must be a decision node
|
||||
if tree.get("type") != "decision":
|
||||
errors.append("Root node must be type 'decision'")
|
||||
|
||||
# Collect all node IDs and validate structure
|
||||
all_ids: set[str] = set()
|
||||
all_referenced_ids: set[str] = set()
|
||||
node_count = 0
|
||||
solution_count = 0
|
||||
|
||||
def _validate_node(node: dict[str, Any], path: str) -> None:
|
||||
nonlocal node_count, solution_count
|
||||
|
||||
if not isinstance(node, dict):
|
||||
errors.append(f"Node at {path} is not an object")
|
||||
return
|
||||
|
||||
node_count += 1
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
|
||||
# Check node ID
|
||||
if not node_id:
|
||||
errors.append(f"Node at {path} missing 'id'")
|
||||
elif node_id in all_ids:
|
||||
errors.append(f"Duplicate node ID: '{node_id}'")
|
||||
else:
|
||||
all_ids.add(node_id)
|
||||
|
||||
# Check node type
|
||||
if node_type not in VALID_NODE_TYPES:
|
||||
errors.append(
|
||||
f"Node '{node_id or path}' has invalid type '{node_type}'. "
|
||||
f"Must be one of: {', '.join(sorted(VALID_NODE_TYPES))}"
|
||||
)
|
||||
return
|
||||
|
||||
# Check required fields
|
||||
required = REQUIRED_FIELDS[node_type]
|
||||
missing = required - set(node.keys())
|
||||
if missing:
|
||||
errors.append(
|
||||
f"Node '{node_id}' (type={node_type}) missing fields: {', '.join(sorted(missing))}"
|
||||
)
|
||||
|
||||
# Type-specific validation
|
||||
if node_type == "decision":
|
||||
options = node.get("options", [])
|
||||
if not isinstance(options, list) or len(options) < 2:
|
||||
errors.append(
|
||||
f"Decision node '{node_id}' must have at least 2 options"
|
||||
)
|
||||
else:
|
||||
children = node.get("children", [])
|
||||
child_ids = {c.get("id") for c in children if isinstance(c, dict)}
|
||||
option_ids: set[str] = set()
|
||||
|
||||
for opt in options:
|
||||
if not isinstance(opt, dict):
|
||||
errors.append(f"Option in node '{node_id}' is not an object")
|
||||
continue
|
||||
opt_id = opt.get("id")
|
||||
if opt_id and opt_id in option_ids:
|
||||
errors.append(
|
||||
f"Duplicate option ID '{opt_id}' in node '{node_id}'"
|
||||
)
|
||||
if opt_id:
|
||||
option_ids.add(opt_id)
|
||||
|
||||
next_id = opt.get("next_node_id")
|
||||
if next_id:
|
||||
all_referenced_ids.add(next_id)
|
||||
if child_ids and next_id not in child_ids:
|
||||
errors.append(
|
||||
f"Option '{opt.get('label', '?')}' in node '{node_id}' "
|
||||
f"references non-existent child '{next_id}'"
|
||||
)
|
||||
|
||||
elif node_type == "action":
|
||||
next_id = node.get("next_node_id")
|
||||
if next_id:
|
||||
all_referenced_ids.add(next_id)
|
||||
|
||||
elif node_type == "solution":
|
||||
solution_count += 1
|
||||
|
||||
# Recurse into children
|
||||
for i, child in enumerate(node.get("children", [])):
|
||||
_validate_node(child, f"{path}.children[{i}]")
|
||||
|
||||
_validate_node(tree, "root")
|
||||
|
||||
# Global checks
|
||||
if node_count < 5:
|
||||
errors.append(
|
||||
f"Tree has only {node_count} nodes. Minimum 5 required for a useful tree."
|
||||
)
|
||||
if node_count > 50:
|
||||
errors.append(
|
||||
f"Tree has {node_count} nodes. Maximum 50 allowed."
|
||||
)
|
||||
if solution_count < 2:
|
||||
errors.append(
|
||||
f"Tree has only {solution_count} solution nodes. "
|
||||
"Need at least 2 to cover different resolution paths."
|
||||
)
|
||||
|
||||
# Check that all leaf (non-solution) nodes have children or are solutions
|
||||
_check_branch_termination(tree, errors)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def _check_branch_termination(node: dict[str, Any], errors: list[str]) -> None:
|
||||
"""Verify every branch eventually reaches a solution node."""
|
||||
if not isinstance(node, dict):
|
||||
return
|
||||
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id", "?")
|
||||
children = node.get("children", [])
|
||||
|
||||
if node_type == "solution":
|
||||
return # Solution is a valid terminus
|
||||
|
||||
if not children and node_type != "solution":
|
||||
errors.append(
|
||||
f"Node '{node_id}' (type={node_type}) is a dead end — "
|
||||
"it has no children and is not a solution node"
|
||||
)
|
||||
return
|
||||
|
||||
for child in children:
|
||||
_check_branch_termination(child, errors)
|
||||
|
||||
|
||||
def count_tree_stats(tree: dict[str, Any]) -> dict[str, int]:
|
||||
"""Count node types and calculate depth of a tree."""
|
||||
stats = {
|
||||
"node_count": 0,
|
||||
"decision_count": 0,
|
||||
"action_count": 0,
|
||||
"solution_count": 0,
|
||||
"depth": 0,
|
||||
}
|
||||
|
||||
def _count(node: dict[str, Any], depth: int) -> None:
|
||||
if not isinstance(node, dict):
|
||||
return
|
||||
stats["node_count"] += 1
|
||||
node_type = node.get("type", "")
|
||||
if node_type == "decision":
|
||||
stats["decision_count"] += 1
|
||||
elif node_type == "action":
|
||||
stats["action_count"] += 1
|
||||
elif node_type == "solution":
|
||||
stats["solution_count"] += 1
|
||||
stats["depth"] = max(stats["depth"], depth)
|
||||
for child in node.get("children", []):
|
||||
_count(child, depth + 1)
|
||||
|
||||
_count(tree, 1)
|
||||
return stats
|
||||
@@ -72,6 +72,18 @@ class Settings(BaseSettings):
|
||||
"""Check if Stripe is configured."""
|
||||
return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None
|
||||
|
||||
# AI Flow Builder
|
||||
ANTHROPIC_API_KEY: Optional[str] = None
|
||||
AI_MODEL: str = "claude-haiku-4-5"
|
||||
AI_CONVERSATION_TTL_HOURS: int = 24
|
||||
AI_MAX_CALLS_PER_FLOW: int = 10
|
||||
AI_REQUEST_TIMEOUT_SECONDS: int = 45
|
||||
|
||||
@property
|
||||
def ai_enabled(self) -> bool:
|
||||
"""Check if AI Flow Builder is configured."""
|
||||
return self.ANTHROPIC_API_KEY is not None
|
||||
|
||||
# Deployment – auto-seed test data on PR environments
|
||||
SEED_ON_DEPLOY: bool = False
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""APScheduler integration for maintenance flow auto-session creation."""
|
||||
"""APScheduler integration for maintenance flow auto-session creation and AI cleanup."""
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -7,8 +7,9 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.schedulers.base import SchedulerNotRunningError
|
||||
from apscheduler.jobstores.base import JobLookupError
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
import pytz
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -114,6 +115,27 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
|
||||
await db.rollback()
|
||||
|
||||
|
||||
async def _cleanup_expired_ai_conversations() -> None:
|
||||
"""Delete expired AI wizard conversations."""
|
||||
import app.models # noqa: F401
|
||||
from app.core.database import async_session_maker
|
||||
from app.models.ai_conversation import AIConversation
|
||||
|
||||
async with async_session_maker() as db:
|
||||
try:
|
||||
result = await db.execute(
|
||||
delete(AIConversation).where(
|
||||
AIConversation.expires_at < datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
logger.info(f"Cleaned up {result.rowcount} expired AI conversation(s)")
|
||||
await db.commit()
|
||||
except Exception:
|
||||
logger.exception("Error cleaning up expired AI conversations")
|
||||
await db.rollback()
|
||||
|
||||
|
||||
async def load_all_schedules(db: AsyncSession) -> None:
|
||||
"""Load all active schedules into APScheduler on startup."""
|
||||
# Import all models to ensure SQLAlchemy mapper relationships resolve
|
||||
|
||||
Reference in New Issue
Block a user