feat: self-serve signup Phase 2 (frontend cutover) (#162)
Some checks failed
CI / e2e (push) Has been cancelled
CI / frontend (push) Has been cancelled
CI / backend (push) Has been cancelled
Mirror to GitHub / mirror (push) Has been cancelled

Co-authored-by: Michael Chihlas <michael@resolutionflow.com>
Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
This commit was merged in pull request #162.
This commit is contained in:
2026-05-07 18:42:20 +00:00
committed by chihlasm
parent f918b766b0
commit f1be3abcc5
123 changed files with 11563 additions and 559 deletions

View File

@@ -21,4 +21,12 @@ ANTHROPIC_API_KEY=
VOYAGE_API_KEY=
# ConnectWise PSA Integration
CW_CLIENT_ID=<CONNECTWISE CLIENT ID>
CW_CLIENT_ID=<CONNECTWISE CLIENT ID>
# Stripe
# Test keys from Stripe Dashboard → Developers → API keys (with Test mode toggled on).
# Webhook secret for local dev: from `stripe listen --forward-to localhost:8000/api/v1/webhooks/stripe`.
# When unset, app/core/config.py:stripe_enabled returns False and Stripe code paths short-circuit.
STRIPE_SECRET_KEY=sk_test_
STRIPE_PUBLISHABLE_KEY=pk_test_
STRIPE_WEBHOOK_SECRET=whsec_

View File

@@ -235,6 +235,7 @@ _SUBSCRIPTION_GUARD_ALLOWLIST = {
"/api/v1/billing/portal-session",
"/api/v1/users/me",
"/api/v1/users/me/onboarding-step",
"/api/v1/users/me/onboarding-dismiss-rest",
}
@@ -298,6 +299,8 @@ _EMAIL_VERIFICATION_ALLOWLIST = {
"/api/v1/auth/email/verify",
"/api/v1/auth/password/change",
"/api/v1/users/me",
"/api/v1/users/me/onboarding-step",
"/api/v1/users/me/onboarding-dismiss-rest",
"/api/v1/billing/state",
"/api/v1/billing/checkout-session",
"/api/v1/billing/portal-session",

View File

@@ -0,0 +1,54 @@
"""Public endpoint for resolving an account invite code into display info.
Mounted as a public route (no tenant context, no auth) — used by the
/accept-invite page on the frontend so an invitee can see what account they
are about to join before they sign up. Uses the BYPASSRLS admin session
factory because account_invites is account-scoped under Phase 4 RLS but the
caller has no tenant identity yet.
"""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.admin_database import get_admin_db
from app.models.account_invite import AccountInvite
from app.schemas.oauth import InviteLookupResponse
router = APIRouter(prefix="/accounts", tags=["account-invite-lookup"])
@router.get("/invites/{code}/lookup", response_model=InviteLookupResponse)
async def lookup_invite(
code: str,
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> InviteLookupResponse:
"""Return minimal display data for a valid (unused, unexpired, not revoked)
invite. Returns 404 with `invite_invalid_or_expired_or_revoked` for any
invalid state — the AcceptInvitePage shows a single "ask the inviter to
resend" message regardless of which condition failed (anti-enumeration)."""
result = await db.execute(
select(AccountInvite)
.where(AccountInvite.code == code)
.options(
joinedload(AccountInvite.account),
joinedload(AccountInvite.invited_by),
)
)
invite = result.scalar_one_or_none()
if invite is None or not invite.is_valid:
raise HTTPException(
status_code=404,
detail={"error": "invite_invalid_or_expired_or_revoked"},
)
return InviteLookupResponse(
account_name=invite.account.name,
inviter_name=invite.invited_by.name,
invited_email=invite.email,
role=invite.role,
)

View File

@@ -8,34 +8,101 @@ from app.core.database import get_db
from app.core.audit import log_audit
from app.models.user import User
from app.models.plan_limits import PlanLimits
from app.models.plan_billing import PlanBilling
from app.models.account import Account
from app.models.account_limit_override import AccountLimitOverride
from app.models.subscription import Subscription
from app.schemas.admin import (
PlanLimitResponse, PlanLimitUpdate,
PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
)
from app.api.deps import require_admin
from app.services.billing import BillingService
router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])
@router.get("/plan-limits", response_model=list[PlanLimitResponse])
# Fields on PlanLimitUpdate that map to plan_billing (not plan_limits).
_PLAN_BILLING_FIELDS = (
"display_name",
"description",
"monthly_price_cents",
"annual_price_cents",
"stripe_product_id",
"stripe_monthly_price_id",
"stripe_annual_price_id",
"is_public",
"is_archived",
"sort_order",
)
# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model.
# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit
# null for any of them would otherwise trigger a NOT NULL violation at commit.
_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({
"display_name",
"is_public",
"is_archived",
"sort_order",
})
def _merge_plan_with_billing(
plan: PlanLimits, billing: PlanBilling | None
) -> PlanLimitWithBillingResponse:
"""Build a merged response. Billing fields are None when no plan_billing row
exists for the plan."""
payload = {
"plan": plan.plan,
"max_trees": plan.max_trees,
"max_sessions_per_month": plan.max_sessions_per_month,
"max_users": plan.max_users,
"custom_branding": plan.custom_branding,
"priority_support": plan.priority_support,
"export_formats": plan.export_formats or [],
}
if billing is not None:
payload.update({
"display_name": billing.display_name,
"description": billing.description,
"monthly_price_cents": billing.monthly_price_cents,
"annual_price_cents": billing.annual_price_cents,
"stripe_product_id": billing.stripe_product_id,
"stripe_monthly_price_id": billing.stripe_monthly_price_id,
"stripe_annual_price_id": billing.stripe_annual_price_id,
"is_public": billing.is_public,
"is_archived": billing.is_archived,
"sort_order": billing.sort_order,
})
return PlanLimitWithBillingResponse(**payload)
@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
async def list_plan_limits(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""List all plan limit configurations."""
result = await db.execute(select(PlanLimits))
return result.scalars().all()
"""List all plan limit configurations, merged with plan_billing fields
where present. Plans without a plan_billing row return None for the
billing fields."""
rows = (await db.execute(
select(PlanLimits, PlanBilling)
.outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
)).all()
return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
@router.put("/plan-limits", response_model=PlanLimitResponse)
@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
async def update_plan_limits(
data: PlanLimitUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Update a plan's limits."""
"""Update a plan's limits and (if any plan_billing field is included)
upsert the matching plan_billing row in the same transaction. After
commit, invalidates the in-process billing cache for accounts on this
plan (currently a no-op — see BillingService.invalidate_billing_cache).
"""
result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
plan = result.scalar_one_or_none()
if not plan:
@@ -48,10 +115,50 @@ async def update_plan_limits(
plan.priority_support = data.priority_support
plan.export_formats = data.export_formats
await log_audit(db, current_user.id, "plan_limits.update", "plan_limits", details={"plan": data.plan})
# Did the request include any plan_billing field? (Pydantic gives us
# `model_fields_set` to distinguish "user passed null" from "field omitted".)
billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)
billing: PlanBilling | None = None
if billing_fields_set:
billing = (await db.execute(
select(PlanBilling).where(PlanBilling.plan == data.plan)
)).scalar_one_or_none()
if billing is None:
# Create. display_name is required on the model — derive from the
# plan name when the caller didn't supply one (e.g. "pro" → "Pro").
display_name = data.display_name or data.plan.capitalize()
billing = PlanBilling(plan=data.plan, display_name=display_name)
db.add(billing)
# Apply only the fields the caller actually included. Allows partial
# updates without clobbering existing values.
for field in billing_fields_set:
value = getattr(data, field)
if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
# Don't NULL out a NOT NULL column on update.
continue
setattr(billing, field, value)
await log_audit(
db, current_user.id, "plan_limits.update", "plan_limits",
details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
)
await db.commit()
await db.refresh(plan)
return plan
if billing is not None:
await db.refresh(billing)
# Invalidate any in-process billing cache for accounts on this plan.
# TODO: invalidate app.state.billing_cache when added.
account_ids = [
row[0] for row in (await db.execute(
select(Subscription.account_id).where(Subscription.plan == data.plan)
)).all()
]
await BillingService.invalidate_billing_cache(account_ids)
return _merge_plan_with_billing(plan, billing)
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])

View File

