Implements three-phase AI assistant feature: - Phase 0: RAG infrastructure with pgvector embeddings, Voyage AI integration, tree chunking service, and semantic search over team's flow library - Phase 1: In-session copilot panel during flow navigation with contextual AI help, current step awareness, and suggested related flows - Phase 2: Standalone AI chat page with persistent conversation history, pin/delete, and configurable retention policies (account-level) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
193 lines
6.4 KiB
Python
193 lines
6.4 KiB
Python
"""AI generation quota management.
|
|
|
|
Enforces monthly and daily limits on AI usage (flow builder, chat, copilot and assistant calls).
Monthly quota consumed only on successful tree assembly (counts_toward_quota=True).
Daily limit is an anti-abuse guard consumed on conversation start; it is evaluated over a rolling 24-hour window.
|
|
"""
|
|
import calendar
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.ai_usage import AIUsage
|
|
from app.models.plan_limits import PlanLimits
|
|
from app.models.account_limit_override import AccountLimitOverride
|
|
from app.core.subscriptions import get_account_subscription, get_plan_limits
|
|
|
|
|
|
async def get_user_plan(account_id: Optional[UUID], db: AsyncSession) -> str:
    """Resolve the plan tier name for an account.

    Returns ``"free"`` when no account id is supplied, when the account has
    no subscription, or when the subscription carries an empty plan value;
    otherwise returns the subscription's ``plan``.
    """
    if account_id:
        subscription = await get_account_subscription(account_id, db)
        # A missing subscription and a blank plan both fall through to "free".
        if subscription is not None and subscription.plan:
            return subscription.plan
    return "free"
|
|
|
|
|
|
async def _get_effective_limits(
    account_id: UUID, plan: str, db: AsyncSession
) -> tuple[Optional[int], Optional[int]]:
    """Resolve the AI build caps that apply to an account.

    Starts from the plan's limits and lets any non-null value on the
    account's ``AccountLimitOverride`` row replace the plan value.

    Returns:
        ``(monthly_limit, daily_limit)``; ``None`` in either slot means
        unlimited.
    """
    plan_limits = await get_plan_limits(plan, db)
    if plan_limits:
        monthly_cap = plan_limits.max_ai_builds_per_month
        daily_cap = plan_limits.max_ai_builds_per_24h
    else:
        # No limits row for this plan: treat both caps as unlimited.
        monthly_cap = daily_cap = None

    override_query = select(AccountLimitOverride).where(
        AccountLimitOverride.account_id == account_id
    )
    account_override = (await db.execute(override_query)).scalar_one_or_none()
    if not account_override:
        return monthly_cap, daily_cap

    # Only explicitly set override columns win; NULL keeps the plan value.
    month_override = account_override.override_max_ai_builds_per_month
    day_override = account_override.override_max_ai_builds_per_24h
    if month_override is not None:
        monthly_cap = month_override
    if day_override is not None:
        daily_cap = day_override

    return monthly_cap, daily_cap
|
|
|
|
|
|
def _get_billing_anchor_month_start(
    anchor: Optional[datetime], *, now: Optional[datetime] = None
) -> datetime:
    """Return the start (midnight) of the billing month that contains ``now``.

    If the anchor falls on day 15, the billing month runs from the 15th of
    one month to the 15th of the next. Anchor days above 28 are clamped to 28
    so the boundary exists in every month. Without an anchor the calendar
    month is used.

    Args:
        anchor: Billing anchor date, or ``None`` for calendar months. Only
            its day-of-month is used. A timezone-aware anchor is converted
            to UTC first so the day agrees with the UTC boundary computed
            here; a naive anchor is used as given.
        now: Reference instant. Defaults to the current UTC time. An aware
            value is converted to UTC; a naive value is used as given (and
            the result is then naive too).

    Returns:
        The most recent billing-period boundary that is ``<= now``.
    """
    if now is None:
        now = datetime.now(timezone.utc)
    elif now.tzinfo is not None:
        now = now.astimezone(timezone.utc)

    midnight = now.replace(hour=0, minute=0, second=0, microsecond=0)
    if anchor is None:
        return midnight.replace(day=1)

    if anchor.tzinfo is not None:
        # Read the anchor's day in the same timezone the boundary is built in.
        anchor = anchor.astimezone(timezone.utc)

    anchor_day = min(anchor.day, 28)  # Clamp to avoid month overflow
    this_month_anchor = midnight.replace(day=anchor_day)

    if now >= this_month_anchor:
        return this_month_anchor

    # We're before the anchor day, so the billing month started last month.
    # anchor_day <= 28, so the previous month always has that day.
    if now.month == 1:
        return this_month_anchor.replace(year=now.year - 1, month=12)
    return this_month_anchor.replace(month=now.month - 1)
