feat: AI-assisted flow builder with 4-stage wizard

Implements the complete AI flow builder feature using a guided 4-stage
wizard (Foundation → Scaffold → Branch Detail → Review & Assemble).
AI assists at bounded points using Claude Haiku for cost-efficient
structured JSON generation (~$0.01-0.03/flow).

Backend: new models (ai_conversations, ai_usage), Alembic migration,
quota enforcement with billing anchor, Anthropic API integration with
prompt caching, tree validation, conversation CRUD with 24h TTL,
APScheduler cleanup job, 5 API endpoints, Pydantic schemas.

Frontend: TypeScript types, API client, Zustand store for wizard state,
7 components (modal, step indicator, foundation form, branch selector,
branch detail view, tree preview, quota display), MyTreesPage integration
with "Build with AI" button (hidden when AI not configured).

Tests: 14 validator unit tests + 11 endpoint integration tests with
mocked Anthropic (zero real API spend). All 25 tests passing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-20 08:07:08 -05:00
parent aef40078d0
commit 44432413c2
35 changed files with 3662 additions and 5 deletions

View File

@@ -0,0 +1,216 @@
"""add ai flow builder tables and columns
Revision ID: a1b2c3d4e5f6
Revises: e65b9f8fd458
Create Date: 2026-02-20 12:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision: str = "a1b2c3d4e5f6"
down_revision: Union[str, None] = "e65b9f8fd458"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── ai_conversations table ──
op.create_table(
"ai_conversations",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"account_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("status", sa.String(20), nullable=False, server_default="foundation"),
sa.Column("messages", postgresql.JSONB(), nullable=False, server_default="[]"),
sa.Column(
"wizard_state", postgresql.JSONB(), nullable=False, server_default="{}"
),
sa.Column("generated_tree", postgresql.JSONB(), nullable=True),
sa.Column("question_rounds", sa.Integer(), nullable=False, server_default="0"),
sa.Column(
"expires_at", sa.DateTime(timezone=True), nullable=False
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.create_index(
"ix_ai_conversations_user_id", "ai_conversations", ["user_id"]
)
op.create_index(
"ix_ai_conversations_account_id", "ai_conversations", ["account_id"]
)
op.create_index(
"ix_ai_conversations_user_created",
"ai_conversations",
["user_id", sa.text("created_at DESC")],
)
op.create_index(
"ix_ai_conversations_expires_at", "ai_conversations", ["expires_at"]
)
# ── ai_usage table ──
op.create_table(
"ai_usage",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"account_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"conversation_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("ai_conversations.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("generation_type", sa.String(20), nullable=False),
sa.Column("tier_at_time", sa.String(20), nullable=False),
sa.Column("input_tokens", sa.Integer(), nullable=False, server_default="0"),
sa.Column("output_tokens", sa.Integer(), nullable=False, server_default="0"),
sa.Column(
"estimated_cost_usd",
sa.Numeric(10, 6),
nullable=False,
server_default="0",
),
sa.Column("succeeded", sa.Boolean(), nullable=False, server_default="true"),
sa.Column(
"counts_toward_quota",
sa.Boolean(),
nullable=False,
server_default="false",
),
sa.Column("error_code", sa.String(100), nullable=True),
sa.Column("metadata", postgresql.JSONB(), nullable=False, server_default="{}"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.create_index("ix_ai_usage_user_id", "ai_usage", ["user_id"])
op.create_index("ix_ai_usage_account_id", "ai_usage", ["account_id"])
op.create_index("ix_ai_usage_created_at", "ai_usage", ["created_at"])
op.create_index(
"ix_ai_usage_user_created",
"ai_usage",
["user_id", sa.text("created_at DESC")],
)
op.create_index(
"ix_ai_usage_user_type_created",
"ai_usage",
["user_id", "generation_type", sa.text("created_at DESC")],
)
# Prevents double quota decrement from race conditions
op.execute(
"""
CREATE UNIQUE INDEX ix_ai_usage_unique_quota
ON ai_usage (conversation_id)
WHERE counts_toward_quota = true;
"""
)
# ── Schema modifications to existing tables ──
# users: add ai_billing_cycle_anchor_at
op.add_column(
"users",
sa.Column("ai_billing_cycle_anchor_at", sa.DateTime(timezone=True), nullable=True),
)
# Backfill: use created_at as the billing anchor
op.execute(
"UPDATE users SET ai_billing_cycle_anchor_at = created_at WHERE ai_billing_cycle_anchor_at IS NULL"
)
# plan_limits: add AI limit columns
op.add_column(
"plan_limits",
sa.Column("max_ai_builds_per_month", sa.Integer(), nullable=True),
)
op.add_column(
"plan_limits",
sa.Column("max_ai_builds_per_24h", sa.Integer(), nullable=True),
)
# account_limit_overrides: add AI override columns
op.add_column(
"account_limit_overrides",
sa.Column("override_max_ai_builds_per_month", sa.Integer(), nullable=True),
)
op.add_column(
"account_limit_overrides",
sa.Column("override_max_ai_builds_per_24h", sa.Integer(), nullable=True),
)
# Seed plan_limits with AI quota values
op.execute(
"""
UPDATE plan_limits SET max_ai_builds_per_month = 2, max_ai_builds_per_24h = 1
WHERE plan = 'free';
"""
)
op.execute(
"""
UPDATE plan_limits SET max_ai_builds_per_month = 50, max_ai_builds_per_24h = 10
WHERE plan = 'pro';
"""
)
op.execute(
"""
UPDATE plan_limits SET max_ai_builds_per_month = 200, max_ai_builds_per_24h = 20
WHERE plan = 'team';
"""
)
# Enterprise: NULL means unlimited (no update needed as default is NULL)
def downgrade() -> None:
# Drop AI override columns from account_limit_overrides
op.drop_column("account_limit_overrides", "override_max_ai_builds_per_24h")
op.drop_column("account_limit_overrides", "override_max_ai_builds_per_month")
# Drop AI limit columns from plan_limits
op.drop_column("plan_limits", "max_ai_builds_per_24h")
op.drop_column("plan_limits", "max_ai_builds_per_month")
# Drop ai_billing_cycle_anchor_at from users
op.drop_column("users", "ai_billing_cycle_anchor_at")
# Drop ai_usage table (indexes drop automatically)
op.drop_table("ai_usage")
# Drop ai_conversations table (indexes drop automatically)
op.drop_table("ai_conversations")

View File

@@ -0,0 +1,427 @@
"""AI Flow Builder wizard endpoints.
4-stage wizard:
POST /ai/start — Stage 1: create conversation with metadata
POST /ai/scaffold — Stage 2: AI suggests branches
POST /ai/branch-detail — Stage 3: AI generates detail for one branch
POST /ai/assemble — Stage 4: assemble branches into tree (no AI)
GET /ai/quota — quota status
"""
import logging
from typing import Annotated
import anthropic
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin
from app.core.config import settings
from app.core.ai_conversation_store import (
create_conversation,
get_conversation,
update_conversation,
)
from app.core.ai_quota_service import check_ai_quota, record_ai_usage, get_user_plan
from app.core.ai_tree_generator_service import (
scaffold_branches,
generate_branch_detail,
assemble_tree,
)
from app.models.user import User
from app.schemas.ai_builder import (
AIStartRequest,
AIStartResponse,
AIScaffoldRequest,
AIScaffoldResponse,
AIBranchDetailRequest,
AIBranchDetailResponse,
AIAssembleRequest,
AIAssembleResponse,
AIQuotaStatusResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ai", tags=["ai-builder"])
def _require_ai_enabled() -> None:
"""Raise 503 if AI is not configured."""
if not settings.ai_enabled:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.",
)
@router.get("/quota", response_model=AIQuotaStatusResponse)
async def get_quota(
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Get current user's AI quota status."""
if not settings.ai_enabled:
return AIQuotaStatusResponse(
plan="free",
monthly_used=0,
monthly_limit=None,
monthly_reset_at="",
daily_used=0,
daily_limit=None,
daily_reset_at="",
allowed=False,
ai_enabled=False,
)
_, quota_status = await check_ai_quota(
user_id=current_user.id,
account_id=current_user.account_id,
db=db,
billing_anchor=current_user.ai_billing_cycle_anchor_at,
)
return AIQuotaStatusResponse(
**quota_status,
ai_enabled=True,
)
@router.post("/start", response_model=AIStartResponse, status_code=201)
async def start_conversation(
data: AIStartRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
_: None = Depends(require_engineer_or_admin),
):
"""Stage 1: Create a new AI wizard conversation with foundation metadata."""
_require_ai_enabled()
# Check daily quota (anti-abuse)
allowed, quota_status = await check_ai_quota(
user_id=current_user.id,
account_id=current_user.account_id,
db=db,
billing_anchor=current_user.ai_billing_cycle_anchor_at,
)
if not allowed:
reset_key = (
"daily_reset_at"
if quota_status.get("deny_reason") == "daily"
else "monthly_reset_at"
)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail={
"message": f"AI build limit exceeded ({quota_status['deny_reason']})",
"reset_at": quota_status.get(reset_key),
"quota": quota_status,
},
)
wizard_state = {
"flow_type": data.flow_type,
"name": data.name,
"description": data.description,
"environment_tags": data.environment_tags,
"category_id": str(data.category_id) if data.category_id else None,
}
conversation = await create_conversation(
user_id=current_user.id,
account_id=current_user.account_id,
wizard_state=wizard_state,
db=db,
)
await db.commit()
return AIStartResponse(
conversation_id=conversation.id,
status=conversation.status,
)
@router.post("/scaffold", response_model=AIScaffoldResponse)
async def scaffold(
data: AIScaffoldRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
_: None = Depends(require_engineer_or_admin),
):
"""Stage 2: AI suggests top-level branches."""
_require_ai_enabled()
conversation = await get_conversation(
data.conversation_id, current_user.id, db
)
# Check per-flow call limit
if conversation.question_rounds >= settings.AI_MAX_CALLS_PER_FLOW:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Maximum AI calls per flow exceeded",
)
plan = await get_user_plan(current_user.account_id, db)
try:
branches, input_tokens, output_tokens, cost = await scaffold_branches(
conversation.wizard_state,
)
except anthropic.APIError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="scaffold",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e)},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
except ValueError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="scaffold",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code="invalid_output",
extra_data={"error": str(e)},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"AI returned invalid output: {e}",
)
# Record successful usage
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="scaffold",
tier=plan,
input_tokens=input_tokens,
output_tokens=output_tokens,
estimated_cost=cost,
succeeded=True,
counts_toward_quota=False,
error_code=None,
extra_data=None,
db=db,
)
# Update conversation state
wizard_state = dict(conversation.wizard_state)
wizard_state["branches"] = branches
await update_conversation(
conversation.id,
current_user.id,
{
"status": "scaffolding",
"wizard_state": wizard_state,
"question_rounds": conversation.question_rounds + 1,
},
db,
)
await db.commit()
return AIScaffoldResponse(
conversation_id=conversation.id,
branches=branches,
status="scaffolding",
)
@router.post("/branch-detail", response_model=AIBranchDetailResponse)
async def branch_detail(
data: AIBranchDetailRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
_: None = Depends(require_engineer_or_admin),
):
"""Stage 3: AI generates detailed nodes for one branch."""
_require_ai_enabled()
conversation = await get_conversation(
data.conversation_id, current_user.id, db
)
if conversation.question_rounds >= settings.AI_MAX_CALLS_PER_FLOW:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Maximum AI calls per flow exceeded",
)
wizard_state = conversation.wizard_state
existing_branches = [
b.get("name", "") for b in wizard_state.get("branches", [])
]
plan = await get_user_plan(current_user.account_id, db)
try:
branch_tree, input_tokens, output_tokens, cost = (
await generate_branch_detail(
wizard_state,
data.branch_name,
existing_branches,
)
)
except anthropic.APIError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="branch_detail",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e), "branch_name": data.branch_name},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="AI provider error. Please try again.",
)
except ValueError as e:
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="branch_detail",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=False,
counts_toward_quota=False,
error_code="invalid_output",
extra_data={"error": str(e), "branch_name": data.branch_name},
db=db,
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"AI returned invalid output: {e}",
)
# Record successful usage
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="branch_detail",
tier=plan,
input_tokens=input_tokens,
output_tokens=output_tokens,
estimated_cost=cost,
succeeded=True,
counts_toward_quota=False,
error_code=None,
extra_data={"branch_name": data.branch_name},
db=db,
)
# Update conversation
await update_conversation(
conversation.id,
current_user.id,
{
"status": "detailing",
"question_rounds": conversation.question_rounds + 1,
},
db,
)
await db.commit()
return AIBranchDetailResponse(
conversation_id=conversation.id,
branch_name=data.branch_name,
steps=branch_tree,
status="detailing",
)
@router.post("/assemble", response_model=AIAssembleResponse)
async def assemble(
data: AIAssembleRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
_: None = Depends(require_engineer_or_admin),
):
"""Stage 4: Assemble selected branches into a complete tree (no AI calls)."""
conversation = await get_conversation(
data.conversation_id, current_user.id, db
)
wizard_state = conversation.wizard_state
branches_for_assembly = [b.model_dump() for b in data.selected_branches]
try:
tree_structure, name, description, stats = assemble_tree(
wizard_state, branches_for_assembly
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=str(e),
)
# Record quota-consuming usage on successful assembly
plan = await get_user_plan(current_user.account_id, db)
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=conversation.id,
generation_type="tree",
tier=plan,
input_tokens=0,
output_tokens=0,
estimated_cost=0,
succeeded=True,
counts_toward_quota=True,
error_code=None,
extra_data={"stats": stats},
db=db,
)
# Update conversation with assembled tree
await update_conversation(
conversation.id,
current_user.id,
{
"status": "completed",
"generated_tree": tree_structure,
},
db,
)
await db.commit()
return AIAssembleResponse(
tree_structure=tree_structure,
suggested_name=name,
suggested_description=description,
summary=stats,
status="completed",
)

