From 44432413c2e641d0599f68a7ab22b8faf1c9b61c Mon Sep 17 00:00:00 2001 From: chihlasm Date: Fri, 20 Feb 2026 08:07:08 -0500 Subject: [PATCH 01/25] feat: AI-assisted flow builder with 4-stage wizard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the complete AI flow builder feature using a guided 4-stage wizard (Foundation → Scaffold → Branch Detail → Review & Assemble). AI assists at bounded points using Claude Haiku for cost-efficient structured JSON generation (~$0.01-0.03/flow). Backend: new models (ai_conversations, ai_usage), Alembic migration, quota enforcement with billing anchor, Anthropic API integration with prompt caching, tree validation, conversation CRUD with 24h TTL, APScheduler cleanup job, 5 API endpoints, Pydantic schemas. Frontend: TypeScript types, API client, Zustand store for wizard state, 7 components (modal, step indicator, foundation form, branch selector, branch detail view, tree preview, quota display), MyTreesPage integration with "Build with AI" button (hidden when AI not configured). Tests: 14 validator unit tests + 11 endpoint integration tests with mocked Anthropic (zero real API spend). All 25 tests passing. Co-Authored-By: Claude Opus 4.6 --- .../a1b2c3d4e5f6_add_ai_flow_builder.py | 216 +++++++++ backend/app/api/endpoints/ai_builder.py | 427 ++++++++++++++++++ backend/app/api/router.py | 2 + backend/app/core/ai_conversation_store.py | 87 ++++ backend/app/core/ai_quota_service.py | 181 ++++++++ backend/app/core/ai_tree_generator_service.py | 293 ++++++++++++ backend/app/core/ai_tree_validator.py | 199 ++++++++ backend/app/core/config.py | 12 + backend/app/core/scheduler.py | 26 +- backend/app/main.py | 11 +- backend/app/models/__init__.py | 4 + backend/app/models/account_limit_override.py | 2 + backend/app/models/ai_conversation.py | 67 +++ backend/app/models/ai_usage.py | 69 +++ backend/app/models/plan_limits.py | 4 + backend/app/models/user.py | 5 + backend/app/schemas/__init__.py | 9 + backend/app/schemas/ai_builder.py | 116 +++++ backend/requirements.txt | 3 + backend/tests/test_ai_endpoints.py | 360 +++++++++++++++ backend/tests/test_ai_tree_validator.py | 183 ++++++++ frontend/src/api/aiBuilder.ts | 61 +++ frontend/src/api/index.ts | 1 + .../ai-builder/AIFlowBuilderModal.tsx | 135 ++++++ .../ai-builder/BranchDetailView.tsx | 208 +++++++++ .../components/ai-builder/BranchSelector.tsx | 280 ++++++++++++ .../components/ai-builder/FoundationForm.tsx | 163 +++++++ .../ai-builder/GeneratingAnimation.tsx | 33 ++ .../components/ai-builder/QuotaDisplay.tsx | 48 ++ .../components/ai-builder/TreePreviewCard.tsx | 85 ++++ .../ai-builder/WizardStepIndicator.tsx | 70 +++ frontend/src/pages/MyTreesPage.tsx | 35 +- frontend/src/store/aiFlowBuilderStore.ts | 201 +++++++++ frontend/src/types/ai.ts | 60 +++ frontend/src/types/index.ts | 11 + 35 files changed, 3662 insertions(+), 5 deletions(-) create mode 100644 backend/alembic/versions/a1b2c3d4e5f6_add_ai_flow_builder.py create mode 100644 backend/app/api/endpoints/ai_builder.py create mode 100644 backend/app/core/ai_conversation_store.py create mode 100644 backend/app/core/ai_quota_service.py create mode 100644 backend/app/core/ai_tree_generator_service.py create mode 100644 backend/app/core/ai_tree_validator.py create mode 100644 backend/app/models/ai_conversation.py create mode 100644 backend/app/models/ai_usage.py create mode 100644 backend/app/schemas/ai_builder.py create mode 100644 backend/tests/test_ai_endpoints.py create mode 100644 backend/tests/test_ai_tree_validator.py create mode 100644 frontend/src/api/aiBuilder.ts create mode 100644 frontend/src/components/ai-builder/AIFlowBuilderModal.tsx create mode 100644 frontend/src/components/ai-builder/BranchDetailView.tsx create mode 100644 frontend/src/components/ai-builder/BranchSelector.tsx create mode 100644 frontend/src/components/ai-builder/FoundationForm.tsx create mode 100644 frontend/src/components/ai-builder/GeneratingAnimation.tsx create mode 100644 frontend/src/components/ai-builder/QuotaDisplay.tsx create mode 100644 frontend/src/components/ai-builder/TreePreviewCard.tsx create mode 100644 frontend/src/components/ai-builder/WizardStepIndicator.tsx create mode 100644 frontend/src/store/aiFlowBuilderStore.ts create mode 100644 frontend/src/types/ai.ts diff --git a/backend/alembic/versions/a1b2c3d4e5f6_add_ai_flow_builder.py b/backend/alembic/versions/a1b2c3d4e5f6_add_ai_flow_builder.py new file mode 100644 index 00000000..cc6286dc --- /dev/null +++ b/backend/alembic/versions/a1b2c3d4e5f6_add_ai_flow_builder.py @@ -0,0 +1,216 @@ +"""add ai flow builder tables and columns + +Revision ID: a1b2c3d4e5f6 +Revises: e65b9f8fd458 +Create Date: 2026-02-20 12:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +revision: str = "a1b2c3d4e5f6" +down_revision: Union[str, None] = "e65b9f8fd458" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── ai_conversations table ── + op.create_table( + "ai_conversations", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "account_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("status", sa.String(20), nullable=False, server_default="foundation"), + sa.Column("messages", postgresql.JSONB(), nullable=False, server_default="[]"), + sa.Column( + "wizard_state", postgresql.JSONB(), nullable=False, server_default="{}" + ), + sa.Column("generated_tree", postgresql.JSONB(), nullable=True), + sa.Column("question_rounds", sa.Integer(), nullable=False, server_default="0"), + sa.Column( + "expires_at", sa.DateTime(timezone=True), nullable=False + ), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + ) + op.create_index( + "ix_ai_conversations_user_id", "ai_conversations", ["user_id"] + ) + op.create_index( + "ix_ai_conversations_account_id", "ai_conversations", ["account_id"] + ) + op.create_index( + "ix_ai_conversations_user_created", + "ai_conversations", + ["user_id", sa.text("created_at DESC")], + ) + op.create_index( + "ix_ai_conversations_expires_at", "ai_conversations", ["expires_at"] + ) + + # ── ai_usage table ── + op.create_table( + "ai_usage", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "account_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "conversation_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("ai_conversations.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("generation_type", sa.String(20), nullable=False), + sa.Column("tier_at_time", sa.String(20), nullable=False), + sa.Column("input_tokens", sa.Integer(), nullable=False, server_default="0"), + sa.Column("output_tokens", sa.Integer(), nullable=False, server_default="0"), + sa.Column( + "estimated_cost_usd", + sa.Numeric(10, 6), + nullable=False, + server_default="0", + ), + sa.Column("succeeded", sa.Boolean(), nullable=False, server_default="true"), + sa.Column( + "counts_toward_quota", + sa.Boolean(), + nullable=False, + server_default="false", + ), + sa.Column("error_code", sa.String(100), nullable=True), + sa.Column("metadata", postgresql.JSONB(), nullable=False, server_default="{}"), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + ) + op.create_index("ix_ai_usage_user_id", "ai_usage", ["user_id"]) + op.create_index("ix_ai_usage_account_id", "ai_usage", ["account_id"]) + op.create_index("ix_ai_usage_created_at", "ai_usage", ["created_at"]) + op.create_index( + "ix_ai_usage_user_created", + "ai_usage", + ["user_id", sa.text("created_at DESC")], + ) + op.create_index( + "ix_ai_usage_user_type_created", + "ai_usage", + ["user_id", "generation_type", sa.text("created_at DESC")], + ) + # Prevents double quota decrement from race conditions + op.execute( + """ + CREATE UNIQUE INDEX ix_ai_usage_unique_quota + ON ai_usage (conversation_id) + WHERE counts_toward_quota = true; + """ + ) + + # ── Schema modifications to existing tables ── + + # users: add ai_billing_cycle_anchor_at + op.add_column( + "users", + sa.Column("ai_billing_cycle_anchor_at", sa.DateTime(timezone=True), nullable=True), + ) + # Backfill: use created_at as the billing anchor + op.execute( + "UPDATE users SET ai_billing_cycle_anchor_at = created_at WHERE ai_billing_cycle_anchor_at IS NULL" + ) + + # plan_limits: add AI limit columns + op.add_column( + "plan_limits", + sa.Column("max_ai_builds_per_month", sa.Integer(), nullable=True), + ) + op.add_column( + "plan_limits", + sa.Column("max_ai_builds_per_24h", sa.Integer(), nullable=True), + ) + + # account_limit_overrides: add AI override columns + op.add_column( + "account_limit_overrides", + sa.Column("override_max_ai_builds_per_month", sa.Integer(), nullable=True), + ) + op.add_column( + "account_limit_overrides", + sa.Column("override_max_ai_builds_per_24h", sa.Integer(), nullable=True), + ) + + # Seed plan_limits with AI quota values + op.execute( + """ + UPDATE plan_limits SET max_ai_builds_per_month = 2, max_ai_builds_per_24h = 1 + WHERE plan = 'free'; + """ + ) + op.execute( + """ + UPDATE plan_limits SET max_ai_builds_per_month = 50, max_ai_builds_per_24h = 10 + WHERE plan = 'pro'; + """ + ) + op.execute( + """ + UPDATE plan_limits SET max_ai_builds_per_month = 200, max_ai_builds_per_24h = 20 + WHERE plan = 'team'; + """ + ) + # Enterprise: NULL means unlimited (no update needed as default is NULL) + + +def downgrade() -> None: + # Drop AI override columns from account_limit_overrides + op.drop_column("account_limit_overrides", "override_max_ai_builds_per_24h") + op.drop_column("account_limit_overrides", "override_max_ai_builds_per_month") + + # Drop AI limit columns from plan_limits + op.drop_column("plan_limits", "max_ai_builds_per_24h") + op.drop_column("plan_limits", "max_ai_builds_per_month") + + # Drop ai_billing_cycle_anchor_at from users + op.drop_column("users", "ai_billing_cycle_anchor_at") + + # Drop ai_usage table (indexes drop automatically) + op.drop_table("ai_usage") + + # Drop ai_conversations table (indexes drop automatically) + op.drop_table("ai_conversations") diff --git a/backend/app/api/endpoints/ai_builder.py b/backend/app/api/endpoints/ai_builder.py new file mode 100644 index 00000000..099cf539 --- /dev/null +++ b/backend/app/api/endpoints/ai_builder.py @@ -0,0 +1,427 @@ +"""AI Flow Builder wizard endpoints. + +4-stage wizard: + POST /ai/start — Stage 1: create conversation with metadata + POST /ai/scaffold — Stage 2: AI suggests branches + POST /ai/branch-detail — Stage 3: AI generates detail for one branch + POST /ai/assemble — Stage 4: assemble branches into tree (no AI) + GET /ai/quota — quota status +""" +import logging +from typing import Annotated + +import anthropic +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin +from app.core.config import settings +from app.core.ai_conversation_store import ( + create_conversation, + get_conversation, + update_conversation, +) +from app.core.ai_quota_service import check_ai_quota, record_ai_usage, get_user_plan +from app.core.ai_tree_generator_service import ( + scaffold_branches, + generate_branch_detail, + assemble_tree, +) +from app.models.user import User +from app.schemas.ai_builder import ( + AIStartRequest, + AIStartResponse, + AIScaffoldRequest, + AIScaffoldResponse, + AIBranchDetailRequest, + AIBranchDetailResponse, + AIAssembleRequest, + AIAssembleResponse, + AIQuotaStatusResponse, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ai", tags=["ai-builder"]) + + +def _require_ai_enabled() -> None: + """Raise 503 if AI is not configured.""" + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI flow builder is not configured. Set ANTHROPIC_API_KEY.", + ) + + +@router.get("/quota", response_model=AIQuotaStatusResponse) +async def get_quota( + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Get current user's AI quota status.""" + if not settings.ai_enabled: + return AIQuotaStatusResponse( + plan="free", + monthly_used=0, + monthly_limit=None, + monthly_reset_at="", + daily_used=0, + daily_limit=None, + daily_reset_at="", + allowed=False, + ai_enabled=False, + ) + + _, quota_status = await check_ai_quota( + user_id=current_user.id, + account_id=current_user.account_id, + db=db, + billing_anchor=current_user.ai_billing_cycle_anchor_at, + ) + return AIQuotaStatusResponse( + **quota_status, + ai_enabled=True, + ) + + +@router.post("/start", response_model=AIStartResponse, status_code=201) +async def start_conversation( + data: AIStartRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Stage 1: Create a new AI wizard conversation with foundation metadata.""" + _require_ai_enabled() + + # Check daily quota (anti-abuse) + allowed, quota_status = await check_ai_quota( + user_id=current_user.id, + account_id=current_user.account_id, + db=db, + billing_anchor=current_user.ai_billing_cycle_anchor_at, + ) + if not allowed: + reset_key = ( + "daily_reset_at" + if quota_status.get("deny_reason") == "daily" + else "monthly_reset_at" + ) + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail={ + "message": f"AI build limit exceeded ({quota_status['deny_reason']})", + "reset_at": quota_status.get(reset_key), + "quota": quota_status, + }, + ) + + wizard_state = { + "flow_type": data.flow_type, + "name": data.name, + "description": data.description, + "environment_tags": data.environment_tags, + "category_id": str(data.category_id) if data.category_id else None, + } + + conversation = await create_conversation( + user_id=current_user.id, + account_id=current_user.account_id, + wizard_state=wizard_state, + db=db, + ) + await db.commit() + + return AIStartResponse( + conversation_id=conversation.id, + status=conversation.status, + ) + + +@router.post("/scaffold", response_model=AIScaffoldResponse) +async def scaffold( + data: AIScaffoldRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Stage 2: AI suggests top-level branches.""" + _require_ai_enabled() + + conversation = await get_conversation( + data.conversation_id, current_user.id, db + ) + + # Check per-flow call limit + if conversation.question_rounds >= settings.AI_MAX_CALLS_PER_FLOW: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Maximum AI calls per flow exceeded", + ) + + plan = await get_user_plan(current_user.account_id, db) + + try: + branches, input_tokens, output_tokens, cost = await scaffold_branches( + conversation.wizard_state, + ) + except anthropic.APIError as e: + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="scaffold", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"error": str(e)}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="AI provider error. Please try again.", + ) + except ValueError as e: + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="scaffold", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code="invalid_output", + extra_data={"error": str(e)}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"AI returned invalid output: {e}", + ) + + # Record successful usage + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="scaffold", + tier=plan, + input_tokens=input_tokens, + output_tokens=output_tokens, + estimated_cost=cost, + succeeded=True, + counts_toward_quota=False, + error_code=None, + extra_data=None, + db=db, + ) + + # Update conversation state + wizard_state = dict(conversation.wizard_state) + wizard_state["branches"] = branches + await update_conversation( + conversation.id, + current_user.id, + { + "status": "scaffolding", + "wizard_state": wizard_state, + "question_rounds": conversation.question_rounds + 1, + }, + db, + ) + await db.commit() + + return AIScaffoldResponse( + conversation_id=conversation.id, + branches=branches, + status="scaffolding", + ) + + +@router.post("/branch-detail", response_model=AIBranchDetailResponse) +async def branch_detail( + data: AIBranchDetailRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Stage 3: AI generates detailed nodes for one branch.""" + _require_ai_enabled() + + conversation = await get_conversation( + data.conversation_id, current_user.id, db + ) + + if conversation.question_rounds >= settings.AI_MAX_CALLS_PER_FLOW: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Maximum AI calls per flow exceeded", + ) + + wizard_state = conversation.wizard_state + existing_branches = [ + b.get("name", "") for b in wizard_state.get("branches", []) + ] + + plan = await get_user_plan(current_user.account_id, db) + + try: + branch_tree, input_tokens, output_tokens, cost = ( + await generate_branch_detail( + wizard_state, + data.branch_name, + existing_branches, + ) + ) + except anthropic.APIError as e: + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="branch_detail", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"error": str(e), "branch_name": data.branch_name}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="AI provider error. Please try again.", + ) + except ValueError as e: + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="branch_detail", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code="invalid_output", + extra_data={"error": str(e), "branch_name": data.branch_name}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"AI returned invalid output: {e}", + ) + + # Record successful usage + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="branch_detail", + tier=plan, + input_tokens=input_tokens, + output_tokens=output_tokens, + estimated_cost=cost, + succeeded=True, + counts_toward_quota=False, + error_code=None, + extra_data={"branch_name": data.branch_name}, + db=db, + ) + + # Update conversation + await update_conversation( + conversation.id, + current_user.id, + { + "status": "detailing", + "question_rounds": conversation.question_rounds + 1, + }, + db, + ) + await db.commit() + + return AIBranchDetailResponse( + conversation_id=conversation.id, + branch_name=data.branch_name, + steps=branch_tree, + status="detailing", + ) + + +@router.post("/assemble", response_model=AIAssembleResponse) +async def assemble( + data: AIAssembleRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Stage 4: Assemble selected branches into a complete tree (no AI calls).""" + conversation = await get_conversation( + data.conversation_id, current_user.id, db + ) + + wizard_state = conversation.wizard_state + branches_for_assembly = [b.model_dump() for b in data.selected_branches] + + try: + tree_structure, name, description, stats = assemble_tree( + wizard_state, branches_for_assembly + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=str(e), + ) + + # Record quota-consuming usage on successful assembly + plan = await get_user_plan(current_user.account_id, db) + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=conversation.id, + generation_type="tree", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=True, + counts_toward_quota=True, + error_code=None, + extra_data={"stats": stats}, + db=db, + ) + + # Update conversation with assembled tree + await update_conversation( + conversation.id, + current_user.id, + { + "status": "completed", + "generated_tree": tree_structure, + }, + db, + ) + await db.commit() + + return AIAssembleResponse( + tree_structure=tree_structure, + suggested_name=name, + suggested_description=description, + summary=stats, + status="completed", + ) diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 3aac1d7f..2c79e039 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -5,6 +5,7 @@ from app.api.endpoints import ratings, analytics from app.api.endpoints import target_lists from app.api.endpoints import maintenance_schedules from app.api.endpoints import feedback +from app.api.endpoints import ai_builder api_router = APIRouter() @@ -34,3 +35,4 @@ api_router.include_router(analytics.router) api_router.include_router(target_lists.router) api_router.include_router(maintenance_schedules.router) api_router.include_router(feedback.router) +api_router.include_router(ai_builder.router) diff --git a/backend/app/core/ai_conversation_store.py b/backend/app/core/ai_conversation_store.py new file mode 100644 index 00000000..2a5f0744 --- /dev/null +++ b/backend/app/core/ai_conversation_store.py @@ -0,0 +1,87 @@ +"""DB-backed CRUD for AI wizard conversation state. + +Conversations have a 24-hour TTL. Every access validates ownership and expiry. +""" +import uuid +from datetime import datetime, timezone, timedelta +from typing import Any, Optional +from uuid import UUID + +from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import settings +from app.models.ai_conversation import AIConversation + + +async def create_conversation( + user_id: UUID, + account_id: UUID, + wizard_state: dict[str, Any], + db: AsyncSession, +) -> AIConversation: + """Create a new AI wizard conversation.""" + conversation = AIConversation( + user_id=user_id, + account_id=account_id, + status="foundation", + wizard_state=wizard_state, + messages=[], + expires_at=datetime.now(timezone.utc) + + timedelta(hours=settings.AI_CONVERSATION_TTL_HOURS), + ) + db.add(conversation) + await db.flush() + return conversation + + +async def get_conversation( + conversation_id: UUID, + user_id: UUID, + db: AsyncSession, +) -> AIConversation: + """Get a conversation, validating ownership and expiry. + + Raises HTTPException 410 if expired, 404 if not found or wrong owner. + """ + result = await db.execute( + select(AIConversation).where(AIConversation.id == conversation_id) + ) + conversation = result.scalar_one_or_none() + + if not conversation or conversation.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Conversation not found", + ) + + if conversation.expires_at < datetime.now(timezone.utc): + conversation.status = "expired" + await db.flush() + raise HTTPException( + status_code=status.HTTP_410_GONE, + detail="Conversation expired. Please start a new AI build.", + ) + + return conversation + + +async def update_conversation( + conversation_id: UUID, + user_id: UUID, + updates: dict[str, Any], + db: AsyncSession, +) -> AIConversation: + """Update a conversation's fields. + + Validates ownership and expiry before updating. + """ + conversation = await get_conversation(conversation_id, user_id, db) + + for key, value in updates.items(): + if hasattr(conversation, key): + setattr(conversation, key, value) + + await db.flush() + return conversation diff --git a/backend/app/core/ai_quota_service.py b/backend/app/core/ai_quota_service.py new file mode 100644 index 00000000..1f89ec8b --- /dev/null +++ b/backend/app/core/ai_quota_service.py @@ -0,0 +1,181 @@ +"""AI generation quota management. + +Enforces monthly and daily limits on AI flow builder usage. +Monthly quota consumed only on successful tree assembly (counts_toward_quota=True). +Daily limit is an anti-abuse guard consumed on conversation start. +""" +from datetime import datetime, timezone, timedelta +from typing import Optional +from uuid import UUID + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.ai_usage import AIUsage +from app.models.plan_limits import PlanLimits +from app.models.account_limit_override import AccountLimitOverride +from app.core.subscriptions import get_account_subscription, get_plan_limits + + +async def get_user_plan(account_id: Optional[UUID], db: AsyncSession) -> str: + """Get the plan tier for an account.""" + if not account_id: + return "free" + sub = await get_account_subscription(account_id, db) + if sub is None: + return "free" + return sub.plan if sub.plan else "free" + + +async def _get_effective_limits( + account_id: UUID, plan: str, db: AsyncSession +) -> tuple[Optional[int], Optional[int]]: + """Get effective AI limits (monthly, daily), applying account overrides. + + Returns (monthly_limit, daily_limit). None means unlimited. + """ + limits = await get_plan_limits(plan, db) + monthly = limits.max_ai_builds_per_month if limits else None + daily = limits.max_ai_builds_per_24h if limits else None + + # Check for account-level overrides + result = await db.execute( + select(AccountLimitOverride).where( + AccountLimitOverride.account_id == account_id + ) + ) + override = result.scalar_one_or_none() + if override: + if override.override_max_ai_builds_per_month is not None: + monthly = override.override_max_ai_builds_per_month + if override.override_max_ai_builds_per_24h is not None: + daily = override.override_max_ai_builds_per_24h + + return monthly, daily + + +def _get_billing_anchor_month_start(anchor: Optional[datetime]) -> datetime: + """Calculate the start of the current billing month from the anchor date. + + If the anchor is day 15, the billing month runs from the 15th of each month. + Falls back to calendar month if anchor is None. + """ + now = datetime.now(timezone.utc) + if not anchor: + return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + anchor_day = min(anchor.day, 28) # Clamp to avoid month overflow + this_month_anchor = now.replace( + day=anchor_day, hour=0, minute=0, second=0, microsecond=0 + ) + + if now >= this_month_anchor: + return this_month_anchor + else: + # We're before the anchor day, so billing month started last month + if now.month == 1: + return this_month_anchor.replace(year=now.year - 1, month=12) + else: + return this_month_anchor.replace(month=now.month - 1) + + +async def check_ai_quota( + user_id: UUID, + account_id: UUID, + db: AsyncSession, + billing_anchor: Optional[datetime] = None, +) -> tuple[bool, dict]: + """Check if user can make an AI generation. + + Returns (allowed, quota_status_dict). + Monthly counts only rows with counts_toward_quota=True. + Daily counts only rows with generation_type in ('scaffold', 'branch_detail'). + """ + plan = await get_user_plan(account_id, db) + monthly_limit, daily_limit = await _get_effective_limits(account_id, plan, db) + + now = datetime.now(timezone.utc) + month_start = _get_billing_anchor_month_start(billing_anchor) + day_start = now - timedelta(hours=24) + + # Monthly: count successful quota-consuming records + monthly_count = await db.scalar( + select(func.count(AIUsage.id)).where( + AIUsage.user_id == user_id, + AIUsage.counts_toward_quota == True, # noqa: E712 + AIUsage.created_at >= month_start, + ) + ) or 0 + + # Daily: count all AI API calls (scaffold + branch_detail) in last 24h + daily_count = await db.scalar( + select(func.count(AIUsage.id)).where( + AIUsage.user_id == user_id, + AIUsage.succeeded == True, # noqa: E712 + AIUsage.generation_type.in_(["scaffold", "branch_detail"]), + AIUsage.created_at >= day_start, + ) + ) or 0 + + allowed = True + deny_reason = None + if monthly_limit is not None and monthly_count >= monthly_limit: + allowed = False + deny_reason = "monthly" + if daily_limit is not None and daily_count >= daily_limit: + allowed = False + deny_reason = "daily" + + # Calculate reset timestamps + monthly_reset_at = month_start.replace( + month=month_start.month % 12 + 1, + year=month_start.year + (1 if month_start.month == 12 else 0), + ) + daily_reset_at = day_start + timedelta(hours=24) + + return allowed, { + "plan": plan, + "monthly_used": monthly_count, + "monthly_limit": monthly_limit, + "monthly_reset_at": monthly_reset_at.isoformat(), + "daily_used": daily_count, + "daily_limit": daily_limit, + "daily_reset_at": daily_reset_at.isoformat(), + "allowed": allowed, + "deny_reason": deny_reason, + } + + +async def record_ai_usage( + user_id: UUID, + account_id: UUID, + conversation_id: Optional[UUID], + generation_type: str, + tier: str, + input_tokens: int, + output_tokens: int, + estimated_cost: float, + succeeded: bool, + counts_toward_quota: bool, + error_code: Optional[str], + extra_data: Optional[dict], + db: AsyncSession, +) -> AIUsage: + """Record an AI usage entry.""" + usage = AIUsage( + user_id=user_id, + account_id=account_id, + conversation_id=conversation_id, + generation_type=generation_type, + tier_at_time=tier, + input_tokens=input_tokens, + output_tokens=output_tokens, + estimated_cost_usd=estimated_cost, + succeeded=succeeded, + counts_toward_quota=counts_toward_quota, + error_code=error_code, + extra_data=extra_data or {}, + ) + db.add(usage) + await db.flush() + return usage diff --git a/backend/app/core/ai_tree_generator_service.py b/backend/app/core/ai_tree_generator_service.py new file mode 100644 index 00000000..805c84f7 --- /dev/null +++ b/backend/app/core/ai_tree_generator_service.py @@ -0,0 +1,293 @@ +"""AI-powered tree generation service using Anthropic Claude API. + +Implements the 4-stage wizard flow: + Stage 2 (scaffold): AI suggests 4-7 top-level branches + Stage 3 (branch_detail): AI generates detailed nodes per branch + Stage 4 (assemble): Pure assembly logic — zero AI calls + +System prompts are static constants to enable Anthropic prompt caching. +""" +import json +import logging +import uuid +from typing import Any + +import anthropic + +from app.core.config import settings +from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats + +logger = logging.getLogger(__name__) + +# ── Cost estimation (Haiku 4.5 pricing) ── +COST_PER_INPUT_TOKEN = 1.0 / 1_000_000 # $1.00 per 1M input tokens +COST_PER_OUTPUT_TOKEN = 5.0 / 1_000_000 # $5.00 per 1M output tokens + + +# ── System Prompts ── + +SCAFFOLD_SYSTEM_PROMPT = """You are ResolutionFlow AI, assisting MSP engineers to build troubleshooting and procedural flows for IT service management. + +Context: Your audience is technical MSP staff experienced with Windows Server, Active Directory, networking, and common MSP tooling (ConnectWise, Datto, SonicWall, etc.). + +Task: Given a flow type, category, name, description, and environment tags, suggest 4-7 top-level branches for the flow. + +For TROUBLESHOOTING flows: +- Branches should be symptom-based categories (e.g., "Authentication Failures", "Connectivity Issues", "Performance Degradation") +- Each branch represents a common way the problem manifests +- Order from most common to least common + +For PROCEDURE flows: +- Branches should be phase-based stages (e.g., "Prerequisites", "Configuration", "Verification", "Documentation") +- Each branch represents a major step in the process +- Order in logical execution sequence + +Rules: +- Suggest 4-7 branches +- Be specific to the technology/service described — avoid generic branches +- Branch names should be concise (2-5 words) +- Each branch needs a brief description (1 sentence) +- Return ONLY valid JSON, no markdown, no explanation + +Output format: +{"branches": [{"name": "Branch Name", "description": "Brief description of what this covers"}]}""" + + +BRANCH_DETAIL_SYSTEM_PROMPT = """You are ResolutionFlow AI generating step-by-step detail for one branch of a troubleshooting or procedural flow for MSP engineers. + +Context: Your audience is technical MSP staff experienced with Windows Server, Active Directory, networking, and common MSP tooling. + +You must return ONLY valid JSON — no markdown, no code fences, no explanation. + +Required node schema: + +Decision nodes (branching diagnostic questions): +{"id": "unique-slug", "type": "decision", "question": "The diagnostic question", "help_text": "Optional context or command hint", "options": [{"id": "opt-id", "label": "Answer choice", "next_node_id": "child-node-id"}], "children": []} + +Action nodes (investigation or remediation steps): +{"id": "unique-slug", "type": "action", "title": "Short title", "description": "Detailed instructions", "commands": ["PowerShell or CMD commands"], "expected_outcome": "What success looks like", "children": []} + +Solution nodes (leaf nodes — the resolution): +{"id": "unique-slug", "type": "solution", "title": "Resolution title", "description": "Full resolution description", "resolution_steps": ["Step 1", "Step 2"]} + +Rules: +1. Generate 3-10 nodes for this branch +2. Start with a decision node if troubleshooting, action node if procedure +3. Every branch path MUST end in a solution node — no dead ends +4. Include realistic MSP commands (PowerShell preferred for Windows) +5. Use unique node IDs prefixed with the branch context (e.g., "dns-check-service") +6. Every option's next_node_id must match an existing child node's id +7. All option labels must be meaningful and specific +8. Decision nodes must have at least 2 options +9. Return a single root node with its children nested inside + +Few-shot example (abbreviated): +{"id": "dns-root", "type": "decision", "question": "Can the client resolve any DNS names?", "help_text": "Run: nslookup google.com", "options": [{"id": "dns-opt-none", "label": "No DNS resolution at all", "next_node_id": "dns-check-service"}, {"id": "dns-opt-partial", "label": "Some names resolve, others don't", "next_node_id": "dns-check-specific"}], "children": [{"id": "dns-check-service", "type": "action", "title": "Check DNS Service", "description": "Verify the DNS Client service is running", "commands": ["Get-Service -Name Dnscache"], "expected_outcome": "Service should be Running", "children": [{"id": "dns-resolved", "type": "solution", "title": "DNS Service Restored", "description": "DNS client service was stopped", "resolution_steps": ["Restart DNS Client service", "Flush DNS cache: ipconfig /flushdns", "Test resolution"]}]}, {"id": "dns-check-specific", "type": "solution", "title": "Selective DNS Failure", "description": "Specific records missing or stale", "resolution_steps": ["Check DNS server configuration", "Verify zone records", "Clear DNS cache"]}]}""" + + +CORRECTIVE_PROMPT_TEMPLATE = """Your previous JSON was invalid for ResolutionFlow's tree schema. + +Validation errors: +{error_list} + +Return a corrected full JSON object only. No markdown, no prose, no code fences. +Fix ALL listed errors while maintaining the same troubleshooting/procedural logic.""" + + +def _get_client() -> anthropic.AsyncAnthropic: + """Get configured async Anthropic client.""" + if not settings.ANTHROPIC_API_KEY: + raise RuntimeError("ANTHROPIC_API_KEY not configured") + return anthropic.AsyncAnthropic( + api_key=settings.ANTHROPIC_API_KEY, + timeout=settings.AI_REQUEST_TIMEOUT_SECONDS, + ) + + +def _estimate_cost(input_tokens: int, output_tokens: int) -> float: + """Estimate USD cost from token counts.""" + return (input_tokens * COST_PER_INPUT_TOKEN) + ( + output_tokens * COST_PER_OUTPUT_TOKEN + ) + + +async def scaffold_branches( + wizard_state: dict[str, Any], +) -> tuple[list[dict[str, str]], int, int, float]: + """Stage 2: AI suggests top-level branches. + + Returns (branches, input_tokens, output_tokens, estimated_cost). + Raises ValueError on invalid response. + """ + client = _get_client() + + flow_type = wizard_state.get("flow_type", "troubleshooting") + name = wizard_state.get("name", "") + description = wizard_state.get("description", "") + tags = wizard_state.get("environment_tags", []) + + user_message = ( + f"Flow type: {flow_type}\n" + f"Name: {name}\n" + f"Description: {description}\n" + ) + if tags: + user_message += f"Environment: {', '.join(tags)}\n" + + response = await client.messages.create( + model=settings.AI_MODEL, + max_tokens=1024, + system=SCAFFOLD_SYSTEM_PROMPT, + messages=[{"role": "user", "content": user_message}], + ) + + raw_text = response.content[0].text + input_tokens = response.usage.input_tokens + output_tokens = response.usage.output_tokens + cost = _estimate_cost(input_tokens, output_tokens) + + try: + data = json.loads(raw_text) + except json.JSONDecodeError as e: + raise ValueError(f"AI returned invalid JSON: {e}") + + branches = data.get("branches", []) + if not isinstance(branches, list) or len(branches) < 2: + raise ValueError("AI returned fewer than 2 branches") + + return branches, input_tokens, output_tokens, cost + + +async def generate_branch_detail( + wizard_state: dict[str, Any], + branch_name: str, + existing_branches: list[str], +) -> tuple[dict[str, Any], int, int, float]: + """Stage 3: AI generates detailed nodes for one branch. + + Returns (branch_tree, input_tokens, output_tokens, estimated_cost). + On validation failure, retries once with corrective prompt. + Raises ValueError if both attempts fail. + """ + client = _get_client() + + flow_type = wizard_state.get("flow_type", "troubleshooting") + name = wizard_state.get("name", "") + description = wizard_state.get("description", "") + + user_message = ( + f"Flow: {name} ({flow_type})\n" + f"Description: {description}\n" + f"Branch to detail: {branch_name}\n" + ) + if existing_branches: + other = [b for b in existing_branches if b != branch_name] + if other: + user_message += f"Other branches (avoid overlap): {', '.join(other)}\n" + + messages = [{"role": "user", "content": user_message}] + total_input = 0 + total_output = 0 + + for attempt in range(2): + response = await client.messages.create( + model=settings.AI_MODEL, + max_tokens=4096, + system=BRANCH_DETAIL_SYSTEM_PROMPT, + messages=messages, + ) + + raw_text = response.content[0].text + total_input += response.usage.input_tokens + total_output += response.usage.output_tokens + + try: + branch_tree = json.loads(raw_text) + except json.JSONDecodeError as e: + if attempt == 0: + messages.append({"role": "assistant", "content": raw_text}) + messages.append({ + "role": "user", + "content": CORRECTIVE_PROMPT_TEMPLATE.format( + error_list=f"JSON parse error: {e}" + ), + }) + continue + raise ValueError(f"AI returned invalid JSON after retry: {e}") + + # Validate the branch structure + errors = validate_generated_tree(branch_tree) + if not errors: + cost = _estimate_cost(total_input, total_output) + return branch_tree, total_input, total_output, cost + + if attempt == 0: + messages.append({"role": "assistant", "content": raw_text}) + messages.append({ + "role": "user", + "content": CORRECTIVE_PROMPT_TEMPLATE.format( + error_list="\n".join(f"- {e}" for e in errors) + ), + }) + continue + + raise ValueError( + f"AI tree validation failed after retry: {'; '.join(errors)}" + ) + + # Should not reach here + raise ValueError("Branch detail generation failed") + + +def assemble_tree( + wizard_state: dict[str, Any], + branches: list[dict[str, Any]], +) -> tuple[dict[str, Any], str, str, dict[str, int]]: + """Stage 4: Assemble branches into a complete tree. + + Zero AI calls — pure assembly logic. + Returns (tree_structure, suggested_name, suggested_description, summary_stats). + """ + flow_type = wizard_state.get("flow_type", "troubleshooting") + name = wizard_state.get("name", "Untitled Flow") + description = wizard_state.get("description", "") + + # Build root decision node pointing to each branch + options = [] + children = [] + for i, branch in enumerate(branches): + branch_name = branch.get("name", f"Branch {i + 1}") + branch_tree = branch.get("steps") + + if not branch_tree or not isinstance(branch_tree, dict): + # Skip branches without detail + continue + + branch_id = branch_tree.get("id", f"branch_{i}") + options.append({ + "id": f"opt_{i + 1}", + "label": branch_name, + "next_node_id": branch_id, + }) + children.append(branch_tree) + + if len(options) < 2: + raise ValueError("Need at least 2 branches with detail to assemble a tree") + + # Determine root question based on flow type + if flow_type == "troubleshooting": + root_question = f"What issue is the user experiencing with {name}?" + else: + root_question = f"Which phase of {name} are you working on?" + + tree_structure = { + "id": "root", + "type": "decision", + "question": root_question, + "options": options, + "children": children, + } + + stats = count_tree_stats(tree_structure) + + return tree_structure, name, description, stats diff --git a/backend/app/core/ai_tree_validator.py b/backend/app/core/ai_tree_validator.py new file mode 100644 index 00000000..b58ef28d --- /dev/null +++ b/backend/app/core/ai_tree_validator.py @@ -0,0 +1,199 @@ +"""Validation for AI-generated tree structures. + +Ensures generated trees conform to ResolutionFlow's node schema +before they are saved to the database. +""" +from typing import Any + + +VALID_NODE_TYPES = {"decision", "action", "solution"} + +# Required fields per node type +REQUIRED_FIELDS = { + "decision": {"id", "type", "question", "options", "children"}, + "action": {"id", "type", "title", "description"}, + "solution": {"id", "type", "title", "description"}, +} + + +class TreeValidationError(Exception): + """Raised when a generated tree fails validation.""" + + def __init__(self, errors: list[str]): + self.errors = errors + super().__init__(f"Tree validation failed: {'; '.join(errors)}") + + +def validate_generated_tree(tree: dict[str, Any]) -> list[str]: + """Validate an AI-generated tree structure. + + Returns a list of error strings. Empty list means valid. + """ + errors: list[str] = [] + + if not isinstance(tree, dict): + return ["Tree must be a JSON object"] + + # Root must be a decision node + if tree.get("type") != "decision": + errors.append("Root node must be type 'decision'") + + # Collect all node IDs and validate structure + all_ids: set[str] = set() + all_referenced_ids: set[str] = set() + node_count = 0 + solution_count = 0 + + def _validate_node(node: dict[str, Any], path: str) -> None: + nonlocal node_count, solution_count + + if not isinstance(node, dict): + errors.append(f"Node at {path} is not an object") + return + + node_count += 1 + node_type = node.get("type") + node_id = node.get("id") + + # Check node ID + if not node_id: + errors.append(f"Node at {path} missing 'id'") + elif node_id in all_ids: + errors.append(f"Duplicate node ID: '{node_id}'") + else: + all_ids.add(node_id) + + # Check node type + if node_type not in VALID_NODE_TYPES: + errors.append( + f"Node '{node_id or path}' has invalid type '{node_type}'. " + f"Must be one of: {', '.join(sorted(VALID_NODE_TYPES))}" + ) + return + + # Check required fields + required = REQUIRED_FIELDS[node_type] + missing = required - set(node.keys()) + if missing: + errors.append( + f"Node '{node_id}' (type={node_type}) missing fields: {', '.join(sorted(missing))}" + ) + + # Type-specific validation + if node_type == "decision": + options = node.get("options", []) + if not isinstance(options, list) or len(options) < 2: + errors.append( + f"Decision node '{node_id}' must have at least 2 options" + ) + else: + children = node.get("children", []) + child_ids = {c.get("id") for c in children if isinstance(c, dict)} + option_ids: set[str] = set() + + for opt in options: + if not isinstance(opt, dict): + errors.append(f"Option in node '{node_id}' is not an object") + continue + opt_id = opt.get("id") + if opt_id and opt_id in option_ids: + errors.append( + f"Duplicate option ID '{opt_id}' in node '{node_id}'" + ) + if opt_id: + option_ids.add(opt_id) + + next_id = opt.get("next_node_id") + if next_id: + all_referenced_ids.add(next_id) + if child_ids and next_id not in child_ids: + errors.append( + f"Option '{opt.get('label', '?')}' in node '{node_id}' " + f"references non-existent child '{next_id}'" + ) + + elif node_type == "action": + next_id = node.get("next_node_id") + if next_id: + all_referenced_ids.add(next_id) + + elif node_type == "solution": + solution_count += 1 + + # Recurse into children + for i, child in enumerate(node.get("children", [])): + _validate_node(child, f"{path}.children[{i}]") + + _validate_node(tree, "root") + + # Global checks + if node_count < 5: + errors.append( + f"Tree has only {node_count} nodes. Minimum 5 required for a useful tree." + ) + if node_count > 50: + errors.append( + f"Tree has {node_count} nodes. Maximum 50 allowed." + ) + if solution_count < 2: + errors.append( + f"Tree has only {solution_count} solution nodes. " + "Need at least 2 to cover different resolution paths." + ) + + # Check that all leaf (non-solution) nodes have children or are solutions + _check_branch_termination(tree, errors) + + return errors + + +def _check_branch_termination(node: dict[str, Any], errors: list[str]) -> None: + """Verify every branch eventually reaches a solution node.""" + if not isinstance(node, dict): + return + + node_type = node.get("type") + node_id = node.get("id", "?") + children = node.get("children", []) + + if node_type == "solution": + return # Solution is a valid terminus + + if not children and node_type != "solution": + errors.append( + f"Node '{node_id}' (type={node_type}) is a dead end — " + "it has no children and is not a solution node" + ) + return + + for child in children: + _check_branch_termination(child, errors) + + +def count_tree_stats(tree: dict[str, Any]) -> dict[str, int]: + """Count node types and calculate depth of a tree.""" + stats = { + "node_count": 0, + "decision_count": 0, + "action_count": 0, + "solution_count": 0, + "depth": 0, + } + + def _count(node: dict[str, Any], depth: int) -> None: + if not isinstance(node, dict): + return + stats["node_count"] += 1 + node_type = node.get("type", "") + if node_type == "decision": + stats["decision_count"] += 1 + elif node_type == "action": + stats["action_count"] += 1 + elif node_type == "solution": + stats["solution_count"] += 1 + stats["depth"] = max(stats["depth"], depth) + for child in node.get("children", []): + _count(child, depth + 1) + + _count(tree, 1) + return stats diff --git a/backend/app/core/config.py b/backend/app/core/config.py index eb073d08..1f7ce77c 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -72,6 +72,18 @@ class Settings(BaseSettings): """Check if Stripe is configured.""" return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None + # AI Flow Builder + ANTHROPIC_API_KEY: Optional[str] = None + AI_MODEL: str = "claude-haiku-4-5" + AI_CONVERSATION_TTL_HOURS: int = 24 + AI_MAX_CALLS_PER_FLOW: int = 10 + AI_REQUEST_TIMEOUT_SECONDS: int = 45 + + @property + def ai_enabled(self) -> bool: + """Check if AI Flow Builder is configured.""" + return self.ANTHROPIC_API_KEY is not None + # Deployment – auto-seed test data on PR environments SEED_ON_DEPLOY: bool = False diff --git a/backend/app/core/scheduler.py b/backend/app/core/scheduler.py index 30c3c1f2..3d5f5ff6 100644 --- a/backend/app/core/scheduler.py +++ b/backend/app/core/scheduler.py @@ -1,4 +1,4 @@ -"""APScheduler integration for maintenance flow auto-session creation.""" +"""APScheduler integration for maintenance flow auto-session creation and AI cleanup.""" import logging import uuid from datetime import datetime, timezone @@ -7,8 +7,9 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.base import SchedulerNotRunningError from apscheduler.jobstores.base import JobLookupError from apscheduler.triggers.cron import CronTrigger +from apscheduler.triggers.interval import IntervalTrigger import pytz -from sqlalchemy import select +from sqlalchemy import select, delete from sqlalchemy.ext.asyncio import AsyncSession logger = logging.getLogger(__name__) @@ -114,6 +115,27 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None: await db.rollback() +async def _cleanup_expired_ai_conversations() -> None: + """Delete expired AI wizard conversations.""" + import app.models # noqa: F401 + from app.core.database import async_session_maker + from app.models.ai_conversation import AIConversation + + async with async_session_maker() as db: + try: + result = await db.execute( + delete(AIConversation).where( + AIConversation.expires_at < datetime.now(timezone.utc) + ) + ) + if result.rowcount > 0: + logger.info(f"Cleaned up {result.rowcount} expired AI conversation(s)") + await db.commit() + except Exception: + logger.exception("Error cleaning up expired AI conversations") + await db.rollback() + + async def load_all_schedules(db: AsyncSession) -> None: """Load all active schedules into APScheduler on startup.""" # Import all models to ensure SQLAlchemy mapper relationships resolve diff --git a/backend/app/main.py b/backend/app/main.py index 01ccb65f..78462f76 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -13,7 +13,7 @@ from app.core.logging_config import setup_logging from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware from app.core.rate_limit import limiter from app.api.router import api_router -from app.core.scheduler import scheduler, load_all_schedules +from app.core.scheduler import scheduler, load_all_schedules, _cleanup_expired_ai_conversations # Initialize logging configuration setup_logging() @@ -103,10 +103,17 @@ async def lifespan(app: FastAPI): # Note: In production, use Alembic migrations instead of init_db # await init_db() - # Start maintenance schedule runner + # Start maintenance schedule runner + AI conversation cleanup scheduler.start() async with async_session_maker() as db: await load_all_schedules(db) + scheduler.add_job( + _cleanup_expired_ai_conversations, + trigger="interval", + hours=1, + id="cleanup_ai_conversations", + replace_existing=True, + ) # Auto-seed trees in background on PR environments seed_task = None diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 5bd2f3e6..3731740b 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -26,6 +26,8 @@ from .user_pinned_tree import UserPinnedTree from .target_list import TargetList from .maintenance_schedule import MaintenanceSchedule from .feedback import Feedback +from .ai_conversation import AIConversation +from .ai_usage import AIUsage __all__ = [ "User", @@ -63,4 +65,6 @@ __all__ = [ "TargetList", "MaintenanceSchedule", "Feedback", + "AIConversation", + "AIUsage", ] diff --git a/backend/app/models/account_limit_override.py b/backend/app/models/account_limit_override.py index 62d241ac..322c15df 100644 --- a/backend/app/models/account_limit_override.py +++ b/backend/app/models/account_limit_override.py @@ -24,6 +24,8 @@ class AccountLimitOverride(Base): override_max_trees: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) override_max_sessions_per_month: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) override_max_users: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + override_max_ai_builds_per_month: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + override_max_ai_builds_per_24h: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) note: Mapped[Optional[str]] = mapped_column(Text, nullable=True) created_by_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), diff --git a/backend/app/models/ai_conversation.py b/backend/app/models/ai_conversation.py new file mode 100644 index 00000000..00fa2faa --- /dev/null +++ b/backend/app/models/ai_conversation.py @@ -0,0 +1,67 @@ +"""AI Flow Builder conversation tracking. + +Stores wizard session state across the 4-stage flow builder process. +Conversations expire after 24 hours and are cleaned up by the scheduler. +""" +import uuid +from datetime import datetime, timezone +from typing import Optional, Any + +from sqlalchemy import String, DateTime, ForeignKey, Integer, Text +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID, JSONB + +from app.core.database import Base + + +class AIConversation(Base): + __tablename__ = "ai_conversations" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + status: Mapped[str] = mapped_column( + String(20), + nullable=False, + default="foundation", + comment="foundation | scaffolding | detailing | reviewing | completed | expired", + ) + # Conversation history across all wizard stages + messages: Mapped[list[dict[str, Any]]] = mapped_column( + JSONB, nullable=False, default=list + ) + # Wizard state: Stage 1 metadata, Stage 2 branches, Stage 3 detail + wizard_state: Mapped[dict[str, Any]] = mapped_column( + JSONB, nullable=False, default=dict + ) + # Assembled tree from Stage 4 (null until assembly) + generated_tree: Mapped[Optional[dict[str, Any]]] = mapped_column( + JSONB, nullable=True + ) + # Tracks AI call count for per-flow limits + question_rounds: Mapped[int] = mapped_column( + Integer, nullable=False, default=0 + ) + expires_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) diff --git a/backend/app/models/ai_usage.py b/backend/app/models/ai_usage.py new file mode 100644 index 00000000..2cb8abd1 --- /dev/null +++ b/backend/app/models/ai_usage.py @@ -0,0 +1,69 @@ +"""AI usage tracking for quota enforcement and cost visibility. + +Every AI API call is recorded here. Only rows with counts_toward_quota=True +and succeeded=True are counted against the user's monthly quota. +""" +import uuid +from datetime import datetime, timezone +from typing import Optional, Any + +from sqlalchemy import String, DateTime, ForeignKey, Integer, Boolean, Numeric +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID, JSONB + +from app.core.database import Base + + +class AIUsage(Base): + __tablename__ = "ai_usage" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + conversation_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("ai_conversations.id", ondelete="SET NULL"), + nullable=True, + ) + generation_type: Mapped[str] = mapped_column( + String(20), + nullable=False, + comment="scaffold | branch_detail | branch_suggest", + ) + tier_at_time: Mapped[str] = mapped_column( + String(20), + nullable=False, + comment="free | pro | team | enterprise", + ) + input_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + output_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + estimated_cost_usd: Mapped[float] = mapped_column( + Numeric(10, 6), nullable=False, default=0 + ) + succeeded: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + counts_toward_quota: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False + ) + error_code: Mapped[Optional[str]] = mapped_column( + String(100), nullable=True + ) + extra_data: Mapped[dict[str, Any]] = mapped_column( + "metadata", JSONB, nullable=False, default=dict + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + index=True, + ) diff --git a/backend/app/models/plan_limits.py b/backend/app/models/plan_limits.py index 1a6b0511..65bd0c3a 100644 --- a/backend/app/models/plan_limits.py +++ b/backend/app/models/plan_limits.py @@ -14,3 +14,7 @@ class PlanLimits(Base): custom_branding: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) priority_support: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) export_formats: Mapped[list] = mapped_column(JSONB, nullable=False, default=lambda: ["markdown", "text"]) + + # AI Flow Builder limits + max_ai_builds_per_month: Mapped[int | None] = mapped_column(Integer, nullable=True) + max_ai_builds_per_24h: Mapped[int | None] = mapped_column(Integer, nullable=True) diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 439482ef..275ed5ca 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -67,6 +67,11 @@ class User(Base): ) last_login: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + # AI billing cycle anchor (for quota reset calculation) + ai_billing_cycle_anchor_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), nullable=True + ) + # Soft delete deleted_at: Mapped[Optional[datetime]] = mapped_column( DateTime(timezone=True), diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py index e5fe8241..1dd46a2e 100644 --- a/backend/app/schemas/__init__.py +++ b/backend/app/schemas/__init__.py @@ -5,6 +5,11 @@ from .session import SessionCreate, SessionUpdate, SessionResponse, SessionExpor from .category import CategoryCreate, CategoryUpdate, CategoryResponse, CategoryListResponse from .tag import TagCreate, TagResponse, TagListResponse, TagAssignment from .folder import FolderCreate, FolderUpdate, FolderResponse, FolderListResponse, FolderReorderRequest, FolderTreeRequest +from .ai_builder import ( + AIStartRequest, AIScaffoldRequest, AIBranchDetailRequest, AIAssembleRequest, + AIStartResponse, AIScaffoldResponse, AIBranchDetailResponse, AIAssembleResponse, + AIQuotaStatusResponse, +) __all__ = [ # User @@ -21,4 +26,8 @@ __all__ = [ "TagCreate", "TagResponse", "TagListResponse", "TagAssignment", # Folder "FolderCreate", "FolderUpdate", "FolderResponse", "FolderListResponse", "FolderReorderRequest", "FolderTreeRequest", + # AI Builder + "AIStartRequest", "AIScaffoldRequest", "AIBranchDetailRequest", "AIAssembleRequest", + "AIStartResponse", "AIScaffoldResponse", "AIBranchDetailResponse", "AIAssembleResponse", + "AIQuotaStatusResponse", ] diff --git a/backend/app/schemas/ai_builder.py b/backend/app/schemas/ai_builder.py new file mode 100644 index 00000000..67e4e87d --- /dev/null +++ b/backend/app/schemas/ai_builder.py @@ -0,0 +1,116 @@ +"""Pydantic schemas for the AI Flow Builder wizard.""" +from typing import Any, Literal, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + + +# ── Requests ── + + +class AIStartRequest(BaseModel): + """Stage 1: Foundation — engineer provides flow metadata.""" + + flow_type: Literal["troubleshooting", "procedural"] = Field( + ..., description="Type of flow to generate" + ) + category_id: Optional[UUID] = None + name: str = Field(..., min_length=1, max_length=255) + description: str = Field("", max_length=2000) + environment_tags: list[str] = Field(default_factory=list) + + +class AIScaffoldRequest(BaseModel): + """Stage 2: Request AI-generated branch suggestions.""" + + conversation_id: UUID + + +class AIBranchDetailRequest(BaseModel): + """Stage 3: Request AI-generated detail for one branch.""" + + conversation_id: UUID + branch_name: str = Field(..., min_length=1, max_length=255) + + +class AIBranchUpdate(BaseModel): + """A branch with optional user edits for assembly.""" + + name: str + description: str = "" + steps: Optional[dict[str, Any]] = None + + +class AIAssembleRequest(BaseModel): + """Stage 4: Assemble selected branches into a complete tree.""" + + conversation_id: UUID + selected_branches: list[AIBranchUpdate] = Field(..., min_length=2) + + +# ── Responses ── + + +class AIStartResponse(BaseModel): + """Response after creating a conversation.""" + + conversation_id: UUID + status: str + + +class AIBranchSuggestion(BaseModel): + """A single branch suggestion from the AI.""" + + name: str + description: str + + +class AIScaffoldResponse(BaseModel): + """Response with AI-suggested branches.""" + + conversation_id: UUID + branches: list[AIBranchSuggestion] + status: str + + +class AIBranchDetailResponse(BaseModel): + """Response with AI-generated detail for one branch.""" + + conversation_id: UUID + branch_name: str + steps: dict[str, Any] + status: str + + +class AITreeSummary(BaseModel): + """Summary statistics for an assembled tree.""" + + node_count: int + decision_count: int + action_count: int + solution_count: int + depth: int + + +class AIAssembleResponse(BaseModel): + """Response with the fully assembled tree.""" + + tree_structure: dict[str, Any] + suggested_name: str + suggested_description: str + summary: AITreeSummary + status: str + + +class AIQuotaStatusResponse(BaseModel): + """Current user's AI quota status.""" + + plan: str + monthly_used: int + monthly_limit: Optional[int] + monthly_reset_at: str + daily_used: int + daily_limit: Optional[int] + daily_reset_at: str + allowed: bool + ai_enabled: bool diff --git a/backend/requirements.txt b/backend/requirements.txt index 47d6eaf8..7c8b5493 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -31,6 +31,9 @@ resend==2.21.0 # HTTP client (seed scripts, internal API calls) httpx>=0.27.0 +# AI Flow Builder +anthropic>=0.40.0 + # Utilities python-dotenv==1.0.1 croniter>=2.0.0 diff --git a/backend/tests/test_ai_endpoints.py b/backend/tests/test_ai_endpoints.py new file mode 100644 index 00000000..3eb7b596 --- /dev/null +++ b/backend/tests/test_ai_endpoints.py @@ -0,0 +1,360 @@ +"""Integration tests for AI Flow Builder endpoints. + +All Anthropic API calls are mocked — zero real API spend. +""" +import json +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest + +from app.core.config import settings + + +# ── Sample AI responses ── + +SCAFFOLD_RESPONSE_JSON = json.dumps({ + "branches": [ + {"name": "Service Not Running", "description": "The target service is stopped or crashed."}, + {"name": "Authentication Failures", "description": "Users cannot authenticate against the service."}, + {"name": "Network Connectivity", "description": "Network-level issues preventing access."}, + {"name": "Configuration Errors", "description": "Misconfiguration of the service or its dependencies."}, + ] +}) + +BRANCH_DETAIL_JSON = json.dumps({ + "id": "svc-root", + "type": "decision", + "question": "Is the service running?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "svc-check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "svc-restart"}, + ], + "children": [ + { + "id": "svc-check-logs", + "type": "action", + "title": "Check Event Logs", + "description": "Check Windows Event Viewer for errors.", + "commands": ["Get-EventLog -LogName Application -Newest 20"], + "children": [ + { + "id": "svc-logs-resolved", + "type": "solution", + "title": "Issue Found in Logs", + "description": "Error identified and resolved.", + "resolution_steps": ["Fix the error", "Restart service"], + } + ], + }, + { + "id": "svc-restart", + "type": "action", + "title": "Restart Service", + "description": "Attempt to restart the service.", + "commands": ["Restart-Service -Name 'TestService'"], + "children": [ + { + "id": "svc-restart-ok", + "type": "solution", + "title": "Service Restored", + "description": "Service is running after restart.", + "resolution_steps": ["Verify connectivity", "Document in ticket"], + } + ], + }, + ], +}) + + +def _mock_anthropic_response(text: str, input_tokens: int = 100, output_tokens: int = 200): + """Create a mock Anthropic API response.""" + response = MagicMock() + response.content = [MagicMock(text=text)] + response.usage = MagicMock(input_tokens=input_tokens, output_tokens=output_tokens) + return response + + +@pytest.fixture +def enable_ai(): + """Temporarily enable AI by setting a fake API key.""" + original = settings.ANTHROPIC_API_KEY + settings.ANTHROPIC_API_KEY = "test-key-fake" + yield + settings.ANTHROPIC_API_KEY = original + + +@pytest.fixture +def disable_ai(): + """Ensure AI is disabled.""" + original = settings.ANTHROPIC_API_KEY + settings.ANTHROPIC_API_KEY = None + yield + settings.ANTHROPIC_API_KEY = original + + +# ── Quota endpoint ── + + +@pytest.mark.asyncio +async def test_quota_returns_disabled_when_no_key(client, auth_headers, disable_ai): + """GET /ai/quota returns ai_enabled=false when no API key.""" + response = await client.get("/api/v1/ai/quota", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["ai_enabled"] is False + assert data["allowed"] is False + + +@pytest.mark.asyncio +async def test_quota_returns_enabled_with_key(client, auth_headers, enable_ai): + """GET /ai/quota returns ai_enabled=true with API key configured.""" + response = await client.get("/api/v1/ai/quota", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["ai_enabled"] is True + assert data["allowed"] is True + + +# ── Start endpoint ── + + +@pytest.mark.asyncio +async def test_start_requires_auth(client, enable_ai): + """POST /ai/start requires authentication.""" + response = await client.post("/api/v1/ai/start", json={ + "flow_type": "troubleshooting", + "name": "Test Flow", + "description": "Test", + }) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_start_returns_503_when_disabled(client, auth_headers, disable_ai): + """POST /ai/start returns 503 when AI is not configured.""" + response = await client.post( + "/api/v1/ai/start", + json={ + "flow_type": "troubleshooting", + "name": "Test Flow", + "description": "Test description", + }, + headers=auth_headers, + ) + assert response.status_code == 503 + + +@pytest.mark.asyncio +async def test_start_creates_conversation(client, auth_headers, enable_ai): + """POST /ai/start creates a conversation and returns conversation_id.""" + response = await client.post( + "/api/v1/ai/start", + json={ + "flow_type": "troubleshooting", + "name": "DNS Issues", + "description": "Troubleshooting DNS resolution failures", + "environment_tags": ["Windows Server", "Active Directory"], + }, + headers=auth_headers, + ) + assert response.status_code == 201 + data = response.json() + assert "conversation_id" in data + assert data["status"] == "foundation" + + +@pytest.mark.asyncio +async def test_start_validates_input(client, auth_headers, enable_ai): + """POST /ai/start rejects invalid input.""" + response = await client.post( + "/api/v1/ai/start", + json={ + "flow_type": "troubleshooting", + "name": "", # Empty name + "description": "Test", + }, + headers=auth_headers, + ) + assert response.status_code == 422 + + +# ── Scaffold endpoint ── + + +@pytest.mark.asyncio +async def test_scaffold_success(client, auth_headers, enable_ai): + """POST /ai/scaffold returns AI-generated branches.""" + # Create conversation first + start_resp = await client.post( + "/api/v1/ai/start", + json={ + "flow_type": "troubleshooting", + "name": "DNS Issues", + "description": "DNS resolution failures", + }, + headers=auth_headers, + ) + conversation_id = start_resp.json()["conversation_id"] + + # Mock Anthropic + mock_response = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service._get_client") as mock_client: + mock_client.return_value.messages.create = AsyncMock(return_value=mock_response) + + response = await client.post( + "/api/v1/ai/scaffold", + json={"conversation_id": conversation_id}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "scaffolding" + assert len(data["branches"]) == 4 + assert data["branches"][0]["name"] == "Service Not Running" + + +@pytest.mark.asyncio +async def test_scaffold_invalid_conversation(client, auth_headers, enable_ai): + """POST /ai/scaffold returns 404 for nonexistent conversation.""" + response = await client.post( + "/api/v1/ai/scaffold", + json={"conversation_id": "00000000-0000-0000-0000-000000000000"}, + headers=auth_headers, + ) + assert response.status_code == 404 + + +# ── Branch detail endpoint ── + + +@pytest.mark.asyncio +async def test_branch_detail_success(client, auth_headers, enable_ai): + """POST /ai/branch-detail returns AI-generated branch nodes.""" + # Create and scaffold first + start_resp = await client.post( + "/api/v1/ai/start", + json={ + "flow_type": "troubleshooting", + "name": "Service Issues", + "description": "Service troubleshooting", + }, + headers=auth_headers, + ) + conversation_id = start_resp.json()["conversation_id"] + + scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service._get_client") as mock_client: + mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock) + await client.post( + "/api/v1/ai/scaffold", + json={"conversation_id": conversation_id}, + headers=auth_headers, + ) + + # Now generate branch detail + detail_mock = _mock_anthropic_response(BRANCH_DETAIL_JSON) + with patch("app.core.ai_tree_generator_service._get_client") as mock_client: + mock_client.return_value.messages.create = AsyncMock(return_value=detail_mock) + + response = await client.post( + "/api/v1/ai/branch-detail", + json={ + "conversation_id": conversation_id, + "branch_name": "Service Not Running", + }, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["branch_name"] == "Service Not Running" + assert data["steps"]["id"] == "svc-root" + assert data["steps"]["type"] == "decision" + + +# ── Assemble endpoint ── + + +@pytest.mark.asyncio +async def test_assemble_success(client, auth_headers, enable_ai): + """POST /ai/assemble returns assembled tree from branches with detail.""" + # Create conversation + start_resp = await client.post( + "/api/v1/ai/start", + json={ + "flow_type": "troubleshooting", + "name": "Service Issues", + "description": "Service troubleshooting", + }, + headers=auth_headers, + ) + conversation_id = start_resp.json()["conversation_id"] + + # Scaffold + scaffold_mock = _mock_anthropic_response(SCAFFOLD_RESPONSE_JSON) + with patch("app.core.ai_tree_generator_service._get_client") as mock_client: + mock_client.return_value.messages.create = AsyncMock(return_value=scaffold_mock) + await client.post( + "/api/v1/ai/scaffold", + json={"conversation_id": conversation_id}, + headers=auth_headers, + ) + + # Assemble with branch detail included + branch_tree = json.loads(BRANCH_DETAIL_JSON) + response = await client.post( + "/api/v1/ai/assemble", + json={ + "conversation_id": conversation_id, + "selected_branches": [ + { + "name": "Service Not Running", + "description": "The target service is stopped.", + "steps": branch_tree, + }, + { + "name": "Authentication Failures", + "description": "Users cannot authenticate.", + "steps": branch_tree, + }, + ], + }, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "completed" + assert data["suggested_name"] == "Service Issues" + assert "tree_structure" in data + assert data["tree_structure"]["type"] == "decision" + assert data["summary"]["node_count"] > 0 + assert data["summary"]["solution_count"] >= 2 + + +@pytest.mark.asyncio +async def test_assemble_requires_min_2_branches(client, auth_headers, enable_ai): + """POST /ai/assemble rejects fewer than 2 branches.""" + start_resp = await client.post( + "/api/v1/ai/start", + json={ + "flow_type": "troubleshooting", + "name": "Test", + "description": "Test", + }, + headers=auth_headers, + ) + conversation_id = start_resp.json()["conversation_id"] + + response = await client.post( + "/api/v1/ai/assemble", + json={ + "conversation_id": conversation_id, + "selected_branches": [ + {"name": "Only Branch", "description": "Just one"}, + ], + }, + headers=auth_headers, + ) + assert response.status_code == 422 diff --git a/backend/tests/test_ai_tree_validator.py b/backend/tests/test_ai_tree_validator.py new file mode 100644 index 00000000..6e95545d --- /dev/null +++ b/backend/tests/test_ai_tree_validator.py @@ -0,0 +1,183 @@ +"""Tests for AI-generated tree structure validation.""" +import pytest + +from app.core.ai_tree_validator import validate_generated_tree, count_tree_stats + + +def _make_valid_tree(): + """Helper: minimal valid tree for testing.""" + return { + "id": "root", + "type": "decision", + "question": "Is the service running?", + "options": [ + {"id": "opt-yes", "label": "Yes", "next_node_id": "check-logs"}, + {"id": "opt-no", "label": "No", "next_node_id": "restart-service"}, + ], + "children": [ + { + "id": "check-logs", + "type": "decision", + "question": "Are there errors in the logs?", + "options": [ + {"id": "opt-errors", "label": "Yes", "next_node_id": "fix-errors"}, + {"id": "opt-clean", "label": "No", "next_node_id": "escalate"}, + ], + "children": [ + { + "id": "fix-errors", + "type": "solution", + "title": "Fix Errors", + "description": "Apply the fix for the errors found.", + }, + { + "id": "escalate", + "type": "solution", + "title": "Escalate", + "description": "No errors found; escalate to Tier 2.", + }, + ], + }, + { + "id": "restart-service", + "type": "action", + "title": "Restart the Service", + "description": "Restart the service and verify.", + "commands": ["Restart-Service -Name 'TestService'"], + "children": [ + { + "id": "service-resolved", + "type": "solution", + "title": "Service Restored", + "description": "Service is running after restart.", + }, + ], + }, + ], + } + + +class TestValidTree: + def test_valid_tree_passes(self): + errors = validate_generated_tree(_make_valid_tree()) + assert errors == [] + + def test_not_a_dict(self): + errors = validate_generated_tree("not a dict") + assert any("must be a JSON object" in e for e in errors) + + def test_root_not_decision(self): + tree = _make_valid_tree() + tree["type"] = "action" + tree["title"] = "Fake" + errors = validate_generated_tree(tree) + assert any("Root node must be type 'decision'" in e for e in errors) + + +class TestNodeValidation: + def test_missing_id(self): + tree = _make_valid_tree() + del tree["children"][0]["id"] + errors = validate_generated_tree(tree) + assert any("missing 'id'" in e for e in errors) + + def test_duplicate_ids(self): + tree = _make_valid_tree() + tree["children"][1]["id"] = "check-logs" # same as sibling + errors = validate_generated_tree(tree) + assert any("Duplicate node ID" in e for e in errors) + + def test_invalid_node_type(self): + tree = _make_valid_tree() + tree["children"][0]["type"] = "unknown" + errors = validate_generated_tree(tree) + assert any("invalid type" in e for e in errors) + + def test_decision_missing_options(self): + tree = _make_valid_tree() + del tree["children"][0]["options"] + errors = validate_generated_tree(tree) + assert any("missing fields" in e for e in errors) + + def test_decision_less_than_2_options(self): + tree = _make_valid_tree() + tree["children"][0]["options"] = [ + {"id": "opt-1", "label": "Only", "next_node_id": "fix-errors"} + ] + errors = validate_generated_tree(tree) + assert any("at least 2 options" in e for e in errors) + + +class TestReferenceIntegrity: + def test_option_references_nonexistent_child(self): + tree = _make_valid_tree() + tree["options"][0]["next_node_id"] = "nonexistent" + errors = validate_generated_tree(tree) + assert any("non-existent child" in e for e in errors) + + def test_duplicate_option_ids(self): + tree = _make_valid_tree() + tree["options"][0]["id"] = "same" + tree["options"][1]["id"] = "same" + errors = validate_generated_tree(tree) + assert any("Duplicate option ID" in e for e in errors) + + +class TestGlobalChecks: + def test_too_few_nodes(self): + tree = { + "id": "root", + "type": "decision", + "question": "Test?", + "options": [ + {"id": "o1", "label": "A", "next_node_id": "s1"}, + {"id": "o2", "label": "B", "next_node_id": "s2"}, + ], + "children": [ + {"id": "s1", "type": "solution", "title": "S1", "description": "D1"}, + {"id": "s2", "type": "solution", "title": "S2", "description": "D2"}, + ], + } + errors = validate_generated_tree(tree) + assert any("Minimum 5 required" in e for e in errors) + + def test_too_few_solutions(self): + tree = _make_valid_tree() + # Remove all solutions except one — replace children of check-logs + tree["children"][0]["children"] = [ + { + "id": "only-solution", + "type": "solution", + "title": "Only", + "description": "Only solution", + } + ] + tree["children"][0]["options"] = [ + {"id": "o1", "label": "A", "next_node_id": "only-solution"}, + {"id": "o2", "label": "B", "next_node_id": "only-solution"}, + ] + # Now restart-service branch has 1 solution, check-logs has 1 = total 2 + # Remove one more to get to 1 + tree["children"][1]["children"] = [] + errors = validate_generated_tree(tree) + assert any("solution" in e.lower() for e in errors) + + +class TestDeadEndDetection: + def test_dead_end_action_node(self): + tree = _make_valid_tree() + # Remove restart-service's children — becomes dead end + tree["children"][1]["children"] = [] + errors = validate_generated_tree(tree) + assert any("dead end" in e for e in errors) + + +class TestCountTreeStats: + def test_stats_correct(self): + tree = _make_valid_tree() + stats = count_tree_stats(tree) + assert stats["node_count"] == 6 + assert stats["decision_count"] == 2 + assert stats["action_count"] == 1 + assert stats["solution_count"] == 3 + assert stats["depth"] >= 3 diff --git a/frontend/src/api/aiBuilder.ts b/frontend/src/api/aiBuilder.ts new file mode 100644 index 00000000..521b7a99 --- /dev/null +++ b/frontend/src/api/aiBuilder.ts @@ -0,0 +1,61 @@ +import { apiClient } from './client' +import type { + AIQuotaStatus, + AIStartResponse, + AIScaffoldResponse, + AIBranchDetailResponse, + AIAssembleResponse, +} from '@/types' + +export const aiBuilderApi = { + getQuota: async (): Promise => { + const { data } = await apiClient.get('/ai/quota') + return data + }, + + start: async (params: { + flow_type: 'troubleshooting' | 'procedural' + name: string + description: string + environment_tags?: string[] + category_id?: string + }): Promise => { + const { data } = await apiClient.post('/ai/start', params) + return data + }, + + scaffold: async (conversationId: string): Promise => { + const { data } = await apiClient.post('/ai/scaffold', { + conversation_id: conversationId, + }) + return data + }, + + branchDetail: async ( + conversationId: string, + branchName: string + ): Promise => { + const { data } = await apiClient.post('/ai/branch-detail', { + conversation_id: conversationId, + branch_name: branchName, + }) + return data + }, + + assemble: async ( + conversationId: string, + selectedBranches: Array<{ + name: string + description: string + steps?: Record + }> + ): Promise => { + const { data } = await apiClient.post('/ai/assemble', { + conversation_id: conversationId, + selected_branches: selectedBranches, + }) + return data + }, +} + +export default aiBuilderApi diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index caf7c6df..fba65d24 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -16,3 +16,4 @@ export { default as analyticsApi } from './analytics' export { targetListsApi } from './targetLists' export { maintenanceSchedulesApi, batchLaunchApi } from './maintenanceSchedules' export { default as feedbackApi } from './feedback' +export { default as aiBuilderApi } from './aiBuilder' diff --git a/frontend/src/components/ai-builder/AIFlowBuilderModal.tsx b/frontend/src/components/ai-builder/AIFlowBuilderModal.tsx new file mode 100644 index 00000000..5d1b0494 --- /dev/null +++ b/frontend/src/components/ai-builder/AIFlowBuilderModal.tsx @@ -0,0 +1,135 @@ +import { useEffect } from 'react' +import { useNavigate } from 'react-router-dom' +import { Modal } from '@/components/common/Modal' +import { useAIFlowBuilderStore } from '@/store/aiFlowBuilderStore' +import { treesApi } from '@/api/trees' +import { toast } from '@/lib/toast' +import { WizardStepIndicator } from './WizardStepIndicator' +import { FoundationForm } from './FoundationForm' +import { BranchSelector } from './BranchSelector' +import { BranchDetailView } from './BranchDetailView' +import { TreePreviewCard } from './TreePreviewCard' +import { GeneratingAnimation } from './GeneratingAnimation' + +interface AIFlowBuilderModalProps { + isOpen: boolean + onClose: () => void +} + +export function AIFlowBuilderModal({ isOpen, onClose }: AIFlowBuilderModalProps) { + const navigate = useNavigate() + const { + phase, + metadata, + assembledTree, + loadQuota, + scaffold, + reset, + } = useAIFlowBuilderStore() + + // Load quota when modal opens + useEffect(() => { + if (isOpen) { + loadQuota() + } + }, [isOpen, loadQuota]) + + // Auto-trigger scaffold after conversation starts + useEffect(() => { + if (phase === 'scaffolding' && !useAIFlowBuilderStore.getState().suggestedBranches.length) { + scaffold() + } + }, [phase, scaffold]) + + const handleClose = () => { + reset() + onClose() + } + + const handleOpenInEditor = async () => { + if (!assembledTree) return + try { + const tree = await treesApi.create({ + name: assembledTree.suggested_name, + description: assembledTree.suggested_description, + tree_structure: assembledTree.tree_structure, + tree_type: metadata.flow_type, + }) + handleClose() + const editorPath = + metadata.flow_type === 'procedural' + ? `/flows/${tree.id}/edit` + : `/trees/${tree.id}/edit` + navigate(editorPath) + } catch { + toast.error('Failed to create flow. Please try again.') + } + } + + const getTitle = () => { + switch (phase) { + case 'foundation': + return 'Build with AI' + case 'scaffolding': + case 'generating': + return 'AI Scaffold' + case 'detailing': + return 'Branch Detail' + case 'reviewing': + return 'Review & Assemble' + case 'error': + return 'AI Flow Builder' + default: + return 'Build with AI' + } + } + + return ( + + } + > + {phase === 'foundation' && } + {phase === 'scaffolding' && } + {phase === 'generating' && } + {phase === 'detailing' && } + {phase === 'reviewing' && ( + + )} + {phase === 'error' && } + + ) +} + +function ErrorView() { + const { error, reset, setPhase } = useAIFlowBuilderStore() + + return ( +
+
+ {error || 'An unexpected error occurred.'} +
+
+ + +
+
+ ) +} diff --git a/frontend/src/components/ai-builder/BranchDetailView.tsx b/frontend/src/components/ai-builder/BranchDetailView.tsx new file mode 100644 index 00000000..8f4eb31d --- /dev/null +++ b/frontend/src/components/ai-builder/BranchDetailView.tsx @@ -0,0 +1,208 @@ +import { useState } from 'react' +import { Check, RefreshCw, SkipForward, ChevronRight, ChevronLeft } from 'lucide-react' +import { useAIFlowBuilderStore } from '@/store/aiFlowBuilderStore' +import { GeneratingAnimation } from './GeneratingAnimation' +import { cn } from '@/lib/utils' + +export function BranchDetailView() { + const { + selectedBranches, + generateBranchDetail, + assemble, + isLoading, + error, + phase, + setError, + } = useAIFlowBuilderStore() + + const [viewingIndex, setViewingIndex] = useState(0) + const currentBranch = selectedBranches[viewingIndex] + + const allBranchesHaveDetail = selectedBranches.every((b) => b.steps) + const branchesWithDetail = selectedBranches.filter((b) => b.steps).length + + const handleGenerate = async (branchName: string) => { + setError(null) + await generateBranchDetail(branchName) + } + + const handleAssemble = async () => { + await assemble() + } + + if (phase === 'generating' && isLoading) { + return + } + + return ( +
+ {/* Branch tabs */} +
+ {selectedBranches.map((branch, i) => ( + + ))} +
+ + {/* Current branch detail */} + {currentBranch && ( +
+
+
+

{currentBranch.name}

+

{currentBranch.description}

+
+
+ + {currentBranch.steps ? ( +
+ {/* Mini tree preview */} +
+ +
+ +
+ +
+
+ ) : ( +
+

+ Generate AI detail for this branch +

+
+ + +
+
+ )} +
+ )} + + {/* Error */} + {error && ( +
+ {error} +
+ )} + + {/* Navigation */} +
+
+ + +
+ +
+ + {branchesWithDetail}/{selectedBranches.length} detailed + + +
+
+
+ ) +} + +/** Recursive mini-preview of a node tree */ +function NodePreview({ node, depth }: { node: Record; depth: number }) { + const type = node.type as string + const label = + type === 'decision' + ? (node.question as string) + : (node.title as string) || 'Untitled' + const children = (node.children as Record[]) || [] + + const typeColors: Record = { + decision: 'bg-blue-400', + action: 'bg-amber-400', + solution: 'bg-green-400', + } + + return ( +
+
+
+ {label} + {type} +
+ {children.map((child, i) => ( + + ))} +
+ ) +} diff --git a/frontend/src/components/ai-builder/BranchSelector.tsx b/frontend/src/components/ai-builder/BranchSelector.tsx new file mode 100644 index 00000000..458f5b80 --- /dev/null +++ b/frontend/src/components/ai-builder/BranchSelector.tsx @@ -0,0 +1,280 @@ +import { useState } from 'react' +import { GripVertical, Plus, X, Pencil, Check } from 'lucide-react' +import { useAIFlowBuilderStore } from '@/store/aiFlowBuilderStore' +import { cn } from '@/lib/utils' +import type { AIBranch } from '@/types' + +export function BranchSelector() { + const { + suggestedBranches, + selectedBranches, + selectBranches, + setPhase, + error, + } = useAIFlowBuilderStore() + + const [editingIndex, setEditingIndex] = useState(null) + const [editName, setEditName] = useState('') + const [editDesc, setEditDesc] = useState('') + const [showAddForm, setShowAddForm] = useState(false) + const [newName, setNewName] = useState('') + const [newDesc, setNewDesc] = useState('') + + const toggleBranch = (branch: AIBranch) => { + const isSelected = selectedBranches.some((b) => b.name === branch.name) + if (isSelected) { + selectBranches(selectedBranches.filter((b) => b.name !== branch.name)) + } else { + selectBranches([...selectedBranches, branch]) + } + } + + const startEditing = (index: number) => { + const branch = selectedBranches[index] + setEditingIndex(index) + setEditName(branch.name) + setEditDesc(branch.description) + } + + const saveEdit = () => { + if (editingIndex === null || !editName.trim()) return + const updated = [...selectedBranches] + updated[editingIndex] = { + ...updated[editingIndex], + name: editName.trim(), + description: editDesc.trim(), + } + selectBranches(updated) + setEditingIndex(null) + } + + const addCustomBranch = () => { + if (!newName.trim()) return + const branch: AIBranch = { + name: newName.trim(), + description: newDesc.trim(), + isCustom: true, + } + selectBranches([...selectedBranches, branch]) + setNewName('') + setNewDesc('') + setShowAddForm(false) + } + + const moveBranch = (fromIndex: number, direction: 'up' | 'down') => { + const toIndex = direction === 'up' ? fromIndex - 1 : fromIndex + 1 + if (toIndex < 0 || toIndex >= selectedBranches.length) return + const updated = [...selectedBranches] + ;[updated[fromIndex], updated[toIndex]] = [updated[toIndex], updated[fromIndex]] + selectBranches(updated) + } + + const canProceed = selectedBranches.length >= 2 + + return ( +
+
+

+ AI suggested {suggestedBranches.length} branches. Select, reorder, rename, or add your own. +

+
+ + {/* Branch list */} +
+ {suggestedBranches.map((branch) => { + const isSelected = selectedBranches.some((b) => b.name === branch.name) + const selectedIndex = selectedBranches.findIndex((b) => b.name === branch.name) + + return ( +
toggleBranch(branch)} + > +
+ {isSelected && } +
+
+ {editingIndex !== null && selectedIndex === editingIndex ? ( +
e.stopPropagation()}> + setEditName(e.target.value)} + className="w-full rounded border border-border bg-card px-2 py-1 text-sm text-foreground focus:border-primary focus:outline-none" + autoFocus + /> + setEditDesc(e.target.value)} + className="w-full rounded border border-border bg-card px-2 py-1 text-xs text-muted-foreground focus:border-primary focus:outline-none" + /> +
+ + +
+
+ ) : ( + <> +
{branch.name}
+
{branch.description}
+ + )} +
+ {isSelected && editingIndex !== selectedIndex && ( +
e.stopPropagation()}> + + +
+ )} +
+ ) + })} + + {/* Custom branches (not in suggested) */} + {selectedBranches + .filter((b) => b.isCustom) + .map((branch, i) => { + return ( +
+
+ +
+
+
{branch.name}
+
{branch.description}
+ + Custom + +
+ +
+ ) + })} +
+ + {/* Add custom branch */} + {showAddForm ? ( +
+ setNewName(e.target.value)} + placeholder="Branch name" + className="w-full rounded border border-border bg-card px-2 py-1.5 text-sm text-foreground placeholder:text-muted-foreground focus:border-primary focus:outline-none" + autoFocus + /> + setNewDesc(e.target.value)} + placeholder="Brief description" + className="w-full rounded border border-border bg-card px-2 py-1.5 text-xs text-muted-foreground placeholder:text-muted-foreground/60 focus:border-primary focus:outline-none" + /> +
+ + +
+
+ ) : ( + + )} + + {/* Error */} + {error && ( +
+ {error} +
+ )} + + {/* Footer */} +
+ + {selectedBranches.length} branch{selectedBranches.length !== 1 ? 'es' : ''} selected (min 2) + + +
+
+ ) +} diff --git a/frontend/src/components/ai-builder/FoundationForm.tsx b/frontend/src/components/ai-builder/FoundationForm.tsx new file mode 100644 index 00000000..ef07dff4 --- /dev/null +++ b/frontend/src/components/ai-builder/FoundationForm.tsx @@ -0,0 +1,163 @@ +import { useState } from 'react' +import { useAIFlowBuilderStore } from '@/store/aiFlowBuilderStore' +import { QuotaDisplay } from './QuotaDisplay' +import { cn } from '@/lib/utils' + +export function FoundationForm() { + const { metadata, setMetadata, quota, start, isLoading, error } = useAIFlowBuilderStore() + const [tagInput, setTagInput] = useState('') + + const canSubmit = + metadata.name.trim().length > 0 && + metadata.description.trim().length > 0 && + !isLoading && + (quota?.allowed !== false) + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault() + if (!canSubmit) return + await start() + } + + const addTag = () => { + const tag = tagInput.trim() + if (tag && !metadata.environment_tags.includes(tag)) { + setMetadata({ environment_tags: [...metadata.environment_tags, tag] }) + } + setTagInput('') + } + + const removeTag = (tag: string) => { + setMetadata({ environment_tags: metadata.environment_tags.filter((t) => t !== tag) }) + } + + return ( +
+ {quota && } + + {/* Flow Type */} +
+ +
+ {(['troubleshooting', 'procedural'] as const).map((type) => ( + + ))} +
+
+ + {/* Name */} +
+ + setMetadata({ name: e.target.value })} + placeholder="e.g. DNS Resolution Failures" + className="w-full rounded-lg border border-border bg-card px-3 py-2 text-sm text-foreground placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary/20" + maxLength={255} + /> +
+ + {/* Description */} +
+ +