|
|
|
|
|
|
async def check_ai_quota(
    user_id: UUID,
    account_id: UUID,
    db: AsyncSession,
    billing_anchor: Optional[datetime] = None,
    is_super_admin: bool = False,
) -> tuple[bool, dict]:
    """Check whether a user may make another AI generation.

    Limits are resolved per account (plan limits plus account overrides);
    usage is counted per user.

    * Monthly usage counts rows with ``counts_toward_quota=True`` created
      since the start of the current billing month.
    * Daily usage counts rows with ``succeeded=True`` and a
      ``generation_type`` of ``scaffold``, ``branch_detail``,
      ``chat_message``, ``chat_generate``, ``copilot_message`` or
      ``assistant_message`` created within the last 24 hours (rolling).
    * Super admins bypass all limits; both limits are then reported as
      ``None``.

    Args:
        user_id: User whose usage rows are counted.
        account_id: Account used to resolve the plan and limit overrides.
        db: Async database session.
        billing_anchor: Billing anchor date; ``None`` means calendar months.
        is_super_admin: When true, no limit is enforced.

    Returns:
        ``(allowed, status)`` where ``status`` has the keys ``plan``,
        ``monthly_used``, ``monthly_limit``, ``monthly_reset_at``,
        ``daily_used``, ``daily_limit``, ``daily_reset_at``, ``allowed`` and
        ``deny_reason``. ``deny_reason`` is ``"monthly"``, ``"daily"`` or
        ``None``; when both limits are exceeded it is ``"daily"``.
    """
    plan = await get_user_plan(account_id, db)
    monthly_limit, daily_limit = await _get_effective_limits(account_id, plan, db)

    now = datetime.now(timezone.utc)
    month_start = _get_billing_anchor_month_start(billing_anchor)
    day_start = now - timedelta(hours=24)

    # Monthly: count successful quota-consuming records
    monthly_count = await db.scalar(
        select(func.count(AIUsage.id)).where(
            AIUsage.user_id == user_id,
            AIUsage.counts_toward_quota == True,  # noqa: E712
            AIUsage.created_at >= month_start,
        )
    ) or 0

    # Daily: count successful AI API calls in the last 24h and, in the same
    # round-trip, fetch the oldest of them so the reset time can be derived.
    daily_result = await db.execute(
        select(func.count(AIUsage.id), func.min(AIUsage.created_at)).where(
            AIUsage.user_id == user_id,
            AIUsage.succeeded == True,  # noqa: E712
            AIUsage.generation_type.in_(["scaffold", "branch_detail", "chat_message", "chat_generate", "copilot_message", "assistant_message"]),
            AIUsage.created_at >= day_start,
        )
    )
    # An aggregate query without GROUP BY always yields exactly one row.
    daily_count, oldest_daily = daily_result.one()
    daily_count = daily_count or 0

    allowed = True
    deny_reason = None
    if is_super_admin:
        # Super admins bypass all limits
        monthly_limit = None
        daily_limit = None
    if monthly_limit is not None and monthly_count >= monthly_limit:
        allowed = False
        deny_reason = "monthly"
    # Checked second on purpose: "daily" overrides "monthly" when both apply.
    if daily_limit is not None and daily_count >= daily_limit:
        allowed = False
        deny_reason = "daily"

    # Monthly reset: same day-of-month in the month after month_start,
    # clamped to that month's length.
    next_month = month_start.month % 12 + 1
    next_year = month_start.year + (1 if month_start.month == 12 else 0)
    max_day = calendar.monthrange(next_year, next_month)[1]
    monthly_reset_at = month_start.replace(
        month=next_month,
        year=next_year,
        day=min(month_start.day, max_day),
    )

    # Daily reset: the rolling window frees a slot when its oldest counted
    # row turns 24h old. (Previously this was day_start + 24h, i.e. always
    # "now".) With nothing in the window there is nothing to wait for.
    if oldest_daily is None:
        daily_reset_at = now
    else:
        if oldest_daily.tzinfo is None:
            # NOTE(review): assumes naive created_at values are stored in UTC.
            oldest_daily = oldest_daily.replace(tzinfo=timezone.utc)
        daily_reset_at = oldest_daily + timedelta(hours=24)

    return allowed, {
        "plan": plan,
        "monthly_used": monthly_count,
        "monthly_limit": monthly_limit,
        "monthly_reset_at": monthly_reset_at.isoformat(),
        "daily_used": daily_count,
        "daily_limit": daily_limit,
        "daily_reset_at": daily_reset_at.isoformat(),
        "allowed": allowed,
        "deny_reason": deny_reason,
    }
|
|
|
|
|
|
async def record_ai_usage(
    user_id: UUID,
    account_id: UUID,
    conversation_id: Optional[UUID],
    generation_type: str,
    tier: str,
    input_tokens: int,
    output_tokens: int,
    estimated_cost: float,
    succeeded: bool,
    counts_toward_quota: bool,
    error_code: Optional[str],
    extra_data: Optional[dict],
    db: AsyncSession,
) -> AIUsage:
    """Create an ``AIUsage`` row for one AI call and flush it to the session.

    ``tier`` is stored as ``tier_at_time`` and ``estimated_cost`` as
    ``estimated_cost_usd``. A missing or empty ``extra_data`` is stored as an
    empty dict. The row is added and flushed but not committed here.

    Returns:
        The newly added ``AIUsage`` instance.
    """
    row_fields = {
        "user_id": user_id,
        "account_id": account_id,
        "conversation_id": conversation_id,
        "generation_type": generation_type,
        "tier_at_time": tier,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "estimated_cost_usd": estimated_cost,
        "succeeded": succeeded,
        "counts_toward_quota": counts_toward_quota,
        "error_code": error_code,
        "extra_data": extra_data if extra_data else {},
    }
    entry = AIUsage(**row_fields)
    db.add(entry)
    await db.flush()
    return entry
|