View File

@@ -5,6 +5,7 @@ from app.api.endpoints import ratings, analytics
from app.api.endpoints import target_lists
from app.api.endpoints import maintenance_schedules
from app.api.endpoints import feedback
from app.api.endpoints import ai_builder
api_router = APIRouter()
@@ -34,3 +35,4 @@ api_router.include_router(analytics.router)
api_router.include_router(target_lists.router)
api_router.include_router(maintenance_schedules.router)
api_router.include_router(feedback.router)
api_router.include_router(ai_builder.router)

View File

@@ -0,0 +1,87 @@
"""DB-backed CRUD for AI wizard conversation state.
Conversations have a 24-hour TTL. Every access validates ownership and expiry.
"""
import uuid
from datetime import datetime, timezone, timedelta
from typing import Any, Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.models.ai_conversation import AIConversation
async def create_conversation(
user_id: UUID,
account_id: UUID,
wizard_state: dict[str, Any],
db: AsyncSession,
) -> AIConversation:
"""Create a new AI wizard conversation."""
conversation = AIConversation(
user_id=user_id,
account_id=account_id,
status="foundation",
wizard_state=wizard_state,
messages=[],
expires_at=datetime.now(timezone.utc)
+ timedelta(hours=settings.AI_CONVERSATION_TTL_HOURS),
)
db.add(conversation)
await db.flush()
return conversation
async def get_conversation(
conversation_id: UUID,
user_id: UUID,
db: AsyncSession,
) -> AIConversation:
"""Get a conversation, validating ownership and expiry.
Raises HTTPException 410 if expired, 404 if not found or wrong owner.
"""
result = await db.execute(
select(AIConversation).where(AIConversation.id == conversation_id)
)
conversation = result.scalar_one_or_none()
if not conversation or conversation.user_id != user_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Conversation not found",
)
if conversation.expires_at < datetime.now(timezone.utc):
conversation.status = "expired"
await db.flush()
raise HTTPException(
status_code=status.HTTP_410_GONE,
detail="Conversation expired. Please start a new AI build.",
)
return conversation
async def update_conversation(
conversation_id: UUID,
user_id: UUID,
updates: dict[str, Any],
db: AsyncSession,
) -> AIConversation:
"""Update a conversation's fields.
Validates ownership and expiry before updating.
"""
conversation = await get_conversation(conversation_id, user_id, db)
for key, value in updates.items():
if hasattr(conversation, key):
setattr(conversation, key, value)
await db.flush()
return conversation

View File

@@ -0,0 +1,181 @@
"""AI generation quota management.
Enforces monthly and daily limits on AI flow builder usage.
Monthly quota consumed only on successful tree assembly (counts_toward_quota=True).
Daily limit is an anti-abuse guard consumed on conversation start.
"""
from datetime import datetime, timezone, timedelta
from typing import Optional
from uuid import UUID
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ai_usage import AIUsage
from app.models.plan_limits import PlanLimits
from app.models.account_limit_override import AccountLimitOverride
from app.core.subscriptions import get_account_subscription, get_plan_limits
async def get_user_plan(account_id: Optional[UUID], db: AsyncSession) -> str:
"""Get the plan tier for an account."""
if not account_id:
return "free"
sub = await get_account_subscription(account_id, db)
if sub is None:
return "free"
return sub.plan if sub.plan else "free"
async def _get_effective_limits(
account_id: UUID, plan: str, db: AsyncSession
) -> tuple[Optional[int], Optional[int]]:
"""Get effective AI limits (monthly, daily), applying account overrides.
Returns (monthly_limit, daily_limit). None means unlimited.
"""
limits = await get_plan_limits(plan, db)
monthly = limits.max_ai_builds_per_month if limits else None
daily = limits.max_ai_builds_per_24h if limits else None
# Check for account-level overrides
result = await db.execute(
select(AccountLimitOverride).where(
AccountLimitOverride.account_id == account_id
)
)
override = result.scalar_one_or_none()
if override:
if override.override_max_ai_builds_per_month is not None:
monthly = override.override_max_ai_builds_per_month
if override.override_max_ai_builds_per_24h is not None:
daily = override.override_max_ai_builds_per_24h
return monthly, daily
def _get_billing_anchor_month_start(anchor: Optional[datetime]) -> datetime:
"""Calculate the start of the current billing month from the anchor date.
If the anchor is day 15, the billing month runs from the 15th of each month.
Falls back to calendar month if anchor is None.
"""
now = datetime.now(timezone.utc)
if not anchor:
return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
anchor_day = min(anchor.day, 28) # Clamp to avoid month overflow
this_month_anchor = now.replace(
day=anchor_day, hour=0, minute=0, second=0, microsecond=0
)
if now >= this_month_anchor:
return this_month_anchor
else:
# We're before the anchor day, so billing month started last month
if now.month == 1:
return this_month_anchor.replace(year=now.year - 1, month=12)
else:
return this_month_anchor.replace(month=now.month - 1)
async def check_ai_quota(
user_id: UUID,
account_id: UUID,
db: AsyncSession,
billing_anchor: Optional[datetime] = None,
) -> tuple[bool, dict]:
"""Check if user can make an AI generation.
Returns (allowed, quota_status_dict).
Monthly counts only rows with counts_toward_quota=True.
Daily counts only rows with generation_type in ('scaffold', 'branch_detail').
"""
plan = await get_user_plan(account_id, db)
monthly_limit, daily_limit = await _get_effective_limits(account_id, plan, db)
now = datetime.now(timezone.utc)
month_start = _get_billing_anchor_month_start(billing_anchor)
day_start = now - timedelta(hours=24)
# Monthly: count successful quota-consuming records
monthly_count = await db.scalar(
select(func.count(AIUsage.id)).where(
AIUsage.user_id == user_id,
AIUsage.counts_toward_quota == True, # noqa: E712
AIUsage.created_at >= month_start,
)
) or 0
# Daily: count all AI API calls (scaffold + branch_detail) in last 24h
daily_count = await db.scalar(
select(func.count(AIUsage.id)).where(
AIUsage.user_id == user_id,
AIUsage.succeeded == True, # noqa: E712
AIUsage.generation_type.in_(["scaffold", "branch_detail"]),
AIUsage.created_at >= day_start,
)
) or 0
allowed = True
deny_reason = None
if monthly_limit is not None and monthly_count >= monthly_limit:
allowed = False
deny_reason = "monthly"
if daily_limit is not None and daily_count >= daily_limit:
allowed = False
deny_reason = "daily"
# Calculate reset timestamps
monthly_reset_at = month_start.replace(
month=month_start.month % 12 + 1,
year=month_start.year + (1 if month_start.month == 12 else 0),
)
daily_reset_at = day_start + timedelta(hours=24)
return allowed, {
"plan": plan,
"monthly_used": monthly_count,
"monthly_limit": monthly_limit,
"monthly_reset_at": monthly_reset_at.isoformat(),
"daily_used": daily_count,
"daily_limit": daily_limit,
"daily_reset_at": daily_reset_at.isoformat(),
"allowed": allowed,
"deny_reason": deny_reason,
}
async def record_ai_usage(
user_id: UUID,
account_id: UUID,
conversation_id: Optional[UUID],
generation_type: str,
tier: str,
input_tokens: int,
output_tokens: int,
estimated_cost: float,
succeeded: bool,
counts_toward_quota: bool,
error_code: Optional[str],
extra_data: Optional[dict],
db: AsyncSession,
) -> AIUsage:
"""Record an AI usage entry."""
usage = AIUsage(
user_id=user_id,
account_id=account_id,
conversation_id=conversation_id,
generation_type=generation_type,
tier_at_time=tier,
input_tokens=input_tokens,
output_tokens=output_tokens,
estimated_cost_usd=estimated_cost,
succeeded=succeeded,
counts_toward_quota=counts_toward_quota,
error_code=error_code,
extra_data=extra_data or {},
)
db.add(usage)
await db.flush()
return usage

View File

