Co-authored-by: Michael Chihlas <michael@resolutionflow.com> Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
357 lines
14 KiB
Python
357 lines
14 KiB
Python
"""Single billing service module. Stripe is the only implementation — there is
no provider abstraction. The Account row is the canonical local state, Stripe
is the canonical remote state, and the webhook handler bridges the two."""
|
|
import logging
|
|
from datetime import datetime, timezone, timedelta
|
|
|
|
import stripe
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.models.account import Account
|
|
from app.models.plan_billing import PlanBilling
|
|
from app.models.stripe_event import StripeEvent
|
|
from app.models.subscription import Subscription
|
|
|
|
|
|
# Length, in days, of the "trialing" Pro period that
# BillingService.start_trial creates for a new account.
TRIAL_DAYS: int = 14

# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BillingService:
    """Billing operations for accounts: trials, Stripe Checkout and Customer
    Portal sessions, aggregated billing state, and webhook event application.

    Used as a namespace — every method is a ``@staticmethod``.
    """

    # Stripe Checkout requires ``subscription_data.trial_end`` to be at least
    # 48 hours in the future. A local trial with less time than this left is
    # not carried over to Stripe (see create_checkout_session).
    _CHECKOUT_MIN_TRIAL_REMAINING = timedelta(hours=48)

    @staticmethod
    async def invalidate_billing_cache(account_ids) -> None:
        """No-op stub for future in-process billing cache invalidation.

        Today there is no `app.state.billing_cache` — `BillingService.get_billing_state`
        always reads fresh from the DB. Call sites that mutate plan/feature data
        invoke this hook so that wiring is in place when an in-process cache is
        added later. Until then, this just logs.

        ``account_ids`` may be any iterable; it is materialized only to count
        it. A non-iterable argument is tolerated and logged with a count of -1.

        TODO: when an in-process billing cache (e.g. `app.state.billing_cache`)
        is introduced, evict entries for the given account_ids here.
        """
        try:
            count = len(list(account_ids))
        except TypeError:
            count = -1
        logger.debug(
            "BillingService.invalidate_billing_cache called for %d account(s) "
            "(no-op stub — wire to app.state.billing_cache when added)",
            count,
        )

    @staticmethod
    async def start_trial(db: AsyncSession, account_id) -> Subscription:
        """Ensure the account has a Subscription row, creating a Pro trial if not.

        Idempotent. If a Subscription already exists for ``account_id`` it is
        returned unchanged. Otherwise a ``"trialing"`` Subscription on plan
        ``"pro"`` is inserted, with a period of ``TRIAL_DAYS`` days starting
        now. It is then committed, refreshed and returned.

        Note: this commits ``db``, so any other pending changes on the session
        are committed too.
        NOTE(review): two concurrent calls can both miss the existing row and
        both insert. Whether the second fails depends on a unique constraint on
        ``Subscription.account_id``, which is not visible here.
        """
        existing = (await db.execute(
            select(Subscription).where(Subscription.account_id == account_id)
        )).scalar_one_or_none()
        if existing is not None:
            return existing

        # One timestamp for both bounds, so the trial is exactly TRIAL_DAYS long.
        now = datetime.now(timezone.utc)
        sub = Subscription(
            account_id=account_id,
            plan="pro",
            status="trialing",
            current_period_start=now,
            current_period_end=now + timedelta(days=TRIAL_DAYS),
        )
        db.add(sub)
        await db.commit()
        await db.refresh(sub)
        return sub

    @staticmethod
    async def create_checkout_session(
        db: AsyncSession,
        account: Account,
        plan: str,
        seats: int,
        billing_interval: str,
        success_url: str,
        cancel_url: str,
    ) -> str:
        """Create a Stripe Checkout Session for a subscription purchase.

        Args:
            db: Session used to read PlanBilling / Subscription and to persist
                a newly created Stripe customer id on ``account``.
            account: Account being billed. If it has no ``stripe_customer_id``,
                a Stripe Customer is created and its id is committed.
            plan: Plan key looked up in ``PlanBilling.plan``.
            seats: Quantity of the single subscription line item.
            billing_interval: ``"monthly"`` selects the monthly price. Any
                other value selects the annual price.
            success_url: Redirect target after a successful checkout.
            cancel_url: Redirect target if the user abandons checkout.

        Returns:
            The hosted Checkout Session URL.

        Raises:
            RuntimeError: Stripe is not configured, or the plan has no Stripe
                price for the requested interval.
            ValueError: ``plan`` has no PlanBilling row.

        Trial handling: if the account's subscription is ``"trialing"`` with at
        least 48 hours left, its trial end is carried over to the new Stripe
        subscription so the user isn't charged before the trial ends. Stripe
        Checkout rejects a ``trial_end`` closer than 48 hours. With less time
        left, the trial is therefore not carried over and billing starts at
        checkout, instead of the session creation failing.
        """
        if not settings.stripe_enabled:
            raise RuntimeError("Stripe not configured")
        stripe.api_key = settings.STRIPE_SECRET_KEY

        plan_billing = (await db.execute(
            select(PlanBilling).where(PlanBilling.plan == plan)
        )).scalar_one_or_none()
        if plan_billing is None:
            raise ValueError(f"Unknown plan: {plan}")
        price_id = (
            plan_billing.stripe_monthly_price_id
            if billing_interval == "monthly"
            else plan_billing.stripe_annual_price_id
        )
        if price_id is None:
            raise RuntimeError(
                f"Plan '{plan}' has no Stripe price for {billing_interval}"
            )

        # Read these before the commit below. If the session expires
        # attributes on commit (SQLAlchemy's default expire_on_commit=True),
        # reading them afterwards would need implicit IO, which AsyncSession
        # does not allow.
        account_id = account.id
        customer_id = account.stripe_customer_id
        if customer_id is None:
            customer = stripe.Customer.create(
                email=None,
                metadata={"account_id": str(account_id)},
            )
            customer_id = customer.id
            account.stripe_customer_id = customer_id
            await db.commit()

        sub = (await db.execute(
            select(Subscription).where(Subscription.account_id == account_id)
        )).scalar_one_or_none()
        subscription_data = {}
        now = datetime.now(timezone.utc)
        if (
            sub
            and sub.status == "trialing"
            and sub.current_period_end
            and sub.current_period_end
            >= now + BillingService._CHECKOUT_MIN_TRIAL_REMAINING
        ):
            subscription_data["trial_end"] = int(sub.current_period_end.timestamp())

        session = stripe.checkout.Session.create(
            customer=customer_id,
            line_items=[{"price": price_id, "quantity": seats}],
            mode="subscription",
            subscription_data=subscription_data or None,
            success_url=success_url,
            cancel_url=cancel_url,
            allow_promotion_codes=False,
        )
        return session.url

    @staticmethod
    async def open_customer_portal(account: Account) -> str:
        """Create a Stripe-hosted Customer Portal session and return the URL.

        The portal returns the user to ``{FRONTEND_URL}/account/billing``.

        Raises RuntimeError if Stripe isn't configured (endpoint maps to 503).
        Raises ValueError if the account has no stripe_customer_id yet — the
        user must complete a checkout first (endpoint maps to 400).
        """
        if not settings.stripe_enabled:
            raise RuntimeError("Stripe not configured")
        if account.stripe_customer_id is None:
            raise ValueError("no_stripe_customer")
        stripe.api_key = settings.STRIPE_SECRET_KEY
        session = stripe.billing_portal.Session.create(
            customer=account.stripe_customer_id,
            return_url=f"{settings.FRONTEND_URL}/account/billing",
        )
        return session.url

    @staticmethod
    async def get_billing_state(db: AsyncSession, account):
        """Aggregate billing state for ``account``.

        Returns a dict with:
          * ``subscription``: selected fields of the account's Subscription row.
          * ``plan_billing``: the PlanBilling row for that plan, or ``None``.
          * ``plan_limits``: the PlanLimits row as a ``{column: value}`` dict,
            or ``{}`` if there is none.
          * ``enabled_features``: ``{flag_key: enabled}`` from the plan's
            PlanFeatureDefault rows, overridden by the account's
            AccountFeatureOverride rows.

        Raises:
            HTTPException: 404 when the account has no Subscription row.
        """
        # Imported locally. Presumably this avoids import cycles — verify
        # before hoisting. PlanBilling is already imported at module level.
        from app.models.plan_limits import PlanLimits
        from app.models.feature_flag import (
            FeatureFlag, PlanFeatureDefault, AccountFeatureOverride,
        )

        sub = (await db.execute(
            select(Subscription).where(Subscription.account_id == account.id)
        )).scalar_one_or_none()
        if sub is None:
            from fastapi import HTTPException
            raise HTTPException(status_code=404, detail="No subscription for account")

        pl = (await db.execute(
            select(PlanLimits).where(PlanLimits.plan == sub.plan)
        )).scalar_one_or_none()
        pb = (await db.execute(
            select(PlanBilling).where(PlanBilling.plan == sub.plan)
        )).scalar_one_or_none()

        # Resolved feature flags: plan defaults overridden by account overrides.
        defaults = (await db.execute(
            select(PlanFeatureDefault, FeatureFlag)
            .join(FeatureFlag, PlanFeatureDefault.flag_id == FeatureFlag.id)
            .where(PlanFeatureDefault.plan == sub.plan)
        )).all()
        resolved = {flag.flag_key: pfd.enabled for pfd, flag in defaults}
        overrides = (await db.execute(
            select(AccountFeatureOverride, FeatureFlag)
            .join(FeatureFlag, AccountFeatureOverride.flag_id == FeatureFlag.id)
            .where(AccountFeatureOverride.account_id == account.id)
        )).all()
        for ovr, flag in overrides:
            resolved[flag.flag_key] = ovr.enabled

        return {
            "subscription": {
                "status": sub.status,
                "plan": sub.plan,
                "current_period_start": sub.current_period_start,
                "current_period_end": sub.current_period_end,
                "cancel_at_period_end": sub.cancel_at_period_end,
                "seat_limit": sub.seat_limit,
                "has_pro_entitlement": sub.has_pro_entitlement,
                "is_paid": sub.is_paid,
            },
            "plan_billing": pb,
            "plan_limits": _plan_limits_to_dict(pl) if pl else {},
            "enabled_features": resolved,
        }

    @staticmethod
    async def apply_subscription_event(
        db: AsyncSession, event_id: str, event_type: str, payload: dict
    ) -> bool:
        """Idempotent. Returns True if the event was applied; False if it had
        already been processed (idempotent ack). The webhook handler returns 200
        either way.

        Atomic: the StripeEvent idempotency mark and the handler's state
        mutations are committed in a single transaction. If the handler raises,
        the entire transaction (idempotency mark + partial mutations) is rolled
        back, so a Stripe retry will re-run the handler. Without this, a
        handler that fails mid-flight would leave the StripeEvent row persisted
        and silently desync subscription state from Stripe.

        Event types without a handler below are still recorded (and committed)
        as processed.
        """
        db.add(StripeEvent(
            id=event_id,
            event_type=event_type,
            payload_excerpt=_excerpt(payload),
        ))
        try:
            await db.flush()
        except IntegrityError:
            # Duplicate event_id — already processed (or in flight). Ack with False.
            # NOTE: this rolls back the whole session transaction, including
            # anything else pending on ``db``.
            await db.rollback()
            return False

        try:
            if event_type == "checkout.session.completed":
                await _handle_checkout_completed(db, payload)
            elif event_type == "customer.subscription.updated":
                await _handle_subscription_updated(db, payload)
            elif event_type == "customer.subscription.deleted":
                await _handle_subscription_deleted(db, payload)
            elif event_type == "invoice.payment_failed":
                await _handle_payment_failed(db, payload)
            elif event_type == "invoice.payment_succeeded":
                await _handle_payment_succeeded(db, payload)
            await db.commit()
        except Exception:
            # Roll back the StripeEvent insert + any partial handler mutations
            # so Stripe's retry can re-run cleanly.
            await db.rollback()
            raise
        return True
