"""Single billing service module.

Stripe is the only impl — no provider abstraction. Account row is canonical
local state; Stripe is canonical remote state; the webhook handler bridges
the two.
"""
import logging
from datetime import datetime, timezone, timedelta

import stripe
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.models.account import Account
from app.models.plan_billing import PlanBilling
from app.models.stripe_event import StripeEvent
from app.models.subscription import Subscription

TRIAL_DAYS = 14

# Stripe Checkout rejects `subscription_data.trial_end` values less than 48
# hours in the future. The extra minutes absorb request latency / clock skew
# between computing the timestamp here and Stripe validating it.
STRIPE_MIN_TRIAL_END_LEAD = timedelta(hours=48, minutes=5)

logger = logging.getLogger(__name__)


class BillingService:
    @staticmethod
    async def invalidate_billing_cache(account_ids) -> None:
        """No-op stub for future in-process billing cache invalidation.

        Today there is no `app.state.billing_cache` —
        `BillingService.get_billing_state` always reads fresh from the DB.
        Call sites that mutate plan/feature data invoke this hook so that
        wiring is in place when an in-process cache is added later. Until
        then, this just logs.

        TODO: when an in-process billing cache (e.g. `app.state.billing_cache`)
        is introduced, evict entries for the given account_ids here.
        """
        try:
            count = len(list(account_ids))
        except TypeError:
            # Non-iterable (e.g. None) — log a sentinel count instead of failing.
            count = -1
        logger.debug(
            "BillingService.invalidate_billing_cache called for %d account(s) "
            "(no-op stub — wire to app.state.billing_cache when added)",
            count,
        )

    @staticmethod
    async def start_trial(db: AsyncSession, account_id) -> Subscription:
        """Idempotent. Create a trialing Pro Subscription for the account.

        If a Subscription row already exists for ``account_id`` it is returned
        unchanged; otherwise a new ``trialing`` row on plan ``"pro"`` lasting
        ``TRIAL_DAYS`` days is inserted, committed, and returned.
        """
        result = await db.execute(
            select(Subscription).where(Subscription.account_id == account_id)
        )
        existing = result.scalar_one_or_none()
        if existing is not None:
            return existing

        # Single timestamp so the trial window is exactly TRIAL_DAYS long.
        now = datetime.now(timezone.utc)
        sub = Subscription(
            account_id=account_id,
            plan="pro",
            status="trialing",
            current_period_start=now,
            current_period_end=now + timedelta(days=TRIAL_DAYS),
        )
        db.add(sub)
        await db.commit()
        await db.refresh(sub)
        return sub

    @staticmethod
    async def create_checkout_session(
        db: AsyncSession,
        account: Account,
        plan: str,
        seats: int,
        billing_interval: str,
        success_url: str,
        cancel_url: str,
    ) -> str:
        """Create a Stripe Checkout Session for a subscription purchase.

        If the account currently has a trialing subscription that still has
        more than ~48 hours remaining, that trial end is preserved on the new
        Stripe subscription so the user isn't charged early. With less time
        left than Stripe's minimum, no trial is carried over and billing
        starts at checkout (passing such a ``trial_end`` would make Stripe
        reject the whole session).

        Creates (and commits) a Stripe Customer for the account on first use.

        Returns:
            The hosted Checkout URL.

        Raises:
            RuntimeError: Stripe isn't configured, or the plan has no Stripe
                price for the requested interval.
            ValueError: ``plan`` has no PlanBilling row.
        """
        if not settings.stripe_enabled:
            raise RuntimeError("Stripe not configured")
        stripe.api_key = settings.STRIPE_SECRET_KEY

        plan_billing = (await db.execute(
            select(PlanBilling).where(PlanBilling.plan == plan)
        )).scalar_one_or_none()
        if plan_billing is None:
            raise ValueError(f"Unknown plan: {plan}")

        # Any interval other than "monthly" selects the annual price.
        price_id = (
            plan_billing.stripe_monthly_price_id
            if billing_interval == "monthly"
            else plan_billing.stripe_annual_price_id
        )
        if price_id is None:
            raise RuntimeError(
                f"Plan '{plan}' has no Stripe price for {billing_interval}"
            )

        if account.stripe_customer_id is None:
            customer = stripe.Customer.create(
                email=None,
                metadata={"account_id": str(account.id)},
            )
            account.stripe_customer_id = customer.id
            await db.commit()

        sub = (await db.execute(
            select(Subscription).where(Subscription.account_id == account.id)
        )).scalar_one_or_none()

        subscription_data = {}
        now = datetime.now(timezone.utc)
        if (
            sub is not None
            and sub.status == "trialing"
            and sub.current_period_end is not None
            and sub.current_period_end > now + STRIPE_MIN_TRIAL_END_LEAD
        ):
            subscription_data["trial_end"] = int(sub.current_period_end.timestamp())

        session = stripe.checkout.Session.create(
            customer=account.stripe_customer_id,
            line_items=[{"price": price_id, "quantity": seats}],
            mode="subscription",
            subscription_data=subscription_data or None,
            success_url=success_url,
            cancel_url=cancel_url,
            allow_promotion_codes=False,
        )
        return session.url

    @staticmethod
    async def open_customer_portal(account: Account) -> str:
        """Create a Stripe-hosted Customer Portal session and return the URL.

        Raises RuntimeError if Stripe isn't configured (endpoint maps to 503).
        Raises ValueError if the account has no stripe_customer_id yet — the
        user must complete a checkout first (endpoint maps to 400).
        """
        if not settings.stripe_enabled:
            raise RuntimeError("Stripe not configured")
        if account.stripe_customer_id is None:
            raise ValueError("no_stripe_customer")
        stripe.api_key = settings.STRIPE_SECRET_KEY
        session = stripe.billing_portal.Session.create(
            customer=account.stripe_customer_id,
            return_url=f"{settings.FRONTEND_URL}/account/billing",
        )
        return session.url

    @staticmethod
    async def get_billing_state(db: AsyncSession, account):
        """Aggregate Subscription + PlanLimits + PlanBilling + resolved feature
        flags for the account.

        Feature flags resolve as the plan's defaults, then per-account
        overrides on top.

        Raises:
            fastapi.HTTPException: 404 if the account has no Subscription row.
        """
        from app.models.plan_limits import PlanLimits
        from app.models.feature_flag import (
            FeatureFlag, PlanFeatureDefault, AccountFeatureOverride,
        )

        sub = (await db.execute(
            select(Subscription).where(Subscription.account_id == account.id)
        )).scalar_one_or_none()
        if sub is None:
            from fastapi import HTTPException
            raise HTTPException(status_code=404, detail="No subscription for account")

        pl = (await db.execute(
            select(PlanLimits).where(PlanLimits.plan == sub.plan)
        )).scalar_one_or_none()
        pb = (await db.execute(
            select(PlanBilling).where(PlanBilling.plan == sub.plan)
        )).scalar_one_or_none()

        # Resolved feature flags: plan defaults overridden by account overrides
        defaults = (await db.execute(
            select(PlanFeatureDefault, FeatureFlag)
            .join(FeatureFlag, PlanFeatureDefault.flag_id == FeatureFlag.id)
            .where(PlanFeatureDefault.plan == sub.plan)
        )).all()
        resolved = {flag.flag_key: pfd.enabled for pfd, flag in defaults}

        overrides = (await db.execute(
            select(AccountFeatureOverride, FeatureFlag)
            .join(FeatureFlag, AccountFeatureOverride.flag_id == FeatureFlag.id)
            .where(AccountFeatureOverride.account_id == account.id)
        )).all()
        for ovr, flag in overrides:
            resolved[flag.flag_key] = ovr.enabled

        return {
            "subscription": {
                "status": sub.status,
                "plan": sub.plan,
                "current_period_start": sub.current_period_start,
                "current_period_end": sub.current_period_end,
                "cancel_at_period_end": sub.cancel_at_period_end,
                "seat_limit": sub.seat_limit,
                "has_pro_entitlement": sub.has_pro_entitlement,
                "is_paid": sub.is_paid,
            },
            "plan_billing": pb,
            "plan_limits": _plan_limits_to_dict(pl) if pl else {},
            "enabled_features": resolved,
        }

    @staticmethod
    async def apply_subscription_event(
        db: AsyncSession, event_id: str, event_type: str, payload: dict
    ) -> bool:
        """Idempotent. Returns True if the event was applied; False if it had
        already been processed (idempotent ack). The webhook handler returns
        200 either way.

        Atomic: the StripeEvent idempotency mark and the handler's state
        mutations are committed in a single transaction. If the handler raises
        the entire transaction (idempotency mark + partial mutations) is rolled
        back, so a Stripe retry will re-run the handler. Without this, a
        handler that fails mid-flight would leave the StripeEvent row persisted
        and silently desync subscription state from Stripe.

        Event types without a handler are still recorded and return True.
        """
        db.add(StripeEvent(
            id=event_id,
            event_type=event_type,
            payload_excerpt=_excerpt(payload),
        ))
        try:
            await db.flush()
        except IntegrityError:
            # Duplicate event_id — already processed (or in flight). Ack with False.
            await db.rollback()
            return False

        try:
            if event_type == "checkout.session.completed":
                await _handle_checkout_completed(db, payload)
            elif event_type == "customer.subscription.updated":
                await _handle_subscription_updated(db, payload)
            elif event_type == "customer.subscription.deleted":
                await _handle_subscription_deleted(db, payload)
            elif event_type == "invoice.payment_failed":
                await _handle_payment_failed(db, payload)
            elif event_type == "invoice.payment_succeeded":
                await _handle_payment_succeeded(db, payload)
            await db.commit()
        except Exception:
            # Roll back the StripeEvent insert + any partial handler mutations
            # so Stripe's retry can re-run cleanly.
            await db.rollback()
            raise
        return True