@@ -0,0 +1,293 @@
"""AI-powered tree generation service using Anthropic Claude API.
Implements the 4-stage wizard flow:
Stage 2 (scaffold): AI suggests 4-7 top-level branches
Stage 3 (branch_detail): AI generates detailed nodes per branch
Stage 4 (assemble): Pure assembly logic — zero AI calls
System prompts are static constants to enable Anthropic prompt caching.
"""
import json
import logging
import uuid
from typing import Any
import anthropic
from app.core.config import settings
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
logger = logging.getLogger(__name__)
# ── Cost estimation (Haiku 4.5 pricing) ──
COST_PER_INPUT_TOKEN = 1.0 / 1_000_000 # $1.00 per 1M input tokens
COST_PER_OUTPUT_TOKEN = 5.0 / 1_000_000 # $5.00 per 1M output tokens
# ── System Prompts ──
SCAFFOLD_SYSTEM_PROMPT = """You are ResolutionFlow AI, assisting MSP engineers to build troubleshooting and procedural flows for IT service management.
Context: Your audience is technical MSP staff experienced with Windows Server, Active Directory, networking, and common MSP tooling (ConnectWise, Datto, SonicWall, etc.).
Task: Given a flow type, category, name, description, and environment tags, suggest 4-7 top-level branches for the flow.
For TROUBLESHOOTING flows:
- Branches should be symptom-based categories (e.g., "Authentication Failures", "Connectivity Issues", "Performance Degradation")
- Each branch represents a common way the problem manifests
- Order from most common to least common
For PROCEDURE flows:
- Branches should be phase-based stages (e.g., "Prerequisites", "Configuration", "Verification", "Documentation")
- Each branch represents a major step in the process
- Order in logical execution sequence
Rules:
- Suggest 4-7 branches
- Be specific to the technology/service described — avoid generic branches
- Branch names should be concise (2-5 words)
- Each branch needs a brief description (1 sentence)
- Return ONLY valid JSON, no markdown, no explanation
Output format:
{"branches": [{"name": "Branch Name", "description": "Brief description of what this covers"}]}"""
BRANCH_DETAIL_SYSTEM_PROMPT = """You are ResolutionFlow AI generating step-by-step detail for one branch of a troubleshooting or procedural flow for MSP engineers.
Context: Your audience is technical MSP staff experienced with Windows Server, Active Directory, networking, and common MSP tooling.
You must return ONLY valid JSON — no markdown, no code fences, no explanation.
Required node schema:
Decision nodes (branching diagnostic questions):
{"id": "unique-slug", "type": "decision", "question": "The diagnostic question", "help_text": "Optional context or command hint", "options": [{"id": "opt-id", "label": "Answer choice", "next_node_id": "child-node-id"}], "children": []}
Action nodes (investigation or remediation steps):
{"id": "unique-slug", "type": "action", "title": "Short title", "description": "Detailed instructions", "commands": ["PowerShell or CMD commands"], "expected_outcome": "What success looks like", "children": []}
Solution nodes (leaf nodes — the resolution):
{"id": "unique-slug", "type": "solution", "title": "Resolution title", "description": "Full resolution description", "resolution_steps": ["Step 1", "Step 2"]}
Rules:
1. Generate 3-10 nodes for this branch
2. Start with a decision node if troubleshooting, action node if procedure
3. Every branch path MUST end in a solution node — no dead ends
4. Include realistic MSP commands (PowerShell preferred for Windows)
5. Use unique node IDs prefixed with the branch context (e.g., "dns-check-service")
6. Every option's next_node_id must match an existing child node's id
7. All option labels must be meaningful and specific
8. Decision nodes must have at least 2 options
9. Return a single root node with its children nested inside
Few-shot example (abbreviated):
{"id": "dns-root", "type": "decision", "question": "Can the client resolve any DNS names?", "help_text": "Run: nslookup google.com", "options": [{"id": "dns-opt-none", "label": "No DNS resolution at all", "next_node_id": "dns-check-service"}, {"id": "dns-opt-partial", "label": "Some names resolve, others don't", "next_node_id": "dns-check-specific"}], "children": [{"id": "dns-check-service", "type": "action", "title": "Check DNS Service", "description": "Verify the DNS Client service is running", "commands": ["Get-Service -Name Dnscache"], "expected_outcome": "Service should be Running", "children": [{"id": "dns-resolved", "type": "solution", "title": "DNS Service Restored", "description": "DNS client service was stopped", "resolution_steps": ["Restart DNS Client service", "Flush DNS cache: ipconfig /flushdns", "Test resolution"]}]}, {"id": "dns-check-specific", "type": "solution", "title": "Selective DNS Failure", "description": "Specific records missing or stale", "resolution_steps": ["Check DNS server configuration", "Verify zone records", "Clear DNS cache"]}]}"""
CORRECTIVE_PROMPT_TEMPLATE = """Your previous JSON was invalid for ResolutionFlow's tree schema.
Validation errors:
{error_list}
Return a corrected full JSON object only. No markdown, no prose, no code fences.
Fix ALL listed errors while maintaining the same troubleshooting/procedural logic."""
def _get_client() -> anthropic.AsyncAnthropic:
"""Get configured async Anthropic client."""
if not settings.ANTHROPIC_API_KEY:
raise RuntimeError("ANTHROPIC_API_KEY not configured")
return anthropic.AsyncAnthropic(
api_key=settings.ANTHROPIC_API_KEY,
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
)
def _estimate_cost(input_tokens: int, output_tokens: int) -> float:
"""Estimate USD cost from token counts."""
return (input_tokens * COST_PER_INPUT_TOKEN) + (
output_tokens * COST_PER_OUTPUT_TOKEN
)
async def scaffold_branches(
wizard_state: dict[str, Any],
) -> tuple[list[dict[str, str]], int, int, float]:
"""Stage 2: AI suggests top-level branches.
Returns (branches, input_tokens, output_tokens, estimated_cost).
Raises ValueError on invalid response.
"""
client = _get_client()
flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "")
description = wizard_state.get("description", "")
tags = wizard_state.get("environment_tags", [])
user_message = (
f"Flow type: {flow_type}\n"
f"Name: {name}\n"
f"Description: {description}\n"
)
if tags:
user_message += f"Environment: {', '.join(tags)}\n"
response = await client.messages.create(
model=settings.AI_MODEL,
max_tokens=1024,
system=SCAFFOLD_SYSTEM_PROMPT,
messages=[{"role": "user", "content": user_message}],
)
raw_text = response.content[0].text
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
cost = _estimate_cost(input_tokens, output_tokens)
try:
data = json.loads(raw_text)
except json.JSONDecodeError as e:
raise ValueError(f"AI returned invalid JSON: {e}")
branches = data.get("branches", [])
if not isinstance(branches, list) or len(branches) < 2:
raise ValueError("AI returned fewer than 2 branches")
return branches, input_tokens, output_tokens, cost
async def generate_branch_detail(
wizard_state: dict[str, Any],
branch_name: str,
existing_branches: list[str],
) -> tuple[dict[str, Any], int, int, float]:
"""Stage 3: AI generates detailed nodes for one branch.
Returns (branch_tree, input_tokens, output_tokens, estimated_cost).
On validation failure, retries once with corrective prompt.
Raises ValueError if both attempts fail.
"""
client = _get_client()
flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "")
description = wizard_state.get("description", "")
user_message = (
f"Flow: {name} ({flow_type})\n"
f"Description: {description}\n"
f"Branch to detail: {branch_name}\n"
)
if existing_branches:
other = [b for b in existing_branches if b != branch_name]
if other:
user_message += f"Other branches (avoid overlap): {', '.join(other)}\n"
messages = [{"role": "user", "content": user_message}]
total_input = 0
total_output = 0
for attempt in range(2):
response = await client.messages.create(
model=settings.AI_MODEL,
max_tokens=4096,
system=BRANCH_DETAIL_SYSTEM_PROMPT,
messages=messages,
)
raw_text = response.content[0].text
total_input += response.usage.input_tokens
total_output += response.usage.output_tokens
try:
branch_tree = json.loads(raw_text)
except json.JSONDecodeError as e:
if attempt == 0:
messages.append({"role": "assistant", "content": raw_text})
messages.append({
"role": "user",
"content": CORRECTIVE_PROMPT_TEMPLATE.format(
error_list=f"JSON parse error: {e}"
),
})
continue
raise ValueError(f"AI returned invalid JSON after retry: {e}")
# Validate the branch structure
errors = validate_generated_tree(branch_tree)
if not errors:
cost = _estimate_cost(total_input, total_output)
return branch_tree, total_input, total_output, cost
if attempt == 0:
messages.append({"role": "assistant", "content": raw_text})
messages.append({
"role": "user",
"content": CORRECTIVE_PROMPT_TEMPLATE.format(
error_list="\n".join(f"- {e}" for e in errors)
),
})
continue
raise ValueError(
f"AI tree validation failed after retry: {'; '.join(errors)}"
)
# Should not reach here
raise ValueError("Branch detail generation failed")
def assemble_tree(
wizard_state: dict[str, Any],
branches: list[dict[str, Any]],
) -> tuple[dict[str, Any], str, str, dict[str, int]]:
"""Stage 4: Assemble branches into a complete tree.
Zero AI calls — pure assembly logic.
Returns (tree_structure, suggested_name, suggested_description, summary_stats).
"""
flow_type = wizard_state.get("flow_type", "troubleshooting")
name = wizard_state.get("name", "Untitled Flow")
description = wizard_state.get("description", "")
# Build root decision node pointing to each branch
options = []
children = []
for i, branch in enumerate(branches):
branch_name = branch.get("name", f"Branch {i + 1}")
branch_tree = branch.get("steps")
if not branch_tree or not isinstance(branch_tree, dict):
# Skip branches without detail
continue
branch_id = branch_tree.get("id", f"branch_{i}")
options.append({
"id": f"opt_{i + 1}",
"label": branch_name,
"next_node_id": branch_id,
})
children.append(branch_tree)
if len(options) < 2:
raise ValueError("Need at least 2 branches with detail to assemble a tree")
# Determine root question based on flow type
if flow_type == "troubleshooting":
root_question = f"What issue is the user experiencing with {name}?"
else:
root_question = f"Which phase of {name} are you working on?"
tree_structure = {
"id": "root",
"type": "decision",
"question": root_question,
"options": options,
"children": children,
}
stats = count_tree_stats(tree_structure)
return tree_structure, name, description, stats

View File

@@ -0,0 +1,199 @@
"""Validation for AI-generated tree structures.
Ensures generated trees conform to ResolutionFlow's node schema
before they are saved to the database.
"""
from typing import Any
VALID_NODE_TYPES = {"decision", "action", "solution"}
# Required fields per node type
REQUIRED_FIELDS = {
"decision": {"id", "type", "question", "options", "children"},
"action": {"id", "type", "title", "description"},
"solution": {"id", "type", "title", "description"},
}
class TreeValidationError(Exception):
"""Raised when a generated tree fails validation."""
def __init__(self, errors: list[str]):
self.errors = errors
super().__init__(f"Tree validation failed: {'; '.join(errors)}")
def validate_generated_tree(tree: dict[str, Any]) -> list[str]:
"""Validate an AI-generated tree structure.
Returns a list of error strings. Empty list means valid.
"""
errors: list[str] = []
if not isinstance(tree, dict):
return ["Tree must be a JSON object"]
# Root must be a decision node
if tree.get("type") != "decision":
errors.append("Root node must be type 'decision'")
# Collect all node IDs and validate structure
all_ids: set[str] = set()
all_referenced_ids: set[str] = set()
node_count = 0
solution_count = 0
def _validate_node(node: dict[str, Any], path: str) -> None:
nonlocal node_count, solution_count
if not isinstance(node, dict):
errors.append(f"Node at {path} is not an object")
return
node_count += 1
node_type = node.get("type")
node_id = node.get("id")
# Check node ID
if not node_id:
errors.append(f"Node at {path} missing 'id'")
elif node_id in all_ids:
errors.append(f"Duplicate node ID: '{node_id}'")
else:
all_ids.add(node_id)
# Check node type
if node_type not in VALID_NODE_TYPES:
errors.append(
f"Node '{node_id or path}' has invalid type '{node_type}'. "
f"Must be one of: {', '.join(sorted(VALID_NODE_TYPES))}"
)
return
# Check required fields
required = REQUIRED_FIELDS[node_type]
missing = required - set(node.keys())
if missing:
errors.append(
f"Node '{node_id}' (type={node_type}) missing fields: {', '.join(sorted(missing))}"
)
# Type-specific validation
if node_type == "decision":
options = node.get("options", [])
if not isinstance(options, list) or len(options) < 2:
errors.append(
f"Decision node '{node_id}' must have at least 2 options"
)
else:
children = node.get("children", [])
child_ids = {c.get("id") for c in children if isinstance(c, dict)}
option_ids: set[str] = set()
for opt in options:
if not isinstance(opt, dict):
errors.append(f"Option in node '{node_id}' is not an object")
continue
opt_id = opt.get("id")
if opt_id and opt_id in option_ids:
errors.append(
f"Duplicate option ID '{opt_id}' in node '{node_id}'"
)
if opt_id:
option_ids.add(opt_id)
next_id = opt.get("next_node_id")
if next_id:
all_referenced_ids.add(next_id)
if child_ids and next_id not in child_ids:
errors.append(
f"Option '{opt.get('label', '?')}' in node '{node_id}' "
f"references non-existent child '{next_id}'"
)
elif node_type == "action":
next_id = node.get("next_node_id")
if next_id:
all_referenced_ids.add(next_id)
elif node_type == "solution":
solution_count += 1
# Recurse into children
for i, child in enumerate(node.get("children", [])):
_validate_node(child, f"{path}.children[{i}]")
_validate_node(tree, "root")
# Global checks
if node_count < 5:
errors.append(
f"Tree has only {node_count} nodes. Minimum 5 required for a useful tree."
)
if node_count > 50:
errors.append(
f"Tree has {node_count} nodes. Maximum 50 allowed."
)
if solution_count < 2:
errors.append(
f"Tree has only {solution_count} solution nodes. "
"Need at least 2 to cover different resolution paths."
)
# Check that all leaf (non-solution) nodes have children or are solutions
_check_branch_termination(tree, errors)
return errors
def _check_branch_termination(node: dict[str, Any], errors: list[str]) -> None:
"""Verify every branch eventually reaches a solution node."""
if not isinstance(node, dict):
return
node_type = node.get("type")
node_id = node.get("id", "?")
children = node.get("children", [])
if node_type == "solution":
return # Solution is a valid terminus
if not children and node_type != "solution":
errors.append(
f"Node '{node_id}' (type={node_type}) is a dead end — "
"it has no children and is not a solution node"
)
return
for child in children:
_check_branch_termination(child, errors)
def count_tree_stats(tree: dict[str, Any]) -> dict[str, int]:
"""Count node types and calculate depth of a tree."""
stats = {
"node_count": 0,
"decision_count": 0,
"action_count": 0,
"solution_count": 0,
"depth": 0,
}
def _count(node: dict[str, Any], depth: int) -> None:
if not isinstance(node, dict):
return
stats["node_count"] += 1
node_type = node.get("type", "")
if node_type == "decision":
stats["decision_count"] += 1
elif node_type == "action":
stats["action_count"] += 1
elif node_type == "solution":
stats["solution_count"] += 1
stats["depth"] = max(stats["depth"], depth)
for child in node.get("children", []):
_count(child, depth + 1)
_count(tree, 1)
return stats

View File

@@ -72,6 +72,18 @@ class Settings(BaseSettings):
"""Check if Stripe is configured."""
return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None
# AI Flow Builder
ANTHROPIC_API_KEY: Optional[str] = None
AI_MODEL: str = "claude-haiku-4-5"
AI_CONVERSATION_TTL_HOURS: int = 24
AI_MAX_CALLS_PER_FLOW: int = 10
AI_REQUEST_TIMEOUT_SECONDS: int = 45
@property
def ai_enabled(self) -> bool:
"""Check if AI Flow Builder is configured."""
return self.ANTHROPIC_API_KEY is not None
# Deployment auto-seed test data on PR environments
SEED_ON_DEPLOY: bool = False

View File

@@ -1,4 +1,4 @@
"""APScheduler integration for maintenance flow auto-session creation."""
"""APScheduler integration for maintenance flow auto-session creation and AI cleanup."""
import logging
import uuid
from datetime import datetime, timezone
@@ -7,8 +7,9 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.schedulers.base import SchedulerNotRunningError
from apscheduler.jobstores.base import JobLookupError
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
import pytz
from sqlalchemy import select
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
@@ -114,6 +115,27 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
await db.rollback()
async def _cleanup_expired_ai_conversations() -> None:
"""Delete expired AI wizard conversations."""
import app.models # noqa: F401
from app.core.database import async_session_maker
from app.models.ai_conversation import AIConversation
async with async_session_maker() as db:
try:
result = await db.execute(
delete(AIConversation).where(
AIConversation.expires_at < datetime.now(timezone.utc)
)
)
if result.rowcount > 0:
logger.info(f"Cleaned up {result.rowcount} expired AI conversation(s)")
await db.commit()
except Exception:
logger.exception("Error cleaning up expired AI conversations")
await db.rollback()
async def load_all_schedules(db: AsyncSession) -> None:
"""Load all active schedules into APScheduler on startup."""
# Import all models to ensure SQLAlchemy mapper relationships resolve