|
|
|
|
|
|
def _plan_limits_to_dict(pl) -> dict:
    """Serialize a mapped row into a ``{column_name: value}`` dict.

    Walks ``pl.__table__.columns`` in table order and reads each column's
    value from the instance with ``getattr``.
    """
    row = {}
    for column in pl.__table__.columns:
        row[column.name] = getattr(pl, column.name)
    return row
|
|
|
|
|
|
def _excerpt(payload: dict) -> dict:
    """Return a small summary of a Stripe event payload for storage.

    Picks ``id`` (reported as ``object_id``), ``customer``, ``subscription``
    and ``status`` from ``payload["data"]["object"]``. A missing ``data`` or
    ``object`` key, or a missing field, yields ``None`` for that entry.
    """
    data = payload.get("data", {})
    event_object = data.get("object", {})
    field_map = (
        ("object_id", "id"),
        ("customer", "customer"),
        ("subscription", "subscription"),
        ("status", "status"),
    )
    return {name: event_object.get(source) for name, source in field_map}
|
|
|
|
|
|
async def _handle_checkout_completed(db: AsyncSession, payload: dict):
    """Link a completed Checkout Session to the account's local Subscription.

    Finds the Account by the session's ``customer`` id and its Subscription
    row. It then retrieves the Stripe subscription the session created and
    copies its id, first-item price and quantity, status and period bounds
    onto the local row. ``plan`` is updated when the price maps to a
    PlanBilling row.

    Ignored (no changes): sessions without a subscription id, unknown
    customers, and accounts without a Subscription row.

    Does not commit — ``BillingService.apply_subscription_event`` commits once
    for the whole event.
    """
    obj = payload["data"]["object"]
    customer_id = obj["customer"]
    subscription_id = obj.get("subscription")
    if not subscription_id:
        # Nothing to link. Without this guard, a missing key raises KeyError
        # and a None id makes Subscription.retrieve fail; either way the event
        # would be rolled back and retried by Stripe indefinitely.
        return

    account = (await db.execute(
        select(Account).where(Account.stripe_customer_id == customer_id)
    )).scalar_one_or_none()
    if account is None:
        return

    sub = (await db.execute(
        select(Subscription).where(Subscription.account_id == account.id)
    )).scalar_one_or_none()
    if sub is None:
        return

    stripe.api_key = settings.STRIPE_SECRET_KEY
    stripe_sub = stripe.Subscription.retrieve(subscription_id)
    item = stripe_sub["items"]["data"][0]

    sub.stripe_subscription_id = subscription_id
    sub.stripe_price_id = item["price"]["id"]
    # Mirror Stripe's status instead of assuming "active". A trial carried
    # over by create_checkout_session leaves the subscription "trialing", and
    # a subscription can also still be e.g. "incomplete". Hard-coding
    # "active" would misreport both.
    sub.status = stripe_sub["status"]
    sub.current_period_start = datetime.fromtimestamp(stripe_sub["current_period_start"], tz=timezone.utc)
    sub.current_period_end = datetime.fromtimestamp(stripe_sub["current_period_end"], tz=timezone.utc)
    sub.seat_limit = item["quantity"]

    pb = (await db.execute(
        select(PlanBilling).where(
            (PlanBilling.stripe_monthly_price_id == sub.stripe_price_id)
            | (PlanBilling.stripe_annual_price_id == sub.stripe_price_id)
        )
    )).scalar_one_or_none()
    if pb is not None:
        sub.plan = pb.plan
    # No commit — apply_subscription_event commits once for the full event.