@@ -47,8 +47,16 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["authentication"])
async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
"""Decode a refresh token JWT and store its hash in the database."""
async def store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
"""Decode a refresh token JWT and store its hash in the database.
Module-public so OAuth callback endpoints (and any future token-issuing
surface) can register the JTI in the ``refresh_tokens`` table the same
way ``/auth/login`` does. Without this the first ``/auth/refresh`` call
will reject the token as "revoked" because no row exists.
Caller is responsible for committing the session.
"""
payload = decode_token(refresh_token_str)
if payload and payload.get("jti"):
token_record = RefreshToken(
@@ -136,7 +144,15 @@ async def register(
# Validate platform invite code (skip if account invite was provided)
invite_code_record = None
if not account_invite_record:
if settings.REQUIRE_INVITE_CODE and not user_data.invite_code:
# When SELF_SERVE_ENABLED is on, the platform invite gate is bypassed
# entirely — public self-serve signup is the whole point. The
# invite_code field stays in the schema for backward compatibility
# and so paid/trial-bearing codes still apply when supplied.
if (
settings.REQUIRE_INVITE_CODE
and not settings.SELF_SERVE_ENABLED
and not user_data.invite_code
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invite code is required"
@@ -312,7 +328,7 @@ async def login(
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await _store_refresh_token(db, refresh_token_str, user.id)
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return Token(
@@ -347,7 +363,7 @@ async def login_json(
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await _store_refresh_token(db, refresh_token_str, user.id)
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return Token(
@@ -405,7 +421,7 @@ async def refresh_token(
new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store new refresh token
await _store_refresh_token(db, new_refresh_token_str, user.id)
await store_refresh_token(db, new_refresh_token_str, user.id)
await db.commit()
return Token(

View File

@@ -1,31 +1,44 @@
"""Public beta signup endpoint — no auth required."""
"""Legacy beta signup endpoint — redirects to /register?from=beta.
Phase 2 (self-serve signup) makes the public register flow the canonical
front door. The old `/api/v1/beta-signup` POST endpoint is kept mounted to
preserve any external links that still hit it, but now responds with a
307 Temporary Redirect to `/register?from=beta` so the user lands in the
real signup flow. The `?from=beta` marker lets the frontend tag the
signup origin for analytics.
Note: there is no `beta_signup` database table — the original endpoint
only fired a notification email. There is therefore no waitlist to email
and no migration to run when retiring the endpoint.
"""
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, EmailStr
from app.core.email import EmailService
from fastapi import APIRouter
from fastapi.responses import RedirectResponse
from app.core.config import settings
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/beta-signup", tags=["beta"])
class BetaSignupRequest(BaseModel):
email: EmailStr
# Local-dev fallback when FRONTEND_URL isn't configured. The redirect must
# be absolute — a relative URL would resolve against the API origin
# (api.resolutionflow.com), which has no /register page.
_DEFAULT_FRONTEND_URL = "http://localhost:5173"
class BetaSignupResponse(BaseModel):
success: bool
message: str
@router.post("", include_in_schema=False)
async def beta_signup_redirect() -> RedirectResponse:
"""Redirect legacy beta-signup POST to the public register page.
@router.post("", response_model=BetaSignupResponse)
async def beta_signup(data: BetaSignupRequest):
"""Collect beta interest — sends notification to beta@resolutionflow.com."""
sent = await EmailService.send_beta_signup_notification(data.email)
if not sent:
logger.warning("Beta signup recorded (email delivery skipped): %s", data.email)
return BetaSignupResponse(
success=True,
message="Thanks! We'll be in touch with beta access details.",
Returns 307 so any client following the redirect preserves the HTTP
method; the frontend treats `/register?from=beta` as the canonical
entry point and reads the `from` query param for analytics.
"""
frontend_url = settings.FRONTEND_URL or _DEFAULT_FRONTEND_URL
return RedirectResponse(
url=f"{frontend_url}/register?from=beta",
status_code=307,
)

View File

@@ -1,6 +1,6 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -10,6 +10,7 @@ from app.core.config import settings
from app.models.account import Account
from app.models.user import User
from app.schemas.billing import (
BillingPortalSessionResponse,
BillingStateResponse,
CheckoutSessionCreate,
CheckoutSessionResponse,
@@ -50,3 +51,26 @@ async def get_billing_state(
)).scalar_one()
state = await BillingService.get_billing_state(db, account)
return BillingStateResponse(**state)
@router.get("/portal-session", response_model=BillingPortalSessionResponse)
async def get_billing_portal_session(
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> BillingPortalSessionResponse:
"""Return a Stripe-hosted Customer Portal URL for the account so the user
can update card / cancel. Allowlisted from the subscription + email-verify
guards (a canceled or unverified-past-grace user must still be able to
update billing)."""
if not settings.stripe_enabled:
raise HTTPException(status_code=503, detail={"error": "stripe_not_configured"})
account = (await db.execute(
select(Account).where(Account.id == current_user.account_id)
)).scalar_one()
try:
url = await BillingService.open_customer_portal(account)
except ValueError:
raise HTTPException(status_code=400, detail={"error": "no_stripe_customer"})
return BillingPortalSessionResponse(url=url)

View File

@@ -0,0 +1,40 @@
"""Public runtime configuration endpoint.
GET /api/v1/config/public
Returns the small set of runtime flags the frontend needs at app load
to decide whether to render the self-serve signup flow and which OAuth
buttons to show. No authentication required.
The response model lives in `app.schemas.config` so it can be reused by
frontend codegen and other call sites if needed.
"""
from __future__ import annotations
from fastapi import APIRouter
from app.core.config import settings
from app.schemas.config import PublicConfigResponse
router = APIRouter(prefix="/config", tags=["config"])
@router.get("/public", response_model=PublicConfigResponse)
async def get_public_config() -> PublicConfigResponse:
"""Return public-safe runtime config.
`oauth_providers` reflects which OAuth client IDs are configured server
side; the frontend uses it to render only buttons that will actually
succeed. `self_serve_enabled` is the master switch for the new public
self-serve signup flow.
"""
providers: list[str] = []
if settings.GOOGLE_CLIENT_ID:
providers.append("google")
if settings.MS_CLIENT_ID:
providers.append("microsoft")
return PublicConfigResponse(
self_serve_enabled=settings.SELF_SERVE_ENABLED,
oauth_providers=providers,
)

View File

@@ -7,10 +7,12 @@ from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.endpoints.auth import store_refresh_token
from app.core.admin_database import get_admin_db
from app.core.config import settings
from app.core.security import create_access_token, create_refresh_token
from app.models.account import Account
from app.models.account_invite import AccountInvite
from app.models.oauth_identity import OAuthIdentity
from app.models.user import User
from app.schemas.oauth import OAuthCallbackPayload, OAuthCallbackResponse
@@ -31,9 +33,21 @@ def _generate_display_code(length: int = 8) -> str:
async def _sign_in_or_register(
db: AsyncSession, provider: str, profile: OAuthProfile
db: AsyncSession,
provider: str,
profile: OAuthProfile,
*,
account_invite_code: str | None = None,
invited_email: str | None = None,
) -> tuple[User, bool]:
"""Returns (user, is_new_user). Idempotent on (provider, provider_subject)."""
"""Returns (user, is_new_user). Idempotent on (provider, provider_subject).
When ``account_invite_code`` is supplied (from the /accept-invite flow),
a brand-new user is created inside the invited account instead of getting
a personal account + Pro trial. Mismatch between the OAuth profile email
and ``invited_email`` raises ``invite_email_mismatch`` per the spec
contract that mirrors the email+password register path.
"""
identity = (
await db.execute(
select(OAuthIdentity).where(
@@ -53,28 +67,96 @@ async def _sign_in_or_register(
await db.execute(select(User).where(User.email == profile.email))
).scalar_one_or_none()
is_new_user = user is None
# If the user arrived via an invite link but already has a ResolutionFlow
# account (e.g., previously signed up with email+password), silently
# linking the OAuth identity to that existing account would bypass the
# invite — they'd stay in their personal account and the invite would
# never be consumed. Fail loud instead so they can sign in and accept the
# invite from the dashboard. The "invited user wants to transfer accounts"
# case is a v2 concern.
if account_invite_code and not is_new_user:
raise HTTPException(
status_code=400,
detail={
"error": "email_already_registered_use_login",
"message": (
"An account already exists for this email. Please sign in "
"instead, then accept the invite from your dashboard."
),
},
)
invite_record: AccountInvite | None = None
if is_new_user and account_invite_code:
# SELECT FOR UPDATE so two concurrent OAuth callbacks can't both
# consume the same invite code.
invite_record = (
await db.execute(
select(AccountInvite)
.where(AccountInvite.code == account_invite_code)
.with_for_update()
)
).scalar_one_or_none()
if invite_record is None or not invite_record.is_valid:
raise HTTPException(
status_code=400,
detail={"error": "invite_invalid_or_expired_or_revoked"},
)
# Verify the OAuth profile email matches what was invited. We compare
# against the invite row directly (source of truth), but also accept
# the client-supplied invited_email as a defensive equality check.
if invite_record.email.lower() != profile.email.lower():
raise HTTPException(
status_code=400,
detail={"error": "invite_email_mismatch"},
)
if invited_email and invited_email.lower() != invite_record.email.lower():
raise HTTPException(
status_code=400,
detail={"error": "invite_email_mismatch"},
)
if is_new_user:
account = Account(
name=f"{profile.name}'s Account",
display_code=_generate_display_code(),
)
db.add(account)
await db.flush()
user = User(
email=profile.email,
name=profile.name,
password_hash=None,
account_id=account.id,
account_role="owner",
role="engineer",
email_verified_at=datetime.now(timezone.utc),
)
db.add(user)
await db.flush()
account.owner_id = user.id
await db.flush()
# start_trial commits internally; flushed account/user above.
await BillingService.start_trial(db, account.id)
if invite_record is not None:
# Join the invited account directly — no personal account, no
# trial creation.
user = User(
email=profile.email,
name=profile.name,
password_hash=None,
account_id=invite_record.account_id,
account_role=invite_record.role,
role="engineer",
email_verified_at=datetime.now(timezone.utc),
)
db.add(user)
await db.flush()
invite_record.accepted_by_id = user.id
invite_record.used_at = datetime.now(timezone.utc)
await db.flush()
else:
account = Account(
name=f"{profile.name}'s Account",
display_code=_generate_display_code(),
)
db.add(account)
await db.flush()
user = User(
email=profile.email,
name=profile.name,
password_hash=None,
account_id=account.id,
account_role="owner",
role="engineer",
email_verified_at=datetime.now(timezone.utc),
)
db.add(user)
await db.flush()
account.owner_id = user.id
await db.flush()
# start_trial commits internally; flushed account/user above.
await BillingService.start_trial(db, account.id)
db.add(
OAuthIdentity(
@@ -98,10 +180,23 @@ async def google_callback(
raise HTTPException(status_code=503, detail="Google sign-in not configured")
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/google/callback"
profile = await google_exchange_code(payload.code, redirect_uri)
user, is_new = await _sign_in_or_register(db, "google", profile)
user, is_new = await _sign_in_or_register(
db,
"google",
profile,
account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email,
)
refresh_token_str = create_refresh_token({"sub": str(user.id)})
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}),
refresh_token=create_refresh_token({"sub": str(user.id)}),
refresh_token=refresh_token_str,
is_new_user=is_new,
)
@@ -115,9 +210,22 @@ async def microsoft_callback(
raise HTTPException(status_code=503, detail="Microsoft sign-in not configured")
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/microsoft/callback"
profile = await microsoft_exchange_code(payload.code, redirect_uri)
user, is_new = await _sign_in_or_register(db, "microsoft", profile)
user, is_new = await _sign_in_or_register(
db,
"microsoft",
profile,
account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email,
)
refresh_token_str = create_refresh_token({"sub": str(user.id)})
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}),
refresh_token=create_refresh_token({"sub": str(user.id)}),
refresh_token=refresh_token_str,
is_new_user=is_new,
)

View File

@@ -2,19 +2,24 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_active_user
from app.core.database import get_db
from app.core.admin_database import get_admin_db
from app.models.account import Account
from app.models.assistant_chat import AssistantChat
from app.models.psa_connection import PsaConnection
from app.models.session import Session
from app.models.tree import Tree
from app.models.user import User
from app.schemas.onboarding import OnboardingStatus
from app.schemas.onboarding import (
OnboardingStatus,
OnboardingStepRequest,
OnboardingStepResponse,
)
router = APIRouter(prefix="/users", tags=["onboarding"])
@@ -85,6 +90,10 @@ async def get_onboarding_status(
)
connected_psa = (psa_q.scalar() or 0) > 0
# New (Phase 2 — Task 41)
email_verified = current_user.email_verified_at is not None
shop_setup_done = (current_user.onboarding_step_completed or 0) >= 1
return OnboardingStatus(
created_flow=created_flow,
ran_session=ran_session,
@@ -94,6 +103,8 @@ async def get_onboarding_status(
connected_psa=connected_psa,
is_team_user=is_team_user,
dismissed=current_user.onboarding_dismissed,
email_verified=email_verified,
shop_setup_done=shop_setup_done,
)
@@ -109,3 +120,98 @@ async def dismiss_onboarding(
# Return updated status (reuse the GET logic)
return await get_onboarding_status(db=db, current_user=current_user)
# ---------------------------------------------------------------------------
# Welcome wizard endpoints (Phase 2)
#
# These persist Step 1/2/3 progress for the post-signup welcome wizard.
# Mounted on /users/me/* (the parent router prefix is /users) so the wizard
# can run before email verification and during trial.
# ---------------------------------------------------------------------------
@router.patch("/me/onboarding-step", response_model=OnboardingStepResponse)
async def patch_onboarding_step(
body: OnboardingStepRequest,
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> OnboardingStepResponse:
"""Persist welcome-wizard progress for the current user.
Contract:
- step=1 + complete writes accounts.name, accounts.team_size_bucket,
users.role_at_signup, then sets users.onboarding_step_completed=1.
- step=2 + complete writes accounts.primary_psa, then sets
users.onboarding_step_completed=2.
- step=3 + complete just sets users.onboarding_step_completed=3
(invites are POSTed separately).
- action="skip" ignores `data` entirely and only advances the step.
- The new step must be >= current onboarding_step_completed (None=>0);
otherwise 400. Idempotent re-PATCH of the same step succeeds.
"""
current_step = current_user.onboarding_step_completed or 0
if body.step < current_step:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "step_cannot_decrease",
"current_step": current_step,
"requested_step": body.step,
},
)
if body.action == "complete" and body.data is not None and body.step in (1, 2):
# Load the user's account for field writes. Step 3 has no data writes.
account_result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = account_result.scalar_one_or_none()
if account is None:
# Should never happen — user is required to have an account_id.
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="account_not_found",
)
if body.step == 1:
data = body.data
if data.company_name is not None:
account.name = data.company_name
if data.team_size_bucket is not None:
account.team_size_bucket = data.team_size_bucket
if data.role_at_signup is not None:
current_user.role_at_signup = data.role_at_signup
elif body.step == 2:
data = body.data
if data.primary_psa is not None:
account.primary_psa = data.primary_psa
current_user.onboarding_step_completed = body.step
await db.commit()
await db.refresh(current_user)
return OnboardingStepResponse(
onboarding_step_completed=current_user.onboarding_step_completed,
onboarding_dismissed=current_user.onboarding_dismissed,
)
@router.post("/me/onboarding-dismiss-rest", response_model=OnboardingStepResponse)
async def dismiss_onboarding_rest(
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> OnboardingStepResponse:
"""Set users.onboarding_dismissed=TRUE — backs the wizard's "Skip the rest" button.
Returns the same shape as the step PATCH so the frontend can update its
local store from a single response.
"""
current_user.onboarding_dismissed = True
await db.commit()
await db.refresh(current_user)
return OnboardingStepResponse(
onboarding_step_completed=current_user.onboarding_step_completed,
onboarding_dismissed=current_user.onboarding_dismissed,
)

View File

@@ -0,0 +1,58 @@
"""Public plans endpoint — no auth required.
GET /api/v1/plans/public
Returns the public-safe view of `plan_billing` joined with
`plan_limits.max_users` (exposed as `max_seats`), filtered to
`is_public=True AND is_archived=False`, ordered by sort_order ASC, plan ASC.
Distinct from `/admin/plan-limits` (admin-only, returns ALL plans including
archived/internal). This endpoint exists to power the marketing /pricing page
without exposing the rest of the admin-only billing surface.
"""
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_database import get_admin_db
from app.models.plan_billing import PlanBilling
from app.models.plan_limits import PlanLimits
from app.schemas.billing import PublicPlanResponse
router = APIRouter(prefix="/plans", tags=["plans"])
@router.get("/public", response_model=list[PublicPlanResponse])
async def list_public_plans(
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> list[PublicPlanResponse]:
"""List public, non-archived plans for the marketing /pricing page.
Public — no auth. Uses `get_admin_db` because this is a cross-tenant read
of the global plan catalog (same pattern as `/config/public`).
"""
stmt = (
select(PlanBilling, PlanLimits.max_users)
.outerjoin(PlanLimits, PlanBilling.plan == PlanLimits.plan)
.where(PlanBilling.is_public.is_(True))
.where(PlanBilling.is_archived.is_(False))
.order_by(PlanBilling.sort_order.asc(), PlanBilling.plan.asc())
)
rows = (await db.execute(stmt)).all()
return [
PublicPlanResponse(
plan=billing.plan,
display_name=billing.display_name,
description=billing.description,
monthly_price_cents=billing.monthly_price_cents,
annual_price_cents=billing.annual_price_cents,
max_seats=max_users,
sort_order=billing.sort_order,
is_public=billing.is_public,
)
for billing, max_users in rows
]

View File

@@ -0,0 +1,114 @@
"""Public Talk-to-Sales endpoint — no auth required.
POST /api/v1/sales-leads
- Inserts a sales_leads row.
- Fires (best-effort) a notification email to settings.SALES_LEAD_RECIPIENT_EMAIL.
- Emits a server-side PostHog event (best-effort).
- Rate-limited per IP (5/hour).
"""
from __future__ import annotations
import asyncio
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_database import get_admin_db
from app.core.config import settings
from app.core.email import EmailService
from app.core.rate_limit import limiter
from app.models.sales_lead import SalesLead
from app.schemas.sales_lead import SalesLeadCreate, SalesLeadCreateResponse
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/sales-leads", tags=["sales"])
async def _send_notification_email(lead: SalesLead) -> None:
"""Fire-and-forget wrapper. EmailService methods never raise, but we
still wrap in a try/except to defend against future regressions."""
try:
await EmailService.send_sales_lead_notification(
to_email=settings.SALES_LEAD_RECIPIENT_EMAIL,
lead=lead,
)
except Exception:
logger.warning(
"Sales lead notification email failed for lead %s",
lead.id,
exc_info=True,
)
def _capture_posthog_event(lead: SalesLead) -> None:
"""Emit `talk_to_sales_form_submitted` server-side. Best-effort.
Backend PostHog SDK isn't initialized in the project today; this function
is the single instrumentation point so wiring it up later is a one-line
change. The call is wrapped so any future failure can never fail the
request.
"""
try:
# Lazy import — keeps the dependency optional. When the backend
# PostHog client is wired in (likely as `app.core.analytics.posthog`),
# swap the import path here and the event will fire automatically.
try:
from app.core.analytics import posthog # type: ignore[attr-defined]
except ImportError:
logger.debug(
"PostHog server-side capture skipped — client not configured"
)
return
distinct_id = lead.posthog_distinct_id or f"sales_lead:{lead.id}"
posthog.capture(
distinct_id=distinct_id,
event="talk_to_sales_form_submitted",
properties={
"source": lead.source,
"company": lead.company,
"team_size": lead.team_size,
},
)
except Exception:
logger.warning(
"PostHog capture failed for sales lead %s",
lead.id,
exc_info=True,
)
@router.post("", response_model=SalesLeadCreateResponse, status_code=201)
@limiter.limit("5/hour")
async def create_sales_lead(
request: Request,
data: SalesLeadCreate,
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> SalesLeadCreateResponse:
"""Public Talk-to-Sales submission.
Creates a sales_leads row, fires (best-effort) a notification email and a
server-side PostHog event. Rate-limited per IP at 5/hour.
"""
lead = SalesLead(
email=str(data.email).lower(),
name=data.name,
company=data.company,
team_size=data.team_size,
message=data.message,
source=data.source,
posthog_distinct_id=data.posthog_distinct_id,
)
db.add(lead)
await db.commit()
await db.refresh(lead)
# Fire-and-forget: email + analytics. Failures must not fail the request.
asyncio.create_task(_send_notification_email(lead))
_capture_posthog_event(lead)
return SalesLeadCreateResponse(id=lead.id, status="received")

View File

@@ -26,8 +26,10 @@ from app.api.endpoints import (
billing,
beta_feedback,
beta_signup,
sales_leads,
branding,
categories,
config as config_endpoints,
copilot,
device_types,
draft_templates,
@@ -43,6 +45,7 @@ from app.api.endpoints import (
notifications,
oauth as oauth_endpoints,
onboarding,
plans_public,
public_templates,
ratings,
scripts,
@@ -68,6 +71,7 @@ from app.api.endpoints import (
uploads,
webhooks,
accounts,
account_invite_lookup,
)
api_router = APIRouter()
@@ -88,9 +92,13 @@ api_router.include_router(billing.router) # Reachable when subscription lock
api_router.include_router(shared.router) # Public share links (no auth)
api_router.include_router(shares.public_router) # Public session share links (optional auth)
api_router.include_router(beta_signup.router)
api_router.include_router(sales_leads.router) # Talk-to-Sales (no auth, rate-limited)
api_router.include_router(webhooks.router) # Stripe webhook receiver
api_router.include_router(public_templates.router) # Public gallery (no auth, rate-limited)
api_router.include_router(survey.router) # Public survey flow (no auth, rate-limited)
api_router.include_router(config_endpoints.router) # Public runtime feature flags
api_router.include_router(account_invite_lookup.router) # Public invite-code lookup for /accept-invite
api_router.include_router(plans_public.router) # Public plan catalog for /pricing page
# ---------------------------------------------------------------------------
# Admin endpoints — super_admin only

View File

@@ -84,6 +84,7 @@ class Settings(BaseSettings):
RESEND_API_KEY: Optional[str] = None
FROM_EMAIL: str = "ResolutionFlow <invites@resolutionflow.com>"
FEEDBACK_EMAIL: Optional[str] = None
SALES_LEAD_RECIPIENT_EMAIL: str = "sales@resolutionflow.com"
@property
def email_enabled(self) -> bool:

View File

@@ -1,6 +1,11 @@
import logging
from typing import TYPE_CHECKING
from app.core.config import settings
if TYPE_CHECKING:
from app.models.sales_lead import SalesLead
logger = logging.getLogger(__name__)
@@ -484,6 +489,99 @@ class EmailService:
logger.exception("Failed to send beta signup notification for %s", signup_email)
return False
@staticmethod
async def send_sales_lead_notification(
to_email: str,
lead: "SalesLead",
) -> bool:
"""Notify the sales recipient about a new Talk-to-Sales submission.
Fire-and-forget. Returns False (and logs) on any failure; never raises.
"""
if not settings.email_enabled:
logger.warning(
"Sales lead email not sent — RESEND_API_KEY not configured (lead %s)",
lead.id,
)
return False
try:
import resend
import html as html_mod
from datetime import datetime, timezone
resend.api_key = settings.RESEND_API_KEY
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
safe_email = html_mod.escape(lead.email)
safe_name = html_mod.escape(lead.name)
safe_company = html_mod.escape(lead.company)
safe_team_size = html_mod.escape(lead.team_size or "")
safe_source = html_mod.escape(lead.source)
safe_message = html_mod.escape(lead.message or "(no message)")
subject = f"[ResolutionFlow Sales] New lead — {safe_company} ({safe_email})"
email_html = f"""<!DOCTYPE html>
<html><head><meta charset="utf-8"><meta name="viewport" content="width=device-width"></head>
<body style="margin:0;padding:0;background:#101114;font-family:'Inter',Helvetica,Arial,sans-serif;">
<table width="100%" cellpadding="0" cellspacing="0" style="background:#101114;padding:40px 0;">
<tr><td align="center">
<table width="560" cellpadding="0" cellspacing="0" style="background:#14161a;border:1px solid rgba(255,255,255,0.06);border-radius:16px;">
<tr><td style="padding:40px 40px 24px;text-align:center;">
<h1 style="margin:0;color:#f8fafc;font-size:24px;font-weight:600;">Resolution<span style="color:#06b6d4;">Flow</span></h1>
<p style="margin:8px 0 0;color:#5a6170;font-size:14px;">New Sales Lead</p>
</td></tr>
<tr><td style="padding:0 40px 16px;">
<p style="margin:0;color:#8891a0;font-size:16px;line-height:1.6;">
Source: <strong style="color:#f8fafc;">{safe_source}</strong>
</p>
</td></tr>
<tr><td style="padding:0 40px 16px;">
<table width="100%" cellpadding="0" cellspacing="0" style="background:rgba(0,0,0,0.3);border:1px solid rgba(255,255,255,0.06);border-radius:12px;">
<tr><td style="padding:16px;">
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Name</p>
<p style="margin:0 0 12px;color:#f8fafc;font-size:16px;font-weight:600;">{safe_name}</p>
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Email</p>
<p style="margin:0 0 12px;color:#22d3ee;font-size:16px;font-weight:600;">{safe_email}</p>
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Company</p>
<p style="margin:0 0 12px;color:#f8fafc;font-size:16px;font-weight:600;">{safe_company}</p>
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Team Size</p>
<p style="margin:0;color:#f8fafc;font-size:16px;font-weight:600;">{safe_team_size}</p>
</td></tr>
</table>
</td></tr>
<tr><td style="padding:0 40px 16px;">
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Message</p>
<p style="margin:0;color:#8891a0;font-size:14px;line-height:1.6;white-space:pre-wrap;">{safe_message}</p>
</td></tr>
<tr><td style="padding:0 40px 32px;">
<p style="margin:0;color:#5a6170;font-size:12px;text-align:center;">
Submitted at {date_str} · Lead ID: {lead.id}
</p>
</td></tr>
</table>
</td></tr>
</table>
</body></html>"""
resend.Emails.send({
"from": settings.FROM_EMAIL,
"to": [to_email],
"reply_to": lead.email,
"subject": subject,
"html": email_html,
})
logger.info("Sales lead notification sent for %s (lead %s)", lead.email, lead.id)
return True
except Exception:
logger.exception(
"Failed to send sales lead notification for %s (lead %s)",
lead.email,
lead.id,
)
return False
@staticmethod
async def send_notification_email(
to_email: str,

View File

@@ -172,6 +172,21 @@ class PlanLimitResponse(BaseModel):
from_attributes = True
class PlanLimitWithBillingResponse(PlanLimitResponse):
"""PlanLimits + plan_billing fields merged. Billing fields are None when no
plan_billing row exists for the plan yet."""
display_name: Optional[str] = None
description: Optional[str] = None
monthly_price_cents: Optional[int] = None
annual_price_cents: Optional[int] = None
stripe_product_id: Optional[str] = None
stripe_monthly_price_id: Optional[str] = None
stripe_annual_price_id: Optional[str] = None
is_public: Optional[bool] = None
is_archived: Optional[bool] = None
sort_order: Optional[int] = None
class PlanLimitUpdate(BaseModel):
plan: str
max_trees: Optional[int] = None
@@ -180,6 +195,19 @@ class PlanLimitUpdate(BaseModel):
custom_branding: bool = False
priority_support: bool = False
export_formats: list = Field(default_factory=lambda: ["markdown", "text"])
# plan_billing fields — all optional, partial-update semantics. If any are
# set in the body, the admin endpoint upserts the plan_billing row in the
# same transaction.
display_name: Optional[str] = None
description: Optional[str] = None
monthly_price_cents: Optional[int] = None
annual_price_cents: Optional[int] = None
stripe_product_id: Optional[str] = None
stripe_monthly_price_id: Optional[str] = None
stripe_annual_price_id: Optional[str] = None
is_public: Optional[bool] = None
is_archived: Optional[bool] = None
sort_order: Optional[int] = None
class AccountOverrideCreate(BaseModel):

View File

@@ -13,6 +13,10 @@ class CheckoutSessionResponse(BaseModel):
url: str
class BillingPortalSessionResponse(BaseModel):
url: str
class SubscriptionState(BaseModel):
status: str
plan: str
@@ -38,3 +42,23 @@ class BillingStateResponse(BaseModel):
plan_billing: Optional[PlanBillingState]
plan_limits: Dict[str, Any]
enabled_features: Dict[str, bool]
class PublicPlanResponse(BaseModel):
"""Public-safe view of a billable plan, used by the marketing /pricing page.
Sourced from `plan_billing` joined with `plan_limits.max_users` (exposed
here as `max_seats`). Always filtered server-side to is_public=True and
is_archived=False, so `is_public` is a constant True for any row returned
here — included for clarity and forward compatibility.
"""
plan: str
display_name: str
description: Optional[str] = None
monthly_price_cents: Optional[int] = None
annual_price_cents: Optional[int] = None
max_seats: Optional[int] = None
sort_order: int
is_public: bool = True
model_config = {"from_attributes": True}

View File

@@ -0,0 +1,18 @@
"""Pydantic schemas for public runtime configuration."""
from __future__ import annotations
from typing import List
from pydantic import BaseModel
class PublicConfigResponse(BaseModel):
"""Runtime feature flags + OAuth provider list exposed to anonymous clients.
Read once by the frontend at app load to decide whether to render the
self-serve signup flow and which OAuth buttons to show.
"""
self_serve_enabled: bool
oauth_providers: List[str]

View File

@@ -4,6 +4,11 @@ from pydantic import BaseModel
class OAuthCallbackPayload(BaseModel):
code: str
state: str | None = None
# When the OAuth flow originated from /accept-invite, the frontend round-trips
# the invite code + invited email so the backend can link the new user to the
# invited account instead of creating a personal one.
account_invite_code: str | None = None
invited_email: str | None = None
class OAuthCallbackResponse(BaseModel):
@@ -11,3 +16,17 @@ class OAuthCallbackResponse(BaseModel):
refresh_token: str
token_type: str = "bearer"
is_new_user: bool
class InviteLookupResponse(BaseModel):
"""Public response surface for GET /accounts/invites/{code}/lookup.
Returns the minimum context needed for the AcceptInvitePage:
account name (so we can title the card), inviter name (for the resend
fallback message), invited email (locked into the form), and role.
"""
account_name: str
inviter_name: str
invited_email: str
role: str

View File

@@ -1,12 +1,55 @@
from pydantic import BaseModel
from typing import Literal, Optional
from pydantic import BaseModel, Field
class OnboardingStatus(BaseModel):
created_flow: bool
ran_session: bool
exported_session: bool
# Kept for backward-compat during deploy; new code paths should not branch on this.
tried_ai_assistant: bool
invited_teammate: bool
connected_psa: bool
is_team_user: bool
dismissed: bool
# New (Phase 2 — Task 41) — drive the unified next-step card + checklist.
email_verified: bool
shop_setup_done: bool
# --- Welcome wizard (Phase 2) ----------------------------------------------
TeamSizeBucket = Literal["1-2", "3-5", "6-10", "11-25", "26+"]
RoleAtSignup = Literal["owner", "lead_tech", "tech", "other"]
PrimaryPsa = Literal["connectwise", "autotask", "halopsa", "none"]
WizardStep = Literal[1, 2, 3]
WizardAction = Literal["complete", "skip"]
class OnboardingStepData(BaseModel):
"""Optional payload carried with `action="complete"` for steps 1 and 2.
Step 1 fields: company_name, team_size_bucket, role_at_signup
Step 2 fields: primary_psa
Step 3 has no data (invitations posted separately).
"""
# Step 1
company_name: Optional[str] = Field(default=None, max_length=255)
team_size_bucket: Optional[TeamSizeBucket] = None
role_at_signup: Optional[RoleAtSignup] = None
# Step 2
primary_psa: Optional[PrimaryPsa] = None
class OnboardingStepRequest(BaseModel):
step: WizardStep
action: WizardAction
data: Optional[OnboardingStepData] = None
class OnboardingStepResponse(BaseModel):
onboarding_step_completed: Optional[int]
onboarding_dismissed: bool

View File

@@ -0,0 +1,27 @@
"""Pydantic schemas for Talk-to-Sales submissions."""
from typing import Literal, Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, EmailStr, Field
SalesLeadSource = Literal["pricing_page", "register_footer", "landing_page"]
class SalesLeadCreate(BaseModel):
"""Public Talk-to-Sales form submission."""
model_config = ConfigDict(str_strip_whitespace=True)
email: EmailStr
name: str = Field(..., min_length=1, max_length=255)
company: str = Field(..., min_length=1, max_length=255)
team_size: Optional[str] = Field(default=None, max_length=20)
message: Optional[str] = Field(default=None, max_length=5000)
source: SalesLeadSource
posthog_distinct_id: Optional[str] = Field(default=None, max_length=255)
class SalesLeadCreateResponse(BaseModel):
id: UUID
status: Literal["received"] = "received"

View File

@@ -58,6 +58,8 @@ class UserResponse(UserBase):
timezone: str = "UTC"
avatar_url: Optional[str] = None
email_verified_at: Optional[datetime] = None
onboarding_step_completed: Optional[int] = None
onboarding_dismissed: bool = False
class Config:
from_attributes = True

View File

@@ -1,6 +1,7 @@
"""Single billing service module. Stripe is the only impl — no provider
abstraction. Account row is canonical local state; Stripe is canonical
remote state; the webhook handler bridges the two."""
import logging
from datetime import datetime, timezone, timedelta
import stripe
@@ -17,8 +18,32 @@ from app.models.subscription import Subscription
TRIAL_DAYS = 14
logger = logging.getLogger(__name__)
class BillingService:
@staticmethod
async def invalidate_billing_cache(account_ids) -> None:
"""No-op stub for future in-process billing cache invalidation.
Today there is no `app.state.billing_cache` — `BillingService.get_billing_state`
always reads fresh from the DB. Call sites that mutate plan/feature data
invoke this hook so that wiring is in place when an in-process cache is
added later. Until then, this just logs.
TODO: when an in-process billing cache (e.g. `app.state.billing_cache`)
is introduced, evict entries for the given account_ids here.
"""
try:
count = len(list(account_ids))
except TypeError:
count = -1
logger.debug(
"BillingService.invalidate_billing_cache called for %d account(s) "
"(no-op stub — wire to app.state.billing_cache when added)",
count,
)
@staticmethod
async def start_trial(db: AsyncSession, account_id) -> Subscription:
"""Idempotent. Creates a trialing Subscription on Pro for the account if
@@ -105,6 +130,25 @@ class BillingService:
)
return session.url
@staticmethod
async def open_customer_portal(account: Account) -> str:
"""Create a Stripe-hosted Customer Portal session and return the URL.
Raises RuntimeError if Stripe isn't configured (endpoint maps to 503).
Raises ValueError if the account has no stripe_customer_id yet — the
user must complete a checkout first (endpoint maps to 400).
"""
if not settings.stripe_enabled:
raise RuntimeError("Stripe not configured")
if account.stripe_customer_id is None:
raise ValueError("no_stripe_customer")
stripe.api_key = settings.STRIPE_SECRET_KEY
session = stripe.billing_portal.Session.create(
customer=account.stripe_customer_id,
return_url=f"{settings.FRONTEND_URL}/account/billing",
)
return session.url
@staticmethod
async def get_billing_state(db: AsyncSession, account):
"""Aggregate Subscription + PlanLimits + PlanBilling + resolved feature
@@ -166,28 +210,44 @@ class BillingService:
) -> bool:
"""Idempotent. Returns True if the event was applied; False if it had
already been processed (idempotent ack). The webhook handler returns 200
either way."""
either way.
Atomic: the StripeEvent idempotency mark and the handler's state
mutations are committed in a single transaction. If the handler raises
the entire transaction (idempotency mark + partial mutations) is rolled
back, so a Stripe retry will re-run the handler. Without this, a
handler that fails mid-flight would leave the StripeEvent row persisted
and silently desync subscription state from Stripe.
"""
db.add(StripeEvent(
id=event_id,
event_type=event_type,
payload_excerpt=_excerpt(payload),
))
try:
db.add(StripeEvent(
id=event_id,
event_type=event_type,
payload_excerpt=_excerpt(payload),
))
await db.commit()
await db.flush()
except IntegrityError:
# Duplicate event_id — already processed (or in flight). Ack with False.
await db.rollback()
return False
if event_type == "checkout.session.completed":
await _handle_checkout_completed(db, payload)
elif event_type == "customer.subscription.updated":
await _handle_subscription_updated(db, payload)
elif event_type == "customer.subscription.deleted":
await _handle_subscription_deleted(db, payload)
elif event_type == "invoice.payment_failed":
await _handle_payment_failed(db, payload)
elif event_type == "invoice.payment_succeeded":
await _handle_payment_succeeded(db, payload)
try:
if event_type == "checkout.session.completed":
await _handle_checkout_completed(db, payload)
elif event_type == "customer.subscription.updated":
await _handle_subscription_updated(db, payload)
elif event_type == "customer.subscription.deleted":
await _handle_subscription_deleted(db, payload)
elif event_type == "invoice.payment_failed":
await _handle_payment_failed(db, payload)
elif event_type == "invoice.payment_succeeded":
await _handle_payment_succeeded(db, payload)
await db.commit()
except Exception:
# Roll back the StripeEvent insert + any partial handler mutations
# so Stripe's retry can re-run cleanly.
await db.rollback()
raise
return True
@@ -238,7 +298,7 @@ async def _handle_checkout_completed(db: AsyncSession, payload: dict):
)).scalar_one_or_none()
if pb is not None:
sub.plan = pb.plan
await db.commit()
# No commit — apply_subscription_event commits once for the full event.
async def _handle_subscription_updated(db: AsyncSession, payload: dict):
@@ -253,7 +313,7 @@ async def _handle_subscription_updated(db: AsyncSession, payload: dict):
sub.current_period_end = datetime.fromtimestamp(obj["current_period_end"], tz=timezone.utc)
sub.cancel_at_period_end = obj.get("cancel_at_period_end", False)
sub.seat_limit = obj["items"]["data"][0]["quantity"]
await db.commit()
# No commit — apply_subscription_event commits once for the full event.
async def _handle_subscription_deleted(db: AsyncSession, payload: dict):
@@ -264,7 +324,7 @@ async def _handle_subscription_deleted(db: AsyncSession, payload: dict):
if sub is None:
return
sub.status = "canceled"
await db.commit()
# No commit — apply_subscription_event commits once for the full event.
async def _handle_payment_failed(db: AsyncSession, payload: dict):
@@ -278,7 +338,7 @@ async def _handle_payment_failed(db: AsyncSession, payload: dict):
if sub is None:
return
sub.status = "past_due"
await db.commit()
# No commit — apply_subscription_event commits once for the full event.
async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
@@ -293,4 +353,4 @@ async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
return
if sub.status == "past_due":
sub.status = "active"
await db.commit()
# No commit — apply_subscription_event commits once for the full event.

View File

@@ -0,0 +1,290 @@
"""Tests for the public GET /accounts/invites/{code}/lookup endpoint
(consumed by the /accept-invite page on the frontend)."""
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import select
from app.models.account_invite import AccountInvite
@pytest.mark.asyncio
async def test_invite_lookup_returns_account_info_for_valid_code(
client, test_db, test_user, auth_headers
):
"""A freshly-created, unused, unexpired invite resolves to the inviter's
account name + the inviter's display name + the invited email + role."""
with patch(
"app.core.email.EmailService.send_account_invite_email",
new_callable=AsyncMock,
return_value=True,
):
create_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "lookup@example.com", "role": "engineer"},
headers=auth_headers,
)
assert create_resp.status_code == 201, create_resp.json()
code = create_resp.json()["code"]
response = await client.get(f"/api/v1/accounts/invites/{code}/lookup")
assert response.status_code == 200, response.json()
body = response.json()
assert body["invited_email"] == "lookup@example.com"
assert body["role"] == "engineer"
assert body["inviter_name"] == test_user["user_data"]["name"]
# account_name is whatever the test_user fixture seeded for the account.
assert isinstance(body["account_name"], str) and body["account_name"]
@pytest.mark.asyncio
async def test_invite_lookup_returns_404_for_invalid_or_expired_code(
client, test_db, test_user
):
"""Three failure modes (unknown code, expired, revoked, used) all collapse
to the same 404 + invite_invalid_or_expired_or_revoked error code."""
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
account_id = uuid.UUID(test_user["user_data"]["account_id"])
# 1) Unknown code
unknown = await client.get("/api/v1/accounts/invites/DOESNOTEXIST/lookup")
assert unknown.status_code == 404
assert unknown.json()["detail"]["error"] == "invite_invalid_or_expired_or_revoked"
# 2) Expired
expired_invite = AccountInvite(
account_id=account_id,
invited_by_id=invited_by_id,
email="expired@example.com",
code="EXPIREDLOOKUP01",
role="engineer",
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
)
test_db.add(expired_invite)
await test_db.commit()
expired = await client.get("/api/v1/accounts/invites/EXPIREDLOOKUP01/lookup")
assert expired.status_code == 404
assert expired.json()["detail"]["error"] == "invite_invalid_or_expired_or_revoked"
# 3) Revoked
revoked_invite = AccountInvite(
account_id=account_id,
invited_by_id=invited_by_id,
email="revoked@example.com",
code="REVOKEDLOOKUP01",
role="engineer",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
revoked_at=datetime.now(timezone.utc),
)
test_db.add(revoked_invite)
await test_db.commit()
revoked = await client.get("/api/v1/accounts/invites/REVOKEDLOOKUP01/lookup")
assert revoked.status_code == 404
assert revoked.json()["detail"]["error"] == "invite_invalid_or_expired_or_revoked"
# 4) Already used
used_invite = AccountInvite(
account_id=account_id,
invited_by_id=invited_by_id,
email="used@example.com",
code="USEDLOOKUP01",
role="engineer",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
accepted_by_id=invited_by_id,
used_at=datetime.now(timezone.utc),
)
test_db.add(used_invite)
await test_db.commit()
used = await client.get("/api/v1/accounts/invites/USEDLOOKUP01/lookup")
assert used.status_code == 404
assert used.json()["detail"]["error"] == "invite_invalid_or_expired_or_revoked"
# Sanity: rows survived (no destructive side effects).
persisted = (
await test_db.execute(
select(AccountInvite).where(
AccountInvite.code.in_(
["EXPIREDLOOKUP01", "REVOKEDLOOKUP01", "USEDLOOKUP01"]
)
)
)
).scalars().all()
assert len(persisted) == 3
@pytest.mark.asyncio
async def test_oauth_callback_links_invite_when_account_invite_code_supplied(
client, test_db, test_user, auth_headers, monkeypatch
):
"""Brand-new OAuth user with account_invite_code joins the invited account
instead of getting a personal one. Invite is marked used."""
from app.core.config import settings
from app.models.user import User
from app.services.oauth_providers import OAuthProfile
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
with patch(
"app.core.email.EmailService.send_account_invite_email",
new_callable=AsyncMock,
return_value=True,
):
create_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "oauth-invite@example.com", "role": "engineer"},
headers=auth_headers,
)
code = create_resp.json()["code"]
inviter_account_id = uuid.UUID(test_user["user_data"]["account_id"])
profile = OAuthProfile(
provider_subject="google_invite_subject_1",
email="oauth-invite@example.com",
name="OAuth Invitee",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback",
json={
"code": "auth_code_xyz",
"account_invite_code": code,
"invited_email": "oauth-invite@example.com",
},
)
assert response.status_code == 200, response.json()
assert response.json()["is_new_user"] is True
user = (
await test_db.execute(
select(User).where(User.email == "oauth-invite@example.com")
)
).scalar_one()
assert user.account_id == inviter_account_id
assert user.account_role == "engineer"
invite = (
await test_db.execute(
select(AccountInvite).where(AccountInvite.code == code)
)
).scalar_one()
assert invite.used_at is not None
assert invite.accepted_by_id == user.id
@pytest.mark.asyncio
async def test_oauth_callback_existing_email_with_invite_returns_400(
client, test_db, test_user, auth_headers, monkeypatch
):
"""If a user already exists with the invited email (e.g., previously
registered via password), arriving via /accept-invite OAuth must NOT
silently link the OAuth identity to their existing account and skip the
invite. Surface email_already_registered_use_login so the user signs in
and accepts the invite from the dashboard instead."""
from app.core.config import settings
from app.services.oauth_providers import OAuthProfile
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
# 1) Pre-existing user with a password (separate from the inviter).
existing_email = "already-here@example.com"
register_resp = await client.post(
"/api/v1/auth/register",
json={
"email": existing_email,
"password": "PreviousPassword123!",
"name": "Already Here",
},
)
assert register_resp.status_code in (200, 201), register_resp.json()
# 2) Inviter creates an invite for that exact email.
with patch(
"app.core.email.EmailService.send_account_invite_email",
new_callable=AsyncMock,
return_value=True,
):
create_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": existing_email, "role": "engineer"},
headers=auth_headers,
)
assert create_resp.status_code == 201, create_resp.json()
code = create_resp.json()["code"]
# 3) The existing user does Google OAuth and the callback receives the
# invite code. Backend must reject — not link silently.
profile = OAuthProfile(
provider_subject="google_existing_subject_1",
email=existing_email,
name="Already Here",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback",
json={
"code": "auth_code_xyz",
"account_invite_code": code,
"invited_email": existing_email,
},
)
assert response.status_code == 400, response.json()
assert (
response.json()["detail"]["error"] == "email_already_registered_use_login"
)
# 4) Sanity: the invite was NOT consumed.
invite = (
await test_db.execute(
select(AccountInvite).where(AccountInvite.code == code)
)
).scalar_one()
assert invite.used_at is None
assert invite.accepted_by_id is None
@pytest.mark.asyncio
async def test_oauth_callback_invite_email_mismatch_returns_400(
client, test_db, test_user, auth_headers, monkeypatch
):
"""If the OAuth profile's email differs from the invite's email, the
backend rejects the link with invite_email_mismatch (mirrors register)."""
from app.core.config import settings
from app.services.oauth_providers import OAuthProfile
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
with patch(
"app.core.email.EmailService.send_account_invite_email",
new_callable=AsyncMock,
return_value=True,
):
create_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "expected@example.com", "role": "engineer"},
headers=auth_headers,
)
code = create_resp.json()["code"]
profile = OAuthProfile(
provider_subject="google_invite_subject_2",
email="different@example.com",
name="Wrong Email",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback",
json={
"code": "auth_code_xyz",
"account_invite_code": code,
"invited_email": "expected@example.com",
},
)
assert response.status_code == 400, response.json()
assert response.json()["detail"]["error"] == "invite_email_mismatch"

View File

@@ -1,7 +1,12 @@
"""Integration tests for admin plan limits and account override endpoints."""
from unittest.mock import AsyncMock, patch
import pytest
from httpx import AsyncClient
from sqlalchemy import select
from app.models.plan_billing import PlanBilling
class TestAdminPlanLimits:
@@ -56,3 +61,204 @@ class TestAdminPlanLimits:
"""Non-admin gets 403."""
response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present(
self, client: AsyncClient, admin_auth_headers: dict, test_db
):
"""GET /admin/plan-limits returns plan_billing fields when a row exists,
and None for plans that don't have one yet."""
# Seed a plan_billing row for "pro".
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "pro")
)).scalar_one_or_none()
if existing is None:
test_db.add(PlanBilling(
plan="pro",
display_name="Pro",
description="For working teams",
monthly_price_cents=4900,
annual_price_cents=49000,
stripe_product_id="prod_seed",
stripe_monthly_price_id="price_seed_m",
stripe_annual_price_id="price_seed_a",
is_public=True,
is_archived=False,
sort_order=10,
))
await test_db.commit()
response = await client.get(
"/api/v1/admin/plan-limits", headers=admin_auth_headers
)
assert response.status_code == 200
plans_by_name = {p["plan"]: p for p in response.json()}
assert "pro" in plans_by_name
pro = plans_by_name["pro"]
assert pro["display_name"] == "Pro"
assert pro["monthly_price_cents"] == 4900
assert pro["stripe_monthly_price_id"] == "price_seed_m"
assert pro["is_public"] is True
assert pro["is_archived"] is False
assert pro["sort_order"] == 10
# A plan without a plan_billing row should still return, with None
# billing fields.
if "free" in plans_by_name:
free = plans_by_name["free"]
# free has no plan_billing row in the seed → fields are None.
no_billing_row = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "free")
)).scalar_one_or_none() is None
if no_billing_row:
assert free["display_name"] is None
assert free["monthly_price_cents"] is None
assert free["stripe_product_id"] is None
@pytest.mark.asyncio
async def test_admin_plan_limits_put_creates_plan_billing_row(
self, client: AsyncClient, admin_auth_headers: dict, test_db
):
"""PUT /admin/plan-limits upserts a plan_billing row when billing
fields are included in the body."""
# Ensure no plan_billing row exists for "team" yet.
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
if existing is not None:
await test_db.delete(existing)
await test_db.commit()
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "team",
"max_trees": None,
"max_sessions_per_month": None,
"max_users": None,
"custom_branding": True,
"priority_support": True,
"export_formats": ["markdown", "text", "pdf"],
"display_name": "Team",
"description": "For growing shops",
"monthly_price_cents": 9900,
"annual_price_cents": 99000,
"stripe_product_id": "prod_team_test",
"stripe_monthly_price_id": "price_team_m",
"stripe_annual_price_id": "price_team_a",
"is_public": True,
"is_archived": False,
"sort_order": 20,
},
headers=admin_auth_headers,
)
assert response.status_code == 200, response.text
body = response.json()
assert body["display_name"] == "Team"
assert body["monthly_price_cents"] == 9900
assert body["stripe_product_id"] == "prod_team_test"
assert body["sort_order"] == 20
# Confirm the row was actually persisted.
await test_db.commit() # ensure session sees other-session writes
pb = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
assert pb is not None
assert pb.display_name == "Team"
assert pb.monthly_price_cents == 9900
assert pb.stripe_monthly_price_id == "price_team_m"
assert pb.is_public is True
@pytest.mark.asyncio
async def test_admin_plan_limits_put_does_not_null_out_required_fields(
self, client: AsyncClient, admin_auth_headers: dict, test_db
):
"""PUT /admin/plan-limits must not NULL out NOT NULL columns on the
plan_billing row when the caller passes explicit nulls. The set of
guarded fields is {display_name, is_public, is_archived, sort_order}.
"""
# Seed a plan_billing row for "team" with non-default values for every
# NOT NULL field so we can detect any clobbering.
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
if existing is not None:
await test_db.delete(existing)
await test_db.commit()
seeded = PlanBilling(
plan="team",
display_name="Team Seeded",
is_public=False,
is_archived=True,
sort_order=5,
)
test_db.add(seeded)
await test_db.commit()
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "team",
"max_trees": None,
"max_sessions_per_month": None,
"max_users": None,
"custom_branding": True,
"priority_support": True,
"export_formats": ["markdown", "text"],
# Explicit nulls for every NOT NULL plan_billing field.
"display_name": None,
"is_public": None,
"is_archived": None,
"sort_order": None,
},
headers=admin_auth_headers,
)
assert response.status_code == 200, response.text
# Confirm the seeded NOT NULL values were preserved.
await test_db.commit() # ensure session sees writes from the request
pb = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
assert pb is not None
assert pb.display_name == "Team Seeded"
assert pb.is_public is False
assert pb.is_archived is True
assert pb.sort_order == 5
@pytest.mark.asyncio
async def test_admin_plan_limits_put_invalidates_billing_cache(
self, client: AsyncClient, admin_auth_headers: dict
):
"""PUT /admin/plan-limits calls BillingService.invalidate_billing_cache
with the account_ids on the affected plan."""
# Patch the staticmethod on the class. The endpoint imports
# BillingService at module load, so patch the symbol on the class
# itself — both the import and the dotted reference resolve to it.
with patch(
"app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache",
new_callable=AsyncMock,
) as spy:
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "pro",
"max_trees": 25,
"max_sessions_per_month": 500,
"max_users": 10,
"custom_branding": True,
"priority_support": True,
"export_formats": ["markdown", "text"],
},
headers=admin_auth_headers,
)
assert response.status_code == 200, response.text
spy.assert_awaited_once()
(account_ids_arg,) = spy.await_args.args
# admin fixture seeds an active Pro Subscription, so we expect at
# least one account_id in the invalidation list.
assert isinstance(account_ids_arg, list)
assert len(account_ids_arg) >= 1