View File

@@ -13,7 +13,7 @@ from app.core.logging_config import setup_logging
from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware
from app.core.rate_limit import limiter
from app.api.router import api_router
from app.core.scheduler import scheduler, load_all_schedules
from app.core.scheduler import scheduler, load_all_schedules, _cleanup_expired_ai_conversations
# Initialize logging configuration
setup_logging()
@@ -103,10 +103,17 @@ async def lifespan(app: FastAPI):
# Note: In production, use Alembic migrations instead of init_db
# await init_db()
# Start maintenance schedule runner
# Start maintenance schedule runner + AI conversation cleanup
scheduler.start()
async with async_session_maker() as db:
await load_all_schedules(db)
scheduler.add_job(
_cleanup_expired_ai_conversations,
trigger="interval",
hours=1,
id="cleanup_ai_conversations",
replace_existing=True,
)
# Auto-seed trees in background on PR environments
seed_task = None

View File

@@ -26,6 +26,8 @@ from .user_pinned_tree import UserPinnedTree
from .target_list import TargetList
from .maintenance_schedule import MaintenanceSchedule
from .feedback import Feedback
from .ai_conversation import AIConversation
from .ai_usage import AIUsage
__all__ = [
"User",
@@ -63,4 +65,6 @@ __all__ = [
"TargetList",
"MaintenanceSchedule",
"Feedback",
"AIConversation",
"AIUsage",
]

View File

@@ -24,6 +24,8 @@ class AccountLimitOverride(Base):
override_max_trees: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
override_max_sessions_per_month: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
override_max_users: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
override_max_ai_builds_per_month: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
override_max_ai_builds_per_24h: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
note: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_by_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),

View File

@@ -0,0 +1,67 @@
"""AI Flow Builder conversation tracking.
Stores wizard session state across the 4-stage flow builder process.
Conversations expire after 24 hours and are cleaned up by the scheduler.
"""
import uuid
from datetime import datetime, timezone
from typing import Optional, Any
from sqlalchemy import String, DateTime, ForeignKey, Integer, Text
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.core.database import Base
class AIConversation(Base):
__tablename__ = "ai_conversations"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
status: Mapped[str] = mapped_column(
String(20),
nullable=False,
default="foundation",
comment="foundation | scaffolding | detailing | reviewing | completed | expired",
)
# Conversation history across all wizard stages
messages: Mapped[list[dict[str, Any]]] = mapped_column(
JSONB, nullable=False, default=list
)
# Wizard state: Stage 1 metadata, Stage 2 branches, Stage 3 detail
wizard_state: Mapped[dict[str, Any]] = mapped_column(
JSONB, nullable=False, default=dict
)
# Assembled tree from Stage 4 (null until assembly)
generated_tree: Mapped[Optional[dict[str, Any]]] = mapped_column(
JSONB, nullable=True
)
# Tracks AI call count for per-flow limits
question_rounds: Mapped[int] = mapped_column(
Integer, nullable=False, default=0
)
expires_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)

View File

@@ -0,0 +1,69 @@
"""AI usage tracking for quota enforcement and cost visibility.
Every AI API call is recorded here. Only rows with counts_toward_quota=True
and succeeded=True are counted against the user's monthly quota.
"""
import uuid
from datetime import datetime, timezone
from typing import Optional, Any
from sqlalchemy import String, DateTime, ForeignKey, Integer, Boolean, Numeric
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.core.database import Base
class AIUsage(Base):
__tablename__ = "ai_usage"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
conversation_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("ai_conversations.id", ondelete="SET NULL"),
nullable=True,
)
generation_type: Mapped[str] = mapped_column(
String(20),
nullable=False,
comment="scaffold | branch_detail | branch_suggest",
)
tier_at_time: Mapped[str] = mapped_column(
String(20),
nullable=False,
comment="free | pro | team | enterprise",
)
input_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
output_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
estimated_cost_usd: Mapped[float] = mapped_column(
Numeric(10, 6), nullable=False, default=0
)
succeeded: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
counts_toward_quota: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
error_code: Mapped[Optional[str]] = mapped_column(
String(100), nullable=True
)
extra_data: Mapped[dict[str, Any]] = mapped_column(
"metadata", JSONB, nullable=False, default=dict
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
index=True,
)

View File

@@ -14,3 +14,7 @@ class PlanLimits(Base):
custom_branding: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
priority_support: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
export_formats: Mapped[list] = mapped_column(JSONB, nullable=False, default=lambda: ["markdown", "text"])
# AI Flow Builder limits
max_ai_builds_per_month: Mapped[int | None] = mapped_column(Integer, nullable=True)
max_ai_builds_per_24h: Mapped[int | None] = mapped_column(Integer, nullable=True)

View File

@@ -67,6 +67,11 @@ class User(Base):
)
last_login: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
# AI billing cycle anchor (for quota reset calculation)
ai_billing_cycle_anchor_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Soft delete
deleted_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True),

View File

