diff --git a/backend/app/services/billing.py b/backend/app/services/billing.py index 1ae0a999..444867a5 100644 --- a/backend/app/services/billing.py +++ b/backend/app/services/billing.py @@ -210,28 +210,44 @@ class BillingService: ) -> bool: """Idempotent. Returns True if the event was applied; False if it had already been processed (idempotent ack). The webhook handler returns 200 - either way.""" + either way. + + Atomic: the StripeEvent idempotency mark and the handler's state + mutations are committed in a single transaction. If the handler raises + the entire transaction (idempotency mark + partial mutations) is rolled + back, so a Stripe retry will re-run the handler. Without this, a + handler that fails mid-flight would leave the StripeEvent row persisted + and silently desync subscription state from Stripe. + """ + db.add(StripeEvent( + id=event_id, + event_type=event_type, + payload_excerpt=_excerpt(payload), + )) try: - db.add(StripeEvent( - id=event_id, - event_type=event_type, - payload_excerpt=_excerpt(payload), - )) - await db.commit() + await db.flush() except IntegrityError: + # Duplicate event_id — already processed (or in flight). Ack with False. await db.rollback() return False - if event_type == "checkout.session.completed": - await _handle_checkout_completed(db, payload) - elif event_type == "customer.subscription.updated": - await _handle_subscription_updated(db, payload) - elif event_type == "customer.subscription.deleted": - await _handle_subscription_deleted(db, payload) - elif event_type == "invoice.payment_failed": - await _handle_payment_failed(db, payload) - elif event_type == "invoice.payment_succeeded": - await _handle_payment_succeeded(db, payload) + try: + if event_type == "checkout.session.completed": + await _handle_checkout_completed(db, payload) + elif event_type == "customer.subscription.updated": + await _handle_subscription_updated(db, payload) + elif event_type == "customer.subscription.deleted": + await _handle_subscription_deleted(db, payload) + elif event_type == "invoice.payment_failed": + await _handle_payment_failed(db, payload) + elif event_type == "invoice.payment_succeeded": + await _handle_payment_succeeded(db, payload) + await db.commit() + except Exception: + # Roll back the StripeEvent insert + any partial handler mutations + # so Stripe's retry can re-run cleanly. + await db.rollback() + raise return True @@ -282,7 +298,7 @@ async def _handle_checkout_completed(db: AsyncSession, payload: dict): )).scalar_one_or_none() if pb is not None: sub.plan = pb.plan - await db.commit() + # No commit — apply_subscription_event commits once for the full event. async def _handle_subscription_updated(db: AsyncSession, payload: dict): @@ -297,7 +313,7 @@ async def _handle_subscription_updated(db: AsyncSession, payload: dict): sub.current_period_end = datetime.fromtimestamp(obj["current_period_end"], tz=timezone.utc) sub.cancel_at_period_end = obj.get("cancel_at_period_end", False) sub.seat_limit = obj["items"]["data"][0]["quantity"] - await db.commit() + # No commit — apply_subscription_event commits once for the full event. async def _handle_subscription_deleted(db: AsyncSession, payload: dict): @@ -308,7 +324,7 @@ async def _handle_subscription_deleted(db: AsyncSession, payload: dict): if sub is None: return sub.status = "canceled" - await db.commit() + # No commit — apply_subscription_event commits once for the full event. async def _handle_payment_failed(db: AsyncSession, payload: dict): @@ -322,7 +338,7 @@ async def _handle_payment_failed(db: AsyncSession, payload: dict): if sub is None: return sub.status = "past_due" - await db.commit() + # No commit — apply_subscription_event commits once for the full event. async def _handle_payment_succeeded(db: AsyncSession, payload: dict): @@ -337,4 +353,4 @@ async def _handle_payment_succeeded(db: AsyncSession, payload: dict): return if sub.status == "past_due": sub.status = "active" - await db.commit() + # No commit — apply_subscription_event commits once for the full event. diff --git a/backend/tests/test_stripe_webhook_handler.py b/backend/tests/test_stripe_webhook_handler.py index 0430b9f3..e14b9925 100644 --- a/backend/tests/test_stripe_webhook_handler.py +++ b/backend/tests/test_stripe_webhook_handler.py @@ -142,3 +142,178 @@ async def test_webhook_idempotency( assert r2.status_code == 200 assert r1.json()["applied"] is True assert r2.json()["applied"] is False + + +# ---------------------------------------------------------------------------- +# Atomic-idempotency regression tests +# ---------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_apply_event_handler_failure_does_not_persist_idempotency_mark( + test_db, test_user, +): + """If the handler raises, the StripeEvent row must NOT be persisted — + otherwise Stripe's retry will be silently dropped as a duplicate and the + subscription state will desync from Stripe.""" + from app.services.billing import BillingService + from app.models.stripe_event import StripeEvent + + event_id = "evt_handler_fail_1" + payload = {"data": {"object": { + "id": "sub_doesnotmatter", + "status": "active", + "current_period_start": 1714521600, + "current_period_end": 1717113600, + "items": {"data": [{"quantity": 1}]}, + "cancel_at_period_end": False, + }}} + + boom = RuntimeError("simulated handler failure") + with patch( + "app.services.billing._handle_subscription_updated", + side_effect=boom, + ): + with pytest.raises(RuntimeError, match="simulated handler failure"): + await BillingService.apply_subscription_event( + test_db, + event_id=event_id, + event_type="customer.subscription.updated", + payload=payload, + ) + + # The StripeEvent row must not exist — handler raised, the entire + # transaction (idempotency mark + partial mutations) was rolled back. + row = (await test_db.execute( + select(StripeEvent).where(StripeEvent.id == event_id) + )).scalar_one_or_none() + assert row is None, ( + "StripeEvent row was persisted despite handler failure — " + "Stripe retry will be silently dropped" + ) + + +@pytest.mark.asyncio +async def test_apply_event_retry_after_failure_succeeds( + test_db, test_user, +): + """A failed first attempt followed by a successful retry must apply state. + This is the core Stripe webhook retry contract.""" + from app.services.billing import BillingService + from app.models.stripe_event import StripeEvent + + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id)) + test_db.add(Subscription( + account_id=account_id, plan="pro", status="trialing", + stripe_subscription_id="sub_retry", + )) + await test_db.commit() + + event_id = "evt_retry_1" + payload = {"data": {"object": { + "id": "sub_retry", + "status": "active", + "current_period_start": 1714521600, + "current_period_end": 1717113600, + "items": {"data": [{"quantity": 3}]}, + "cancel_at_period_end": False, + }}} + + # First attempt — handler raises mid-flight. + with patch( + "app.services.billing._handle_subscription_updated", + side_effect=RuntimeError("transient blip"), + ): + with pytest.raises(RuntimeError): + await BillingService.apply_subscription_event( + test_db, + event_id=event_id, + event_type="customer.subscription.updated", + payload=payload, + ) + + # No idempotency mark, sub still trialing. + row = (await test_db.execute( + select(StripeEvent).where(StripeEvent.id == event_id) + )).scalar_one_or_none() + assert row is None + sub = (await test_db.execute( + select(Subscription).where(Subscription.account_id == account_id) + )).scalar_one() + assert sub.status == "trialing" + + # Second attempt — same event_id, handler succeeds. + applied = await BillingService.apply_subscription_event( + test_db, + event_id=event_id, + event_type="customer.subscription.updated", + payload=payload, + ) + assert applied is True + + # Idempotency mark now persisted, sub state reconciled. + row = (await test_db.execute( + select(StripeEvent).where(StripeEvent.id == event_id) + )).scalar_one() + assert row.id == event_id + await test_db.refresh(sub) + assert sub.status == "active" + assert sub.seat_limit == 3 + + +@pytest.mark.asyncio +async def test_apply_event_duplicate_event_id_skips( + test_db, test_user, +): + """Two successful invocations with the same event_id must not double-apply. + Second call returns False; mutations are not repeated.""" + from app.services.billing import BillingService + + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id)) + test_db.add(Subscription( + account_id=account_id, plan="pro", status="trialing", + stripe_subscription_id="sub_dup", + )) + await test_db.commit() + + event_id = "evt_dedupe_1" + payload = {"data": {"object": { + "id": "sub_dup", + "status": "active", + "current_period_start": 1714521600, + "current_period_end": 1717113600, + "items": {"data": [{"quantity": 7}]}, + "cancel_at_period_end": False, + }}} + + applied1 = await BillingService.apply_subscription_event( + test_db, + event_id=event_id, + event_type="customer.subscription.updated", + payload=payload, + ) + assert applied1 is True + + sub = (await test_db.execute( + select(Subscription).where(Subscription.account_id == account_id) + )).scalar_one() + assert sub.status == "active" + assert sub.seat_limit == 7 + + # Mutate locally so we can prove the second call doesn't re-run the handler. + sub.seat_limit = 99 + await test_db.commit() + + applied2 = await BillingService.apply_subscription_event( + test_db, + event_id=event_id, + event_type="customer.subscription.updated", + payload=payload, + ) + assert applied2 is False + + await test_db.refresh(sub) + # Handler did NOT run again — our local mutation is preserved. + assert sub.seat_limit == 99