feat(billing): add /billing/checkout-session via BillingService

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-06 14:51:06 -04:00
parent 9851d56633
commit f683bb5720
6 changed files with 176 additions and 1 deletions

View File

@@ -0,0 +1,36 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_active_user
from app.core.admin_database import get_admin_db
from app.core.config import settings
from app.models.account import Account
from app.models.user import User
from app.schemas.billing import CheckoutSessionCreate, CheckoutSessionResponse
from app.services.billing import BillingService
router = APIRouter(prefix="/billing", tags=["billing"])
@router.post("/checkout-session", response_model=CheckoutSessionResponse)
async def create_checkout_session(
payload: CheckoutSessionCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> CheckoutSessionResponse:
account = (await db.execute(
select(Account).where(Account.id == current_user.account_id)
)).scalar_one()
url = await BillingService.create_checkout_session(
db=db,
account=account,
plan=payload.plan,
seats=payload.seats,
billing_interval=payload.billing_interval,
success_url=f"{settings.FRONTEND_URL}/account/billing?success=1",
cancel_url=f"{settings.FRONTEND_URL}/account/billing/select-plan",
)
return CheckoutSessionResponse(url=url)

View File

@@ -23,6 +23,7 @@ from app.api.endpoints import (
analytics,
assistant_chat,
auth,
billing,
beta_feedback,
beta_signup,
branding,
@@ -81,6 +82,7 @@ api_router = APIRouter()
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
# ---------------------------------------------------------------------------
api_router.include_router(auth.router)
api_router.include_router(billing.router) # Reachable when subscription locked
api_router.include_router(shared.router) # Public share links (no auth)
api_router.include_router(shares.public_router) # Public session share links (optional auth)
api_router.include_router(beta_signup.router)

View File

@@ -94,11 +94,12 @@ class Settings(BaseSettings):
STRIPE_SECRET_KEY: Optional[str] = None
STRIPE_PUBLISHABLE_KEY: Optional[str] = None
STRIPE_WEBHOOK_SECRET: Optional[str] = None
SELF_SERVE_ENABLED: bool = False
@property
def stripe_enabled(self) -> bool:
"""Check if Stripe is configured."""
return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None
return bool(self.STRIPE_SECRET_KEY)
# AI Flow Builder
ANTHROPIC_API_KEY: Optional[str] = None

View File

@@ -0,0 +1,12 @@
from typing import Literal
from pydantic import BaseModel
class CheckoutSessionCreate(BaseModel):
plan: Literal["pro", "starter", "team", "enterprise"]
seats: int
billing_interval: Literal["monthly", "annual"] = "monthly"
class CheckoutSessionResponse(BaseModel):
url: str

View File

@@ -2,9 +2,14 @@
abstraction. Account row is canonical local state; Stripe is canonical
remote state; the webhook handler bridges the two."""
from datetime import datetime, timezone, timedelta
import stripe
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.models.account import Account
from app.models.plan_billing import PlanBilling
from app.models.subscription import Subscription
@@ -34,3 +39,66 @@ class BillingService:
await db.commit()
await db.refresh(sub)
return sub
@staticmethod
async def create_checkout_session(
db: AsyncSession,
account: Account,
plan: str,
seats: int,
billing_interval: str,
success_url: str,
cancel_url: str,
) -> str:
"""Create a Stripe Checkout Session for subscription purchase. If the
account currently has a trialing subscription with time remaining, that
trial end is preserved on the new Stripe subscription so the user
isn't charged early."""
if not settings.stripe_enabled:
raise RuntimeError("Stripe not configured")
stripe.api_key = settings.STRIPE_SECRET_KEY
plan_billing = (await db.execute(
select(PlanBilling).where(PlanBilling.plan == plan)
)).scalar_one_or_none()
if plan_billing is None:
raise ValueError(f"Unknown plan: {plan}")
price_id = (
plan_billing.stripe_monthly_price_id if billing_interval == "monthly"
else plan_billing.stripe_annual_price_id
)
if price_id is None:
raise RuntimeError(
f"Plan '{plan}' has no Stripe price for {billing_interval}"
)
if account.stripe_customer_id is None:
customer = stripe.Customer.create(
email=None,
metadata={"account_id": str(account.id)},
)
account.stripe_customer_id = customer.id
await db.commit()
sub = (await db.execute(
select(Subscription).where(Subscription.account_id == account.id)
)).scalar_one_or_none()
subscription_data = {}
if (
sub
and sub.status == "trialing"
and sub.current_period_end
and sub.current_period_end > datetime.now(timezone.utc)
):
subscription_data["trial_end"] = int(sub.current_period_end.timestamp())
session = stripe.checkout.Session.create(
customer=account.stripe_customer_id,
line_items=[{"price": price_id, "quantity": seats}],
mode="subscription",
subscription_data=subscription_data or None,
success_url=success_url,
cancel_url=cancel_url,
allow_promotion_codes=False,
)
return session.url

View File

@@ -0,0 +1,56 @@
import pytest
from unittest.mock import patch, MagicMock
from app.models.plan_billing import PlanBilling
@pytest.mark.asyncio
async def test_checkout_session_creates_stripe_session(
client, test_db, test_user, auth_headers, monkeypatch
):
"""End-to-end: post body → Stripe SDK called → URL returned. Stripe SDK
mocked; Customer + Session calls patched."""
from app.core.config import settings
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
test_db.add(PlanBilling(
plan="pro",
display_name="Pro",
stripe_product_id="prod_test",
stripe_monthly_price_id="price_test_monthly",
))
await test_db.commit()
fake_customer = MagicMock()
fake_customer.id = "cus_test_123"
fake_session = MagicMock()
fake_session.url = "https://checkout.stripe.com/test"
with patch("stripe.Customer.create", return_value=fake_customer) as cust_mock, \
patch("stripe.checkout.Session.create", return_value=fake_session) as sess_mock:
response = await client.post(
"/api/v1/billing/checkout-session",
json={"plan": "pro", "seats": 3, "billing_interval": "monthly"},
headers=auth_headers,
)
assert response.status_code == 200, response.json()
assert response.json()["url"] == "https://checkout.stripe.com/test"
cust_mock.assert_called_once()
sess_mock.assert_called_once()
@pytest.mark.asyncio
async def test_checkout_session_unknown_plan_returns_500(
client, test_db, test_user, auth_headers, monkeypatch
):
"""No PlanBilling row → ValueError surfaces as 500 (the endpoint doesn't
catch business errors)."""
from app.core.config import settings
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
response = await client.post(
"/api/v1/billing/checkout-session",
json={"plan": "pro", "seats": 1, "billing_interval": "monthly"},
headers=auth_headers,
)
assert response.status_code == 500