@@ -5,6 +5,11 @@ from .session import SessionCreate, SessionUpdate, SessionResponse, SessionExpor
from .category import CategoryCreate, CategoryUpdate, CategoryResponse, CategoryListResponse
from .tag import TagCreate, TagResponse, TagListResponse, TagAssignment
from .folder import FolderCreate, FolderUpdate, FolderResponse, FolderListResponse, FolderReorderRequest, FolderTreeRequest
from .ai_builder import (
AIStartRequest, AIScaffoldRequest, AIBranchDetailRequest, AIAssembleRequest,
AIStartResponse, AIScaffoldResponse, AIBranchDetailResponse, AIAssembleResponse,
AIQuotaStatusResponse,
)
__all__ = [
# User
@@ -21,4 +26,8 @@ __all__ = [
"TagCreate", "TagResponse", "TagListResponse", "TagAssignment",
# Folder
"FolderCreate", "FolderUpdate", "FolderResponse", "FolderListResponse", "FolderReorderRequest", "FolderTreeRequest",
# AI Builder
"AIStartRequest", "AIScaffoldRequest", "AIBranchDetailRequest", "AIAssembleRequest",
"AIStartResponse", "AIScaffoldResponse", "AIBranchDetailResponse", "AIAssembleResponse",
"AIQuotaStatusResponse",
]

View File

@@ -0,0 +1,116 @@
"""Pydantic schemas for the AI Flow Builder wizard."""
from typing import Any, Literal, Optional
from uuid import UUID
from pydantic import BaseModel, Field
# ── Requests ──
class AIStartRequest(BaseModel):
"""Stage 1: Foundation — engineer provides flow metadata."""
flow_type: Literal["troubleshooting", "procedural"] = Field(
..., description="Type of flow to generate"
)
category_id: Optional[UUID] = None
name: str = Field(..., min_length=1, max_length=255)
description: str = Field("", max_length=2000)
environment_tags: list[str] = Field(default_factory=list)
class AIScaffoldRequest(BaseModel):
"""Stage 2: Request AI-generated branch suggestions."""
conversation_id: UUID
class AIBranchDetailRequest(BaseModel):
"""Stage 3: Request AI-generated detail for one branch."""
conversation_id: UUID
branch_name: str = Field(..., min_length=1, max_length=255)
class AIBranchUpdate(BaseModel):
"""A branch with optional user edits for assembly."""
name: str
description: str = ""
steps: Optional[dict[str, Any]] = None
class AIAssembleRequest(BaseModel):
"""Stage 4: Assemble selected branches into a complete tree."""
conversation_id: UUID
selected_branches: list[AIBranchUpdate] = Field(..., min_length=2)
# ── Responses ──
class AIStartResponse(BaseModel):
"""Response after creating a conversation."""
conversation_id: UUID
status: str
class AIBranchSuggestion(BaseModel):
"""A single branch suggestion from the AI."""
name: str
description: str
class AIScaffoldResponse(BaseModel):
"""Response with AI-suggested branches."""
conversation_id: UUID
branches: list[AIBranchSuggestion]
status: str
class AIBranchDetailResponse(BaseModel):
"""Response with AI-generated detail for one branch."""
conversation_id: UUID
branch_name: str
steps: dict[str, Any]
status: str
class AITreeSummary(BaseModel):
"""Summary statistics for an assembled tree."""
node_count: int
decision_count: int
action_count: int
solution_count: int
depth: int
class AIAssembleResponse(BaseModel):
"""Response with the fully assembled tree."""
tree_structure: dict[str, Any]
suggested_name: str
suggested_description: str
summary: AITreeSummary
status: str
class AIQuotaStatusResponse(BaseModel):
"""Current user's AI quota status."""
plan: str
monthly_used: int
monthly_limit: Optional[int]
monthly_reset_at: str
daily_used: int
daily_limit: Optional[int]
daily_reset_at: str
allowed: bool
ai_enabled: bool

View File

@@ -31,6 +31,9 @@ resend==2.21.0
# HTTP client (seed scripts, internal API calls)
httpx>=0.27.0
# AI Flow Builder
anthropic>=0.40.0
# Utilities
python-dotenv==1.0.1
croniter>=2.0.0

View File

@@ -0,0 +1,360 @@
"""Integration tests for AI Flow Builder endpoints.
All Anthropic API calls are mocked — zero real API spend.
"""
import json
from unittest.mock import AsyncMock, patch, MagicMock
import pytest
from app.core.config import settings
# ── Sample AI responses ──
SCAFFOLD_RESPONSE_JSON = json.dumps({
"branches": [
{"name": "Service Not Running", "description": "The target service is stopped or crashed."},
{"name": "Authentication Failures", "description": "Users cannot authenticate against the service."},
{"name": "Network Connectivity", "description": "Network-level issues preventing access."},
{"name": "Configuration Errors", "description": "Misconfiguration of the service or its dependencies."},
]
})
BRANCH_DETAIL_JSON = json.dumps({
"id": "svc-root",
"type": "decision",
"question": "Is the service running?",
"options": [
{"id": "opt-yes", "label": "Yes", "next_node_id": "svc-check-logs"},
{"id": "opt-no", "label": "No", "next_node_id": "svc-restart"},
],
"children": [
{
"id": "svc-check-logs",
"type": "action",
"title": "Check Event Logs",
"description": "Check Windows Event Viewer for errors.",
"commands": ["Get-EventLog -LogName Application -Newest 20"],
"children": [
{
"id": "svc-logs-resolved",
"type": "solution",
"title": "Issue Found in Logs",
"description": "Error identified and resolved.",
"resolution_steps": ["Fix the error", "Restart service"],
}
],
},
{
"id": "svc-restart",
"type": "action",
"title": "Restart Service",
"description": "Attempt to restart the service.",
"commands": ["Restart-Service -Name 'TestService'"],
"children": [
{
"id": "svc-restart-ok",
"type": "solution",
"title": "Service Restored",
"description": "Service is running after restart.",
"resolution_steps": ["Verify connectivity", "Document in ticket"],
}
],
},
],
})
def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200):
"""Create a mock Anthropic API response."""
response = MagicMock()
response.content = [MagicMock(text=text)]
response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens)
return response
@pytest.fixture
def enable_ai():
"""Temporarily enable AI by setting a fake API key."""
original = settings.ANTHROPIC_API_KEY
settings.ANTHROPIC_API_KEY = "test-key-fake"
yield
settings.ANTHROPIC_API_KEY = original
@pytest.fixture
def disable_ai():
"""Ensure AI is disabled."""
original = settings.ANTHROPIC_API_KEY
settings.ANTHROPIC_API_KEY = None
yield
settings.ANTHROPIC_API_KEY = original
# ── Quota endpoint ──
@pytest.mark.asyncio
async def test_quota_returns_disabled_when_no_key(client, auth_headers, disable_ai):
"""GET /ai/quota returns ai_enabled=false when no API key."""
response = await client.get("/api/v1/ai/quota", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert data["ai_enabled"] is False
assert data["allowed"] is False
@pytest.mark.asyncio
async def test_quota_returns_enabled_with_key(client, auth_headers, enable_ai):
"""GET /ai/quota returns ai_enabled=true with API key configured."""
response = await client.get("/api/v1/ai/quota", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert data["ai_enabled"] is True
assert data["allowed"] is True
# ── Start endpoint ──
@pytest.mark.asyncio
async def test_start_requires_auth(client, enable_ai):
"""POST /ai/start requires authentication."""
response = await client.post("/api/v1/ai/start", json={
"flow_type": "troubleshooting",
"name": "Test Flow",
"description": "Test",
})
assert response.status_code == 401
@pytest.mark.asyncio
async def test_start_returns_503_when_disabled(client, auth_headers, disable_ai):
"""POST /ai/start returns 503 when AI is not configured."""
response = await client.post(
"/api/v1/ai/start",
json={
"flow_type": "troubleshooting",
"name": "Test Flow",
"description": "Test description",
},
headers=auth_headers,
)
assert response.status_code == 503
@pytest.mark.asyncio
async def test_start_creates_conversation(client, auth_headers, enable_ai):
"""POST /ai/start creates a conversation and returns conversation_id."""
response = await client.post(
"/api/v1/ai/start",
json={
"flow_type": "troubleshooting",
"name": "DNS Issues",
"description": "Troubleshooting DNS resolution failures",
"environment_tags": ["Windows Server", "Active Directory"],
},
headers=auth_headers,
)
assert response.status_code == 201
data = response.json()
assert "conversation_id" in data
assert data["status"] == "foundation"
@pytest.mark.asyncio
async def test_start_validates_input(client, auth_headers, enable_ai):
"""POST /ai/start rejects invalid input."""
response = await client.post(
"/api/v1/ai/start",
json={
"flow_type": "troubleshooting",
"name": "", # Empty name
"description": "Test",
},
headers=auth_headers,
)
assert response.status_code == 422
# ── Scaffold endpoint ──
@pytest.mark.asyncio
async def test_scaffold_success(client, auth_headers, enable_ai):
"""POST /ai/scaffold returns AI-generated branches."""
# Create conversation first
start_resp = await client.post(
"/api/v1/ai/start",
json={
"flow_type": "troubleshooting",
"name": "DNS Issues",
"description": "DNS resolution failures",
},
headers=auth_headers,
)
conversation_id = start_resp.json()["conversation_id"]
# Mock Anthropic
mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(return_value=mock_response)
response = await client.post(
"/api/v1/ai/scaffold",
json={"conversation_id": conversation_id},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "scaffolding"
assert len(data["branches"]) == 4
assert data["branches"][0]["name"] == "Service Not Running"
@pytest.mark.asyncio
async def test_scaffold_invalid_conversation(client, auth_headers, enable_ai):
"""POST /ai/scaffold returns 404 for nonexistent conversation."""
response = await client.post(
"/api/v1/ai/scaffold",
json={"conversation_id": "00000000-0000-0000-0000-000000000000"},
headers=auth_headers,
)
assert response.status_code == 404
# ── Branch detail endpoint ──
@pytest.mark.asyncio
async def test_branch_detail_success(client, auth_headers, enable_ai):
"""POST /ai/branch-detail returns AI-generated branch nodes."""
# Create and scaffold first
start_resp = await client.post(
"/api/v1/ai/start",
json={
"flow_type": "troubleshooting",
"name": "Service Issues",
"description": "Service troubleshooting",
},
headers=auth_headers,
)
conversation_id = start_resp.json()["conversation_id"]
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
await client.post(
"/api/v1/ai/scaffold",
json={"conversation_id": conversation_id},
headers=auth_headers,
)
# Now generate branch detail
detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock)
response = await client.post(
"/api/v1/ai/branch-detail",
json={
"conversation_id": conversation_id,
"branch_name": "Service Not Running",
},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["branch_name"] == "Service Not Running"
assert data["steps"]["id"] == "svc-root"
assert data["steps"]["type"] == "decision"
# ── Assemble endpoint ──
@pytest.mark.asyncio
async def test_assemble_success(client, auth_headers, enable_ai):
"""POST /ai/assemble returns assembled tree from branches with detail."""
# Create conversation
start_resp = await client.post(
"/api/v1/ai/start",
json={
"flow_type": "troubleshooting",
"name": "Service Issues",
"description": "Service troubleshooting",
},
headers=auth_headers,
)
conversation_id = start_resp.json()["conversation_id"]
# Scaffold
scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON)
with patch("app.core.ai_tree_generator_service._get_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock)
await client.post(
"/api/v1/ai/scaffold",
json={"conversation_id": conversation_id},
headers=auth_headers,
)
# Assemble with branch detail included
branch_tree = json.loads(BRANCH_DETAIL_JSON)
response = await client.post(
"/api/v1/ai/assemble",
json={
"conversation_id": conversation_id,
"selected_branches": [
{
"name": "Service Not Running",
"description": "The target service is stopped.",
"steps": branch_tree,
},
{
"name": "Authentication Failures",
"description": "Users cannot authenticate.",
"steps": branch_tree,
},
],
},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "completed"
assert data["suggested_name"] == "Service Issues"
assert "tree_structure" in data
assert data["tree_structure"]["type"] == "decision"
assert data["summary"]["node_count"] > 0
assert data["summary"]["solution_count"] >= 2
@pytest.mark.asyncio
async def test_assemble_requires_min_2_branches(client, auth_headers, enable_ai):
"""POST /ai/assemble rejects fewer than 2 branches."""
start_resp = await client.post(
"/api/v1/ai/start",
json={
"flow_type": "troubleshooting",
"name": "Test",
"description": "Test",
},
headers=auth_headers,
)
conversation_id = start_resp.json()["conversation_id"]
response = await client.post(
"/api/v1/ai/assemble",
json={
"conversation_id": conversation_id,
"selected_branches": [
{"name": "Only Branch", "description": "Just one"},
],
},
headers=auth_headers,
)
assert response.status_code == 422

View File

@@ -0,0 +1,183 @@
"""Tests for AI-generated tree structure validation."""
import pytest
from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats
def _make_valid_tree():
"""Helper: minimal valid tree for testing."""
return {
"id": "root",
"type": "decision",
"question": "Is the service running?",
"options": [
{"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"},
{"id": "opt-no", "label": "No", "next_node_id": "restart-service"},
],
"children": [
{
"id": "check-logs",
"type": "decision",
"question": "Are there errors in the logs?",
"options": [
{"id": "opt-errors", "label": "Yes", "next_node_id": "fix-errors"},
{"id": "opt-clean", "label": "No", "next_node_id": "escalate"},
],
"children": [
{
"id": "fix-errors",
"type": "solution",
"title": "Fix Errors",
"description": "Apply the fix for the errors found.",
},
{
"id": "escalate",
"type": "solution",
"title": "Escalate",
"description": "No errors found; escalate to Tier 2.",
},
],
},
{
"id": "restart-service",
"type": "action",
"title": "Restart the Service",
"description": "Restart the service and verify.",
"commands": ["Restart-Service -Name 'TestService'"],
"children": [
{
"id": "service-resolved",
"type": "solution",
"title": "Service Restored",
"description": "Service is running after restart.",
},
],
},
],
}
class TestValidTree:
def test_valid_tree_passes(self):
errors = validate_generated_tree(_make_valid_tree())
assert errors == []
def test_not_a_dict(self):
errors = validate_generated_tree("not a dict")
assert any("must be a JSON object" in e for e in errors)
def test_root_not_decision(self):
tree = _make_valid_tree()
tree["type"] = "action"
tree["title"] = "Fake"
errors = validate_generated_tree(tree)
assert any("Root node must be type 'decision'" in e for e in errors)
class TestNodeValidation:
def test_missing_id(self):
tree = _make_valid_tree()
del tree["children"][0]["id"]
errors = validate_generated_tree(tree)
assert any("missing 'id'" in e for e in errors)
def test_duplicate_ids(self):
tree = _make_valid_tree()
tree["children"][1]["id"] = "check-logs" # same as sibling
errors = validate_generated_tree(tree)
assert any("Duplicate node ID" in e for e in errors)
def test_invalid_node_type(self):
tree = _make_valid_tree()
tree["children"][0]["type"] = "unknown"
errors = validate_generated_tree(tree)
assert any("invalid type" in e for e in errors)
def test_decision_missing_options(self):
tree = _make_valid_tree()
del tree["children"][0]["options"]
errors = validate_generated_tree(tree)
assert any("missing fields" in e for e in errors)
def test_decision_less_than_2_options(self):
tree = _make_valid_tree()
tree["children"][0]["options"] = [
{"id": "opt-1", "label": "Only", "next_node_id": "fix-errors"}
]
errors = validate_generated_tree(tree)
assert any("at least 2 options" in e for e in errors)
class TestReferenceIntegrity:
def test_option_references_nonexistent_child(self):
tree = _make_valid_tree()
tree["options"][0]["next_node_id"] = "nonexistent"
errors = validate_generated_tree(tree)
assert any("non-existent child" in e for e in errors)
def test_duplicate_option_ids(self):
tree = _make_valid_tree()
tree["options"][0]["id"] = "same"
tree["options"][1]["id"] = "same"
errors = validate_generated_tree(tree)
assert any("Duplicate option ID" in e for e in errors)
class TestGlobalChecks:
def test_too_few_nodes(self):
tree = {
"id": "root",
"type": "decision",
"question": "Test?",
"options": [
{"id": "o1", "label": "A", "next_node_id": "s1"},
{"id": "o2", "label": "B", "next_node_id": "s2"},
],
"children": [
{"id": "s1", "type": "solution", "title": "S1", "description": "D1"},
{"id": "s2", "type": "solution", "title": "S2", "description": "D2"},
],
}
errors = validate_generated_tree(tree)
assert any("Minimum 5 required" in e for e in errors)
def test_too_few_solutions(self):
tree = _make_valid_tree()
# Remove all solutions except one — replace children of check-logs
tree["children"][0]["children"] = [
{
"id": "only-solution",
"type": "solution",
"title": "Only",
"description": "Only solution",
}
]
tree["children"][0]["options"] = [
{"id": "o1", "label": "A", "next_node_id": "only-solution"},
{"id": "o2", "label": "B", "next_node_id": "only-solution"},
]
# Now restart-service branch has 1 solution, check-logs has 1 = total 2
# Remove one more to get to 1
tree["children"][1]["children"] = []
errors = validate_generated_tree(tree)
assert any("solution" in e.lower() for e in errors)
class TestDeadEndDetection:
def test_dead_end_action_node(self):
tree = _make_valid_tree()
# Remove restart-service's children — becomes dead end
tree["children"][1]["children"] = []
errors = validate_generated_tree(tree)
assert any("dead end" in e for e in errors)
class TestCountTreeStats:
def test_stats_correct(self):
tree = _make_valid_tree()
stats = count_tree_stats(tree)
assert stats["node_count"] == 6
assert stats["decision_count"] == 2
assert stats["action_count"] == 1
assert stats["solution_count"] == 3
assert stats["depth"] >= 3

View File

@@ -0,0 +1,61 @@
import { apiClient } from './client'
import type {
AIQuotaStatus,
AIStartResponse,
AIScaffoldResponse,
AIBranchDetailResponse,
AIAssembleResponse,
} from '@/types'
export const aiBuilderApi = {
getQuota: async (): Promise<AIQuotaStatus> => {
const { data } = await apiClient.get('/ai/quota')
return data
},
start: async (params: {
flow_type: 'troubleshooting' | 'procedural'
name: string
description: string
environment_tags?: string[]
category_id?: string
}): Promise<AIStartResponse> => {
const { data } = await apiClient.post('/ai/start', params)
return data
},
scaffold: async (conversationId: string): Promise<AIScaffoldResponse> => {
const { data } = await apiClient.post('/ai/scaffold', {
conversation_id: conversationId,
})
return data
},
branchDetail: async (
conversationId: string,
branchName: string
): Promise<AIBranchDetailResponse> => {
const { data } = await apiClient.post('/ai/branch-detail', {
conversation_id: conversationId,
branch_name: branchName,
})
return data
},
assemble: async (
conversationId: string,
selectedBranches: Array<{
name: string
description: string
steps?: Record<string, unknown>
}>
): Promise<AIAssembleResponse> => {
const { data } = await apiClient.post('/ai/assemble', {
conversation_id: conversationId,
selected_branches: selectedBranches,
})
return data
},
}
export default aiBuilderApi

View File

@@ -16,3 +16,4 @@ export { default as analyticsApi } from './analytics'
export { targetListsApi } from './targetLists'
export { maintenanceSchedulesApi, batchLaunchApi } from './maintenanceSchedules'
export { default as feedbackApi } from './feedback'
export { default as aiBuilderApi } from './aiBuilder'

View File

@@ -0,0 +1,135 @@
import { useEffect } from 'react'
import { useNavigate } from 'react-router-dom'
import { Modal } from '@/components/common/Modal'
import { useAIFlowBuilderStore } from '@/store/aiFlowBuilderStore'
import { treesApi } from '@/api/trees'
import { toast } from '@/lib/toast'
import { WizardStepIndicator } from './WizardStepIndicator'
import { FoundationForm } from './FoundationForm'
import { BranchSelector } from './BranchSelector'
import { BranchDetailView } from './BranchDetailView'
import { TreePreviewCard } from './TreePreviewCard'
import { GeneratingAnimation } from './GeneratingAnimation'
interface AIFlowBuilderModalProps {
isOpen: boolean
onClose: () => void
}
export function AIFlowBuilderModal({ isOpen, onClose }: AIFlowBuilderModalProps) {
const navigate = useNavigate()
const {
phase,
metadata,
assembledTree,
loadQuota,
scaffold,
reset,
} = useAIFlowBuilderStore()
// Load quota when modal opens
useEffect(() => {
if (isOpen) {
loadQuota()
}
}, [isOpen, loadQuota])
// Auto-trigger scaffold after conversation starts
useEffect(() => {
if (phase === 'scaffolding' && !useAIFlowBuilderStore.getState().suggestedBranches.length) {
scaffold()
}
}, [phase, scaffold])
const handleClose = () => {
reset()
onClose()
}
const handleOpenInEditor = async () => {
if (!assembledTree) return
try {
const tree = await treesApi.create({
name: assembledTree.suggested_name,
description: assembledTree.suggested_description,
tree_structure: assembledTree.tree_structure,
tree_type: metadata.flow_type,
})
handleClose()
const editorPath =
metadata.flow_type === 'procedural'
? `/flows/${tree.id}/edit`
: `/trees/${tree.id}/edit`
navigate(editorPath)
} catch {
toast.error('Failed to create flow. Please try again.')
}
}
const getTitle = () => {
switch (phase) {
case 'foundation':
return 'Build with AI'
case 'scaffolding':
case 'generating':
return 'AI Scaffold'
case 'detailing':
return 'Branch Detail'
case 'reviewing':
return 'Review & Assemble'
case 'error':
return 'AI Flow Builder'
default:
return 'Build with AI'
}
}
return (
<Modal
isOpen={isOpen}
onClose={handleClose}
title={getTitle()}
size="lg"
footer={
<WizardStepIndicator phase={phase} />
}
>
{phase === 'foundation' && <FoundationForm />}
{phase === 'scaffolding' && <BranchSelector />}
{phase === 'generating' && <GeneratingAnimation />}
{phase === 'detailing' && <BranchDetailView />}
{phase === 'reviewing' && (
<TreePreviewCard onOpenInEditor={handleOpenInEditor} />
)}
{phase === 'error' && <ErrorView />}
</Modal>
)
}
function ErrorView() {
const { error, reset, setPhase } = useAIFlowBuilderStore()
return (
<div className="flex flex-col items-center gap-4 py-8">
<div className="rounded-lg border border-red-400/20 bg-red-400/5 px-4 py-3 text-sm text-red-400">
{error || 'An unexpected error occurred.'}
</div>
<div className="flex gap-2">
<button
type="button"
onClick={() => setPhase('foundation')}
className="rounded-lg border border-border px-4 py-2 text-sm text-muted-foreground hover:bg-accent hover:text-foreground"
>
Go Back
</button>
<button
type="button"
onClick={reset}
className="rounded-lg bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20 hover:opacity-90"
>
Start Over
</button>
</div>
</div>
)
}

View File

@@ -0,0 +1,208 @@
import { useState } from 'react'
import { Check, RefreshCw, SkipForward, ChevronRight, ChevronLeft } from 'lucide-react'
import { useAIFlowBuilderStore } from '@/store/aiFlowBuilderStore'
import { GeneratingAnimation } from './GeneratingAnimation'
import { cn } from '@/lib/utils'
export function BranchDetailView() {
const {
selectedBranches,
generateBranchDetail,
assemble,
isLoading,
error,
phase,
setError,
} = useAIFlowBuilderStore()
const [viewingIndex, setViewingIndex] = useState(0)
const currentBranch = selectedBranches[viewingIndex]
const allBranchesHaveDetail = selectedBranches.every((b) => b.steps)
const branchesWithDetail = selectedBranches.filter((b) => b.steps).length
const handleGenerate = async (branchName: string) => {
setError(null)
await generateBranchDetail(branchName)
}
const handleAssemble = async () => {
await assemble()
}
if (phase === 'generating' && isLoading) {
return <GeneratingAnimation />
}
return (
<div className="space-y-4">
{/* Branch tabs */}
<div className="flex items-center gap-2 overflow-x-auto pb-1">
{selectedBranches.map((branch, i) => (
<button
key={branch.name}
type="button"
onClick={() => setViewingIndex(i)}
className={cn(
'flex shrink-0 items-center gap-1.5 rounded-full border px-3 py-1.5 text-xs font-medium transition-colors',
viewingIndex === i
? 'border-primary/30 bg-primary/10 text-foreground'
: 'border-border text-muted-foreground hover:bg-accent',
branch.steps && 'pr-2'
)}
>
{branch.name}
{branch.steps && (
<Check className="h-3 w-3 text-green-400" />
)}
</button>
))}
</div>
{/* Current branch detail */}
{currentBranch && (
<div className="space-y-3">
<div className="flex items-center justify-between">
<div>
<h3 className="text-sm font-medium text-foreground">{currentBranch.name}</h3>
<p className="text-xs text-muted-foreground">{currentBranch.description}</p>
</div>
</div>
{currentBranch.steps ? (
<div className="space-y-3">
{/* Mini tree preview */}
<div className="rounded-lg border border-border bg-accent/30 p-3">
<NodePreview node={currentBranch.steps} depth={0} />
</div>
<div className="flex items-center gap-2">
<button
type="button"
onClick={() => handleGenerate(currentBranch.name)}
disabled={isLoading}
className="flex items-center gap-1.5 rounded-lg border border-border px-3 py-1.5 text-xs text-muted-foreground hover:bg-accent hover:text-foreground disabled:opacity-50"
>
<RefreshCw className="h-3 w-3" />
Regenerate
</button>
</div>
</div>
) : (
<div className="flex flex-col items-center gap-3 rounded-lg border border-dashed border-border bg-accent/20 py-8">
<p className="text-sm text-muted-foreground">
Generate AI detail for this branch
</p>
<div className="flex gap-2">
<button
type="button"
onClick={() => handleGenerate(currentBranch.name)}
disabled={isLoading}
className={cn(
'rounded-lg bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20',
isLoading ? 'cursor-not-allowed opacity-50' : 'hover:opacity-90'
)}
>
Generate Detail
</button>
<button
type="button"
onClick={() => {
if (viewingIndex < selectedBranches.length - 1) {
setViewingIndex(viewingIndex + 1)
}
}}
className="flex items-center gap-1 rounded-lg border border-border px-3 py-2 text-sm text-muted-foreground hover:bg-accent"
>
<SkipForward className="h-3.5 w-3.5" />
Skip
</button>
</div>
</div>
)}
</div>
)}
{/* Error */}
{error && (
<div className="rounded-lg border border-red-400/20 bg-red-400/5 px-3 py-2 text-sm text-red-400">
{error}
</div>
)}
{/* Navigation */}
<div className="flex items-center justify-between border-t border-border pt-3">
<div className="flex items-center gap-2">
<button
type="button"
onClick={() => setViewingIndex(Math.max(0, viewingIndex - 1))}
disabled={viewingIndex === 0}
className="flex items-center gap-1 rounded-lg border border-border px-3 py-1.5 text-xs text-muted-foreground hover:bg-accent disabled:opacity-30"
>
<ChevronLeft className="h-3.5 w-3.5" />
Previous
</button>
<button
type="button"
onClick={() =>
setViewingIndex(Math.min(selectedBranches.length - 1, viewingIndex + 1))
}
disabled={viewingIndex === selectedBranches.length - 1}
className="flex items-center gap-1 rounded-lg border border-border px-3 py-1.5 text-xs text-muted-foreground hover:bg-accent disabled:opacity-30"
>
Next
<ChevronRight className="h-3.5 w-3.5" />
</button>
</div>
<div className="flex items-center gap-3">
<span className="text-xs text-muted-foreground">
{branchesWithDetail}/{selectedBranches.length} detailed
</span>
<button
type="button"
onClick={handleAssemble}
disabled={!allBranchesHaveDetail || isLoading}
className={cn(
'rounded-lg bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20',
allBranchesHaveDetail && !isLoading
? 'hover:opacity-90'
: 'cursor-not-allowed opacity-50'
)}
>
Assemble Tree
</button>
</div>
</div>
</div>
)
}
/** Recursive mini-preview of a node tree */
function NodePreview({ node, depth }: { node: Record<string, unknown>; depth: number }) {
const type = node.type as string
const label =
type === 'decision'
? (node.question as string)
: (node.title as string) || 'Untitled'
const children = (node.children as Record<string, unknown>[]) || []
const typeColors: Record<string, string> = {
decision: 'bg-blue-400',
action: 'bg-amber-400',
solution: 'bg-green-400',
}
return (
<div style={{ marginLeft: depth * 16 }}>
<div className="flex items-center gap-2 py-0.5">
<div className={cn('h-2 w-2 rounded-full', typeColors[type] || 'bg-muted-foreground')} />
<span className="text-xs text-foreground truncate">{label}</span>
<span className="text-[10px] font-label text-muted-foreground">{type}</span>
</div>
{children.map((child, i) => (
<NodePreview key={i} node={child} depth={depth + 1} />
))}
</div>
)
}

View File

@@ -0,0 +1,280 @@
import { useState } from 'react'
import { GripVertical, Plus, X, Pencil, Check } from 'lucide-react'
import { useAIFlowBuilderStore } from '@/store/aiFlowBuilderStore'
import { cn } from '@/lib/utils'
import type { AIBranch } from '@/types'
export function BranchSelector() {
const {
suggestedBranches,
selectedBranches,
selectBranches,
setPhase,
error,
} = useAIFlowBuilderStore()
const [editingIndex, setEditingIndex] = useState<number | null>(null)
const [editName, setEditName] = useState('')
const [editDesc, setEditDesc] = useState('')
const [showAddForm, setShowAddForm] = useState(false)
const [newName, setNewName] = useState('')
const [newDesc, setNewDesc] = useState('')
const toggleBranch = (branch: AIBranch) => {
const isSelected = selectedBranches.some((b) => b.name === branch.name)
if (isSelected) {
selectBranches(selectedBranches.filter((b) => b.name !== branch.name))
} else {
selectBranches([...selectedBranches, branch])
}
}
const startEditing = (index: number) => {
const branch = selectedBranches[index]
setEditingIndex(index)
setEditName(branch.name)
setEditDesc(branch.description)
}
const saveEdit = () => {
if (editingIndex === null || !editName.trim()) return
const updated = [...selectedBranches]
updated[editingIndex] = {
...updated[editingIndex],
name: editName.trim(),
description: editDesc.trim(),
}
selectBranches(updated)
setEditingIndex(null)
}
const addCustomBranch = () => {
if (!newName.trim()) return
const branch: AIBranch = {
name: newName.trim(),
description: newDesc.trim(),
isCustom: true,
}
selectBranches([...selectedBranches, branch])
setNewName('')
setNewDesc('')
setShowAddForm(false)
}
const moveBranch = (fromIndex: number, direction: 'up' | 'down') => {
const toIndex = direction === 'up' ? fromIndex - 1 : fromIndex + 1
if (toIndex < 0 || toIndex >= selectedBranches.length) return
const updated = [...selectedBranches]
;[updated[fromIndex], updated[toIndex]] = [updated[toIndex], updated[fromIndex]]
selectBranches(updated)
}
const canProceed = selectedBranches.length >= 2
return (
<div className="space-y-4">
<div>
<p className="text-sm text-muted-foreground">
AI suggested {suggestedBranches.length} branches. Select, reorder, rename, or add your own.
</p>
</div>
{/* Branch list */}
<div className="space-y-2">
{suggestedBranches.map((branch) => {
const isSelected = selectedBranches.some((b) => b.name === branch.name)
const selectedIndex = selectedBranches.findIndex((b) => b.name === branch.name)
return (
<div
key={branch.name}
className={cn(
'flex items-start gap-3 rounded-lg border p-3 transition-colors cursor-pointer',
isSelected
? 'border-primary/30 bg-primary/5'
: 'border-border bg-card hover:bg-accent/50'
)}
onClick={() => toggleBranch(branch)}
>
<div
className={cn(
'mt-0.5 flex h-5 w-5 shrink-0 items-center justify-center rounded border',
isSelected
? 'border-primary bg-primary text-white'
: 'border-border'
)}
>
{isSelected && <Check className="h-3 w-3" />}
</div>
<div className="flex-1 min-w-0">
{editingIndex !== null && selectedIndex === editingIndex ? (
<div className="space-y-2" onClick={(e) => e.stopPropagation()}>
<input
type="text"
value={editName}
onChange={(e) => setEditName(e.target.value)}
className="w-full rounded border border-border bg-card px-2 py-1 text-sm text-foreground focus:border-primary focus:outline-none"
autoFocus
/>
<input
type="text"
value={editDesc}
onChange={(e) => setEditDesc(e.target.value)}
className="w-full rounded border border-border bg-card px-2 py-1 text-xs text-muted-foreground focus:border-primary focus:outline-none"
/>
<div className="flex gap-1">
<button
type="button"
onClick={saveEdit}
className="rounded bg-primary px-2 py-0.5 text-xs text-white"
>
Save
</button>
<button
type="button"
onClick={() => setEditingIndex(null)}
className="rounded border border-border px-2 py-0.5 text-xs text-muted-foreground"
>
Cancel
</button>
</div>
</div>
) : (
<>
<div className="text-sm font-medium text-foreground">{branch.name}</div>
<div className="text-xs text-muted-foreground">{branch.description}</div>
</>
)}
</div>
{isSelected && editingIndex !== selectedIndex && (
<div className="flex items-center gap-0.5" onClick={(e) => e.stopPropagation()}>
<button
type="button"
onClick={() => moveBranch(selectedIndex, 'up')}
disabled={selectedIndex === 0}
className="rounded p-1 text-muted-foreground hover:text-foreground disabled:opacity-30"
title="Move up"
>
<GripVertical className="h-3.5 w-3.5" />
</button>
<button
type="button"
onClick={() => startEditing(selectedIndex)}
className="rounded p-1 text-muted-foreground hover:text-foreground"
title="Edit"
>
<Pencil className="h-3.5 w-3.5" />
</button>
</div>
)}
</div>
)
})}
{/* Custom branches (not in suggested) */}
{selectedBranches
.filter((b) => b.isCustom)
.map((branch, i) => {
return (
<div
key={`custom-${i}`}
className="flex items-start gap-3 rounded-lg border border-primary/30 bg-primary/5 p-3"
>
<div className="mt-0.5 flex h-5 w-5 shrink-0 items-center justify-center rounded border border-primary bg-primary text-white">
<Check className="h-3 w-3" />
</div>
<div className="flex-1 min-w-0">
<div className="text-sm font-medium text-foreground">{branch.name}</div>
<div className="text-xs text-muted-foreground">{branch.description}</div>
<span className="mt-1 inline-block rounded bg-primary/10 px-1.5 py-0.5 text-[10px] font-label text-primary">
Custom
</span>
</div>
<button
type="button"
onClick={() =>
selectBranches(selectedBranches.filter((b) => b.name !== branch.name))
}
className="rounded p-1 text-muted-foreground hover:text-red-400"
>
<X className="h-3.5 w-3.5" />
</button>
</div>
)
})}
</div>
{/* Add custom branch */}
{showAddForm ? (
<div className="space-y-2 rounded-lg border border-dashed border-border p-3">
<input
type="text"
value={newName}
onChange={(e) => setNewName(e.target.value)}
placeholder="Branch name"
className="w-full rounded border border-border bg-card px-2 py-1.5 text-sm text-foreground placeholder:text-muted-foreground focus:border-primary focus:outline-none"
autoFocus
/>
<input
type="text"
value={newDesc}
onChange={(e) => setNewDesc(e.target.value)}
placeholder="Brief description"
className="w-full rounded border border-border bg-card px-2 py-1.5 text-xs text-muted-foreground placeholder:text-muted-foreground/60 focus:border-primary focus:outline-none"
/>
<div className="flex gap-1">
<button
type="button"
onClick={addCustomBranch}
disabled={!newName.trim()}
className="rounded bg-primary px-2.5 py-1 text-xs text-white disabled:opacity-50"
>
Add
</button>
<button
type="button"
onClick={() => setShowAddForm(false)}
className="rounded border border-border px-2.5 py-1 text-xs text-muted-foreground"
>
Cancel
</button>
</div>
</div>
) : (
<button
type="button"
onClick={() => setShowAddForm(true)}
className="flex items-center gap-1.5 text-sm text-muted-foreground hover:text-foreground"
>
<Plus className="h-4 w-4" />
Add custom branch
</button>
)}
{/* Error */}
{error && (
<div className="rounded-lg border border-red-400/20 bg-red-400/5 px-3 py-2 text-sm text-red-400">
{error}
</div>
)}
{/* Footer */}
<div className="flex items-center justify-between pt-2">
<span className="text-xs text-muted-foreground">
{selectedBranches.length} branch{selectedBranches.length !== 1 ? 'es' : ''} selected (min 2)
</span>
<button
type="button"
onClick={() => setPhase('detailing')}
disabled={!canProceed}
className={cn(
'rounded-lg bg-gradient-brand px-4 py-2 text-sm font-medium text-white shadow-lg shadow-primary/20',
canProceed ? 'hover:opacity-90' : 'cursor-not-allowed opacity-50'
)}
>
Continue to Detail
</button>
</div>
</div>
)
}

View File

@@ -0,0 +1,163 @@
import { useState } from 'react'
import { useAIFlowBuilderStore } from '@/store/aiFlowBuilderStore'
import { QuotaDisplay } from './QuotaDisplay'
import { cn } from '@/lib/utils'
export function FoundationForm() {
const { metadata, setMetadata, quota, start, isLoading, error } = useAIFlowBuilderStore()
const [tagInput, setTagInput] = useState('')
const canSubmit =
metadata.name.trim().length > 0 &&
metadata.description.trim().length > 0 &&
!isLoading &&
(quota?.allowed !== false)
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault()
if (!canSubmit) return
await start()
}
const addTag = () => {
const tag = tagInput.trim()
if (tag && !metadata.environment_tags.includes(tag)) {
setMetadata({ environment_tags: [...metadata.environment_tags, tag] })
}
setTagInput('')
}
const removeTag = (tag: string) => {
setMetadata({ environment_tags: metadata.environment_tags.filter((t) => t !== tag) })
}
return (
<form onSubmit={handleSubmit} className="space-y-5">
{quota && <QuotaDisplay quota={quota} />}
{/* Flow Type */}
<div>
<label className="mb-2 block font-label text-[0.6875rem] uppercase tracking-wide text-muted-foreground">
Flow Type
</label>
<div className="flex gap-2">
{(['troubleshooting', 'procedural'] as const).map((type) => (
<button
key={type}
type="button"
onClick={() => setMetadata({ flow_type: type })}
className={cn(
'flex-1 rounded-lg border px-3 py-2.5 text-sm font-medium transition-colors',
metadata.flow_type === type
? 'border-primary/30 bg-primary/10 text-foreground'
: 'border-border bg-card text-muted-foreground hover:bg-accent'
)}
>
{type === 'troubleshooting' ? 'Troubleshooting' : 'Procedural'}
</button>
))}
</div>
</div>
{/* Name */}
<div>
<label className="mb-2 block font-label text-[0.6875rem] uppercase tracking-wide text-muted-foreground">
Flow Name
</label>
<input
type="text"
value={metadata.name}
onChange={(e) => setMetadata({ name: e.target.value })}
placeholder="e.g. DNS Resolution Failures"
className="w-full rounded-lg border border-border bg-card px-3 py-2 text-sm text-foreground placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary/20"
maxLength={255}
/>
</div>
{/* Description */}
<div>
<label className="mb-2 block font-label text-[0.6875rem] uppercase tracking-wide text-muted-foreground">
Description
</label>
<textarea
value={metadata.description}
onChange={(e) => setMetadata({ description: e.target.value })}
placeholder="Describe what this flow covers. The more detail you provide, the better the AI suggestions will be."
rows={4}
className="w-full rounded-lg border border-border bg-card px-3 py-2 text-sm text-foreground placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary/20 resize-none"
maxLength={2000}
/>
<p className="mt-1 text-right text-[10px] text-muted-foreground">
{metadata.description.length}/2000
</p>
</div>
{/* Environment Tags */}
<div>
<label className="mb-2 block font-label text-[0.6875rem] uppercase tracking-wide text-muted-foreground">
Environment Tags <span className="normal-case tracking-normal text-muted-foreground/60">(optional)</span>
</label>
<div className="flex gap-2">
<input
type="text"
value={tagInput}
onChange={(e) => setTagInput(e.target.value)}
onKeyDown={(e) => {
if (e.key === 'Enter') {
e.preventDefault()
addTag()
}
}}
placeholder="e.g. Windows Server, Active Directory"
className="flex-1 rounded-lg border border-border bg-card px-3 py-2 text-sm text-foreground placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary/20"
/>
<button
type="button"
onClick={addTag}
className="rounded-lg border border-border px-3 py-2 text-sm text-muted-foreground hover:bg-accent hover:text-foreground"
>
Add
</button>
</div>
{metadata.environment_tags.length > 0 && (
<div className="mt-2 flex flex-wrap gap-1.5">
{metadata.environment_tags.map((tag) => (
<span
key={tag}
className="inline-flex items-center gap-1 rounded-full bg-card border border-border px-2.5 py-0.5 font-label text-xs text-muted-foreground"
>
{tag}
<button
type="button"
onClick={() => removeTag(tag)}
className="ml-0.5 text-muted-foreground/60 hover:text-foreground"
>
&times;
</button>
</span>
))}
</div>
)}
</div>
{/* Error */}
{error && (
<div className="rounded-lg border border-red-400/20 bg-red-400/5 px-3 py-2 text-sm text-red-400">
{error}
</div>
)}
{/* Submit */}
<button
type="submit"
disabled={!canSubmit}
className={cn(
'w-full rounded-lg bg-gradient-brand py-2.5 text-sm font-medium text-white shadow-lg shadow-primary/20',
canSubmit ? 'hover:opacity-90' : 'cursor-not-allowed opacity-50'
)}
>
{isLoading ? 'Creating...' : 'Continue to AI Scaffold'}
</button>
</form>
)
}

View File

@@ -0,0 +1,33 @@
import { useEffect, useState } from 'react'
import { Sparkles } from 'lucide-react'
const MESSAGES = [
'Analyzing your flow requirements...',
'Building decision paths...',
'Generating troubleshooting logic...',
'Crafting resolution steps...',
'Structuring the flow...',
]
export function GeneratingAnimation() {
const [messageIndex, setMessageIndex] = useState(0)
useEffect(() => {
const interval = setInterval(() => {
setMessageIndex((prev) => (prev + 1) % MESSAGES.length)
}, 3000)
return () => clearInterval(interval)
}, [])
return (
<div className="flex flex-col items-center justify-center gap-4 py-12">
<div className="relative">
<div className="h-12 w-12 animate-spin rounded-full border-4 border-border border-t-primary" />
<Sparkles className="absolute left-1/2 top-1/2 h-5 w-5 -translate-x-1/2 -translate-y-1/2 text-primary" />
</div>
<p className="text-sm text-muted-foreground animate-pulse">
{MESSAGES[messageIndex]}
</p>
</div>
)
}

View File

@@ -0,0 +1,48 @@
import { cn } from '@/lib/utils'
import type { AIQuotaStatus } from '@/types'
interface QuotaDisplayProps {
quota: AIQuotaStatus
compact?: boolean
}
export function QuotaDisplay({ quota, compact = false }: QuotaDisplayProps) {
if (!quota.ai_enabled) return null
const monthlyRemaining =
quota.monthly_limit !== null
? Math.max(0, quota.monthly_limit - quota.monthly_used)
: null
const getColor = () => {
if (!quota.allowed) return 'text-red-400'
if (monthlyRemaining !== null && monthlyRemaining <= 1) return 'text-amber-400'
return 'text-green-400'
}
if (compact) {
return (
<span className={cn('text-xs font-label', getColor())}>
{monthlyRemaining !== null
? `${monthlyRemaining}/${quota.monthly_limit} builds`
: 'Unlimited'}
</span>
)
}
return (
<div className="flex items-center gap-2 rounded-lg border border-border bg-accent/50 px-3 py-1.5">
<div className={cn('h-2 w-2 rounded-full', getColor().replace('text-', 'bg-'))} />
<span className="text-xs text-muted-foreground">
{monthlyRemaining !== null ? (
<>
<span className={cn('font-medium', getColor())}>{monthlyRemaining}</span>
{' '}of {quota.monthly_limit} AI builds remaining
</>
) : (
'Unlimited AI builds'
)}
</span>
</div>
)
}

View File

@@ -0,0 +1,85 @@
import { GitBranch, Layers, CheckCircle, ArrowRight, RotateCcw } from 'lucide-react'
import { useAIFlowBuilderStore } from '@/store/aiFlowBuilderStore'
import { cn } from '@/lib/utils'
interface TreePreviewCardProps {
onOpenInEditor: () => void
}
export function TreePreviewCard({ onOpenInEditor }: TreePreviewCardProps) {
const { assembledTree, reset, isLoading } = useAIFlowBuilderStore()
if (!assembledTree) return null
const { summary } = assembledTree
const stats = [
{ label: 'Nodes', value: summary.node_count, icon: Layers },
{ label: 'Decisions', value: summary.decision_count, icon: GitBranch },
{ label: 'Solutions', value: summary.solution_count, icon: CheckCircle },
{ label: 'Depth', value: summary.depth, icon: Layers },
]
return (
<div className="space-y-4">
<div className="text-center">
<div className="mx-auto mb-3 flex h-12 w-12 items-center justify-center rounded-full bg-green-400/10">
<CheckCircle className="h-6 w-6 text-green-400" />
</div>
<h3 className="text-lg font-semibold text-foreground">
Tree Assembled
</h3>
<p className="mt-1 text-sm text-muted-foreground">
&quot;{assembledTree.suggested_name}&quot; is ready to review in the editor.
</p>
</div>
{/* Stats grid */}
<div className="grid grid-cols-4 gap-2">
{stats.map(({ label, value, icon: Icon }) => (
<div
key={label}
className="flex flex-col items-center rounded-lg border border-border bg-accent/30 p-2.5"
>
<Icon className="mb-1 h-4 w-4 text-muted-foreground" />
<span className="text-lg font-semibold text-gradient-brand">{value}</span>
<span className="text-[10px] font-label uppercase tracking-wide text-muted-foreground">
{label}
</span>
</div>
))}
</div>
{/* Description */}
{assembledTree.suggested_description && (
<div className="rounded-lg border border-border bg-accent/20 p-3">
<p className="text-xs text-muted-foreground">{assembledTree.suggested_description}</p>
</div>
)}
{/* Actions */}
<div className="flex gap-2">
<button
type="button"
onClick={onOpenInEditor}
disabled={isLoading}
className={cn(
'flex flex-1 items-center justify-center gap-2 rounded-lg bg-gradient-brand py-2.5 text-sm font-medium text-white shadow-lg shadow-primary/20',
'hover:opacity-90'
)}
>
<ArrowRight className="h-4 w-4" />
Open in Editor
</button>
<button
type="button"
onClick={reset}
className="flex items-center gap-2 rounded-lg border border-border px-4 py-2.5 text-sm text-muted-foreground hover:bg-accent hover:text-foreground"
>
<RotateCcw className="h-4 w-4" />
Start Over
</button>
</div>
</div>
)
}

View File

@@ -0,0 +1,70 @@
import { Check } from 'lucide-react'
import { cn } from '@/lib/utils'
import type { AIWizardPhase } from '@/types'
const STEPS = [
{ key: 'foundation', label: 'Foundation' },
{ key: 'scaffolding', label: 'Scaffold' },
{ key: 'detailing', label: 'Detail' },
{ key: 'reviewing', label: 'Review' },
] as const
const PHASE_ORDER: Record<string, number> = {
foundation: 0,
scaffolding: 1,
generating: 1,
detailing: 2,
reviewing: 3,
completed: 4,
error: -1,
}
interface WizardStepIndicatorProps {
phase: AIWizardPhase
}
export function WizardStepIndicator({ phase }: WizardStepIndicatorProps) {
const currentIndex = PHASE_ORDER[phase] ?? 0
return (
<div className="flex items-center gap-1 px-2">
{STEPS.map((step, i) => {
const isCompleted = currentIndex > i
const isCurrent = currentIndex === i
return (
<div key={step.key} className="flex items-center gap-1">
{i > 0 && (
<div
className={cn(
'h-px w-4 sm:w-6',
isCompleted ? 'bg-primary' : 'bg-border'
)}
/>
)}
<div className="flex items-center gap-1.5">
<div
className={cn(
'flex h-5 w-5 items-center justify-center rounded-full text-[10px] font-medium',
isCompleted && 'bg-primary text-white',
isCurrent && 'bg-primary/20 text-primary ring-1 ring-primary/40',
!isCompleted && !isCurrent && 'bg-accent text-muted-foreground'
)}
>
{isCompleted ? <Check className="h-3 w-3" /> : i + 1}
</div>
<span
className={cn(
'hidden text-xs sm:inline',
isCurrent ? 'font-medium text-foreground' : 'text-muted-foreground'
)}
>
{step.label}
</span>
</div>
</div>
)
})}
</div>
)
}

View File

@@ -1,6 +1,6 @@
import { useEffect, useState } from 'react'
import { useNavigate, Link } from 'react-router-dom'
import { Play, Pencil, Share2, Trash2, GitBranch, Clock, TrendingUp, FolderTree, Plus, ListOrdered, ChevronDown, Wrench } from 'lucide-react'
import { Play, Pencil, Share2, Trash2, GitBranch, Clock, TrendingUp, FolderTree, Plus, ListOrdered, ChevronDown, Wrench, Sparkles } from 'lucide-react'
import { treesApi } from '@/api/trees'
import { sessionsApi } from '@/api/sessions'
import type { TreeListItem } from '@/types'
@@ -12,6 +12,8 @@ import { cn } from '@/lib/utils'
import { useAuthStore } from '@/store/authStore'
import { usePermissions } from '@/hooks/usePermissions'
import { toast } from '@/lib/toast'
import { AIFlowBuilderModal } from '@/components/ai-builder/AIFlowBuilderModal'
import { aiBuilderApi } from '@/api/aiBuilder'
interface TreeWithStats extends TreeListItem {
lastUsed?: string
@@ -32,11 +34,17 @@ export function MyTreesPage() {
const [treeToShare, setTreeToShare] = useState<TreeWithStats | null>(null)
const [showShareModal, setShowShareModal] = useState(false)
const [showCreateMenu, setShowCreateMenu] = useState(false)
const [showAIBuilder, setShowAIBuilder] = useState(false)
const [aiEnabled, setAiEnabled] = useState(false)
useEffect(() => {
loadMyTrees()
}, [user?.id])
useEffect(() => {
aiBuilderApi.getQuota().then((q) => setAiEnabled(q.ai_enabled)).catch(() => {})
}, [])
const loadMyTrees = async () => {
if (!user?.id) return
setIsLoading(true)
@@ -168,6 +176,25 @@ export function MyTreesPage() {
<div className="text-xs text-muted-foreground">Scheduled multi-target tasks</div>
</div>
</Link>
{aiEnabled && (
<>
<div className="my-1 border-t border-border" />
<button
type="button"
onClick={() => {
setShowCreateMenu(false)
setShowAIBuilder(true)
}}
className="flex w-full items-center gap-3 rounded-md px-3 py-2.5 text-sm text-foreground hover:bg-accent"
>
<Sparkles className="h-4 w-4 text-primary" />
<div className="text-left">
<div className="font-medium">Build with AI</div>
<div className="text-xs text-muted-foreground">AI-assisted flow creation</div>
</div>
</button>
</>
)}
</div>
</>
)}
@@ -373,6 +400,12 @@ export function MyTreesPage() {
}}
/>
)}
{/* AI Flow Builder Modal */}
<AIFlowBuilderModal
isOpen={showAIBuilder}
onClose={() => setShowAIBuilder(false)}
/>
</div>
)
}

