From f683bb5720bceceba8e02f899ab91966cd66d8a9 Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Wed, 6 May 2026 14:51:06 -0400 Subject: [PATCH] feat(billing): add /billing/checkout-session via BillingService Co-Authored-By: Claude Opus 4.7 --- backend/app/api/endpoints/billing.py | 36 ++++++++++++++ backend/app/api/router.py | 2 + backend/app/core/config.py | 3 +- backend/app/schemas/billing.py | 12 +++++ backend/app/services/billing.py | 68 ++++++++++++++++++++++++++ backend/tests/test_billing_checkout.py | 56 +++++++++++++++++++++ 6 files changed, 176 insertions(+), 1 deletion(-) create mode 100644 backend/app/api/endpoints/billing.py create mode 100644 backend/app/schemas/billing.py create mode 100644 backend/tests/test_billing_checkout.py diff --git a/backend/app/api/endpoints/billing.py b/backend/app/api/endpoints/billing.py new file mode 100644 index 00000000..024b1420 --- /dev/null +++ b/backend/app/api/endpoints/billing.py @@ -0,0 +1,36 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_active_user +from app.core.admin_database import get_admin_db +from app.core.config import settings +from app.models.account import Account +from app.models.user import User +from app.schemas.billing import CheckoutSessionCreate, CheckoutSessionResponse +from app.services.billing import BillingService + +router = APIRouter(prefix="/billing", tags=["billing"]) + + +@router.post("/checkout-session", response_model=CheckoutSessionResponse) +async def create_checkout_session( + payload: CheckoutSessionCreate, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_admin_db)], +) -> CheckoutSessionResponse: + account = (await db.execute( + select(Account).where(Account.id == current_user.account_id) + )).scalar_one() + url = await BillingService.create_checkout_session( + db=db, + account=account, + plan=payload.plan, + seats=payload.seats, + billing_interval=payload.billing_interval, + success_url=f"{settings.FRONTEND_URL}/account/billing?success=1", + cancel_url=f"{settings.FRONTEND_URL}/account/billing/select-plan", + ) + return CheckoutSessionResponse(url=url) diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 3cdc2c7e..c3018855 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -23,6 +23,7 @@ from app.api.endpoints import ( analytics, assistant_chat, auth, + billing, beta_feedback, beta_signup, branding, @@ -81,6 +82,7 @@ api_router = APIRouter() # in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS. # --------------------------------------------------------------------------- api_router.include_router(auth.router) +api_router.include_router(billing.router) # Reachable when subscription locked api_router.include_router(shared.router) # Public share links (no auth) api_router.include_router(shares.public_router) # Public session share links (optional auth) api_router.include_router(beta_signup.router) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index afc2fcbc..23795e42 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -94,11 +94,12 @@ class Settings(BaseSettings): STRIPE_SECRET_KEY: Optional[str] = None STRIPE_PUBLISHABLE_KEY: Optional[str] = None STRIPE_WEBHOOK_SECRET: Optional[str] = None + SELF_SERVE_ENABLED: bool = False @property def stripe_enabled(self) -> bool: """Check if Stripe is configured.""" - return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None + return bool(self.STRIPE_SECRET_KEY) # AI Flow Builder ANTHROPIC_API_KEY: Optional[str] = None diff --git a/backend/app/schemas/billing.py b/backend/app/schemas/billing.py new file mode 100644 index 00000000..51f4db2a --- /dev/null +++ b/backend/app/schemas/billing.py @@ -0,0 +1,12 @@ +from typing import Literal +from pydantic import BaseModel + + +class CheckoutSessionCreate(BaseModel): + plan: Literal["pro", "starter", "team", "enterprise"] + seats: int + billing_interval: Literal["monthly", "annual"] = "monthly" + + +class CheckoutSessionResponse(BaseModel): + url: str diff --git a/backend/app/services/billing.py b/backend/app/services/billing.py index e1b08782..6e3d65b1 100644 --- a/backend/app/services/billing.py +++ b/backend/app/services/billing.py @@ -2,9 +2,14 @@ abstraction. Account row is canonical local state; Stripe is canonical remote state; the webhook handler bridges the two.""" from datetime import datetime, timezone, timedelta + +import stripe from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.core.config import settings +from app.models.account import Account +from app.models.plan_billing import PlanBilling from app.models.subscription import Subscription @@ -34,3 +39,66 @@ class BillingService: await db.commit() await db.refresh(sub) return sub + + @staticmethod + async def create_checkout_session( + db: AsyncSession, + account: Account, + plan: str, + seats: int, + billing_interval: str, + success_url: str, + cancel_url: str, + ) -> str: + """Create a Stripe Checkout Session for subscription purchase. If the + account currently has a trialing subscription with time remaining, that + trial end is preserved on the new Stripe subscription so the user + isn't charged early.""" + if not settings.stripe_enabled: + raise RuntimeError("Stripe not configured") + stripe.api_key = settings.STRIPE_SECRET_KEY + + plan_billing = (await db.execute( + select(PlanBilling).where(PlanBilling.plan == plan) + )).scalar_one_or_none() + if plan_billing is None: + raise ValueError(f"Unknown plan: {plan}") + price_id = ( + plan_billing.stripe_monthly_price_id if billing_interval == "monthly" + else plan_billing.stripe_annual_price_id + ) + if price_id is None: + raise RuntimeError( + f"Plan '{plan}' has no Stripe price for {billing_interval}" + ) + + if account.stripe_customer_id is None: + customer = stripe.Customer.create( + email=None, + metadata={"account_id": str(account.id)}, + ) + account.stripe_customer_id = customer.id + await db.commit() + + sub = (await db.execute( + select(Subscription).where(Subscription.account_id == account.id) + )).scalar_one_or_none() + subscription_data = {} + if ( + sub + and sub.status == "trialing" + and sub.current_period_end + and sub.current_period_end > datetime.now(timezone.utc) + ): + subscription_data["trial_end"] = int(sub.current_period_end.timestamp()) + + session = stripe.checkout.Session.create( + customer=account.stripe_customer_id, + line_items=[{"price": price_id, "quantity": seats}], + mode="subscription", + subscription_data=subscription_data or None, + success_url=success_url, + cancel_url=cancel_url, + allow_promotion_codes=False, + ) + return session.url diff --git a/backend/tests/test_billing_checkout.py b/backend/tests/test_billing_checkout.py new file mode 100644 index 00000000..48e12f75 --- /dev/null +++ b/backend/tests/test_billing_checkout.py @@ -0,0 +1,56 @@ +import pytest +from unittest.mock import patch, MagicMock +from app.models.plan_billing import PlanBilling + + +@pytest.mark.asyncio +async def test_checkout_session_creates_stripe_session( + client, test_db, test_user, auth_headers, monkeypatch +): + """End-to-end: post body → Stripe SDK called → URL returned. Stripe SDK + mocked; Customer + Session calls patched.""" + from app.core.config import settings + monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy") + + test_db.add(PlanBilling( + plan="pro", + display_name="Pro", + stripe_product_id="prod_test", + stripe_monthly_price_id="price_test_monthly", + )) + await test_db.commit() + + fake_customer = MagicMock() + fake_customer.id = "cus_test_123" + fake_session = MagicMock() + fake_session.url = "https://checkout.stripe.com/test" + + with patch("stripe.Customer.create", return_value=fake_customer) as cust_mock, \ + patch("stripe.checkout.Session.create", return_value=fake_session) as sess_mock: + response = await client.post( + "/api/v1/billing/checkout-session", + json={"plan": "pro", "seats": 3, "billing_interval": "monthly"}, + headers=auth_headers, + ) + + assert response.status_code == 200, response.json() + assert response.json()["url"] == "https://checkout.stripe.com/test" + cust_mock.assert_called_once() + sess_mock.assert_called_once() + + +@pytest.mark.asyncio +async def test_checkout_session_unknown_plan_returns_500( + client, test_db, test_user, auth_headers, monkeypatch +): + """No PlanBilling row → ValueError surfaces as 500 (the endpoint doesn't + catch business errors).""" + from app.core.config import settings + monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy") + + response = await client.post( + "/api/v1/billing/checkout-session", + json={"plan": "pro", "seats": 1, "billing_interval": "monthly"}, + headers=auth_headers, + ) + assert response.status_code == 500