|
|
|
|
|
|
async def _handle_subscription_updated(db: AsyncSession, payload: dict):
    """Sync a local Subscription from a ``customer.subscription.updated`` event.

    Finds the row by Stripe subscription id; unknown ids are ignored. Copies
    status, period bounds, ``cancel_at_period_end`` and the first item's seat
    quantity. When the first item's price differs from the stored
    ``stripe_price_id``, the new price is stored and ``plan`` is re-derived
    from the matching PlanBilling row, if any.

    Does not commit — ``BillingService.apply_subscription_event`` commits once
    for the whole event.
    """
    obj = payload["data"]["object"]
    sub = (await db.execute(
        select(Subscription).where(Subscription.stripe_subscription_id == obj["id"])
    )).scalar_one_or_none()
    if sub is None:
        return

    sub.status = obj["status"]
    sub.current_period_start = datetime.fromtimestamp(obj["current_period_start"], tz=timezone.utc)
    sub.current_period_end = datetime.fromtimestamp(obj["current_period_end"], tz=timezone.utc)
    sub.cancel_at_period_end = obj.get("cancel_at_period_end", False)

    item = obj["items"]["data"][0]
    sub.seat_limit = item["quantity"]

    # A plan switch made after checkout (e.g. in the Customer Portal, if it
    # allows plan changes) arrives here as a new price. Without syncing it,
    # the local plan — and therefore the resolved limits and features — would
    # stay on the old plan.
    price_id = (item.get("price") or {}).get("id")
    if price_id and price_id != sub.stripe_price_id:
        sub.stripe_price_id = price_id
        pb = (await db.execute(
            select(PlanBilling).where(
                (PlanBilling.stripe_monthly_price_id == price_id)
                | (PlanBilling.stripe_annual_price_id == price_id)
            )
        )).scalar_one_or_none()
        if pb is not None:
            sub.plan = pb.plan
    # No commit — apply_subscription_event commits once for the full event.
