"""Subscription limit checks and plan helpers."""
from typing import Optional
from uuid import UUID
from datetime import datetime, timezone

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.subscription import Subscription
from app.models.plan_limits import PlanLimits
from app.models.tree import Tree
from app.models.session import Session


async def get_account_subscription(account_id: UUID, db: AsyncSession) -> Optional[Subscription]:
    """Return the account's Subscription row, or ``None`` if it has none.

    Uses ``scalar_one_or_none``, so more than one matching row raises.
    """
    result = await db.execute(
        select(Subscription).where(Subscription.account_id == account_id)
    )
    return result.scalar_one_or_none()


async def get_plan_limits(plan: str, db: AsyncSession) -> Optional[PlanLimits]:
    """Return the PlanLimits row for ``plan``, or ``None`` if no row exists."""
    result = await db.execute(
        select(PlanLimits).where(PlanLimits.plan == plan)
    )
    return result.scalar_one_or_none()


async def get_user_plan_limits(user_account_id: UUID, db: AsyncSession) -> Optional[PlanLimits]:
    """Return the plan limits that apply to an account.

    Accounts without a subscription fall back to the ``"free"`` plan's limits.
    """
    sub = await get_account_subscription(user_account_id, db)
    if sub is None:
        return await get_plan_limits("free", db)
    return await get_plan_limits(sub.plan, db)


def _current_month_start() -> datetime:
    """Return 00:00:00 UTC on the first day of the current calendar month."""
    now = datetime.now(timezone.utc)
    return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


async def _count_active_trees(account_id: UUID, db: AsyncSession) -> int:
    """Count the account's trees whose ``deleted_at`` is NULL (not soft-deleted)."""
    count = await db.scalar(
        select(func.count(Tree.id)).where(
            Tree.account_id == account_id,
            Tree.deleted_at.is_(None),
        )
    )
    return count or 0


async def _count_sessions_this_month(account_id: UUID, db: AsyncSession) -> int:
    """Count sessions started this UTC calendar month by any user of the account."""
    # Kept as a function-scope import like the original code; presumably this
    # avoids a circular import with app.models.user — TODO confirm.
    from app.models.user import User

    count = await db.scalar(
        select(func.count(Session.id)).where(
            Session.user_id.in_(
                select(User.id).where(User.account_id == account_id)
            ),
            # NOTE(review): compares against a timezone-aware UTC datetime;
            # assumes Session.started_at is stored timezone-aware — confirm.
            Session.started_at >= _current_month_start(),
        )
    )
    return count or 0


async def check_tree_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]:
    """Check if account can create a new tree.

    Args:
        account_id: Account whose tree quota is checked.
        db: Async database session used for all queries.

    Returns:
        ``(can_create, limit, current_count)``. ``limit`` is ``None`` when the
        plan imposes no tree cap; ``current_count`` is the number of
        non-deleted trees. An account with no subscription gets
        ``(False, 0, 0)``.
    """
    sub = await get_account_subscription(account_id, db)
    if sub is None:
        # NOTE(review): denied outright here, whereas get_user_plan_limits
        # falls back to the "free" plan — confirm the difference is intended.
        return False, 0, 0

    limits = await get_plan_limits(sub.plan, db)
    # Count before the unlimited short-circuit so callers always get the real
    # usage figure (previously 0 was reported for unlimited plans).
    current_count = await _count_active_trees(account_id, db)
    if limits is None or limits.max_trees is None:
        # Unlimited. NOTE(review): a plan with no PlanLimits row is also
        # treated as unlimited (fail-open) — confirm this is intended.
        return True, None, current_count
    return current_count < limits.max_trees, limits.max_trees, current_count


async def check_session_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]:
    """Check if account can create a new session this month.

    Sessions are counted across all users belonging to the account, from the
    start of the current UTC calendar month.

    Args:
        account_id: Account whose monthly session quota is checked.
        db: Async database session used for all queries.

    Returns:
        ``(can_create, limit, current_count)``. ``limit`` is ``None`` when the
        plan imposes no monthly session cap; ``current_count`` is the number
        of sessions started this month. An account with no subscription gets
        ``(False, 0, 0)``.
    """
    sub = await get_account_subscription(account_id, db)
    if sub is None:
        # NOTE(review): denied outright here, whereas get_user_plan_limits
        # falls back to the "free" plan — confirm the difference is intended.
        return False, 0, 0

    limits = await get_plan_limits(sub.plan, db)
    # Count before the unlimited short-circuit so callers always get the real
    # usage figure (previously 0 was reported for unlimited plans).
    current_count = await _count_sessions_this_month(account_id, db)
    if limits is None or limits.max_sessions_per_month is None:
        # Unlimited. NOTE(review): a plan with no PlanLimits row is also
        # treated as unlimited (fail-open) — confirm this is intended.
        return True, None, current_count
    return (
        current_count < limits.max_sessions_per_month,
        limits.max_sessions_per_month,
        current_count,
    )


async def get_account_usage(account_id: UUID, db: AsyncSession) -> dict:
    """Get current usage stats for an account.

    Returns:
        ``{"tree_count": int, "session_count_this_month": int}``, computed with
        the same queries the limit checks use.
    """
    tree_count = await _count_active_trees(account_id, db)
    session_count = await _count_sessions_this_month(account_id, db)
    return {"tree_count": tree_count, "session_count_this_month": session_count}