"""Single billing service module. Stripe is the only impl — no provider abstraction. Account row is canonical local state; Stripe is canonical remote state; the webhook handler bridges the two.""" from datetime import datetime, timezone, timedelta import stripe from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.models.account import Account from app.models.plan_billing import PlanBilling from app.models.subscription import Subscription TRIAL_DAYS = 14 class BillingService: @staticmethod async def start_trial(db: AsyncSession, account_id) -> Subscription: """Idempotent. Creates a trialing Subscription on Pro for the account if one doesn't exist; otherwise returns the existing row.""" result = await db.execute( select(Subscription).where(Subscription.account_id == account_id) ) existing = result.scalar_one_or_none() if existing is not None: return existing sub = Subscription( account_id=account_id, plan="pro", status="trialing", current_period_start=datetime.now(timezone.utc), current_period_end=datetime.now(timezone.utc) + timedelta(days=TRIAL_DAYS), ) db.add(sub) await db.commit() await db.refresh(sub) return sub @staticmethod async def create_checkout_session( db: AsyncSession, account: Account, plan: str, seats: int, billing_interval: str, success_url: str, cancel_url: str, ) -> str: """Create a Stripe Checkout Session for subscription purchase. If the account currently has a trialing subscription with time remaining, that trial end is preserved on the new Stripe subscription so the user isn't charged early.""" if not settings.stripe_enabled: raise RuntimeError("Stripe not configured") stripe.api_key = settings.STRIPE_SECRET_KEY plan_billing = (await db.execute( select(PlanBilling).where(PlanBilling.plan == plan) )).scalar_one_or_none() if plan_billing is None: raise ValueError(f"Unknown plan: {plan}") price_id = ( plan_billing.stripe_monthly_price_id if billing_interval == "monthly" else plan_billing.stripe_annual_price_id ) if price_id is None: raise RuntimeError( f"Plan '{plan}' has no Stripe price for {billing_interval}" ) if account.stripe_customer_id is None: customer = stripe.Customer.create( email=None, metadata={"account_id": str(account.id)}, ) account.stripe_customer_id = customer.id await db.commit() sub = (await db.execute( select(Subscription).where(Subscription.account_id == account.id) )).scalar_one_or_none() subscription_data = {} if ( sub and sub.status == "trialing" and sub.current_period_end and sub.current_period_end > datetime.now(timezone.utc) ): subscription_data["trial_end"] = int(sub.current_period_end.timestamp()) session = stripe.checkout.Session.create( customer=account.stripe_customer_id, line_items=[{"price": price_id, "quantity": seats}], mode="subscription", subscription_data=subscription_data or None, success_url=success_url, cancel_url=cancel_url, allow_promotion_codes=False, ) return session.url