"""Single billing service module.

Stripe is the only impl — no provider abstraction. Account row is canonical
local state; Stripe is canonical remote state; the webhook handler bridges
the two.
"""

import asyncio
import functools
from datetime import datetime, timezone, timedelta

import stripe
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.models.account import Account
from app.models.plan_billing import PlanBilling
from app.models.stripe_event import StripeEvent
from app.models.subscription import Subscription

TRIAL_DAYS = 14

# Stripe Checkout rejects a subscription_data.trial_end that is less than
# 48 hours in the future. The extra minutes absorb request latency and clock
# skew between this host and Stripe.
_MIN_CHECKOUT_TRIAL_LEAD = timedelta(hours=48, minutes=5)


async def _run_blocking(func, *args, **kwargs):
    """Run a synchronous callable in the default executor and return its result.

    The Stripe calls used in this module are blocking HTTP requests; awaiting
    them through the executor keeps the event loop free while they run.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, functools.partial(func, *args, **kwargs)
    )


class BillingService:
    """Static entry points for trial creation, checkout, billing state and webhooks."""

    @staticmethod
    async def start_trial(db: AsyncSession, account_id) -> Subscription:
        """Return the account's Subscription, creating a Pro trial if none exists.

        Idempotent: when a Subscription row already exists for ``account_id``
        it is returned untouched. Otherwise a ``trialing`` row on plan
        ``"pro"`` is inserted, committed and refreshed, with a period of
        exactly ``TRIAL_DAYS`` days starting now (UTC).
        """
        existing = (await db.execute(
            select(Subscription).where(Subscription.account_id == account_id)
        )).scalar_one_or_none()
        if existing is not None:
            return existing

        # One timestamp for both bounds so the window is exactly TRIAL_DAYS.
        now = datetime.now(timezone.utc)
        sub = Subscription(
            account_id=account_id,
            plan="pro",
            status="trialing",
            current_period_start=now,
            current_period_end=now + timedelta(days=TRIAL_DAYS),
        )
        db.add(sub)
        await db.commit()
        await db.refresh(sub)
        return sub

    @staticmethod
    async def create_checkout_session(
        db: AsyncSession,
        account: Account,
        plan: str,
        seats: int,
        billing_interval: str,
        success_url: str,
        cancel_url: str,
    ) -> str:
        """Create a Stripe Checkout Session for a subscription purchase.

        If the account currently has a trialing subscription with time
        remaining, that trial end is carried onto the new Stripe subscription
        so the user isn't charged early. When less than Stripe's minimum lead
        time remains, the trial end is pushed out to that minimum instead of
        letting Stripe reject the session.

        A Stripe customer is created (and its id committed on ``account``)
        when the account does not have one yet.

        Args:
            db: Session used for plan/subscription lookups and the customer-id commit.
            account: Account purchasing the subscription.
            plan: Plan key looked up in ``PlanBilling``.
            seats: Quantity sent as the single line item.
            billing_interval: ``"monthly"`` selects the monthly price; any
                other value selects the annual price.
            success_url: Redirect target after successful checkout.
            cancel_url: Redirect target when checkout is abandoned.

        Returns:
            The hosted Checkout Session URL.

        Raises:
            RuntimeError: Stripe is not configured, or the plan has no price
                id for the requested interval.
            ValueError: ``plan`` has no ``PlanBilling`` row.
        """
        if not settings.stripe_enabled:
            raise RuntimeError("Stripe not configured")
        stripe.api_key = settings.STRIPE_SECRET_KEY

        plan_billing = (await db.execute(
            select(PlanBilling).where(PlanBilling.plan == plan)
        )).scalar_one_or_none()
        if plan_billing is None:
            raise ValueError(f"Unknown plan: {plan}")

        price_id = (
            plan_billing.stripe_monthly_price_id
            if billing_interval == "monthly"
            else plan_billing.stripe_annual_price_id
        )
        if price_id is None:
            raise RuntimeError(
                f"Plan '{plan}' has no Stripe price for {billing_interval}"
            )

        if account.stripe_customer_id is None:
            customer = await _run_blocking(
                stripe.Customer.create,
                email=None,
                metadata={"account_id": str(account.id)},
            )
            account.stripe_customer_id = customer.id
            await db.commit()

        sub = (await db.execute(
            select(Subscription).where(Subscription.account_id == account.id)
        )).scalar_one_or_none()

        subscription_data = {}
        now = datetime.now(timezone.utc)
        if (
            sub
            and sub.status == "trialing"
            and sub.current_period_end
            and sub.current_period_end > now
        ):
            # NOTE(review): clamping extends the trial by up to ~48h for users
            # who check out at the very end of it; the alternative is dropping
            # trial_end and charging immediately. Confirm the product intent.
            trial_end = max(sub.current_period_end, now + _MIN_CHECKOUT_TRIAL_LEAD)
            subscription_data["trial_end"] = int(trial_end.timestamp())

        session = await _run_blocking(
            stripe.checkout.Session.create,
            customer=account.stripe_customer_id,
            line_items=[{"price": price_id, "quantity": seats}],
            mode="subscription",
            subscription_data=subscription_data or None,
            success_url=success_url,
            cancel_url=cancel_url,
            allow_promotion_codes=False,
        )
        return session.url

    @staticmethod
    async def get_billing_state(db: AsyncSession, account) -> dict:
        """Aggregate Subscription + PlanLimits + PlanBilling + resolved feature flags.

        Returns:
            A dict with keys ``subscription`` (selected Subscription fields),
            ``plan_billing`` (the PlanBilling row or ``None``),
            ``plan_limits`` (column-name → value dict, empty when no row) and
            ``enabled_features`` (flag key → bool, plan defaults overridden
            by per-account overrides).

        Raises:
            fastapi.HTTPException: 404 when the account has no Subscription.
        """
        # Imported locally, as in the original; presumably to avoid import
        # cycles — TODO confirm before hoisting to module level.
        from app.models.plan_limits import PlanLimits
        from app.models.feature_flag import (
            FeatureFlag,
            PlanFeatureDefault,
            AccountFeatureOverride,
        )

        sub = (await db.execute(
            select(Subscription).where(Subscription.account_id == account.id)
        )).scalar_one_or_none()
        if sub is None:
            from fastapi import HTTPException
            raise HTTPException(status_code=404, detail="No subscription for account")

        pl = (await db.execute(
            select(PlanLimits).where(PlanLimits.plan == sub.plan)
        )).scalar_one_or_none()
        pb = (await db.execute(
            select(PlanBilling).where(PlanBilling.plan == sub.plan)
        )).scalar_one_or_none()

        # Resolved feature flags: plan defaults overridden by account overrides
        defaults = (await db.execute(
            select(PlanFeatureDefault, FeatureFlag)
            .join(FeatureFlag, PlanFeatureDefault.flag_id == FeatureFlag.id)
            .where(PlanFeatureDefault.plan == sub.plan)
        )).all()
        resolved = {flag.flag_key: pfd.enabled for pfd, flag in defaults}

        overrides = (await db.execute(
            select(AccountFeatureOverride, FeatureFlag)
            .join(FeatureFlag, AccountFeatureOverride.flag_id == FeatureFlag.id)
            .where(AccountFeatureOverride.account_id == account.id)
        )).all()
        for ovr, flag in overrides:
            resolved[flag.flag_key] = ovr.enabled

        return {
            "subscription": {
                "status": sub.status,
                "plan": sub.plan,
                "current_period_start": sub.current_period_start,
                "current_period_end": sub.current_period_end,
                "cancel_at_period_end": sub.cancel_at_period_end,
                "seat_limit": sub.seat_limit,
                "has_pro_entitlement": sub.has_pro_entitlement,
                "is_paid": sub.is_paid,
            },
            "plan_billing": pb,
            "plan_limits": _plan_limits_to_dict(pl) if pl else {},
            "enabled_features": resolved,
        }

    @staticmethod
    async def apply_subscription_event(
        db: AsyncSession, event_id: str, event_type: str, payload: dict
    ) -> bool:
        """Apply a Stripe webhook event at most once.

        The ``StripeEvent`` row is flushed first so a duplicate ``event_id``
        surfaces as ``IntegrityError``; the handler then runs in the same
        transaction. If the handler raises, the transaction (including the
        event row) is rolled back and the exception propagates, so a later
        redelivery of the same event is processed rather than acknowledged
        as a duplicate.

        Args:
            db: Session used for the event row and the handler's writes.
            event_id: Stripe event id, used as the ``StripeEvent`` primary key.
            event_type: Stripe event type; unrecognised types are recorded
                but otherwise ignored.
            payload: Full event payload.

        Returns:
            True if the event was applied; False if it had already been
            processed (idempotent ack). The webhook handler returns 200
            either way.
        """
        db.add(StripeEvent(
            id=event_id,
            event_type=event_type,
            payload_excerpt=_excerpt(payload),
        ))
        try:
            await db.flush()
        except IntegrityError:
            await db.rollback()
            return False

        handler = _EVENT_HANDLERS.get(event_type)
        try:
            if handler is not None:
                await handler(db, payload)
            # Handlers that return early do not commit; this persists the
            # event row in that case and is a no-op otherwise.
            await db.commit()
        except Exception:
            await db.rollback()
            raise
        return True


def _plan_limits_to_dict(pl) -> dict:
    """Return a column-name → value dict for a mapped row."""
    return {c.name: getattr(pl, c.name) for c in pl.__table__.columns}


def _excerpt(payload: dict) -> dict:
    """Pick the identifying fields of ``payload["data"]["object"]`` for storage.

    Missing or ``None`` ``data`` / ``object`` entries yield ``None`` values
    rather than raising.
    """
    obj = (payload.get("data") or {}).get("object") or {}
    return {
        "object_id": obj.get("id"),
        "customer": obj.get("customer"),
        "subscription": obj.get("subscription"),
        "status": obj.get("status"),
    }


def _utc_from_timestamp(ts) -> datetime:
    """Convert a Unix timestamp to a timezone-aware UTC datetime."""
    return datetime.fromtimestamp(ts, tz=timezone.utc)


async def _find_subscription_by_stripe_id(db: AsyncSession, stripe_subscription_id):
    """Return the local Subscription with this Stripe subscription id, or None."""
    return (await db.execute(
        select(Subscription).where(
            Subscription.stripe_subscription_id == stripe_subscription_id
        )
    )).scalar_one_or_none()


async def _sync_from_stripe_subscription(db: AsyncSession, sub, stripe_sub) -> None:
    """Copy price, seats, billing period and the matching plan onto ``sub``.

    ``stripe_sub`` may be a webhook payload dict or a retrieved Stripe
    object; only ``[]`` access is used. Only the first subscription item is
    read. ``sub.plan`` is left unchanged when no ``PlanBilling`` row
    references the price. Does not commit.
    """
    item = stripe_sub["items"]["data"][0]
    sub.stripe_price_id = item["price"]["id"]
    sub.seat_limit = item["quantity"]
    sub.current_period_start = _utc_from_timestamp(stripe_sub["current_period_start"])
    sub.current_period_end = _utc_from_timestamp(stripe_sub["current_period_end"])

    pb = (await db.execute(
        select(PlanBilling).where(
            (PlanBilling.stripe_monthly_price_id == sub.stripe_price_id)
            | (PlanBilling.stripe_annual_price_id == sub.stripe_price_id)
        )
    )).scalar_one_or_none()
    if pb is not None:
        sub.plan = pb.plan


async def _handle_checkout_completed(db: AsyncSession, payload: dict):
    """Link the account's local Subscription to the Stripe subscription just purchased.

    Does nothing when the session carries no customer or subscription id,
    when no account has that Stripe customer id, or when the account has no
    local Subscription row.
    """
    obj = payload["data"]["object"]
    customer_id = obj.get("customer")
    subscription_id = obj.get("subscription")
    # Without this guard a missing customer would match accounts whose
    # stripe_customer_id IS NULL, and a missing subscription would be passed
    # to Stripe as None.
    if not customer_id or not subscription_id:
        return

    account = (await db.execute(
        select(Account).where(Account.stripe_customer_id == customer_id)
    )).scalar_one_or_none()
    if account is None:
        return

    sub = (await db.execute(
        select(Subscription).where(Subscription.account_id == account.id)
    )).scalar_one_or_none()
    if sub is None:
        return

    stripe.api_key = settings.STRIPE_SECRET_KEY
    stripe_sub = await _run_blocking(stripe.Subscription.retrieve, subscription_id)

    sub.stripe_subscription_id = subscription_id
    # NOTE(review): the Stripe subscription may itself be "trialing" when a
    # trial end was carried over at checkout; local status is still forced
    # to "active" here, as before. Confirm whether it should mirror Stripe.
    sub.status = "active"
    await _sync_from_stripe_subscription(db, sub, stripe_sub)
    await db.commit()


async def _handle_subscription_updated(db: AsyncSession, payload: dict):
    """Mirror status, period, cancellation flag, seats, price and plan from Stripe."""
    obj = payload["data"]["object"]
    sub = await _find_subscription_by_stripe_id(db, obj["id"])
    if sub is None:
        return

    sub.status = obj["status"]
    sub.cancel_at_period_end = obj.get("cancel_at_period_end", False)
    # Also picks up plan/price changes made on the Stripe side.
    await _sync_from_stripe_subscription(db, sub, obj)
    await db.commit()


async def _handle_subscription_deleted(db: AsyncSession, payload: dict):
    """Mark the matching local Subscription as canceled."""
    obj = payload["data"]["object"]
    sub = await _find_subscription_by_stripe_id(db, obj["id"])
    if sub is None:
        return

    sub.status = "canceled"
    await db.commit()


async def _handle_payment_failed(db: AsyncSession, payload: dict):
    """Mark the invoice's Subscription as past_due; ignore invoices without one."""
    subscription_id = payload["data"]["object"].get("subscription")
    if not subscription_id:
        return

    sub = await _find_subscription_by_stripe_id(db, subscription_id)
    if sub is None:
        return

    sub.status = "past_due"
    await db.commit()


async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
    """Return a past_due Subscription to active; other statuses are left as-is."""
    subscription_id = payload["data"]["object"].get("subscription")
    if not subscription_id:
        return

    sub = await _find_subscription_by_stripe_id(db, subscription_id)
    if sub is None:
        return

    if sub.status == "past_due":
        sub.status = "active"
    await db.commit()


# Stripe event type → handler. Types not listed here are recorded by
# BillingService.apply_subscription_event but otherwise ignored.
_EVENT_HANDLERS = {
    "checkout.session.completed": _handle_checkout_completed,
    "customer.subscription.updated": _handle_subscription_updated,
    "customer.subscription.deleted": _handle_subscription_deleted,
    "invoice.payment_failed": _handle_payment_failed,
    "invoice.payment_succeeded": _handle_payment_succeeded,
}