"""AI generation quota management. Enforces monthly and daily limits on AI flow builder usage. Monthly quota consumed only on successful tree assembly (counts_toward_quota=True). Daily limit is an anti-abuse guard consumed on conversation start. """ import calendar from datetime import datetime, timezone, timedelta from typing import Optional from uuid import UUID from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession from app.models.ai_usage import AIUsage from app.models.plan_limits import PlanLimits from app.models.account_limit_override import AccountLimitOverride from app.core.subscriptions import get_account_subscription, get_plan_limits async def get_user_plan(account_id: Optional[UUID], db: AsyncSession) -> str: """Get the plan tier for an account.""" if not account_id: return "free" sub = await get_account_subscription(account_id, db) if sub is None: return "free" return sub.plan if sub.plan else "free" async def _get_effective_limits( account_id: UUID, plan: str, db: AsyncSession ) -> tuple[Optional[int], Optional[int]]: """Get effective AI limits (monthly, daily), applying account overrides. Returns (monthly_limit, daily_limit). None means unlimited. """ limits = await get_plan_limits(plan, db) monthly = limits.max_ai_builds_per_month if limits else None daily = limits.max_ai_builds_per_24h if limits else None # Check for account-level overrides result = await db.execute( select(AccountLimitOverride).where( AccountLimitOverride.account_id == account_id ) ) override = result.scalar_one_or_none() if override: if override.override_max_ai_builds_per_month is not None: monthly = override.override_max_ai_builds_per_month if override.override_max_ai_builds_per_24h is not None: daily = override.override_max_ai_builds_per_24h return monthly, daily def _get_billing_anchor_month_start(anchor: Optional[datetime]) -> datetime: """Calculate the start of the current billing month from the anchor date. If the anchor is day 15, the billing month runs from the 15th of each month. Falls back to calendar month if anchor is None. """ now = datetime.now(timezone.utc) if not anchor: return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) anchor_day = min(anchor.day, 28) # Clamp to avoid month overflow this_month_anchor = now.replace( day=anchor_day, hour=0, minute=0, second=0, microsecond=0 ) if now >= this_month_anchor: return this_month_anchor else: # We're before the anchor day, so billing month started last month if now.month == 1: return this_month_anchor.replace(year=now.year - 1, month=12) else: return this_month_anchor.replace(month=now.month - 1) async def check_ai_quota( user_id: UUID, account_id: UUID, db: AsyncSession, billing_anchor: Optional[datetime] = None, is_super_admin: bool = False, ) -> tuple[bool, dict]: """Check if user can make an AI generation. Returns (allowed, quota_status_dict). Monthly counts only rows with counts_toward_quota=True. Daily counts only rows with generation_type in ('scaffold', 'branch_detail'). Super admins bypass all limits. """ plan = await get_user_plan(account_id, db) monthly_limit, daily_limit = await _get_effective_limits(account_id, plan, db) now = datetime.now(timezone.utc) month_start = _get_billing_anchor_month_start(billing_anchor) day_start = now - timedelta(hours=24) # Monthly: count successful quota-consuming records monthly_count = await db.scalar( select(func.count(AIUsage.id)).where( AIUsage.user_id == user_id, AIUsage.counts_toward_quota == True, # noqa: E712 AIUsage.created_at >= month_start, ) ) or 0 # Daily: count all AI API calls (scaffold + branch_detail) in last 24h daily_count = await db.scalar( select(func.count(AIUsage.id)).where( AIUsage.user_id == user_id, AIUsage.succeeded == True, # noqa: E712 AIUsage.generation_type.in_(["scaffold", "branch_detail", "chat_message", "chat_generate", "copilot_message", "assistant_message"]), AIUsage.created_at >= day_start, ) ) or 0 allowed = True deny_reason = None if is_super_admin: # Super admins bypass all limits monthly_limit = None daily_limit = None if monthly_limit is not None and monthly_count >= monthly_limit: allowed = False deny_reason = "monthly" if daily_limit is not None and daily_count >= daily_limit: allowed = False deny_reason = "daily" # Calculate reset timestamps next_month = month_start.month % 12 + 1 next_year = month_start.year + (1 if month_start.month == 12 else 0) max_day = calendar.monthrange(next_year, next_month)[1] monthly_reset_at = month_start.replace( month=next_month, year=next_year, day=min(month_start.day, max_day), ) daily_reset_at = day_start + timedelta(hours=24) return allowed, { "plan": plan, "monthly_used": monthly_count, "monthly_limit": monthly_limit, "monthly_reset_at": monthly_reset_at.isoformat(), "daily_used": daily_count, "daily_limit": daily_limit, "daily_reset_at": daily_reset_at.isoformat(), "allowed": allowed, "deny_reason": deny_reason, } async def record_ai_usage( user_id: UUID, account_id: UUID, conversation_id: Optional[UUID], generation_type: str, tier: str, input_tokens: int, output_tokens: int, estimated_cost: float, succeeded: bool, counts_toward_quota: bool, error_code: Optional[str], extra_data: Optional[dict], db: AsyncSession, ) -> AIUsage: """Record an AI usage entry.""" usage = AIUsage( user_id=user_id, account_id=account_id, conversation_id=conversation_id, generation_type=generation_type, tier_at_time=tier, input_tokens=input_tokens, output_tokens=output_tokens, estimated_cost_usd=estimated_cost, succeeded=succeeded, counts_toward_quota=counts_toward_quota, error_code=error_code, extra_data=extra_data or {}, ) db.add(usage) await db.flush() return usage