105 lines
3.6 KiB
Python
105 lines
3.6 KiB
Python
"""Single billing service module. Stripe is the only impl — no provider
|
|
abstraction. Account row is canonical local state; Stripe is canonical
|
|
remote state; the webhook handler bridges the two."""
|
|
from datetime import datetime, timezone, timedelta
|
|
|
|
import stripe
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.models.account import Account
|
|
from app.models.plan_billing import PlanBilling
|
|
from app.models.subscription import Subscription
|
|
|
|
|
|
# Length, in days, of the trialing Pro subscription created by
# BillingService.start_trial (current_period_end = start + TRIAL_DAYS).
TRIAL_DAYS = 14
|
|
|
|
|
|
class BillingService:
    """Billing operations for an Account.

    All methods are static; the class only groups them. Local state lives on
    the Account/Subscription rows, remote state lives in Stripe.
    """

    # Stripe rejects a Checkout ``subscription_data.trial_end`` that is less
    # than 48 hours in the future. The extra hour is headroom for clock skew
    # and request latency between computing the timestamp here and Stripe
    # validating it.
    _CHECKOUT_MIN_TRIAL_LEAD = timedelta(hours=49)

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as a timezone-aware datetime.

        Aware values are returned unchanged. Naive values are tagged as UTC
        (start_trial writes UTC timestamps), so they can be compared with
        ``datetime.now(timezone.utc)`` and converted to epoch seconds
        correctly instead of raising TypeError / being read as local time.
        """
        if value.tzinfo is None:
            # NOTE(review): assumes the DB layer dropped the tzinfo of a UTC
            # value; confirm the column type if this branch is ever hit.
            return value.replace(tzinfo=timezone.utc)
        return value

    @staticmethod
    def _checkout_trial_end(sub):
        """Return the ``trial_end`` (Unix seconds) to send to Checkout, or None.

        A value is produced only for a ``trialing`` Subscription whose
        ``current_period_end`` is still in the future. It is raised to at
        least ``_CHECKOUT_MIN_TRIAL_LEAD`` from now so Stripe accepts it: a
        nearly-expired trial becomes slightly longer rather than making the
        checkout fail or charging the customer before the trial ends.
        """
        if (
            sub is None
            or sub.status != "trialing"
            or sub.current_period_end is None
        ):
            return None
        now = datetime.now(timezone.utc)
        period_end = BillingService._as_utc(sub.current_period_end)
        if period_end <= now:
            return None
        earliest_allowed = now + BillingService._CHECKOUT_MIN_TRIAL_LEAD
        return int(max(period_end, earliest_allowed).timestamp())

    @staticmethod
    async def start_trial(db: AsyncSession, account_id) -> Subscription:
        """Idempotently start a Pro trial for ``account_id``.

        If the account already has a Subscription row (in any status), that
        row is returned unchanged. Otherwise a ``trialing`` Pro subscription
        lasting TRIAL_DAYS is inserted, committed and refreshed.

        Args:
            db: Async session used for the lookup and, when needed, the
                insert and commit.
            account_id: Id of the account that owns the subscription.

        Returns:
            The existing or newly created Subscription.
        """
        # NOTE(review): check-then-insert is not atomic; two concurrent calls
        # could both insert unless a unique constraint on account_id exists.
        existing = (await db.execute(
            select(Subscription).where(Subscription.account_id == account_id)
        )).scalar_one_or_none()
        if existing is not None:
            return existing

        # Single clock read so the trial is exactly TRIAL_DAYS long.
        now = datetime.now(timezone.utc)
        sub = Subscription(
            account_id=account_id,
            plan="pro",
            status="trialing",
            current_period_start=now,
            current_period_end=now + timedelta(days=TRIAL_DAYS),
        )
        db.add(sub)
        await db.commit()
        # Reload attributes after the commit so the returned object is usable.
        await db.refresh(sub)
        return sub

    @staticmethod
    async def create_checkout_session(
        db: AsyncSession,
        account: Account,
        plan: str,
        seats: int,
        billing_interval: str,
        success_url: str,
        cancel_url: str,
    ) -> str:
        """Create a Stripe Checkout Session for a subscription purchase.

        On first use a Stripe Customer is created for the account and its id
        is saved on ``account`` (committed via ``db``). If the account has a
        trialing subscription with time remaining, that trial end is carried
        onto the new Stripe subscription so the user isn't charged early.
        Stripe requires a Checkout trial to end at least 48 hours out, so a
        shorter remaining trial is extended to that minimum.

        Args:
            db: Async session used to read PlanBilling/Subscription rows and
                to persist a newly created Stripe customer id.
            account: The purchasing account; expected to be attached to
                ``db`` so the commit persists ``stripe_customer_id``.
            plan: Plan key looked up in PlanBilling.
            seats: Line-item quantity; must be at least 1.
            billing_interval: ``"monthly"`` or ``"annual"``.
            success_url: Where Stripe redirects after a successful checkout.
            cancel_url: Where Stripe redirects if the user cancels.

        Returns:
            The Checkout Session URL.

        Raises:
            RuntimeError: Stripe is not configured, or the plan has no Stripe
                price for the requested interval.
            ValueError: Unknown plan, unsupported billing interval, or
                ``seats`` below 1.
        """
        if not settings.stripe_enabled:
            raise RuntimeError("Stripe not configured")
        # Previously any non-"monthly" value silently selected the annual price.
        if billing_interval not in ("monthly", "annual"):
            raise ValueError(f"Unknown billing interval: {billing_interval}")
        if seats < 1:
            raise ValueError(f"seats must be at least 1, got {seats}")
        stripe.api_key = settings.STRIPE_SECRET_KEY

        # Read ORM attributes into locals before any commit: by default a
        # commit expires loaded instances, and touching an expired attribute
        # needs implicit IO that an AsyncSession cannot perform.
        account_id = account.id
        customer_id = account.stripe_customer_id

        plan_billing = (await db.execute(
            select(PlanBilling).where(PlanBilling.plan == plan)
        )).scalar_one_or_none()
        if plan_billing is None:
            raise ValueError(f"Unknown plan: {plan}")
        price_id = (
            plan_billing.stripe_monthly_price_id
            if billing_interval == "monthly"
            else plan_billing.stripe_annual_price_id
        )
        if price_id is None:
            raise RuntimeError(
                f"Plan '{plan}' has no Stripe price for {billing_interval}"
            )

        # NOTE(review): the stripe.* calls below are synchronous (not
        # awaited) and block the event loop while the request is in flight.
        if customer_id is None:
            customer = stripe.Customer.create(
                email=None,
                metadata={"account_id": str(account_id)},
            )
            customer_id = customer.id
            account.stripe_customer_id = customer_id
            await db.commit()

        sub = (await db.execute(
            select(Subscription).where(Subscription.account_id == account_id)
        )).scalar_one_or_none()

        subscription_data = {}
        trial_end = BillingService._checkout_trial_end(sub)
        if trial_end is not None:
            subscription_data["trial_end"] = trial_end

        session = stripe.checkout.Session.create(
            customer=customer_id,
            line_items=[{"price": price_id, "quantity": seats}],
            mode="subscription",
            subscription_data=subscription_data or None,
            success_url=success_url,
            cancel_url=cancel_url,
            allow_promotion_codes=False,
        )
        return session.url
|