Files
resolutionflow/backend/app/core/ai_quota_service.py
Michael Chihlas 1aa60dada2 feat: add AI assistant with in-session copilot and standalone chat with RAG
Implements three-phase AI assistant feature:
- Phase 0: RAG infrastructure with pgvector embeddings, Voyage AI integration,
  tree chunking service, and semantic search over team's flow library
- Phase 1: In-session copilot panel during flow navigation with contextual
  AI help, current step awareness, and suggested related flows
- Phase 2: Standalone AI chat page with persistent conversation history,
  pin/delete, and configurable retention policies (account-level)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-04 01:36:36 -05:00

193 lines
6.4 KiB
Python

"""AI generation quota management.
Enforces monthly and daily limits on AI flow builder usage.
Monthly quota consumed only on successful tree assembly (counts_toward_quota=True).
Daily limit is an anti-abuse guard consumed on conversation start.
"""
import calendar
from datetime import datetime, timezone, timedelta
from typing import Optional
from uuid import UUID
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ai_usage import AIUsage
from app.models.plan_limits import PlanLimits
from app.models.account_limit_override import AccountLimitOverride
from app.core.subscriptions import get_account_subscription, get_plan_limits
async def get_user_plan(account_id: Optional[UUID], db: AsyncSession) -> str:
    """Resolve the plan tier name for an account.

    Returns ``"free"`` when no account id is supplied, when the account has
    no subscription, or when the subscription's ``plan`` is empty. Otherwise
    the subscription's ``plan`` value is returned.
    """
    if account_id:
        subscription = await get_account_subscription(account_id, db)
        # Only a subscription with a non-empty plan overrides the fallback.
        if subscription is not None and subscription.plan:
            return subscription.plan
    return "free"
async def _get_effective_limits(
    account_id: UUID, plan: str, db: AsyncSession
) -> tuple[Optional[int], Optional[int]]:
    """Resolve the AI build limits that apply to an account.

    Starts from the limits of ``plan`` and then lets any non-null value on
    the account's ``AccountLimitOverride`` row replace the plan value.

    Returns:
        ``(monthly_limit, daily_limit)``. ``None`` in either slot means
        unlimited.
    """
    plan_limits = await get_plan_limits(plan, db)
    if plan_limits:
        monthly_cap = plan_limits.max_ai_builds_per_month
        daily_cap = plan_limits.max_ai_builds_per_24h
    else:
        # No limits row for this plan: treat both as unlimited.
        monthly_cap = daily_cap = None

    # Look up the account-level override row, if one exists.
    override_query = select(AccountLimitOverride).where(
        AccountLimitOverride.account_id == account_id
    )
    account_override = (await db.execute(override_query)).scalar_one_or_none()
    if not account_override:
        return monthly_cap, daily_cap

    # A null override column means "inherit the plan value".
    month_override = account_override.override_max_ai_builds_per_month
    if month_override is not None:
        monthly_cap = month_override
    day_override = account_override.override_max_ai_builds_per_24h
    if day_override is not None:
        daily_cap = day_override
    return monthly_cap, daily_cap
