feat(billing): add /billing/checkout-session via BillingService
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
36
backend/app/api/endpoints/billing.py
Normal file
36
backend/app/api/endpoints/billing.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_active_user
|
||||||
|
from app.core.admin_database import get_admin_db
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.models.account import Account
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.billing import CheckoutSessionCreate, CheckoutSessionResponse
|
||||||
|
from app.services.billing import BillingService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/checkout-session", response_model=CheckoutSessionResponse)
|
||||||
|
async def create_checkout_session(
|
||||||
|
payload: CheckoutSessionCreate,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
|
) -> CheckoutSessionResponse:
|
||||||
|
account = (await db.execute(
|
||||||
|
select(Account).where(Account.id == current_user.account_id)
|
||||||
|
)).scalar_one()
|
||||||
|
url = await BillingService.create_checkout_session(
|
||||||
|
db=db,
|
||||||
|
account=account,
|
||||||
|
plan=payload.plan,
|
||||||
|
seats=payload.seats,
|
||||||
|
billing_interval=payload.billing_interval,
|
||||||
|
success_url=f"{settings.FRONTEND_URL}/account/billing?success=1",
|
||||||
|
cancel_url=f"{settings.FRONTEND_URL}/account/billing/select-plan",
|
||||||
|
)
|
||||||
|
return CheckoutSessionResponse(url=url)
|
||||||
@@ -23,6 +23,7 @@ from app.api.endpoints import (
|
|||||||
analytics,
|
analytics,
|
||||||
assistant_chat,
|
assistant_chat,
|
||||||
auth,
|
auth,
|
||||||
|
billing,
|
||||||
beta_feedback,
|
beta_feedback,
|
||||||
beta_signup,
|
beta_signup,
|
||||||
branding,
|
branding,
|
||||||
@@ -81,6 +82,7 @@ api_router = APIRouter()
|
|||||||
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
|
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
api_router.include_router(auth.router)
|
api_router.include_router(auth.router)
|
||||||
|
api_router.include_router(billing.router) # Reachable when subscription locked
|
||||||
api_router.include_router(shared.router) # Public share links (no auth)
|
api_router.include_router(shared.router) # Public share links (no auth)
|
||||||
api_router.include_router(shares.public_router) # Public session share links (optional auth)
|
api_router.include_router(shares.public_router) # Public session share links (optional auth)
|
||||||
api_router.include_router(beta_signup.router)
|
api_router.include_router(beta_signup.router)
|
||||||
|
|||||||
@@ -94,11 +94,12 @@ class Settings(BaseSettings):
|
|||||||
STRIPE_SECRET_KEY: Optional[str] = None
|
STRIPE_SECRET_KEY: Optional[str] = None
|
||||||
STRIPE_PUBLISHABLE_KEY: Optional[str] = None
|
STRIPE_PUBLISHABLE_KEY: Optional[str] = None
|
||||||
STRIPE_WEBHOOK_SECRET: Optional[str] = None
|
STRIPE_WEBHOOK_SECRET: Optional[str] = None
|
||||||
|
SELF_SERVE_ENABLED: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stripe_enabled(self) -> bool:
|
def stripe_enabled(self) -> bool:
|
||||||
"""Check if Stripe is configured."""
|
"""Check if Stripe is configured."""
|
||||||
return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None
|
return bool(self.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
# AI Flow Builder
|
# AI Flow Builder
|
||||||
ANTHROPIC_API_KEY: Optional[str] = None
|
ANTHROPIC_API_KEY: Optional[str] = None
|
||||||
|
|||||||
12
backend/app/schemas/billing.py
Normal file
12
backend/app/schemas/billing.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from typing import Literal
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class CheckoutSessionCreate(BaseModel):
|
||||||
|
plan: Literal["pro", "starter", "team", "enterprise"]
|
||||||
|
seats: int
|
||||||
|
billing_interval: Literal["monthly", "annual"] = "monthly"
|
||||||
|
|
||||||
|
|
||||||
|
class CheckoutSessionResponse(BaseModel):
|
||||||
|
url: str
|
||||||
@@ -2,9 +2,14 @@
|
|||||||
abstraction. Account row is canonical local state; Stripe is canonical
|
abstraction. Account row is canonical local state; Stripe is canonical
|
||||||
remote state; the webhook handler bridges the two."""
|
remote state; the webhook handler bridges the two."""
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
|
|
||||||
|
import stripe
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.models.account import Account
|
||||||
|
from app.models.plan_billing import PlanBilling
|
||||||
from app.models.subscription import Subscription
|
from app.models.subscription import Subscription
|
||||||
|
|
||||||
|
|
||||||
@@ -34,3 +39,66 @@ class BillingService:
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(sub)
|
await db.refresh(sub)
|
||||||
return sub
|
return sub
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_checkout_session(
|
||||||
|
db: AsyncSession,
|
||||||
|
account: Account,
|
||||||
|
plan: str,
|
||||||
|
seats: int,
|
||||||
|
billing_interval: str,
|
||||||
|
success_url: str,
|
||||||
|
cancel_url: str,
|
||||||
|
) -> str:
|
||||||
|
"""Create a Stripe Checkout Session for subscription purchase. If the
|
||||||
|
account currently has a trialing subscription with time remaining, that
|
||||||
|
trial end is preserved on the new Stripe subscription so the user
|
||||||
|
isn't charged early."""
|
||||||
|
if not settings.stripe_enabled:
|
||||||
|
raise RuntimeError("Stripe not configured")
|
||||||
|
stripe.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
|
||||||
|
plan_billing = (await db.execute(
|
||||||
|
select(PlanBilling).where(PlanBilling.plan == plan)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if plan_billing is None:
|
||||||
|
raise ValueError(f"Unknown plan: {plan}")
|
||||||
|
price_id = (
|
||||||
|
plan_billing.stripe_monthly_price_id if billing_interval == "monthly"
|
||||||
|
else plan_billing.stripe_annual_price_id
|
||||||
|
)
|
||||||
|
if price_id is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Plan '{plan}' has no Stripe price for {billing_interval}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if account.stripe_customer_id is None:
|
||||||
|
customer = stripe.Customer.create(
|
||||||
|
email=None,
|
||||||
|
metadata={"account_id": str(account.id)},
|
||||||
|
)
|
||||||
|
account.stripe_customer_id = customer.id
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
sub = (await db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == account.id)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
subscription_data = {}
|
||||||
|
if (
|
||||||
|
sub
|
||||||
|
and sub.status == "trialing"
|
||||||
|
and sub.current_period_end
|
||||||
|
and sub.current_period_end > datetime.now(timezone.utc)
|
||||||
|
):
|
||||||
|
subscription_data["trial_end"] = int(sub.current_period_end.timestamp())
|
||||||
|
|
||||||
|
session = stripe.checkout.Session.create(
|
||||||
|
customer=account.stripe_customer_id,
|
||||||
|
line_items=[{"price": price_id, "quantity": seats}],
|
||||||
|
mode="subscription",
|
||||||
|
subscription_data=subscription_data or None,
|
||||||
|
success_url=success_url,
|
||||||
|
cancel_url=cancel_url,
|
||||||
|
allow_promotion_codes=False,
|
||||||
|
)
|
||||||
|
return session.url
|
||||||
|
|||||||
56
backend/tests/test_billing_checkout.py
Normal file
56
backend/tests/test_billing_checkout.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from app.models.plan_billing import PlanBilling
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_session_creates_stripe_session(
|
||||||
|
client, test_db, test_user, auth_headers, monkeypatch
|
||||||
|
):
|
||||||
|
"""End-to-end: post body → Stripe SDK called → URL returned. Stripe SDK
|
||||||
|
mocked; Customer + Session calls patched."""
|
||||||
|
from app.core.config import settings
|
||||||
|
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
|
||||||
|
|
||||||
|
test_db.add(PlanBilling(
|
||||||
|
plan="pro",
|
||||||
|
display_name="Pro",
|
||||||
|
stripe_product_id="prod_test",
|
||||||
|
stripe_monthly_price_id="price_test_monthly",
|
||||||
|
))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
fake_customer = MagicMock()
|
||||||
|
fake_customer.id = "cus_test_123"
|
||||||
|
fake_session = MagicMock()
|
||||||
|
fake_session.url = "https://checkout.stripe.com/test"
|
||||||
|
|
||||||
|
with patch("stripe.Customer.create", return_value=fake_customer) as cust_mock, \
|
||||||
|
patch("stripe.checkout.Session.create", return_value=fake_session) as sess_mock:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/checkout-session",
|
||||||
|
json={"plan": "pro", "seats": 3, "billing_interval": "monthly"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200, response.json()
|
||||||
|
assert response.json()["url"] == "https://checkout.stripe.com/test"
|
||||||
|
cust_mock.assert_called_once()
|
||||||
|
sess_mock.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_session_unknown_plan_returns_500(
|
||||||
|
client, test_db, test_user, auth_headers, monkeypatch
|
||||||
|
):
|
||||||
|
"""No PlanBilling row → ValueError surfaces as 500 (the endpoint doesn't
|
||||||
|
catch business errors)."""
|
||||||
|
from app.core.config import settings
|
||||||
|
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/checkout-session",
|
||||||
|
json={"plan": "pro", "seats": 1, "billing_interval": "monthly"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 500
|
||||||
Reference in New Issue
Block a user