Files
resolutionflow/backend/app/services/billing.py
Michael Chihlas 2f8ec3775e feat(billing): add BillingService.open_customer_portal + GET endpoint
Authed users can now request a Stripe-hosted Customer Portal URL for card
updates and cancellation via GET /api/v1/billing/portal-session. The path is
already in both _SUBSCRIPTION_GUARD_ALLOWLIST and _EMAIL_VERIFICATION_ALLOWLIST
so canceled or unverified-past-grace users can still update billing.

- Returns 503 with {"error": "stripe_not_configured"} when STRIPE_SECRET_KEY unset.
- Returns 400 with {"error": "no_stripe_customer"} when account has no
  stripe_customer_id (must complete checkout first).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-06 20:00:08 -04:00

316 lines
12 KiB
Python

"""Single billing service module. Stripe is the only impl — no provider
abstraction. Account row is canonical local state; Stripe is canonical
remote state; the webhook handler bridges the two."""
import asyncio
from datetime import datetime, timezone, timedelta

import stripe
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.models.account import Account
from app.models.plan_billing import PlanBilling
from app.models.stripe_event import StripeEvent
from app.models.subscription import Subscription
# Length, in days, of the trial window that BillingService.start_trial
# puts on a newly created trialing Subscription.
TRIAL_DAYS = 14
class BillingService:
    """Namespace of static billing operations.

    Covers trial creation, Stripe Checkout and Customer Portal sessions,
    aggregation of an account's billing state, and application of Stripe
    webhook events to local Subscription rows.
    """

    # Stripe Checkout rejects a ``subscription_data.trial_end`` that is less
    # than 48 hours in the future; the extra minutes absorb clock skew and
    # request latency between this process and Stripe.
    _MIN_CHECKOUT_TRIAL = timedelta(hours=48, minutes=5)

    @staticmethod
    async def start_trial(db: AsyncSession, account_id) -> Subscription:
        """Return the account's Subscription, creating a Pro trial if absent.

        Idempotent for sequential calls: when a Subscription row already
        exists for ``account_id`` it is returned unchanged; otherwise a
        ``trialing`` row on the ``pro`` plan lasting ``TRIAL_DAYS`` days is
        inserted, committed, refreshed and returned.
        """
        result = await db.execute(
            select(Subscription).where(Subscription.account_id == account_id)
        )
        existing = result.scalar_one_or_none()
        if existing is not None:
            return existing

        # One timestamp for both bounds so the window is exactly TRIAL_DAYS.
        now = datetime.now(timezone.utc)
        sub = Subscription(
            account_id=account_id,
            plan="pro",
            status="trialing",
            current_period_start=now,
            current_period_end=now + timedelta(days=TRIAL_DAYS),
        )
        db.add(sub)
        await db.commit()
        await db.refresh(sub)
        return sub

    @staticmethod
    def _checkout_trial_end(sub, now):
        """Return the trial end to hand to Stripe Checkout, or ``None``.

        ``None`` means no trial should be attached: there is no subscription,
        it is not ``trialing``, or its period end is missing or already past.
        Otherwise the local trial end is returned, pushed out to Stripe's
        minimum lead time when less than that remains.
        """
        if not (
            sub
            and sub.status == "trialing"
            and sub.current_period_end
            and sub.current_period_end > now
        ):
            return None
        # NOTE(review): clamping can extend a nearly-finished trial by up to
        # ~2 days. The alternative (omit trial_end, charge immediately) bills
        # the user early; sending the raw value makes the Checkout call fail.
        return max(
            sub.current_period_end, now + BillingService._MIN_CHECKOUT_TRIAL
        )

    @staticmethod
    async def create_checkout_session(
        db: AsyncSession,
        account: Account,
        plan: str,
        seats: int,
        billing_interval: str,
        success_url: str,
        cancel_url: str,
    ) -> str:
        """Create a Stripe Checkout Session for a subscription purchase.

        A Stripe customer is created (and its id stored on the account) when
        the account has none. If the account has a trialing subscription with
        time remaining, that trial end is carried onto the new Stripe
        subscription so the user isn't charged early; when less than Stripe's
        48-hour minimum remains, the trial end is moved out to that minimum
        so the request is not rejected.

        Returns:
            The URL of the hosted Checkout page.

        Raises:
            RuntimeError: Stripe is not configured, or the plan has no Stripe
                price for the requested interval.
            ValueError: ``plan`` has no PlanBilling row.
        """
        if not settings.stripe_enabled:
            raise RuntimeError("Stripe not configured")
        stripe.api_key = settings.STRIPE_SECRET_KEY

        plan_billing = (await db.execute(
            select(PlanBilling).where(PlanBilling.plan == plan)
        )).scalar_one_or_none()
        if plan_billing is None:
            raise ValueError(f"Unknown plan: {plan}")

        # NOTE(review): every interval other than "monthly" selects the
        # annual price, so a misspelled interval is billed annually. Confirm
        # that callers validate the value before it gets here.
        price_id = (
            plan_billing.stripe_monthly_price_id if billing_interval == "monthly"
            else plan_billing.stripe_annual_price_id
        )
        if price_id is None:
            raise RuntimeError(
                f"Plan '{plan}' has no Stripe price for {billing_interval}"
            )

        # Read these before any commit: a commit may expire ORM attributes,
        # and an AsyncSession cannot lazy-load them on plain attribute access.
        account_id = account.id
        customer_id = account.stripe_customer_id

        if customer_id is None:
            # The Stripe client is synchronous; run it in a worker thread so
            # the HTTP round trip does not block the event loop.
            customer = await asyncio.to_thread(
                stripe.Customer.create,
                email=None,
                metadata={"account_id": str(account_id)},
            )
            customer_id = customer.id
            account.stripe_customer_id = customer_id
            await db.commit()

        sub = (await db.execute(
            select(Subscription).where(Subscription.account_id == account_id)
        )).scalar_one_or_none()

        subscription_data = {}
        trial_end = BillingService._checkout_trial_end(
            sub, datetime.now(timezone.utc)
        )
        if trial_end is not None:
            subscription_data["trial_end"] = int(trial_end.timestamp())

        session = await asyncio.to_thread(
            stripe.checkout.Session.create,
            customer=customer_id,
            line_items=[{"price": price_id, "quantity": seats}],
            mode="subscription",
            subscription_data=subscription_data or None,
            success_url=success_url,
            cancel_url=cancel_url,
            allow_promotion_codes=False,
        )
        return session.url

    @staticmethod
    async def open_customer_portal(account: Account) -> str:
        """Create a Stripe-hosted Customer Portal session and return the URL.

        Raises RuntimeError if Stripe isn't configured (endpoint maps to 503).
        Raises ValueError if the account has no stripe_customer_id yet — the
        user must complete a checkout first (endpoint maps to 400).
        """
        if not settings.stripe_enabled:
            raise RuntimeError("Stripe not configured")
        if account.stripe_customer_id is None:
            raise ValueError("no_stripe_customer")
        stripe.api_key = settings.STRIPE_SECRET_KEY
        # Synchronous Stripe call: keep it off the event loop.
        session = await asyncio.to_thread(
            stripe.billing_portal.Session.create,
            customer=account.stripe_customer_id,
            return_url=f"{settings.FRONTEND_URL}/account/billing",
        )
        return session.url

    @staticmethod
    async def get_billing_state(db: AsyncSession, account):
        """Aggregate Subscription + PlanLimits + PlanBilling + resolved feature
        flags for the account.

        Returns:
            A dict with keys ``subscription`` (selected Subscription fields),
            ``plan_billing`` (the PlanBilling row for the subscription's plan,
            or ``None``), ``plan_limits`` (column-name -> value dict, empty
            when no PlanLimits row exists) and ``enabled_features``
            (flag_key -> bool).

        Raises:
            fastapi.HTTPException: 404 when the account has no Subscription.
        """
        from app.models.plan_limits import PlanLimits
        from app.models.feature_flag import (
            FeatureFlag, PlanFeatureDefault, AccountFeatureOverride,
        )

        sub = (await db.execute(
            select(Subscription).where(Subscription.account_id == account.id)
        )).scalar_one_or_none()
        if sub is None:
            from fastapi import HTTPException
            raise HTTPException(status_code=404, detail="No subscription for account")

        pl = (await db.execute(
            select(PlanLimits).where(PlanLimits.plan == sub.plan)
        )).scalar_one_or_none()
        pb = (await db.execute(
            select(PlanBilling).where(PlanBilling.plan == sub.plan)
        )).scalar_one_or_none()

        # Resolved feature flags: plan defaults overridden by account overrides
        defaults = (await db.execute(
            select(PlanFeatureDefault, FeatureFlag)
            .join(FeatureFlag, PlanFeatureDefault.flag_id == FeatureFlag.id)
            .where(PlanFeatureDefault.plan == sub.plan)
        )).all()
        resolved = {flag.flag_key: pfd.enabled for pfd, flag in defaults}

        overrides = (await db.execute(
            select(AccountFeatureOverride, FeatureFlag)
            .join(FeatureFlag, AccountFeatureOverride.flag_id == FeatureFlag.id)
            .where(AccountFeatureOverride.account_id == account.id)
        )).all()
        # Applied second so an account-level override wins over the default.
        for ovr, flag in overrides:
            resolved[flag.flag_key] = ovr.enabled

        return {
            "subscription": {
                "status": sub.status,
                "plan": sub.plan,
                "current_period_start": sub.current_period_start,
                "current_period_end": sub.current_period_end,
                "cancel_at_period_end": sub.cancel_at_period_end,
                "seat_limit": sub.seat_limit,
                "has_pro_entitlement": sub.has_pro_entitlement,
                "is_paid": sub.is_paid,
            },
            "plan_billing": pb,
            "plan_limits": _plan_limits_to_dict(pl) if pl else {},
            "enabled_features": resolved,
        }

    @staticmethod
    async def apply_subscription_event(
        db: AsyncSession, event_id: str, event_type: str, payload: dict
    ) -> bool:
        """Apply a Stripe webhook event to local state exactly once.

        Returns True if the event was applied; False if it had already been
        processed (idempotent ack). The webhook handler returns 200 either
        way.

        The StripeEvent idempotency row is written in the same transaction as
        the handler's changes. If the handler raises, the transaction is
        rolled back and the exception propagates, so a later redelivery of the
        event is processed instead of being acknowledged as a duplicate.
        Event types without a handler are recorded and reported as applied.
        """
        handlers = {
            "checkout.session.completed": _handle_checkout_completed,
            "customer.subscription.updated": _handle_subscription_updated,
            "customer.subscription.deleted": _handle_subscription_deleted,
            "invoice.payment_failed": _handle_payment_failed,
            "invoice.payment_succeeded": _handle_payment_succeeded,
        }

        db.add(StripeEvent(
            id=event_id,
            event_type=event_type,
            payload_excerpt=_excerpt(payload),
        ))
        try:
            # Flush rather than commit: a duplicate id still surfaces as an
            # IntegrityError here, but the row only becomes durable together
            # with the handler's work.
            await db.flush()
        except IntegrityError:
            await db.rollback()
            return False

        handler = handlers.get(event_type)
        try:
            if handler is not None:
                await handler(db, payload)
            # Handlers that return early do not commit; this makes sure the
            # event record is persisted in every successful path.
            await db.commit()
        except Exception:
            await db.rollback()
            raise
        return True
