feat: AI-assisted flow builder with 4-stage wizard
Implements the complete AI flow builder feature using a guided 4-stage wizard (Foundation → Scaffold → Branch Detail → Review & Assemble). AI assists at bounded points using Claude Haiku for cost-efficient structured JSON generation (~$0.01-0.03/flow). Backend: new models (ai_conversations, ai_usage), Alembic migration, quota enforcement with billing anchor, Anthropic API integration with prompt caching, tree validation, conversation CRUD with 24h TTL, APScheduler cleanup job, 5 API endpoints, Pydantic schemas. Frontend: TypeScript types, API client, Zustand store for wizard state, 7 components (modal, step indicator, foundation form, branch selector, branch detail view, tree preview, quota display), MyTreesPage integration with "Build with AI" button (hidden when AI not configured). Tests: 14 validator unit tests + 11 endpoint integration tests with mocked Anthropic (zero real API spend). All 25 tests passing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
181
backend/app/core/ai_quota_service.py
Normal file
181
backend/app/core/ai_quota_service.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""AI generation quota management.
|
||||
|
||||
Enforces monthly and daily limits on AI flow builder usage.
|
||||
Monthly quota consumed only on successful tree assembly (counts_toward_quota=True).
|
||||
Daily limit is an anti-abuse guard consumed on conversation start.
|
||||
"""
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.ai_usage import AIUsage
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.models.account_limit_override import AccountLimitOverride
|
||||
from app.core.subscriptions import get_account_subscription, get_plan_limits
|
||||
|
||||
|
||||
async def get_user_plan(account_id: Optional[UUID], db: AsyncSession) -> str:
|
||||
"""Get the plan tier for an account."""
|
||||
if not account_id:
|
||||
return "free"
|
||||
sub = await get_account_subscription(account_id, db)
|
||||
if sub is None:
|
||||
return "free"
|
||||
return sub.plan if sub.plan else "free"
|
||||
|
||||
|
||||
async def _get_effective_limits(
|
||||
account_id: UUID, plan: str, db: AsyncSession
|
||||
) -> tuple[Optional[int], Optional[int]]:
|
||||
"""Get effective AI limits (monthly, daily), applying account overrides.
|
||||
|
||||
Returns (monthly_limit, daily_limit). None means unlimited.
|
||||
"""
|
||||
limits = await get_plan_limits(plan, db)
|
||||
monthly = limits.max_ai_builds_per_month if limits else None
|
||||
daily = limits.max_ai_builds_per_24h if limits else None
|
||||
|
||||
# Check for account-level overrides
|
||||
result = await db.execute(
|
||||
select(AccountLimitOverride).where(
|
||||
AccountLimitOverride.account_id == account_id
|
||||
)
|
||||
)
|
||||
override = result.scalar_one_or_none()
|
||||
if override:
|
||||
if override.override_max_ai_builds_per_month is not None:
|
||||
monthly = override.override_max_ai_builds_per_month
|
||||
if override.override_max_ai_builds_per_24h is not None:
|
||||
daily = override.override_max_ai_builds_per_24h
|
||||
|
||||
return monthly, daily
|
||||
|
||||
|
||||
def _get_billing_anchor_month_start(anchor: Optional[datetime]) -> datetime:
|
||||
"""Calculate the start of the current billing month from the anchor date.
|
||||
|
||||
If the anchor is day 15, the billing month runs from the 15th of each month.
|
||||
Falls back to calendar month if anchor is None.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
if not anchor:
|
||||
return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
anchor_day = min(anchor.day, 28) # Clamp to avoid month overflow
|
||||
this_month_anchor = now.replace(
|
||||
day=anchor_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
if now >= this_month_anchor:
|
||||
return this_month_anchor
|
||||
else:
|
||||
# We're before the anchor day, so billing month started last month
|
||||
if now.month == 1:
|
||||
return this_month_anchor.replace(year=now.year - 1, month=12)
|
||||
else:
|
||||
return this_month_anchor.replace(month=now.month - 1)
|
||||
|
||||
|
||||
async def check_ai_quota(
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
db: AsyncSession,
|
||||
billing_anchor: Optional[datetime] = None,
|
||||
) -> tuple[bool, dict]:
|
||||
"""Check if user can make an AI generation.
|
||||
|
||||
Returns (allowed, quota_status_dict).
|
||||
Monthly counts only rows with counts_toward_quota=True.
|
||||
Daily counts only rows with generation_type in ('scaffold', 'branch_detail').
|
||||
"""
|
||||
plan = await get_user_plan(account_id, db)
|
||||
monthly_limit, daily_limit = await _get_effective_limits(account_id, plan, db)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
month_start = _get_billing_anchor_month_start(billing_anchor)
|
||||
day_start = now - timedelta(hours=24)
|
||||
|
||||
# Monthly: count successful quota-consuming records
|
||||
monthly_count = await db.scalar(
|
||||
select(func.count(AIUsage.id)).where(
|
||||
AIUsage.user_id == user_id,
|
||||
AIUsage.counts_toward_quota == True, # noqa: E712
|
||||
AIUsage.created_at >= month_start,
|
||||
)
|
||||
) or 0
|
||||
|
||||
# Daily: count all AI API calls (scaffold + branch_detail) in last 24h
|
||||
daily_count = await db.scalar(
|
||||
select(func.count(AIUsage.id)).where(
|
||||
AIUsage.user_id == user_id,
|
||||
AIUsage.succeeded == True, # noqa: E712
|
||||
AIUsage.generation_type.in_(["scaffold", "branch_detail"]),
|
||||
AIUsage.created_at >= day_start,
|
||||
)
|
||||
) or 0
|
||||
|
||||
allowed = True
|
||||
deny_reason = None
|
||||
if monthly_limit is not None and monthly_count >= monthly_limit:
|
||||
allowed = False
|
||||
deny_reason = "monthly"
|
||||
if daily_limit is not None and daily_count >= daily_limit:
|
||||
allowed = False
|
||||
deny_reason = "daily"
|
||||
|
||||
# Calculate reset timestamps
|
||||
monthly_reset_at = month_start.replace(
|
||||
month=month_start.month % 12 + 1,
|
||||
year=month_start.year + (1 if month_start.month == 12 else 0),
|
||||
)
|
||||
daily_reset_at = day_start + timedelta(hours=24)
|
||||
|
||||
return allowed, {
|
||||
"plan": plan,
|
||||
"monthly_used": monthly_count,
|
||||
"monthly_limit": monthly_limit,
|
||||
"monthly_reset_at": monthly_reset_at.isoformat(),
|
||||
"daily_used": daily_count,
|
||||
"daily_limit": daily_limit,
|
||||
"daily_reset_at": daily_reset_at.isoformat(),
|
||||
"allowed": allowed,
|
||||
"deny_reason": deny_reason,
|
||||
}
|
||||
|
||||
|
||||
async def record_ai_usage(
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
conversation_id: Optional[UUID],
|
||||
generation_type: str,
|
||||
tier: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
estimated_cost: float,
|
||||
succeeded: bool,
|
||||
counts_toward_quota: bool,
|
||||
error_code: Optional[str],
|
||||
extra_data: Optional[dict],
|
||||
db: AsyncSession,
|
||||
) -> AIUsage:
|
||||
"""Record an AI usage entry."""
|
||||
usage = AIUsage(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
conversation_id=conversation_id,
|
||||
generation_type=generation_type,
|
||||
tier_at_time=tier,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
estimated_cost_usd=estimated_cost,
|
||||
succeeded=succeeded,
|
||||
counts_toward_quota=counts_toward_quota,
|
||||
error_code=error_code,
|
||||
extra_data=extra_data or {},
|
||||
)
|
||||
db.add(usage)
|
||||
await db.flush()
|
||||
return usage
|
||||
Reference in New Issue
Block a user