def _plan_limits_to_dict(pl) -> dict:
    """Return every mapped column of a PlanLimits row as a ``{name: value}`` dict."""
    return {c.name: getattr(pl, c.name) for c in pl.__table__.columns}


def _excerpt(payload: dict) -> dict:
    """Pick the few identifying fields of an event's data object for storage.

    Tolerates missing or null ``data`` / ``object`` keys (yields all-None).
    """
    obj = (payload.get("data") or {}).get("object") or {}
    return {
        "object_id": obj.get("id"),
        "customer": obj.get("customer"),
        "subscription": obj.get("subscription"),
        "status": obj.get("status"),
    }


def _from_unix(ts) -> datetime:
    """Convert a Stripe Unix timestamp (seconds) into an aware UTC datetime."""
    return datetime.fromtimestamp(ts, tz=timezone.utc)


async def _plan_for_price(db: AsyncSession, price_id):
    """Return the plan name whose monthly or annual Stripe price is ``price_id``.

    Returns None when no PlanBilling row references that price.
    """
    pb = (await db.execute(
        select(PlanBilling).where(
            (PlanBilling.stripe_monthly_price_id == price_id)
            | (PlanBilling.stripe_annual_price_id == price_id)
        )
    )).scalar_one_or_none()
    return pb.plan if pb is not None else None