View File

@@ -0,0 +1,43 @@
"""Integration tests for the legacy /beta-signup redirect.
Phase 2 retires the public beta-signup form in favor of the regular
register flow. The endpoint stays mounted but answers with a 307 to
the absolute frontend `/register?from=beta` URL so any external links
keep working. There is no `beta_signup` table to migrate — the old
endpoint only fired an email notification — so this test only covers
the redirect contract.
"""
import pytest
from app.core.config import settings
@pytest.mark.asyncio
async def test_beta_signup_redirects_to_register(client, monkeypatch):
"""POST /beta-signup returns 307 to the absolute frontend register URL."""
monkeypatch.setattr(settings, "FRONTEND_URL", "https://example.com")
response = await client.post(
"/api/v1/beta-signup",
json={"email": "anyone@example.com"},
)
assert response.status_code == 307, response.text
assert (
response.headers["location"]
== "https://example.com/register?from=beta"
)
@pytest.mark.asyncio
async def test_beta_signup_redirect_ignores_body(client, monkeypatch):
"""Redirect fires regardless of payload — no validation on the legacy route."""
monkeypatch.setattr(settings, "FRONTEND_URL", "https://example.com")
response = await client.post("/api/v1/beta-signup", json={})
assert response.status_code == 307
assert (
response.headers["location"]
== "https://example.com/register?from=beta"
)