def _plan_limits_to_dict(pl) -> dict:
return {c.name: getattr(pl, c.name) for c in pl.__table__.columns}
def _excerpt(payload: dict) -> dict:
obj = payload.get("data", {}).get("object", {})
return {
"object_id": obj.get("id"),
"customer": obj.get("customer"),
"subscription": obj.get("subscription"),
"status": obj.get("status"),
}
async def _handle_checkout_completed(db: AsyncSession, payload: dict):
    """Attach a completed Checkout Session's Stripe subscription to the account.

    Finds the Account by the session's Stripe customer id and that account's
    local Subscription, fetches the Stripe subscription, and copies its id,
    price, billing period and seat quantity onto the local row. When the
    price matches a PlanBilling row (monthly or annual), the local plan is
    updated too. Does nothing when the session carries no customer or
    subscription id, or when no matching Account / Subscription exists.
    """
    obj = payload["data"]["object"]
    customer_id = obj.get("customer")
    subscription_id = obj.get("subscription")
    # A Checkout Session that is not a subscription purchase has no
    # subscription id; skip it instead of failing on the Stripe lookup.
    if not customer_id or not subscription_id:
        return

    account = (await db.execute(
        select(Account).where(Account.stripe_customer_id == customer_id)
    )).scalar_one_or_none()
    if account is None:
        return
    sub = (await db.execute(
        select(Subscription).where(Subscription.account_id == account.id)
    )).scalar_one_or_none()
    if sub is None:
        return

    stripe.api_key = settings.STRIPE_SECRET_KEY
    # The Stripe client is synchronous; run the HTTP call in a worker thread
    # so it does not block the event loop.
    stripe_sub = await asyncio.to_thread(
        stripe.Subscription.retrieve, subscription_id
    )
    first_item = stripe_sub["items"]["data"][0]

    sub.stripe_subscription_id = subscription_id
    sub.stripe_price_id = first_item["price"]["id"]
    # NOTE(review): status is forced to "active" even if the Stripe
    # subscription is still "trialing" (checkout can carry a trial_end),
    # whereas _handle_subscription_updated mirrors Stripe's status verbatim.
    # Confirm this divergence is intended.
    sub.status = "active"
    sub.current_period_start = datetime.fromtimestamp(
        stripe_sub["current_period_start"], tz=timezone.utc
    )
    sub.current_period_end = datetime.fromtimestamp(
        stripe_sub["current_period_end"], tz=timezone.utc
    )
    sub.seat_limit = first_item["quantity"]

    pb = (await db.execute(
        select(PlanBilling).where(
            (PlanBilling.stripe_monthly_price_id == sub.stripe_price_id) |
            (PlanBilling.stripe_annual_price_id == sub.stripe_price_id)
        )
    )).scalar_one_or_none()
    if pb is not None:
        sub.plan = pb.plan
    await db.commit()