async def _handle_checkout_completed(db: AsyncSession, payload: dict) -> None:
    """Link a completed subscription checkout to the account's Subscription row.

    Fetches the Stripe subscription and copies its id, price, period, and seat
    quantity onto the local row, marks it ``active``, and maps the price to a
    plan. Checkouts with no customer or no subscription id (e.g. non-
    subscription modes), or for unknown accounts/rows, are ignored.
    """
    obj = payload["data"]["object"]
    customer_id = obj.get("customer")
    subscription_id = obj.get("subscription")
    if not customer_id or not subscription_id:
        # Nothing to link. Retrieving a None subscription id would raise and
        # make Stripe retry this event indefinitely.
        return

    account = (await db.execute(
        select(Account).where(Account.stripe_customer_id == customer_id)
    )).scalar_one_or_none()
    if account is None:
        return
    sub = (await db.execute(
        select(Subscription).where(Subscription.account_id == account.id)
    )).scalar_one_or_none()
    if sub is None:
        return

    stripe.api_key = settings.STRIPE_SECRET_KEY
    stripe_sub = stripe.Subscription.retrieve(subscription_id)
    item = stripe_sub["items"]["data"][0]

    sub.stripe_subscription_id = subscription_id
    sub.stripe_price_id = item["price"]["id"]
    sub.status = "active"
    sub.current_period_start = _from_unix(stripe_sub["current_period_start"])
    sub.current_period_end = _from_unix(stripe_sub["current_period_end"])
    sub.seat_limit = item["quantity"]

    plan = await _plan_for_price(db, sub.stripe_price_id)
    if plan is not None:
        sub.plan = plan
    # No commit — apply_subscription_event commits once for the full event.