View File

@@ -0,0 +1,83 @@
import uuid
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy import select
from app.models.account import Account
@pytest.mark.asyncio
async def test_billing_portal_returns_url_for_account_with_stripe_customer(
client, test_db, test_user, auth_headers, monkeypatch
):
"""Happy path: account has a stripe_customer_id and Stripe is configured →
GET /billing/portal-session returns the portal URL."""
from app.core.config import settings
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
monkeypatch.setattr(settings, "FRONTEND_URL", "https://app.example.com")
account_id = uuid.UUID(test_user["user_data"]["account_id"])
account = (await test_db.execute(
select(Account).where(Account.id == account_id)
)).scalar_one()
account.stripe_customer_id = "cus_test_456"
await test_db.commit()
fake_session = MagicMock()
fake_session.url = "https://billing.stripe.com/p/session/test_abc"
with patch(
"stripe.billing_portal.Session.create",
return_value=fake_session,
) as portal_mock:
response = await client.get(
"/api/v1/billing/portal-session",
headers=auth_headers,
)
assert response.status_code == 200, response.json()
assert response.json() == {"url": "https://billing.stripe.com/p/session/test_abc"}
portal_mock.assert_called_once()
call_kwargs = portal_mock.call_args.kwargs
assert call_kwargs["customer"] == "cus_test_456"
assert call_kwargs["return_url"] == "https://app.example.com/account/billing"
@pytest.mark.asyncio
async def test_billing_portal_returns_503_when_stripe_not_configured(
client, test_db, test_user, auth_headers, monkeypatch
):
"""STRIPE_SECRET_KEY unset → settings.stripe_enabled is False → 503."""
from app.core.config import settings
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", None)
response = await client.get(
"/api/v1/billing/portal-session",
headers=auth_headers,
)
assert response.status_code == 503
assert response.json()["detail"]["error"] == "stripe_not_configured"
@pytest.mark.asyncio
async def test_billing_portal_returns_400_when_account_has_no_stripe_customer(
client, test_db, test_user, auth_headers, monkeypatch
):
"""Account with no stripe_customer_id (never completed checkout) → 400
with `no_stripe_customer` error."""
from app.core.config import settings
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
# test_user fixture seeds an account with no stripe_customer_id by default.
account_id = uuid.UUID(test_user["user_data"]["account_id"])
account = (await test_db.execute(
select(Account).where(Account.id == account_id)
)).scalar_one()
assert account.stripe_customer_id is None
response = await client.get(
"/api/v1/billing/portal-session",
headers=auth_headers,
)
assert response.status_code == 400
assert response.json()["detail"]["error"] == "no_stripe_customer"