def _get_billing_anchor_month_start(anchor: Optional[datetime]) -> datetime:
"""Calculate the start of the current billing month from the anchor date.
If the anchor is day 15, the billing month runs from the 15th of each month.
Falls back to calendar month if anchor is None.
"""
now = datetime.now(timezone.utc)
if not anchor:
return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
anchor_day = min(anchor.day, 28) # Clamp to avoid month overflow
this_month_anchor = now.replace(
day=anchor_day, hour=0, minute=0, second=0, microsecond=0
)
if now >= this_month_anchor:
return this_month_anchor
else:
# We're before the anchor day, so billing month started last month
if now.month == 1:
return this_month_anchor.replace(year=now.year - 1, month=12)
else:
return this_month_anchor.replace(month=now.month - 1)
async def check_ai_quota(
    user_id: UUID,
    account_id: UUID,
    db: AsyncSession,
    billing_anchor: Optional[datetime] = None,
    is_super_admin: bool = False,
) -> tuple[bool, dict]:
    """Check if user can make an AI generation.

    Monthly usage counts this user's rows with ``counts_toward_quota=True``
    created since the start of the current billing month.

    Daily usage counts this user's succeeded rows from the last 24 hours
    (rolling window) whose ``generation_type`` is one of ``scaffold``,
    ``branch_detail``, ``chat_message``, ``chat_generate``,
    ``copilot_message`` or ``assistant_message``.

    Super admins bypass all limits (both limits are reported as ``None``).

    Args:
        user_id: User whose usage rows are counted.
        account_id: Account used to resolve the plan and limit overrides.
        db: Async database session.
        billing_anchor: Optional billing anchor date; when omitted the
            calendar month is used.
        is_super_admin: When true, no limit is enforced.

    Returns:
        ``(allowed, quota_status)`` where ``quota_status`` has the keys
        ``plan``, ``monthly_used``, ``monthly_limit``, ``monthly_reset_at``,
        ``daily_used``, ``daily_limit``, ``daily_reset_at``, ``allowed`` and
        ``deny_reason`` (``"monthly"``, ``"daily"`` or ``None``; ``"daily"``
        wins when both limits are exceeded).
    """
    plan = await get_user_plan(account_id, db)
    monthly_limit, daily_limit = await _get_effective_limits(account_id, plan, db)

    now = datetime.now(timezone.utc)
    month_start = _get_billing_anchor_month_start(billing_anchor)
    day_window = timedelta(hours=24)
    day_start = now - day_window

    # Monthly: count successful quota-consuming records
    monthly_count = await db.scalar(
        select(func.count(AIUsage.id)).where(
            AIUsage.user_id == user_id,
            AIUsage.counts_toward_quota == True,  # noqa: E712
            AIUsage.created_at >= month_start,
        )
    ) or 0

    # Daily: count succeeded AI calls of the tracked types in the last 24h and,
    # in the same query, fetch the oldest one so the reset time can be derived.
    daily_generation_types = [
        "scaffold",
        "branch_detail",
        "chat_message",
        "chat_generate",
        "copilot_message",
        "assistant_message",
    ]
    daily_row = (
        await db.execute(
            select(func.count(AIUsage.id), func.min(AIUsage.created_at)).where(
                AIUsage.user_id == user_id,
                AIUsage.succeeded == True,  # noqa: E712
                AIUsage.generation_type.in_(daily_generation_types),
                AIUsage.created_at >= day_start,
            )
        )
    ).one()
    daily_count = daily_row[0] or 0
    oldest_daily_at = daily_row[1]

    allowed = True
    deny_reason = None
    if is_super_admin:
        # Super admins bypass all limits
        monthly_limit = None
        daily_limit = None
    if monthly_limit is not None and monthly_count >= monthly_limit:
        allowed = False
        deny_reason = "monthly"
    if daily_limit is not None and daily_count >= daily_limit:
        allowed = False
        deny_reason = "daily"

    # Monthly reset: same day-of-month in the month after month_start,
    # clamped to that month's length.
    next_month = month_start.month % 12 + 1
    next_year = month_start.year + (1 if month_start.month == 12 else 0)
    max_day = calendar.monthrange(next_year, next_month)[1]
    monthly_reset_at = month_start.replace(
        month=next_month,
        year=next_year,
        day=min(month_start.day, max_day),
    )

    # Daily reset: the rolling window starts to free up when its oldest row
    # turns 24h old. (Previously this was day_start + 24h, i.e. always "now".)
    if oldest_daily_at is None:
        # Nothing in the window, so there is nothing to wait for.
        daily_reset_at = now
    else:
        if oldest_daily_at.tzinfo is None:
            # NOTE(review): naive timestamps are assumed to be UTC — confirm
            # against the AIUsage.created_at column definition.
            oldest_daily_at = oldest_daily_at.replace(tzinfo=timezone.utc)
        daily_reset_at = oldest_daily_at + day_window

    return allowed, {
        "plan": plan,
        "monthly_used": monthly_count,
        "monthly_limit": monthly_limit,
        "monthly_reset_at": monthly_reset_at.isoformat(),
        "daily_used": daily_count,
        "daily_limit": daily_limit,
        "daily_reset_at": daily_reset_at.isoformat(),
        "allowed": allowed,
        "deny_reason": deny_reason,
    }
async def record_ai_usage(
    user_id: UUID,
    account_id: UUID,
    conversation_id: Optional[UUID],
    generation_type: str,
    tier: str,
    input_tokens: int,
    output_tokens: int,
    estimated_cost: float,
    succeeded: bool,
    counts_toward_quota: bool,
    error_code: Optional[str],
    extra_data: Optional[dict],
    db: AsyncSession,
) -> AIUsage:
    """Build an ``AIUsage`` row from the arguments and flush it to the session.

    ``tier`` is stored as ``tier_at_time`` and ``estimated_cost`` as
    ``estimated_cost_usd``; a missing or empty ``extra_data`` is stored as an
    empty dict. The row is added and flushed; no commit is issued here.

    Returns:
        The newly added ``AIUsage`` instance.
    """
    row_fields = {
        "user_id": user_id,
        "account_id": account_id,
        "conversation_id": conversation_id,
        "generation_type": generation_type,
        "tier_at_time": tier,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "estimated_cost_usd": estimated_cost,
        "succeeded": succeeded,
        "counts_toward_quota": counts_toward_quota,
        "error_code": error_code,
        "extra_data": extra_data or {},
    }
    usage = AIUsage(**row_fields)
    db.add(usage)
    await db.flush()
    return usage