diff --git a/backend/app/services/billing.py b/backend/app/services/billing.py index 6e3d65b1..c94d1880 100644 --- a/backend/app/services/billing.py +++ b/backend/app/services/billing.py @@ -5,11 +5,13 @@ from datetime import datetime, timezone, timedelta import stripe from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.models.account import Account from app.models.plan_billing import PlanBilling +from app.models.stripe_event import StripeEvent from app.models.subscription import Subscription @@ -102,3 +104,134 @@ class BillingService: allow_promotion_codes=False, ) return session.url + + @staticmethod + async def apply_subscription_event( + db: AsyncSession, event_id: str, event_type: str, payload: dict + ) -> bool: + """Idempotent. Returns True if the event was applied; False if it had + already been processed (idempotent ack). The webhook handler returns 200 + either way.""" + try: + db.add(StripeEvent( + id=event_id, + event_type=event_type, + payload_excerpt=_excerpt(payload), + )) + await db.commit() + except IntegrityError: + await db.rollback() + return False + + if event_type == "checkout.session.completed": + await _handle_checkout_completed(db, payload) + elif event_type == "customer.subscription.updated": + await _handle_subscription_updated(db, payload) + elif event_type == "customer.subscription.deleted": + await _handle_subscription_deleted(db, payload) + elif event_type == "invoice.payment_failed": + await _handle_payment_failed(db, payload) + elif event_type == "invoice.payment_succeeded": + await _handle_payment_succeeded(db, payload) + return True + + +def _excerpt(payload: dict) -> dict: + obj = payload.get("data", {}).get("object", {}) + return { + "object_id": obj.get("id"), + "customer": obj.get("customer"), + "subscription": obj.get("subscription"), + "status": obj.get("status"), + } + + +async def _handle_checkout_completed(db: AsyncSession, payload: dict): + obj = payload["data"]["object"] + customer_id = obj["customer"] + subscription_id = obj["subscription"] + + account = (await db.execute( + select(Account).where(Account.stripe_customer_id == customer_id) + )).scalar_one_or_none() + if account is None: + return + + sub = (await db.execute( + select(Subscription).where(Subscription.account_id == account.id) + )).scalar_one_or_none() + if sub is None: + return + + stripe.api_key = settings.STRIPE_SECRET_KEY + stripe_sub = stripe.Subscription.retrieve(subscription_id) + sub.stripe_subscription_id = subscription_id + sub.stripe_price_id = stripe_sub["items"]["data"][0]["price"]["id"] + sub.status = "active" + sub.current_period_start = datetime.fromtimestamp(stripe_sub["current_period_start"], tz=timezone.utc) + sub.current_period_end = datetime.fromtimestamp(stripe_sub["current_period_end"], tz=timezone.utc) + sub.seat_limit = stripe_sub["items"]["data"][0]["quantity"] + pb = (await db.execute( + select(PlanBilling).where( + (PlanBilling.stripe_monthly_price_id == sub.stripe_price_id) | + (PlanBilling.stripe_annual_price_id == sub.stripe_price_id) + ) + )).scalar_one_or_none() + if pb is not None: + sub.plan = pb.plan + await db.commit() + + +async def _handle_subscription_updated(db: AsyncSession, payload: dict): + obj = payload["data"]["object"] + sub = (await db.execute( + select(Subscription).where(Subscription.stripe_subscription_id == obj["id"]) + )).scalar_one_or_none() + if sub is None: + return + sub.status = obj["status"] + sub.current_period_start = datetime.fromtimestamp(obj["current_period_start"], tz=timezone.utc) + sub.current_period_end = datetime.fromtimestamp(obj["current_period_end"], tz=timezone.utc) + sub.cancel_at_period_end = obj.get("cancel_at_period_end", False) + sub.seat_limit = obj["items"]["data"][0]["quantity"] + await db.commit() + + +async def _handle_subscription_deleted(db: AsyncSession, payload: dict): + obj = payload["data"]["object"] + sub = (await db.execute( + select(Subscription).where(Subscription.stripe_subscription_id == obj["id"]) + )).scalar_one_or_none() + if sub is None: + return + sub.status = "canceled" + await db.commit() + + +async def _handle_payment_failed(db: AsyncSession, payload: dict): + obj = payload["data"]["object"] + subscription_id = obj.get("subscription") + if not subscription_id: + return + sub = (await db.execute( + select(Subscription).where(Subscription.stripe_subscription_id == subscription_id) + )).scalar_one_or_none() + if sub is None: + return + sub.status = "past_due" + await db.commit() + + +async def _handle_payment_succeeded(db: AsyncSession, payload: dict): + obj = payload["data"]["object"] + subscription_id = obj.get("subscription") + if not subscription_id: + return + sub = (await db.execute( + select(Subscription).where(Subscription.stripe_subscription_id == subscription_id) + )).scalar_one_or_none() + if sub is None: + return + if sub.status == "past_due": + sub.status = "active" + await db.commit() diff --git a/backend/tests/test_billing_service.py b/backend/tests/test_billing_service.py index cb9726d8..7ba62708 100644 --- a/backend/tests/test_billing_service.py +++ b/backend/tests/test_billing_service.py @@ -57,3 +57,24 @@ async def test_register_creates_trial_subscription(client, test_db): assert sub.plan == "pro" assert sub.status == "trialing" assert sub.current_period_end is not None + + +@pytest.mark.asyncio +async def test_apply_subscription_event_is_idempotent(test_db): + payload = { + "data": {"object": { + "id": "evt_test_1", + "customer": "cus_xxx", + "subscription": "sub_xxx", + "status": "active", + }} + } + + applied_first = await BillingService.apply_subscription_event( + test_db, "evt_test_1", "customer.subscription.updated", payload + ) + applied_second = await BillingService.apply_subscription_event( + test_db, "evt_test_1", "customer.subscription.updated", payload + ) + assert applied_first is True + assert applied_second is False # already-processed → ack without re-applying