async def _handle_subscription_updated(db: AsyncSession, payload: dict) -> None:
    """Mirror a Stripe subscription's current state onto the local row.

    Copies status, billing period, cancel-at-period-end, seat quantity, and
    the item's price (mapped to a plan). Unknown subscription ids are ignored.
    """
    obj = payload["data"]["object"]
    sub = (await db.execute(
        select(Subscription).where(Subscription.stripe_subscription_id == obj["id"])
    )).scalar_one_or_none()
    if sub is None:
        return

    item = obj["items"]["data"][0]
    sub.status = obj["status"]
    sub.current_period_start = _from_unix(obj["current_period_start"])
    sub.current_period_end = _from_unix(obj["current_period_end"])
    sub.cancel_at_period_end = obj.get("cancel_at_period_end", False)
    sub.seat_limit = item["quantity"]

    # Plan switches (e.g. through the Customer Portal) arrive as a new price on
    # the subscription item; without this the local price/plan go stale.
    price_id = (item.get("price") or {}).get("id")
    if price_id:
        sub.stripe_price_id = price_id
        plan = await _plan_for_price(db, price_id)
        if plan is not None:
            sub.plan = plan
    # No commit — apply_subscription_event commits once for the full event.


async def _handle_subscription_deleted(db: AsyncSession, payload: dict) -> None:
    """Mark the local row ``canceled`` when Stripe deletes the subscription."""
    obj = payload["data"]["object"]
    sub = (await db.execute(
        select(Subscription).where(Subscription.stripe_subscription_id == obj["id"])
    )).scalar_one_or_none()
    if sub is None:
        return
    sub.status = "canceled"
    # No commit — apply_subscription_event commits once for the full event.


async def _handle_payment_failed(db: AsyncSession, payload: dict) -> None:
    """Mark the subscription ``past_due`` when one of its invoices fails.

    Invoices without a subscription id, or for unknown subscriptions, are ignored.
    """
    obj = payload["data"]["object"]
    subscription_id = obj.get("subscription")
    if not subscription_id:
        return
    sub = (await db.execute(
        select(Subscription).where(Subscription.stripe_subscription_id == subscription_id)
    )).scalar_one_or_none()
    if sub is None:
        return
    sub.status = "past_due"
    # No commit — apply_subscription_event commits once for the full event.
async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
    """Restore a ``past_due`` subscription to ``active`` after a paid invoice.

    Invoices that carry no subscription id, or whose subscription isn't
    tracked locally, are ignored. Any status other than ``past_due`` is left
    untouched. Does not commit — apply_subscription_event commits once for
    the full event.
    """
    invoice = payload["data"]["object"]
    stripe_sub_id = invoice.get("subscription")
    if not stripe_sub_id:
        return

    query = select(Subscription).where(
        Subscription.stripe_subscription_id == stripe_sub_id
    )
    local_sub = (await db.execute(query)).scalar_one_or_none()
    if local_sub is not None and local_sub.status == "past_due":
        local_sub.status = "active"