From 9851d566337ae76668f6f3cb3b977998146b8a6c Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Wed, 6 May 2026 14:48:30 -0400 Subject: [PATCH] feat(billing): add BillingService.start_trial; wire into /auth/register Co-Authored-By: Claude Opus 4.7 --- backend/app/api/endpoints/auth.py | 32 ++++++++------- backend/app/services/billing.py | 36 ++++++++++++++++ backend/tests/test_billing_service.py | 59 +++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 14 deletions(-) create mode 100644 backend/app/services/billing.py create mode 100644 backend/tests/test_billing_service.py diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index 2634a6ef..9484480a 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -195,26 +195,30 @@ async def register( # Now set account owner and create subscription new_account.owner_id = new_user.id - # Apply plan/trial from invite code if present - sub_plan = "free" - sub_status = "active" - period_start = None - period_end = None if invite_code_record and invite_code_record.assigned_plan: + # Plan/trial driven by platform invite code (existing pilot flow) sub_plan = invite_code_record.assigned_plan + sub_status = "active" + period_start = None + period_end = None if invite_code_record.trial_duration_days: sub_status = "trialing" period_start = datetime.now(timezone.utc) period_end = period_start + timedelta(days=invite_code_record.trial_duration_days) - - new_subscription = Subscription( - account_id=new_account.id, - plan=sub_plan, - status=sub_status, - current_period_start=period_start, - current_period_end=period_end, - ) - db.add(new_subscription) + db.add(Subscription( + account_id=new_account.id, + plan=sub_plan, + status=sub_status, + current_period_start=period_start, + current_period_end=period_end, + )) + else: + # New self-serve shop — start the standard Pro trial. + # start_trial commits internally; flush our pending User/Account changes + # first so the FK is satisfied. + await db.flush() + from app.services.billing import BillingService + await BillingService.start_trial(db, new_account.id) # Mark platform invite code as used if invite_code_record: diff --git a/backend/app/services/billing.py b/backend/app/services/billing.py new file mode 100644 index 00000000..e1b08782 --- /dev/null +++ b/backend/app/services/billing.py @@ -0,0 +1,36 @@ +"""Single billing service module. Stripe is the only impl — no provider +abstraction. Account row is canonical local state; Stripe is canonical +remote state; the webhook handler bridges the two.""" +from datetime import datetime, timezone, timedelta +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.subscription import Subscription + + +TRIAL_DAYS = 14 + + +class BillingService: + @staticmethod + async def start_trial(db: AsyncSession, account_id) -> Subscription: + """Idempotent. Creates a trialing Subscription on Pro for the account if + one doesn't exist; otherwise returns the existing row.""" + result = await db.execute( + select(Subscription).where(Subscription.account_id == account_id) + ) + existing = result.scalar_one_or_none() + if existing is not None: + return existing + + sub = Subscription( + account_id=account_id, + plan="pro", + status="trialing", + current_period_start=datetime.now(timezone.utc), + current_period_end=datetime.now(timezone.utc) + timedelta(days=TRIAL_DAYS), + ) + db.add(sub) + await db.commit() + await db.refresh(sub) + return sub diff --git a/backend/tests/test_billing_service.py b/backend/tests/test_billing_service.py new file mode 100644 index 00000000..cb9726d8 --- /dev/null +++ b/backend/tests/test_billing_service.py @@ -0,0 +1,59 @@ +import uuid +import pytest +from datetime import datetime, timezone +from sqlalchemy import select, delete +from app.models.subscription import Subscription +from app.services.billing import BillingService + + +@pytest.mark.asyncio +async def test_start_trial_creates_trialing_pro_subscription(test_db): + """Direct service test — bypasses register, creates account inline.""" + from app.models.account import Account + account = Account(name="DirectTest", display_code="DIRECT01") + test_db.add(account) + await test_db.flush() + + sub = await BillingService.start_trial(test_db, account.id) + assert sub.plan == "pro" + assert sub.status == "trialing" + assert sub.current_period_end is not None + assert sub.current_period_end > datetime.now(timezone.utc) + + +@pytest.mark.asyncio +async def test_start_trial_is_idempotent(test_db): + from app.models.account import Account + account = Account(name="Idempo", display_code="IDEMPO01") + test_db.add(account) + await test_db.flush() + + sub1 = await BillingService.start_trial(test_db, account.id) + sub2 = await BillingService.start_trial(test_db, account.id) + assert sub1.id == sub2.id + + rows = (await test_db.execute( + select(Subscription).where(Subscription.account_id == account.id) + )).scalars().all() + assert len(rows) == 1 + + +@pytest.mark.asyncio +async def test_register_creates_trial_subscription(client, test_db): + """Registering a brand-new shop (no invite code) yields a Pro/trialing sub.""" + response = await client.post("/api/v1/auth/register", json={ + "email": "newshop@example.com", + "password": "Verystrong1Pwd", + "name": "New Shop", + }) + assert response.status_code in (200, 201), response.json() + + body = response.json() + account_id = uuid.UUID(body["account_id"]) + + sub = (await test_db.execute( + select(Subscription).where(Subscription.account_id == account_id) + )).scalar_one() + assert sub.plan == "pro" + assert sub.status == "trialing" + assert sub.current_period_end is not None