View File

@@ -0,0 +1,201 @@
import { create } from 'zustand'
import { aiBuilderApi } from '@/api/aiBuilder'
import type { AIQuotaStatus, AIBranch, AIAssembleResponse, AIWizardPhase } from '@/types'
interface AIFlowBuilderState {
// Wizard state
phase: AIWizardPhase
conversationId: string | null
metadata: {
flow_type: 'troubleshooting' | 'procedural'
name: string
description: string
environment_tags: string[]
category_id: string | null
}
// Stage 2
suggestedBranches: AIBranch[]
selectedBranches: AIBranch[]
// Stage 3
currentBranchIndex: number
// Stage 4
assembledTree: AIAssembleResponse | null
// Quota
quota: AIQuotaStatus | null
// UI state
isLoading: boolean
error: string | null
// Actions
loadQuota: () => Promise<void>
setMetadata: (metadata: Partial<AIFlowBuilderState['metadata']>) => void
start: () => Promise<void>
scaffold: () => Promise<void>
selectBranches: (branches: AIBranch[]) => void
generateBranchDetail: (branchName: string) => Promise<void>
assemble: () => Promise<void>
reset: () => void
setPhase: (phase: AIWizardPhase) => void
setError: (error: string | null) => void
}
const initialMetadata = {
flow_type: 'troubleshooting' as const,
name: '',
description: '',
environment_tags: [] as string[],
category_id: null as string | null,
}
export const useAIFlowBuilderStore = create<AIFlowBuilderState>()((set, get) => ({
phase: 'foundation',
conversationId: null,
metadata: { ...initialMetadata },
suggestedBranches: [],
selectedBranches: [],
currentBranchIndex: 0,
assembledTree: null,
quota: null,
isLoading: false,
error: null,
loadQuota: async () => {
try {
const quota = await aiBuilderApi.getQuota()
set({ quota })
} catch {
// Silently fail — quota display is optional
}
},
setMetadata: (metadata) => {
set((state) => ({
metadata: { ...state.metadata, ...metadata },
}))
},
start: async () => {
const { metadata } = get()
set({ isLoading: true, error: null })
try {
const response = await aiBuilderApi.start({
flow_type: metadata.flow_type,
name: metadata.name,
description: metadata.description,
environment_tags: metadata.environment_tags,
category_id: metadata.category_id ?? undefined,
})
set({
conversationId: response.conversation_id,
phase: 'scaffolding',
isLoading: false,
})
} catch (err) {
const message = _extractError(err)
set({ error: message, isLoading: false })
}
},
scaffold: async () => {
const { conversationId } = get()
if (!conversationId) return
set({ isLoading: true, error: null, phase: 'generating' })
try {
const response = await aiBuilderApi.scaffold(conversationId)
const branches: AIBranch[] = response.branches.map((b) => ({
name: b.name,
description: b.description,
}))
set({
suggestedBranches: branches,
selectedBranches: branches,
phase: 'scaffolding',
isLoading: false,
})
} catch (err) {
const message = _extractError(err)
set({ error: message, phase: 'error', isLoading: false })
}
},
selectBranches: (branches) => {
set({ selectedBranches: branches })
},
generateBranchDetail: async (branchName) => {
const { conversationId, selectedBranches } = get()
if (!conversationId) return
set({ isLoading: true, error: null, phase: 'generating' })
try {
const response = await aiBuilderApi.branchDetail(conversationId, branchName)
const updatedBranches = selectedBranches.map((b) =>
b.name === branchName ? { ...b, steps: response.steps } : b
)
set({
selectedBranches: updatedBranches,
phase: 'detailing',
isLoading: false,
})
} catch (err) {
const message = _extractError(err)
set({ error: message, phase: 'error', isLoading: false })
}
},
assemble: async () => {
const { conversationId, selectedBranches } = get()
if (!conversationId) return
set({ isLoading: true, error: null })
try {
const response = await aiBuilderApi.assemble(
conversationId,
selectedBranches.map((b) => ({
name: b.name,
description: b.description,
steps: b.steps,
}))
)
set({
assembledTree: response,
phase: 'reviewing',
isLoading: false,
})
} catch (err) {
const message = _extractError(err)
set({ error: message, phase: 'error', isLoading: false })
}
},
reset: () => {
set({
phase: 'foundation',
conversationId: null,
metadata: { ...initialMetadata },
suggestedBranches: [],
selectedBranches: [],
currentBranchIndex: 0,
assembledTree: null,
isLoading: false,
error: null,
})
},
setPhase: (phase) => set({ phase }),
setError: (error) => set({ error }),
}))
function _extractError(err: unknown): string {
if (err && typeof err === 'object' && 'response' in err) {
const axiosErr = err as { response?: { data?: { detail?: string | { message?: string } } } }
const detail = axiosErr.response?.data?.detail
if (typeof detail === 'string') return detail
if (detail && typeof detail === 'object' && 'message' in detail) return detail.message ?? 'Unknown error'
}
if (err instanceof Error) return err.message
return 'An unexpected error occurred'
}

