feat(billing): apply_subscription_event with stripe_events idempotency
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -5,11 +5,13 @@ from datetime import datetime, timezone, timedelta
|
||||
|
||||
import stripe
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.account import Account
|
||||
from app.models.plan_billing import PlanBilling
|
||||
from app.models.stripe_event import StripeEvent
|
||||
from app.models.subscription import Subscription
|
||||
|
||||
|
||||
@@ -102,3 +104,134 @@ class BillingService:
|
||||
allow_promotion_codes=False,
|
||||
)
|
||||
return session.url
|
||||
|
||||
@staticmethod
|
||||
async def apply_subscription_event(
|
||||
db: AsyncSession, event_id: str, event_type: str, payload: dict
|
||||
) -> bool:
|
||||
"""Idempotent. Returns True if the event was applied; False if it had
|
||||
already been processed (idempotent ack). The webhook handler returns 200
|
||||
either way."""
|
||||
try:
|
||||
db.add(StripeEvent(
|
||||
id=event_id,
|
||||
event_type=event_type,
|
||||
payload_excerpt=_excerpt(payload),
|
||||
))
|
||||
await db.commit()
|
||||
except IntegrityError:
|
||||
await db.rollback()
|
||||
return False
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
await _handle_checkout_completed(db, payload)
|
||||
elif event_type == "customer.subscription.updated":
|
||||
await _handle_subscription_updated(db, payload)
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
await _handle_subscription_deleted(db, payload)
|
||||
elif event_type == "invoice.payment_failed":
|
||||
await _handle_payment_failed(db, payload)
|
||||
elif event_type == "invoice.payment_succeeded":
|
||||
await _handle_payment_succeeded(db, payload)
|
||||
return True
|
||||
|
||||
|
||||
def _excerpt(payload: dict) -> dict:
|
||||
obj = payload.get("data", {}).get("object", {})
|
||||
return {
|
||||
"object_id": obj.get("id"),
|
||||
"customer": obj.get("customer"),
|
||||
"subscription": obj.get("subscription"),
|
||||
"status": obj.get("status"),
|
||||
}
|
||||
|
||||
|
||||
async def _handle_checkout_completed(db: AsyncSession, payload: dict):
|
||||
obj = payload["data"]["object"]
|
||||
customer_id = obj["customer"]
|
||||
subscription_id = obj["subscription"]
|
||||
|
||||
account = (await db.execute(
|
||||
select(Account).where(Account.stripe_customer_id == customer_id)
|
||||
)).scalar_one_or_none()
|
||||
if account is None:
|
||||
return
|
||||
|
||||
sub = (await db.execute(
|
||||
select(Subscription).where(Subscription.account_id == account.id)
|
||||
)).scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
|
||||
stripe.api_key = settings.STRIPE_SECRET_KEY
|
||||
stripe_sub = stripe.Subscription.retrieve(subscription_id)
|
||||
sub.stripe_subscription_id = subscription_id
|
||||
sub.stripe_price_id = stripe_sub["items"]["data"][0]["price"]["id"]
|
||||
sub.status = "active"
|
||||
sub.current_period_start = datetime.fromtimestamp(stripe_sub["current_period_start"], tz=timezone.utc)
|
||||
sub.current_period_end = datetime.fromtimestamp(stripe_sub["current_period_end"], tz=timezone.utc)
|
||||
sub.seat_limit = stripe_sub["items"]["data"][0]["quantity"]
|
||||
pb = (await db.execute(
|
||||
select(PlanBilling).where(
|
||||
(PlanBilling.stripe_monthly_price_id == sub.stripe_price_id) |
|
||||
(PlanBilling.stripe_annual_price_id == sub.stripe_price_id)
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if pb is not None:
|
||||
sub.plan = pb.plan
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _handle_subscription_updated(db: AsyncSession, payload: dict):
|
||||
obj = payload["data"]["object"]
|
||||
sub = (await db.execute(
|
||||
select(Subscription).where(Subscription.stripe_subscription_id == obj["id"])
|
||||
)).scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
sub.status = obj["status"]
|
||||
sub.current_period_start = datetime.fromtimestamp(obj["current_period_start"], tz=timezone.utc)
|
||||
sub.current_period_end = datetime.fromtimestamp(obj["current_period_end"], tz=timezone.utc)
|
||||
sub.cancel_at_period_end = obj.get("cancel_at_period_end", False)
|
||||
sub.seat_limit = obj["items"]["data"][0]["quantity"]
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _handle_subscription_deleted(db: AsyncSession, payload: dict):
|
||||
obj = payload["data"]["object"]
|
||||
sub = (await db.execute(
|
||||
select(Subscription).where(Subscription.stripe_subscription_id == obj["id"])
|
||||
)).scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
sub.status = "canceled"
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _handle_payment_failed(db: AsyncSession, payload: dict):
|
||||
obj = payload["data"]["object"]
|
||||
subscription_id = obj.get("subscription")
|
||||
if not subscription_id:
|
||||
return
|
||||
sub = (await db.execute(
|
||||
select(Subscription).where(Subscription.stripe_subscription_id == subscription_id)
|
||||
)).scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
sub.status = "past_due"
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
|
||||
obj = payload["data"]["object"]
|
||||
subscription_id = obj.get("subscription")
|
||||
if not subscription_id:
|
||||
return
|
||||
sub = (await db.execute(
|
||||
select(Subscription).where(Subscription.stripe_subscription_id == subscription_id)
|
||||
)).scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
if sub.status == "past_due":
|
||||
sub.status = "active"
|
||||
await db.commit()
|
||||
|
||||
@@ -57,3 +57,24 @@ async def test_register_creates_trial_subscription(client, test_db):
|
||||
assert sub.plan == "pro"
|
||||
assert sub.status == "trialing"
|
||||
assert sub.current_period_end is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_subscription_event_is_idempotent(test_db):
|
||||
payload = {
|
||||
"data": {"object": {
|
||||
"id": "evt_test_1",
|
||||
"customer": "cus_xxx",
|
||||
"subscription": "sub_xxx",
|
||||
"status": "active",
|
||||
}}
|
||||
}
|
||||
|
||||
applied_first = await BillingService.apply_subscription_event(
|
||||
test_db, "evt_test_1", "customer.subscription.updated", payload
|
||||
)
|
||||
applied_second = await BillingService.apply_subscription_event(
|
||||
test_db, "evt_test_1", "customer.subscription.updated", payload
|
||||
)
|
||||
assert applied_first is True
|
||||
assert applied_second is False # already-processed → ack without re-applying
|
||||
|
||||
Reference in New Issue
Block a user