View File

@@ -0,0 +1,100 @@
"""Integration tests for the public runtime config endpoint.
Covers GET /api/v1/config/public and the SELF_SERVE_ENABLED interaction
with the existing /auth/register invite-code gate.
"""
from __future__ import annotations
import pytest
from httpx import AsyncClient
from app.core.config import settings
class TestConfigPublic:
"""GET /api/v1/config/public — anonymous, no auth."""
@pytest.mark.asyncio
async def test_get_config_public_returns_self_serve_flag(
self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
):
"""Endpoint reflects the current SELF_SERVE_ENABLED setting and the
configured OAuth providers, with no auth required."""
# Default-off: SELF_SERVE_ENABLED is False unless explicitly set.
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
monkeypatch.setattr(settings, "MS_CLIENT_ID", None)
response = await client.get("/api/v1/config/public")
assert response.status_code == 200
body = response.json()
assert body == {"self_serve_enabled": False, "oauth_providers": []}
# Flip it on, with both OAuth providers configured.
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", True)
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "google-test-id")
monkeypatch.setattr(settings, "MS_CLIENT_ID", "ms-test-id")
response = await client.get("/api/v1/config/public")
assert response.status_code == 200
body = response.json()
assert body["self_serve_enabled"] is True
assert body["oauth_providers"] == ["google", "microsoft"]
# Only Microsoft configured.
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
monkeypatch.setattr(settings, "MS_CLIENT_ID", "ms-test-id")
response = await client.get("/api/v1/config/public")
assert response.status_code == 200
assert response.json()["oauth_providers"] == ["microsoft"]
class TestRegisterInviteCodeGate:
"""Regression + new-behavior tests for /auth/register vs SELF_SERVE_ENABLED."""
@pytest.mark.asyncio
async def test_register_invite_code_required_when_self_serve_disabled(
self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
):
"""Pre-self-serve behavior: REQUIRE_INVITE_CODE=True without an
invite code (and no account-invite) must still 400."""
monkeypatch.setattr(settings, "REQUIRE_INVITE_CODE", True)
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
response = await client.post(
"/api/v1/auth/register",
json={
"email": "no-invite@example.com",
"password": "SecurePass123!",
"name": "No Invite",
},
)
assert response.status_code == 400
assert "invite code is required" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_register_invite_code_optional_when_self_serve_enabled(
self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
):
"""Self-serve on: registration succeeds with no invite code even
when REQUIRE_INVITE_CODE is True. The user, personal account, and
a Pro trial subscription are all created."""
monkeypatch.setattr(settings, "REQUIRE_INVITE_CODE", True)
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", True)
response = await client.post(
"/api/v1/auth/register",
json={
"email": "self-serve@example.com",
"password": "SecurePass123!",
"name": "Self Serve",
},
)
assert response.status_code == 201, response.text
body = response.json()
assert body["email"] == "self-serve@example.com"
assert body["account_role"] == "owner"
assert "account_id" in body