|
|
|
|
|
|
async def _handle_subscription_deleted(db: AsyncSession, payload: dict):
    """Mark the local Subscription for a deleted Stripe subscription as canceled.

    The row is matched on ``stripe_subscription_id`` against the event
    object's ``id``. Unknown ids are ignored.

    Does not commit — ``BillingService.apply_subscription_event`` commits once
    for the whole event.
    """
    stripe_subscription_id = payload["data"]["object"]["id"]
    lookup = select(Subscription).where(
        Subscription.stripe_subscription_id == stripe_subscription_id
    )
    local_sub = (await db.execute(lookup)).scalar_one_or_none()
    if local_sub is not None:
        local_sub.status = "canceled"
    # No commit — apply_subscription_event commits once for the full event.
|
|
|
|
|
|
async def _handle_payment_failed(db: AsyncSession, payload: dict):
    """Flag the local Subscription as ``past_due`` after a failed invoice payment.

    Invoices without a ``subscription`` id, and subscriptions not known
    locally, are ignored. A subscription that is already ``"canceled"`` is
    left alone.

    Does not commit — ``BillingService.apply_subscription_event`` commits once
    for the whole event.
    """
    obj = payload["data"]["object"]
    subscription_id = obj.get("subscription")
    if not subscription_id:
        return
    sub = (await db.execute(
        select(Subscription).where(Subscription.stripe_subscription_id == subscription_id)
    )).scalar_one_or_none()
    if sub is None:
        return
    if sub.status == "canceled":
        # Stripe does not guarantee webhook delivery order. A failure event
        # that lands after customer.subscription.deleted must not move a
        # canceled subscription back to past_due.
        return
    sub.status = "past_due"
    # No commit — apply_subscription_event commits once for the full event.
|
|
|
|
|
|
async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
    """Restore a ``past_due`` local Subscription to ``active`` after a paid invoice.

    Invoices without a ``subscription`` id, and subscriptions not known
    locally, are ignored. Subscriptions in any status other than
    ``past_due`` are left unchanged.

    Does not commit — ``BillingService.apply_subscription_event`` commits once
    for the whole event.
    """
    stripe_subscription_id = payload["data"]["object"].get("subscription")
    if not stripe_subscription_id:
        return
    lookup = select(Subscription).where(
        Subscription.stripe_subscription_id == stripe_subscription_id
    )
    local_sub = (await db.execute(lookup)).scalar_one_or_none()
    if local_sub is not None and local_sub.status == "past_due":
        local_sub.status = "active"
    # No commit — apply_subscription_event commits once for the full event.
|