60
frontend/src/types/ai.ts Normal file
View File

@@ -0,0 +1,60 @@
export interface AIQuotaStatus {
plan: string
monthly_used: number
monthly_limit: number | null
monthly_reset_at: string
daily_used: number
daily_limit: number | null
daily_reset_at: string
allowed: boolean
ai_enabled: boolean
}
export interface AIBranch {
name: string
description: string
steps?: Record<string, unknown>
isCustom?: boolean
}
export interface AITreeSummary {
node_count: number
decision_count: number
action_count: number
solution_count: number
depth: number
}
export interface AIStartResponse {
conversation_id: string
status: string
}
export interface AIScaffoldResponse {
conversation_id: string
branches: Array<{ name: string; description: string }>
status: string
}
export interface AIBranchDetailResponse {
conversation_id: string
branch_name: string
steps: Record<string, unknown>
status: string
}
export interface AIAssembleResponse {
tree_structure: Record<string, unknown>
suggested_name: string
suggested_description: string
summary: AITreeSummary
status: string
}
export type AIWizardPhase =
| 'foundation'
| 'scaffolding'
| 'detailing'
| 'reviewing'
| 'generating'
| 'error'

View File

@@ -34,3 +34,14 @@ export type {
BatchLaunchRequest,
BatchLaunchResponse,
} from './maintenance'
export type {
AIQuotaStatus,
AIBranch,
AITreeSummary,
AIStartResponse,
AIScaffoldResponse,
AIBranchDetailResponse,
AIAssembleResponse,
AIWizardPhase,
} from './ai'