feat(billing): add BillingService.start_trial; wire into /auth/register
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -195,26 +195,30 @@ async def register(
|
|||||||
# Now set account owner and create subscription
|
# Now set account owner and create subscription
|
||||||
new_account.owner_id = new_user.id
|
new_account.owner_id = new_user.id
|
||||||
|
|
||||||
# Apply plan/trial from invite code if present
|
|
||||||
sub_plan = "free"
|
|
||||||
sub_status = "active"
|
|
||||||
period_start = None
|
|
||||||
period_end = None
|
|
||||||
if invite_code_record and invite_code_record.assigned_plan:
|
if invite_code_record and invite_code_record.assigned_plan:
|
||||||
|
# Plan/trial driven by platform invite code (existing pilot flow)
|
||||||
sub_plan = invite_code_record.assigned_plan
|
sub_plan = invite_code_record.assigned_plan
|
||||||
|
sub_status = "active"
|
||||||
|
period_start = None
|
||||||
|
period_end = None
|
||||||
if invite_code_record.trial_duration_days:
|
if invite_code_record.trial_duration_days:
|
||||||
sub_status = "trialing"
|
sub_status = "trialing"
|
||||||
period_start = datetime.now(timezone.utc)
|
period_start = datetime.now(timezone.utc)
|
||||||
period_end = period_start + timedelta(days=invite_code_record.trial_duration_days)
|
period_end = period_start + timedelta(days=invite_code_record.trial_duration_days)
|
||||||
|
db.add(Subscription(
|
||||||
new_subscription = Subscription(
|
account_id=new_account.id,
|
||||||
account_id=new_account.id,
|
plan=sub_plan,
|
||||||
plan=sub_plan,
|
status=sub_status,
|
||||||
status=sub_status,
|
current_period_start=period_start,
|
||||||
current_period_start=period_start,
|
current_period_end=period_end,
|
||||||
current_period_end=period_end,
|
))
|
||||||
)
|
else:
|
||||||
db.add(new_subscription)
|
# New self-serve shop — start the standard Pro trial.
|
||||||
|
# start_trial commits internally; flush our pending User/Account changes
|
||||||
|
# first so the FK is satisfied.
|
||||||
|
await db.flush()
|
||||||
|
from app.services.billing import BillingService
|
||||||
|
await BillingService.start_trial(db, new_account.id)
|
||||||
|
|
||||||
# Mark platform invite code as used
|
# Mark platform invite code as used
|
||||||
if invite_code_record:
|
if invite_code_record:
|
||||||
|
|||||||
36
backend/app/services/billing.py
Normal file
36
backend/app/services/billing.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""Single billing service module. Stripe is the only impl — no provider
|
||||||
|
abstraction. Account row is canonical local state; Stripe is canonical
|
||||||
|
remote state; the webhook handler bridges the two."""
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
|
||||||
|
|
||||||
|
TRIAL_DAYS = 14
|
||||||
|
|
||||||
|
|
||||||
|
class BillingService:
|
||||||
|
@staticmethod
|
||||||
|
async def start_trial(db: AsyncSession, account_id) -> Subscription:
|
||||||
|
"""Idempotent. Creates a trialing Subscription on Pro for the account if
|
||||||
|
one doesn't exist; otherwise returns the existing row."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == account_id)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
return existing
|
||||||
|
|
||||||
|
sub = Subscription(
|
||||||
|
account_id=account_id,
|
||||||
|
plan="pro",
|
||||||
|
status="trialing",
|
||||||
|
current_period_start=datetime.now(timezone.utc),
|
||||||
|
current_period_end=datetime.now(timezone.utc) + timedelta(days=TRIAL_DAYS),
|
||||||
|
)
|
||||||
|
db.add(sub)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(sub)
|
||||||
|
return sub
|
||||||
59
backend/tests/test_billing_service.py
Normal file
59
backend/tests/test_billing_service.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from sqlalchemy import select, delete
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
from app.services.billing import BillingService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_trial_creates_trialing_pro_subscription(test_db):
|
||||||
|
"""Direct service test — bypasses register, creates account inline."""
|
||||||
|
from app.models.account import Account
|
||||||
|
account = Account(name="DirectTest", display_code="DIRECT01")
|
||||||
|
test_db.add(account)
|
||||||
|
await test_db.flush()
|
||||||
|
|
||||||
|
sub = await BillingService.start_trial(test_db, account.id)
|
||||||
|
assert sub.plan == "pro"
|
||||||
|
assert sub.status == "trialing"
|
||||||
|
assert sub.current_period_end is not None
|
||||||
|
assert sub.current_period_end > datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_trial_is_idempotent(test_db):
|
||||||
|
from app.models.account import Account
|
||||||
|
account = Account(name="Idempo", display_code="IDEMPO01")
|
||||||
|
test_db.add(account)
|
||||||
|
await test_db.flush()
|
||||||
|
|
||||||
|
sub1 = await BillingService.start_trial(test_db, account.id)
|
||||||
|
sub2 = await BillingService.start_trial(test_db, account.id)
|
||||||
|
assert sub1.id == sub2.id
|
||||||
|
|
||||||
|
rows = (await test_db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == account.id)
|
||||||
|
)).scalars().all()
|
||||||
|
assert len(rows) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_creates_trial_subscription(client, test_db):
|
||||||
|
"""Registering a brand-new shop (no invite code) yields a Pro/trialing sub."""
|
||||||
|
response = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "newshop@example.com",
|
||||||
|
"password": "Verystrong1Pwd",
|
||||||
|
"name": "New Shop",
|
||||||
|
})
|
||||||
|
assert response.status_code in (200, 201), response.json()
|
||||||
|
|
||||||
|
body = response.json()
|
||||||
|
account_id = uuid.UUID(body["account_id"])
|
||||||
|
|
||||||
|
sub = (await test_db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == account_id)
|
||||||
|
)).scalar_one()
|
||||||
|
assert sub.plan == "pro"
|
||||||
|
assert sub.status == "trialing"
|
||||||
|
assert sub.current_period_end is not None
|
||||||
Reference in New Issue
Block a user