feat(billing): apply_subscription_event with stripe_events idempotency
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -5,11 +5,13 @@ from datetime import datetime, timezone, timedelta
|
|||||||
|
|
||||||
import stripe
|
import stripe
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.models.account import Account
|
from app.models.account import Account
|
||||||
from app.models.plan_billing import PlanBilling
|
from app.models.plan_billing import PlanBilling
|
||||||
|
from app.models.stripe_event import StripeEvent
|
||||||
from app.models.subscription import Subscription
|
from app.models.subscription import Subscription
|
||||||
|
|
||||||
|
|
||||||
@@ -102,3 +104,134 @@ class BillingService:
|
|||||||
allow_promotion_codes=False,
|
allow_promotion_codes=False,
|
||||||
)
|
)
|
||||||
return session.url
|
return session.url
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def apply_subscription_event(
|
||||||
|
db: AsyncSession, event_id: str, event_type: str, payload: dict
|
||||||
|
) -> bool:
|
||||||
|
"""Idempotent. Returns True if the event was applied; False if it had
|
||||||
|
already been processed (idempotent ack). The webhook handler returns 200
|
||||||
|
either way."""
|
||||||
|
try:
|
||||||
|
db.add(StripeEvent(
|
||||||
|
id=event_id,
|
||||||
|
event_type=event_type,
|
||||||
|
payload_excerpt=_excerpt(payload),
|
||||||
|
))
|
||||||
|
await db.commit()
|
||||||
|
except IntegrityError:
|
||||||
|
await db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
if event_type == "checkout.session.completed":
|
||||||
|
await _handle_checkout_completed(db, payload)
|
||||||
|
elif event_type == "customer.subscription.updated":
|
||||||
|
await _handle_subscription_updated(db, payload)
|
||||||
|
elif event_type == "customer.subscription.deleted":
|
||||||
|
await _handle_subscription_deleted(db, payload)
|
||||||
|
elif event_type == "invoice.payment_failed":
|
||||||
|
await _handle_payment_failed(db, payload)
|
||||||
|
elif event_type == "invoice.payment_succeeded":
|
||||||
|
await _handle_payment_succeeded(db, payload)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _excerpt(payload: dict) -> dict:
|
||||||
|
obj = payload.get("data", {}).get("object", {})
|
||||||
|
return {
|
||||||
|
"object_id": obj.get("id"),
|
||||||
|
"customer": obj.get("customer"),
|
||||||
|
"subscription": obj.get("subscription"),
|
||||||
|
"status": obj.get("status"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_checkout_completed(db: AsyncSession, payload: dict):
|
||||||
|
obj = payload["data"]["object"]
|
||||||
|
customer_id = obj["customer"]
|
||||||
|
subscription_id = obj["subscription"]
|
||||||
|
|
||||||
|
account = (await db.execute(
|
||||||
|
select(Account).where(Account.stripe_customer_id == customer_id)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if account is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
sub = (await db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == account.id)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
stripe.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
stripe_sub = stripe.Subscription.retrieve(subscription_id)
|
||||||
|
sub.stripe_subscription_id = subscription_id
|
||||||
|
sub.stripe_price_id = stripe_sub["items"]["data"][0]["price"]["id"]
|
||||||
|
sub.status = "active"
|
||||||
|
sub.current_period_start = datetime.fromtimestamp(stripe_sub["current_period_start"], tz=timezone.utc)
|
||||||
|
sub.current_period_end = datetime.fromtimestamp(stripe_sub["current_period_end"], tz=timezone.utc)
|
||||||
|
sub.seat_limit = stripe_sub["items"]["data"][0]["quantity"]
|
||||||
|
pb = (await db.execute(
|
||||||
|
select(PlanBilling).where(
|
||||||
|
(PlanBilling.stripe_monthly_price_id == sub.stripe_price_id) |
|
||||||
|
(PlanBilling.stripe_annual_price_id == sub.stripe_price_id)
|
||||||
|
)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if pb is not None:
|
||||||
|
sub.plan = pb.plan
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_subscription_updated(db: AsyncSession, payload: dict):
|
||||||
|
obj = payload["data"]["object"]
|
||||||
|
sub = (await db.execute(
|
||||||
|
select(Subscription).where(Subscription.stripe_subscription_id == obj["id"])
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
sub.status = obj["status"]
|
||||||
|
sub.current_period_start = datetime.fromtimestamp(obj["current_period_start"], tz=timezone.utc)
|
||||||
|
sub.current_period_end = datetime.fromtimestamp(obj["current_period_end"], tz=timezone.utc)
|
||||||
|
sub.cancel_at_period_end = obj.get("cancel_at_period_end", False)
|
||||||
|
sub.seat_limit = obj["items"]["data"][0]["quantity"]
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_subscription_deleted(db: AsyncSession, payload: dict):
|
||||||
|
obj = payload["data"]["object"]
|
||||||
|
sub = (await db.execute(
|
||||||
|
select(Subscription).where(Subscription.stripe_subscription_id == obj["id"])
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
sub.status = "canceled"
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_payment_failed(db: AsyncSession, payload: dict):
|
||||||
|
obj = payload["data"]["object"]
|
||||||
|
subscription_id = obj.get("subscription")
|
||||||
|
if not subscription_id:
|
||||||
|
return
|
||||||
|
sub = (await db.execute(
|
||||||
|
select(Subscription).where(Subscription.stripe_subscription_id == subscription_id)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
sub.status = "past_due"
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
|
||||||
|
obj = payload["data"]["object"]
|
||||||
|
subscription_id = obj.get("subscription")
|
||||||
|
if not subscription_id:
|
||||||
|
return
|
||||||
|
sub = (await db.execute(
|
||||||
|
select(Subscription).where(Subscription.stripe_subscription_id == subscription_id)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
if sub.status == "past_due":
|
||||||
|
sub.status = "active"
|
||||||
|
await db.commit()
|
||||||
|
|||||||
@@ -57,3 +57,24 @@ async def test_register_creates_trial_subscription(client, test_db):
|
|||||||
assert sub.plan == "pro"
|
assert sub.plan == "pro"
|
||||||
assert sub.status == "trialing"
|
assert sub.status == "trialing"
|
||||||
assert sub.current_period_end is not None
|
assert sub.current_period_end is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_apply_subscription_event_is_idempotent(test_db):
|
||||||
|
payload = {
|
||||||
|
"data": {"object": {
|
||||||
|
"id": "evt_test_1",
|
||||||
|
"customer": "cus_xxx",
|
||||||
|
"subscription": "sub_xxx",
|
||||||
|
"status": "active",
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
applied_first = await BillingService.apply_subscription_event(
|
||||||
|
test_db, "evt_test_1", "customer.subscription.updated", payload
|
||||||
|
)
|
||||||
|
applied_second = await BillingService.apply_subscription_event(
|
||||||
|
test_db, "evt_test_1", "customer.subscription.updated", payload
|
||||||
|
)
|
||||||
|
assert applied_first is True
|
||||||
|
assert applied_second is False # already-processed → ack without re-applying
|
||||||
|
|||||||
Reference in New Issue
Block a user