View File

@@ -2,8 +2,10 @@ import uuid
import pytest
from unittest.mock import patch
from sqlalchemy import select
from app.core.security import decode_token, hash_token
from app.models.user import User
from app.models.oauth_identity import OAuthIdentity
from app.models.refresh_token import RefreshToken
from app.models.subscription import Subscription
from app.services.oauth_providers import OAuthProfile
@@ -118,3 +120,77 @@ async def test_microsoft_callback_creates_user(client, test_db, monkeypatch):
select(OAuthIdentity).where(OAuthIdentity.user_id == user.id)
)).scalar_one()
assert identity.provider == "microsoft"
@pytest.mark.asyncio
async def test_oauth_google_callback_stores_refresh_token_jti(
client, test_db, monkeypatch
):
"""A successful Google OAuth callback must persist the refresh-token JTI
in the refresh_tokens table — otherwise /auth/refresh rejects it."""
from app.core.config import settings
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
profile = OAuthProfile(
provider_subject="google_subject_jti_test",
email="jtitest@example.com",
name="JTI Test",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
)
assert response.status_code == 200, response.json()
body = response.json()
refresh_token_str = body["refresh_token"]
payload = decode_token(refresh_token_str)
assert payload is not None
jti = payload["jti"]
token_hash = hash_token(jti)
user = (await test_db.execute(
select(User).where(User.email == "jtitest@example.com")
)).scalar_one()
stored = (await test_db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)).scalar_one_or_none()
assert stored is not None, "OAuth callback did not persist refresh-token JTI"
assert stored.user_id == user.id
assert stored.revoked_at is None
@pytest.mark.asyncio
async def test_oauth_refresh_works_after_oauth_signup(
client, test_db, monkeypatch
):
"""End-to-end: OAuth callback issues a refresh token; calling /auth/refresh
with that token must succeed (not be rejected as revoked)."""
from app.core.config import settings
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
profile = OAuthProfile(
provider_subject="google_subject_refresh_test",
email="refresh-after-oauth@example.com",
name="Refresh After OAuth",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
callback_resp = await client.post(
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
)
assert callback_resp.status_code == 200, callback_resp.json()
refresh_token_str = callback_resp.json()["refresh_token"]
refresh_resp = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {refresh_token_str}"},
)
assert refresh_resp.status_code == 200, refresh_resp.json()
refreshed = refresh_resp.json()
assert refreshed["access_token"]
assert refreshed["refresh_token"]
# Token rotation: new refresh token differs from the original.
assert refreshed["refresh_token"] != refresh_token_str