async def _handle_subscription_updated(db: AsyncSession, payload: dict):
    """Copy a ``customer.subscription.updated`` payload onto the local row.

    The Subscription is looked up by the Stripe subscription id in the
    payload. When found, its status, period bounds, cancel-at-period-end
    flag and seat limit (first item's quantity) are overwritten and the
    change is committed. Unknown subscription ids are ignored.
    """
    stripe_sub = payload["data"]["object"]
    lookup = select(Subscription).where(
        Subscription.stripe_subscription_id == stripe_sub["id"]
    )
    local = (await db.execute(lookup)).scalar_one_or_none()
    if local is None:
        return

    def _as_utc(epoch_seconds):
        # Stripe sends Unix timestamps; store timezone-aware UTC datetimes.
        return datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)

    local.status = stripe_sub["status"]
    local.current_period_start = _as_utc(stripe_sub["current_period_start"])
    local.current_period_end = _as_utc(stripe_sub["current_period_end"])
    local.cancel_at_period_end = stripe_sub.get("cancel_at_period_end", False)
    local.seat_limit = stripe_sub["items"]["data"][0]["quantity"]
    await db.commit()
async def _handle_subscription_deleted(db: AsyncSession, payload: dict):
    """Mark the local Subscription as ``canceled`` for a deleted Stripe one.

    The row is matched on the Stripe subscription id from the payload; when
    no row matches, nothing is written.
    """
    stripe_sub_id = payload["data"]["object"]["id"]
    lookup = select(Subscription).where(
        Subscription.stripe_subscription_id == stripe_sub_id
    )
    result = await db.execute(lookup)
    local = result.scalar_one_or_none()
    if local is not None:
        local.status = "canceled"
        await db.commit()
async def _handle_payment_failed(db: AsyncSession, payload: dict):
    """Flag the invoice's subscription as ``past_due`` after a failed payment.

    Invoices without a subscription id, and subscription ids with no local
    row, are ignored.
    """
    invoice = payload["data"]["object"]
    stripe_sub_id = invoice.get("subscription")
    if stripe_sub_id:
        lookup = select(Subscription).where(
            Subscription.stripe_subscription_id == stripe_sub_id
        )
        local = (await db.execute(lookup)).scalar_one_or_none()
        if local is not None:
            local.status = "past_due"
            await db.commit()
async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
    """Return a ``past_due`` subscription to ``active`` once an invoice is paid.

    Only a local Subscription currently in ``past_due`` is touched; any other
    status is left as it is. Invoices without a subscription id, and ids with
    no local row, are ignored.
    """
    invoice = payload["data"]["object"]
    stripe_sub_id = invoice.get("subscription")
    if not stripe_sub_id:
        return

    lookup = select(Subscription).where(
        Subscription.stripe_subscription_id == stripe_sub_id
    )
    result = await db.execute(lookup)
    local = result.scalar_one_or_none()
    if local is None or local.status != "past_due":
        return

    local.status = "active"
    await db.commit()