View File

@@ -1,6 +1,11 @@
"""Tests for onboarding status endpoints."""
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from app.models.user import User
@pytest.mark.asyncio
@@ -21,6 +26,42 @@ async def test_onboarding_status_fresh_user(client, auth_headers):
assert data["connected_psa"] is False
assert data["is_team_user"] is False
assert data["dismissed"] is False
# Phase 2 fields default to false on a fresh, unverified user with no wizard progress.
assert data["email_verified"] is False
assert data["shop_setup_done"] is False
@pytest.mark.asyncio
async def test_onboarding_status_includes_email_verified_and_shop_setup_done(
client, auth_headers, test_user, test_db
):
"""email_verified flips when email_verified_at is set; shop_setup_done flips at step >= 1."""
# Sanity-check baseline.
response = await client.get(
"/api/v1/users/onboarding-status",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["email_verified"] is False
assert data["shop_setup_done"] is False
# Mutate the underlying user, then re-fetch.
user_email = test_user["email"]
result = await test_db.execute(select(User).where(User.email == user_email))
user = result.scalar_one()
user.email_verified_at = datetime.now(tz=timezone.utc)
user.onboarding_step_completed = 1
await test_db.commit()
response = await client.get(
"/api/v1/users/onboarding-status",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["email_verified"] is True
assert data["shop_setup_done"] is True
@pytest.mark.asyncio

View File

@@ -0,0 +1,149 @@
"""Tests for welcome-wizard onboarding-step endpoints (Phase 2)."""
import pytest
from sqlalchemy import select
from app.models.account import Account
from app.models.user import User
@pytest.mark.asyncio
async def test_onboarding_step1_complete_writes_account_name_and_team_size_and_role(
client, auth_headers, test_db, test_user
):
"""Step 1 + complete writes account.name + team_size_bucket + user.role_at_signup
and advances onboarding_step_completed to 1."""
response = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={
"step": 1,
"action": "complete",
"data": {
"company_name": "Acme MSP",
"team_size_bucket": "3-5",
"role_at_signup": "owner",
},
},
)
assert response.status_code == 200, response.text
data = response.json()
assert data["onboarding_step_completed"] == 1
assert data["onboarding_dismissed"] is False
# Verify persisted writes
account_id = test_user["user_data"]["account_id"]
user_email = test_user["email"]
acct = (
await test_db.execute(select(Account).where(Account.id == account_id))
).scalar_one()
assert acct.name == "Acme MSP"
assert acct.team_size_bucket == "3-5"
user = (
await test_db.execute(select(User).where(User.email == user_email))
).scalar_one()
assert user.role_at_signup == "owner"
assert user.onboarding_step_completed == 1
@pytest.mark.asyncio
async def test_onboarding_step2_skip_advances_without_psa(
client, auth_headers, test_db, test_user
):
"""Step 2 + skip ignores data entirely and only advances the step counter
(no primary_psa write)."""
# Capture original account.primary_psa so we can assert it's untouched.
account_id = test_user["user_data"]["account_id"]
acct_before = (
await test_db.execute(select(Account).where(Account.id == account_id))
).scalar_one()
psa_before = acct_before.primary_psa # likely None
# Advance step 1 first so step 2 is allowed.
r1 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 1, "action": "skip"},
)
assert r1.status_code == 200, r1.text
# Skip step 2 — even if data is present it must be ignored.
r2 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={
"step": 2,
"action": "skip",
"data": {"primary_psa": "connectwise"},
},
)
assert r2.status_code == 200, r2.text
assert r2.json()["onboarding_step_completed"] == 2
# Re-fetch account: primary_psa must NOT have been written.
test_db.expire_all()
acct_after = (
await test_db.execute(select(Account).where(Account.id == account_id))
).scalar_one()
assert acct_after.primary_psa == psa_before
@pytest.mark.asyncio
async def test_onboarding_step_cannot_decrease(client, auth_headers):
"""A step=2 PATCH followed by step=1 must return 400."""
# Advance to step 2.
r1 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 1, "action": "skip"},
)
assert r1.status_code == 200, r1.text
r2 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 2, "action": "skip"},
)
assert r2.status_code == 200, r2.text
assert r2.json()["onboarding_step_completed"] == 2
# Try to go back to step 1 — must fail.
r3 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 1, "action": "skip"},
)
assert r3.status_code == 400, r3.text
# Idempotent re-PATCH of same step succeeds.
r4 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 2, "action": "skip"},
)
assert r4.status_code == 200, r4.text
assert r4.json()["onboarding_step_completed"] == 2
@pytest.mark.asyncio
async def test_onboarding_dismiss_rest_sets_flag(
client, auth_headers, test_db, test_user
):
"""POST /users/me/onboarding-dismiss-rest sets users.onboarding_dismissed=TRUE."""
response = await client.post(
"/api/v1/users/me/onboarding-dismiss-rest",
headers=auth_headers,
)
assert response.status_code == 200, response.text
data = response.json()
assert data["onboarding_dismissed"] is True
# step counter is whatever it was (None for a fresh user).
assert "onboarding_step_completed" in data
# Verify persisted.
user_email = test_user["email"]
user = (
await test_db.execute(select(User).where(User.email == user_email))
).scalar_one()
assert user.onboarding_dismissed is True

View File

@@ -0,0 +1,132 @@
"""Integration tests for the public plans endpoint.
Covers GET /api/v1/plans/public — the marketing /pricing page data source.
"""
from __future__ import annotations
import pytest
from httpx import AsyncClient
from sqlalchemy import delete
from app.models.plan_billing import PlanBilling
from app.models.plan_limits import PlanLimits
async def _seed_plan_limits(test_db, plan: str, max_users: int | None) -> None:
"""Ensure a plan_limits row exists for the given plan name."""
existing = await test_db.get(PlanLimits, plan)
if existing is None:
test_db.add(
PlanLimits(
plan=plan,
max_trees=None,
max_sessions_per_month=None,
max_users=max_users,
custom_branding=False,
priority_support=False,
export_formats=["markdown", "text"],
)
)
await test_db.commit()
class TestGetPlansPublic:
"""GET /api/v1/plans/public — anonymous, no auth."""
@pytest.mark.asyncio
async def test_get_plans_public_returns_only_is_public_rows(
self, client: AsyncClient, test_db
):
"""Rows with is_public=False or is_archived=True must NOT appear."""
# Wipe any existing billing rows so this test owns the fixture state.
await test_db.execute(delete(PlanBilling))
await test_db.commit()
await _seed_plan_limits(test_db, "starter", 3)
await _seed_plan_limits(test_db, "pro", 10)
await _seed_plan_limits(test_db, "internal", None)
await _seed_plan_limits(test_db, "legacy", 5)
test_db.add_all(
[
PlanBilling(
plan="starter",
display_name="Starter",
monthly_price_cents=1900,
is_public=True,
is_archived=False,
sort_order=10,
),
PlanBilling(
plan="pro",
display_name="Pro",
monthly_price_cents=4900,
is_public=True,
is_archived=False,
sort_order=20,
),
PlanBilling(
plan="internal",
display_name="Internal",
is_public=False, # hidden
is_archived=False,
sort_order=30,
),
PlanBilling(
plan="legacy",
display_name="Legacy",
is_public=True,
is_archived=True, # archived
sort_order=40,
),
]
)
await test_db.commit()
response = await client.get("/api/v1/plans/public")
assert response.status_code == 200
plans = response.json()
plan_names = {p["plan"] for p in plans}
assert "starter" in plan_names
assert "pro" in plan_names
assert "internal" not in plan_names
assert "legacy" not in plan_names
# Schema sanity check
starter = next(p for p in plans if p["plan"] == "starter")
assert starter["display_name"] == "Starter"
assert starter["monthly_price_cents"] == 1900
assert starter["max_seats"] == 3
assert starter["is_public"] is True
@pytest.mark.asyncio
async def test_get_plans_public_orders_by_sort_order_then_plan(
self, client: AsyncClient, test_db
):
"""Result must be ordered by sort_order ASC, then plan name ASC."""
await test_db.execute(delete(PlanBilling))
await test_db.commit()
# plan_limits rows for FK satisfaction
for name in ("alpha", "bravo", "charlie", "delta"):
await _seed_plan_limits(test_db, name, None)
# Two with sort_order=10 (charlie should come before delta by plan ASC),
# one with sort_order=5 (alpha first overall),
# one with sort_order=20 (bravo last).
test_db.add_all(
[
PlanBilling(plan="charlie", display_name="C", sort_order=10, is_public=True, is_archived=False),
PlanBilling(plan="delta", display_name="D", sort_order=10, is_public=True, is_archived=False),
PlanBilling(plan="alpha", display_name="A", sort_order=5, is_public=True, is_archived=False),
PlanBilling(plan="bravo", display_name="B", sort_order=20, is_public=True, is_archived=False),
]
)
await test_db.commit()
response = await client.get("/api/v1/plans/public")
assert response.status_code == 200
ordered = [p["plan"] for p in response.json()]
assert ordered == ["alpha", "charlie", "delta", "bravo"]

View File

@@ -0,0 +1,134 @@
"""Integration tests for the public Talk-to-Sales endpoint.
POST /api/v1/sales-leads — no auth, rate-limited 5/hour per IP.
"""
from unittest.mock import AsyncMock, patch
import pytest
import sqlalchemy as sa
@pytest.mark.asyncio
async def test_sales_lead_creates_row_and_sends_notification_email(client, test_db):
"""Happy path: row inserted, notification email fired, 201 returned."""
payload = {
"email": "buyer@acme.example",
"name": "Pat Buyer",
"company": "Acme MSP",
"team_size": "11-50",
"message": "We're evaluating ResolutionFlow for our NOC team.",
"source": "pricing_page",
"posthog_distinct_id": "ph_distinct_123",
}
with patch(
"app.api.endpoints.sales_leads.EmailService.send_sales_lead_notification",
new=AsyncMock(return_value=True),
) as mock_email:
response = await client.post("/api/v1/sales-leads", json=payload)
assert response.status_code == 201, response.text
body = response.json()
assert body["status"] == "received"
assert "id" in body
# Notification email was attempted (asyncio.create_task — give it a tick).
import asyncio
await asyncio.sleep(0)
await asyncio.sleep(0)
assert mock_email.await_count == 1
kwargs = mock_email.await_args.kwargs
assert kwargs["to_email"] # default placeholder until cutover
assert kwargs["lead"].email == "buyer@acme.example"
assert kwargs["lead"].source == "pricing_page"
# Row was inserted with normalized email + all fields preserved.
result = await test_db.execute(
sa.text("SELECT email, name, company, team_size, message, source, posthog_distinct_id, status FROM sales_leads")
)
rows = result.all()
assert len(rows) == 1
row = rows[0]
assert row.email == "buyer@acme.example"
assert row.name == "Pat Buyer"
assert row.company == "Acme MSP"
assert row.team_size == "11-50"
assert row.message == "We're evaluating ResolutionFlow for our NOC team."
assert row.source == "pricing_page"
assert row.posthog_distinct_id == "ph_distinct_123"
assert row.status == "new"
@pytest.mark.asyncio
async def test_sales_lead_email_failure_does_not_fail_request(client, test_db):
"""If the email send raises, the API still returns 201 and the row persists."""
payload = {
"email": "buyer2@acme.example",
"name": "Sam Lead",
"company": "Acme MSP",
"source": "register_footer",
}
with patch(
"app.api.endpoints.sales_leads.EmailService.send_sales_lead_notification",
new=AsyncMock(side_effect=RuntimeError("resend exploded")),
):
response = await client.post("/api/v1/sales-leads", json=payload)
assert response.status_code == 201, response.text
# Row must still be persisted even though email failed.
import asyncio
await asyncio.sleep(0)
result = await test_db.execute(
sa.text("SELECT count(*) FROM sales_leads WHERE email = 'buyer2@acme.example'")
)
assert result.scalar() == 1
@pytest.mark.asyncio
async def test_sales_lead_rate_limited_after_5_per_hour(client):
"""The 6th submission within an hour from the same IP returns 429.
The default `limiter` is disabled in tests (DEBUG=true). We re-enable it
for this test, then reset its state on teardown so other tests aren't
affected.
"""
from app.core.rate_limit import limiter
was_enabled = limiter.enabled
limiter.enabled = True
try:
limiter.reset()
with patch(
"app.api.endpoints.sales_leads.EmailService.send_sales_lead_notification",
new=AsyncMock(return_value=True),
):
for i in range(5):
payload = {
"email": f"lead{i}@acme.example",
"name": f"Lead {i}",
"company": "Acme MSP",
"source": "landing_page",
}
resp = await client.post("/api/v1/sales-leads", json=payload)
assert resp.status_code == 201, f"submission {i}: {resp.text}"
# 6th should be rate-limited.
resp = await client.post(
"/api/v1/sales-leads",
json={
"email": "lead6@acme.example",
"name": "Lead 6",
"company": "Acme MSP",
"source": "landing_page",
},
)
assert resp.status_code == 429, resp.text
finally:
limiter.reset()
limiter.enabled = was_enabled

View File

@@ -142,3 +142,178 @@ async def test_webhook_idempotency(
assert r2.status_code == 200
assert r1.json()["applied"] is True
assert r2.json()["applied"] is False
# ----------------------------------------------------------------------------
# Atomic-idempotency regression tests
# ----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_apply_event_handler_failure_does_not_persist_idempotency_mark(
test_db, test_user,
):
"""If the handler raises, the StripeEvent row must NOT be persisted —
otherwise Stripe's retry will be silently dropped as a duplicate and the
subscription state will desync from Stripe."""
from app.services.billing import BillingService
from app.models.stripe_event import StripeEvent
event_id = "evt_handler_fail_1"
payload = {"data": {"object": {
"id": "sub_doesnotmatter",
"status": "active",
"current_period_start": 1714521600,
"current_period_end": 1717113600,
"items": {"data": [{"quantity": 1}]},
"cancel_at_period_end": False,
}}}
boom = RuntimeError("simulated handler failure")
with patch(
"app.services.billing._handle_subscription_updated",
side_effect=boom,
):
with pytest.raises(RuntimeError, match="simulated handler failure"):
await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
# The StripeEvent row must not exist — handler raised, the entire
# transaction (idempotency mark + partial mutations) was rolled back.
row = (await test_db.execute(
select(StripeEvent).where(StripeEvent.id == event_id)
)).scalar_one_or_none()
assert row is None, (
"StripeEvent row was persisted despite handler failure — "
"Stripe retry will be silently dropped"
)
@pytest.mark.asyncio
async def test_apply_event_retry_after_failure_succeeds(
test_db, test_user,
):
"""A failed first attempt followed by a successful retry must apply state.
This is the core Stripe webhook retry contract."""
from app.services.billing import BillingService
from app.models.stripe_event import StripeEvent
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
test_db.add(Subscription(
account_id=account_id, plan="pro", status="trialing",
stripe_subscription_id="sub_retry",
))
await test_db.commit()
event_id = "evt_retry_1"
payload = {"data": {"object": {
"id": "sub_retry",
"status": "active",
"current_period_start": 1714521600,
"current_period_end": 1717113600,
"items": {"data": [{"quantity": 3}]},
"cancel_at_period_end": False,
}}}
# First attempt — handler raises mid-flight.
with patch(
"app.services.billing._handle_subscription_updated",
side_effect=RuntimeError("transient blip"),
):
with pytest.raises(RuntimeError):
await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
# No idempotency mark, sub still trialing.
row = (await test_db.execute(
select(StripeEvent).where(StripeEvent.id == event_id)
)).scalar_one_or_none()
assert row is None
sub = (await test_db.execute(
select(Subscription).where(Subscription.account_id == account_id)
)).scalar_one()
assert sub.status == "trialing"
# Second attempt — same event_id, handler succeeds.
applied = await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
assert applied is True
# Idempotency mark now persisted, sub state reconciled.
row = (await test_db.execute(
select(StripeEvent).where(StripeEvent.id == event_id)
)).scalar_one()
assert row.id == event_id
await test_db.refresh(sub)
assert sub.status == "active"
assert sub.seat_limit == 3
@pytest.mark.asyncio
async def test_apply_event_duplicate_event_id_skips(
test_db, test_user,
):
"""Two successful invocations with the same event_id must not double-apply.
Second call returns False; mutations are not repeated."""
from app.services.billing import BillingService
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
test_db.add(Subscription(
account_id=account_id, plan="pro", status="trialing",
stripe_subscription_id="sub_dup",
))
await test_db.commit()
event_id = "evt_dedupe_1"
payload = {"data": {"object": {
"id": "sub_dup",
"status": "active",
"current_period_start": 1714521600,
"current_period_end": 1717113600,
"items": {"data": [{"quantity": 7}]},
"cancel_at_period_end": False,
}}}
applied1 = await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
assert applied1 is True
sub = (await test_db.execute(
select(Subscription).where(Subscription.account_id == account_id)
)).scalar_one()
assert sub.status == "active"
assert sub.seat_limit == 7
# Mutate locally so we can prove the second call doesn't re-run the handler.
sub.seat_limit = 99
await test_db.commit()
applied2 = await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
assert applied2 is False
await test_db.refresh(sub)
# Handler did NOT run again — our local mutation is preserved.
assert sub.seat_limit == 99