feat: self-serve signup Phase 2 (frontend cutover) (#162)
Some checks failed
CI / e2e (push) Has been cancelled
CI / frontend (push) Has been cancelled
CI / backend (push) Has been cancelled
Mirror to GitHub / mirror (push) Has been cancelled

Co-authored-by: Michael Chihlas <michael@resolutionflow.com>
Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
This commit was merged in pull request #162.
This commit is contained in:
2026-05-07 18:42:20 +00:00
committed by chihlasm
parent f918b766b0
commit f1be3abcc5
123 changed files with 11563 additions and 559 deletions

View File

@@ -235,6 +235,7 @@ _SUBSCRIPTION_GUARD_ALLOWLIST = {
"/api/v1/billing/portal-session",
"/api/v1/users/me",
"/api/v1/users/me/onboarding-step",
"/api/v1/users/me/onboarding-dismiss-rest",
}
@@ -298,6 +299,8 @@ _EMAIL_VERIFICATION_ALLOWLIST = {
"/api/v1/auth/email/verify",
"/api/v1/auth/password/change",
"/api/v1/users/me",
"/api/v1/users/me/onboarding-step",
"/api/v1/users/me/onboarding-dismiss-rest",
"/api/v1/billing/state",
"/api/v1/billing/checkout-session",
"/api/v1/billing/portal-session",

View File

@@ -0,0 +1,54 @@
"""Public endpoint for resolving an account invite code into display info.
Mounted as a public route (no tenant context, no auth) — used by the
/accept-invite page on the frontend so an invitee can see what account they
are about to join before they sign up. Uses the BYPASSRLS admin session
factory because account_invites is account-scoped under Phase 4 RLS but the
caller has no tenant identity yet.
"""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.admin_database import get_admin_db
from app.models.account_invite import AccountInvite
from app.schemas.oauth import InviteLookupResponse
router = APIRouter(prefix="/accounts", tags=["account-invite-lookup"])
@router.get("/invites/{code}/lookup", response_model=InviteLookupResponse)
async def lookup_invite(
code: str,
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> InviteLookupResponse:
"""Return minimal display data for a valid (unused, unexpired, not revoked)
invite. Returns 404 with `invite_invalid_or_expired_or_revoked` for any
invalid state — the AcceptInvitePage shows a single "ask the inviter to
resend" message regardless of which condition failed (anti-enumeration)."""
result = await db.execute(
select(AccountInvite)
.where(AccountInvite.code == code)
.options(
joinedload(AccountInvite.account),
joinedload(AccountInvite.invited_by),
)
)
invite = result.scalar_one_or_none()
if invite is None or not invite.is_valid:
raise HTTPException(
status_code=404,
detail={"error": "invite_invalid_or_expired_or_revoked"},
)
return InviteLookupResponse(
account_name=invite.account.name,
inviter_name=invite.invited_by.name,
invited_email=invite.email,
role=invite.role,
)

View File

@@ -8,34 +8,101 @@ from app.core.database import get_db
from app.core.audit import log_audit
from app.models.user import User
from app.models.plan_limits import PlanLimits
from app.models.plan_billing import PlanBilling
from app.models.account import Account
from app.models.account_limit_override import AccountLimitOverride
from app.models.subscription import Subscription
from app.schemas.admin import (
PlanLimitResponse, PlanLimitUpdate,
PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
)
from app.api.deps import require_admin
from app.services.billing import BillingService
router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])
@router.get("/plan-limits", response_model=list[PlanLimitResponse])
# Fields on PlanLimitUpdate that map to plan_billing (not plan_limits).
_PLAN_BILLING_FIELDS = (
"display_name",
"description",
"monthly_price_cents",
"annual_price_cents",
"stripe_product_id",
"stripe_monthly_price_id",
"stripe_annual_price_id",
"is_public",
"is_archived",
"sort_order",
)
# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model.
# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit
# null for any of them would otherwise trigger a NOT NULL violation at commit.
_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({
"display_name",
"is_public",
"is_archived",
"sort_order",
})
def _merge_plan_with_billing(
plan: PlanLimits, billing: PlanBilling | None
) -> PlanLimitWithBillingResponse:
"""Build a merged response. Billing fields are None when no plan_billing row
exists for the plan."""
payload = {
"plan": plan.plan,
"max_trees": plan.max_trees,
"max_sessions_per_month": plan.max_sessions_per_month,
"max_users": plan.max_users,
"custom_branding": plan.custom_branding,
"priority_support": plan.priority_support,
"export_formats": plan.export_formats or [],
}
if billing is not None:
payload.update({
"display_name": billing.display_name,
"description": billing.description,
"monthly_price_cents": billing.monthly_price_cents,
"annual_price_cents": billing.annual_price_cents,
"stripe_product_id": billing.stripe_product_id,
"stripe_monthly_price_id": billing.stripe_monthly_price_id,
"stripe_annual_price_id": billing.stripe_annual_price_id,
"is_public": billing.is_public,
"is_archived": billing.is_archived,
"sort_order": billing.sort_order,
})
return PlanLimitWithBillingResponse(**payload)
@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
async def list_plan_limits(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""List all plan limit configurations."""
result = await db.execute(select(PlanLimits))
return result.scalars().all()
"""List all plan limit configurations, merged with plan_billing fields
where present. Plans without a plan_billing row return None for the
billing fields."""
rows = (await db.execute(
select(PlanLimits, PlanBilling)
.outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
)).all()
return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
@router.put("/plan-limits", response_model=PlanLimitResponse)
@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
async def update_plan_limits(
data: PlanLimitUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Update a plan's limits."""
"""Update a plan's limits and (if any plan_billing field is included)
upsert the matching plan_billing row in the same transaction. After
commit, invalidates the in-process billing cache for accounts on this
plan (currently a no-op — see BillingService.invalidate_billing_cache).
"""
result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
plan = result.scalar_one_or_none()
if not plan:
@@ -48,10 +115,50 @@ async def update_plan_limits(
plan.priority_support = data.priority_support
plan.export_formats = data.export_formats
await log_audit(db, current_user.id, "plan_limits.update", "plan_limits", details={"plan": data.plan})
# Did the request include any plan_billing field? (Pydantic gives us
# `model_fields_set` to distinguish "user passed null" from "field omitted".)
billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)
billing: PlanBilling | None = None
if billing_fields_set:
billing = (await db.execute(
select(PlanBilling).where(PlanBilling.plan == data.plan)
)).scalar_one_or_none()
if billing is None:
# Create. display_name is required on the model — derive from the
# plan name when the caller didn't supply one (e.g. "pro" → "Pro").
display_name = data.display_name or data.plan.capitalize()
billing = PlanBilling(plan=data.plan, display_name=display_name)
db.add(billing)
# Apply only the fields the caller actually included. Allows partial
# updates without clobbering existing values.
for field in billing_fields_set:
value = getattr(data, field)
if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
# Don't NULL out a NOT NULL column on update.
continue
setattr(billing, field, value)
await log_audit(
db, current_user.id, "plan_limits.update", "plan_limits",
details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
)
await db.commit()
await db.refresh(plan)
return plan
if billing is not None:
await db.refresh(billing)
# Invalidate any in-process billing cache for accounts on this plan.
# TODO: invalidate app.state.billing_cache when added.
account_ids = [
row[0] for row in (await db.execute(
select(Subscription.account_id).where(Subscription.plan == data.plan)
)).all()
]
await BillingService.invalidate_billing_cache(account_ids)
return _merge_plan_with_billing(plan, billing)
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])

View File

@@ -47,8 +47,16 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["authentication"])
async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
"""Decode a refresh token JWT and store its hash in the database."""
async def store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
"""Decode a refresh token JWT and store its hash in the database.
Module-public so OAuth callback endpoints (and any future token-issuing
surface) can register the JTI in the ``refresh_tokens`` table the same
way ``/auth/login`` does. Without this the first ``/auth/refresh`` call
will reject the token as "revoked" because no row exists.
Caller is responsible for committing the session.
"""
payload = decode_token(refresh_token_str)
if payload and payload.get("jti"):
token_record = RefreshToken(
@@ -136,7 +144,15 @@ async def register(
# Validate platform invite code (skip if account invite was provided)
invite_code_record = None
if not account_invite_record:
if settings.REQUIRE_INVITE_CODE and not user_data.invite_code:
# When SELF_SERVE_ENABLED is on, the platform invite gate is bypassed
# entirely — public self-serve signup is the whole point. The
# invite_code field stays in the schema for backward compatibility
# and so paid/trial-bearing codes still apply when supplied.
if (
settings.REQUIRE_INVITE_CODE
and not settings.SELF_SERVE_ENABLED
and not user_data.invite_code
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invite code is required"
@@ -312,7 +328,7 @@ async def login(
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await _store_refresh_token(db, refresh_token_str, user.id)
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return Token(
@@ -347,7 +363,7 @@ async def login_json(
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await _store_refresh_token(db, refresh_token_str, user.id)
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return Token(
@@ -405,7 +421,7 @@ async def refresh_token(
new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store new refresh token
await _store_refresh_token(db, new_refresh_token_str, user.id)
await store_refresh_token(db, new_refresh_token_str, user.id)
await db.commit()
return Token(

View File

@@ -1,31 +1,44 @@
"""Public beta signup endpoint — no auth required."""
"""Legacy beta signup endpoint — redirects to /register?from=beta.
Phase 2 (self-serve signup) makes the public register flow the canonical
front door. The old `/api/v1/beta-signup` POST endpoint is kept mounted to
preserve any external links that still hit it, but now responds with a
307 Temporary Redirect to `/register?from=beta` so the user lands in the
real signup flow. The `?from=beta` marker lets the frontend tag the
signup origin for analytics.
Note: there is no `beta_signup` database table — the original endpoint
only fired a notification email. There is therefore no waitlist to email
and no migration to run when retiring the endpoint.
"""
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, EmailStr
from app.core.email import EmailService
from fastapi import APIRouter
from fastapi.responses import RedirectResponse
from app.core.config import settings
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/beta-signup", tags=["beta"])
class BetaSignupRequest(BaseModel):
email: EmailStr
# Local-dev fallback when FRONTEND_URL isn't configured. The redirect must
# be absolute — a relative URL would resolve against the API origin
# (api.resolutionflow.com), which has no /register page.
_DEFAULT_FRONTEND_URL = "http://localhost:5173"
class BetaSignupResponse(BaseModel):
success: bool
message: str
@router.post("", include_in_schema=False)
async def beta_signup_redirect() -> RedirectResponse:
"""Redirect legacy beta-signup POST to the public register page.
@router.post("", response_model=BetaSignupResponse)
async def beta_signup(data: BetaSignupRequest):
"""Collect beta interest — sends notification to beta@resolutionflow.com."""
sent = await EmailService.send_beta_signup_notification(data.email)
if not sent:
logger.warning("Beta signup recorded (email delivery skipped): %s", data.email)
return BetaSignupResponse(
success=True,
message="Thanks! We'll be in touch with beta access details.",
Returns 307 so any client following the redirect preserves the HTTP
method; the frontend treats `/register?from=beta` as the canonical
entry point and reads the `from` query param for analytics.
"""
frontend_url = settings.FRONTEND_URL or _DEFAULT_FRONTEND_URL
return RedirectResponse(
url=f"{frontend_url}/register?from=beta",
status_code=307,
)

View File

@@ -1,6 +1,6 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -10,6 +10,7 @@ from app.core.config import settings
from app.models.account import Account
from app.models.user import User
from app.schemas.billing import (
BillingPortalSessionResponse,
BillingStateResponse,
CheckoutSessionCreate,
CheckoutSessionResponse,
@@ -50,3 +51,26 @@ async def get_billing_state(
)).scalar_one()
state = await BillingService.get_billing_state(db, account)
return BillingStateResponse(**state)
@router.get("/portal-session", response_model=BillingPortalSessionResponse)
async def get_billing_portal_session(
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> BillingPortalSessionResponse:
"""Return a Stripe-hosted Customer Portal URL for the account so the user
can update card / cancel. Allowlisted from the subscription + email-verify
guards (a canceled or unverified-past-grace user must still be able to
update billing)."""
if not settings.stripe_enabled:
raise HTTPException(status_code=503, detail={"error": "stripe_not_configured"})
account = (await db.execute(
select(Account).where(Account.id == current_user.account_id)
)).scalar_one()
try:
url = await BillingService.open_customer_portal(account)
except ValueError:
raise HTTPException(status_code=400, detail={"error": "no_stripe_customer"})
return BillingPortalSessionResponse(url=url)

View File

@@ -0,0 +1,40 @@
"""Public runtime configuration endpoint.
GET /api/v1/config/public
Returns the small set of runtime flags the frontend needs at app load
to decide whether to render the self-serve signup flow and which OAuth
buttons to show. No authentication required.
The response model lives in `app.schemas.config` so it can be reused by
frontend codegen and other call sites if needed.
"""
from __future__ import annotations
from fastapi import APIRouter
from app.core.config import settings
from app.schemas.config import PublicConfigResponse
router = APIRouter(prefix="/config", tags=["config"])
@router.get("/public", response_model=PublicConfigResponse)
async def get_public_config() -> PublicConfigResponse:
"""Return public-safe runtime config.
`oauth_providers` reflects which OAuth client IDs are configured server
side; the frontend uses it to render only buttons that will actually
succeed. `self_serve_enabled` is the master switch for the new public
self-serve signup flow.
"""
providers: list[str] = []
if settings.GOOGLE_CLIENT_ID:
providers.append("google")
if settings.MS_CLIENT_ID:
providers.append("microsoft")
return PublicConfigResponse(
self_serve_enabled=settings.SELF_SERVE_ENABLED,
oauth_providers=providers,
)

View File

@@ -7,10 +7,12 @@ from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.endpoints.auth import store_refresh_token
from app.core.admin_database import get_admin_db
from app.core.config import settings
from app.core.security import create_access_token, create_refresh_token
from app.models.account import Account
from app.models.account_invite import AccountInvite
from app.models.oauth_identity import OAuthIdentity
from app.models.user import User
from app.schemas.oauth import OAuthCallbackPayload, OAuthCallbackResponse
@@ -31,9 +33,21 @@ def _generate_display_code(length: int = 8) -> str:
async def _sign_in_or_register(
db: AsyncSession, provider: str, profile: OAuthProfile
db: AsyncSession,
provider: str,
profile: OAuthProfile,
*,
account_invite_code: str | None = None,
invited_email: str | None = None,
) -> tuple[User, bool]:
"""Returns (user, is_new_user). Idempotent on (provider, provider_subject)."""
"""Returns (user, is_new_user). Idempotent on (provider, provider_subject).
When ``account_invite_code`` is supplied (from the /accept-invite flow),
a brand-new user is created inside the invited account instead of getting
a personal account + Pro trial. Mismatch between the OAuth profile email
and ``invited_email`` raises ``invite_email_mismatch`` per the spec
contract that mirrors the email+password register path.
"""
identity = (
await db.execute(
select(OAuthIdentity).where(
@@ -53,28 +67,96 @@ async def _sign_in_or_register(
await db.execute(select(User).where(User.email == profile.email))
).scalar_one_or_none()
is_new_user = user is None
# If the user arrived via an invite link but already has a ResolutionFlow
# account (e.g., previously signed up with email+password), silently
# linking the OAuth identity to that existing account would bypass the
# invite — they'd stay in their personal account and the invite would
# never be consumed. Fail loud instead so they can sign in and accept the
# invite from the dashboard. The "invited user wants to transfer accounts"
# case is a v2 concern.
if account_invite_code and not is_new_user:
raise HTTPException(
status_code=400,
detail={
"error": "email_already_registered_use_login",
"message": (
"An account already exists for this email. Please sign in "
"instead, then accept the invite from your dashboard."
),
},
)
invite_record: AccountInvite | None = None
if is_new_user and account_invite_code:
# SELECT FOR UPDATE so two concurrent OAuth callbacks can't both
# consume the same invite code.
invite_record = (
await db.execute(
select(AccountInvite)
.where(AccountInvite.code == account_invite_code)
.with_for_update()
)
).scalar_one_or_none()
if invite_record is None or not invite_record.is_valid:
raise HTTPException(
status_code=400,
detail={"error": "invite_invalid_or_expired_or_revoked"},
)
# Verify the OAuth profile email matches what was invited. We compare
# against the invite row directly (source of truth), but also accept
# the client-supplied invited_email as a defensive equality check.
if invite_record.email.lower() != profile.email.lower():
raise HTTPException(
status_code=400,
detail={"error": "invite_email_mismatch"},
)
if invited_email and invited_email.lower() != invite_record.email.lower():
raise HTTPException(
status_code=400,
detail={"error": "invite_email_mismatch"},
)
if is_new_user:
account = Account(
name=f"{profile.name}'s Account",
display_code=_generate_display_code(),
)
db.add(account)
await db.flush()
user = User(
email=profile.email,
name=profile.name,
password_hash=None,
account_id=account.id,
account_role="owner",
role="engineer",
email_verified_at=datetime.now(timezone.utc),
)
db.add(user)
await db.flush()
account.owner_id = user.id
await db.flush()
# start_trial commits internally; flushed account/user above.
await BillingService.start_trial(db, account.id)
if invite_record is not None:
# Join the invited account directly — no personal account, no
# trial creation.
user = User(
email=profile.email,
name=profile.name,
password_hash=None,
account_id=invite_record.account_id,
account_role=invite_record.role,
role="engineer",
email_verified_at=datetime.now(timezone.utc),
)
db.add(user)
await db.flush()
invite_record.accepted_by_id = user.id
invite_record.used_at = datetime.now(timezone.utc)
await db.flush()
else:
account = Account(
name=f"{profile.name}'s Account",
display_code=_generate_display_code(),
)
db.add(account)
await db.flush()
user = User(
email=profile.email,
name=profile.name,
password_hash=None,
account_id=account.id,
account_role="owner",
role="engineer",
email_verified_at=datetime.now(timezone.utc),
)
db.add(user)
await db.flush()
account.owner_id = user.id
await db.flush()
# start_trial commits internally; flushed account/user above.
await BillingService.start_trial(db, account.id)
db.add(
OAuthIdentity(
@@ -98,10 +180,23 @@ async def google_callback(
raise HTTPException(status_code=503, detail="Google sign-in not configured")
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/google/callback"
profile = await google_exchange_code(payload.code, redirect_uri)
user, is_new = await _sign_in_or_register(db, "google", profile)
user, is_new = await _sign_in_or_register(
db,
"google",
profile,
account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email,
)
refresh_token_str = create_refresh_token({"sub": str(user.id)})
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}),
refresh_token=create_refresh_token({"sub": str(user.id)}),
refresh_token=refresh_token_str,
is_new_user=is_new,
)
@@ -115,9 +210,22 @@ async def microsoft_callback(
raise HTTPException(status_code=503, detail="Microsoft sign-in not configured")
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/microsoft/callback"
profile = await microsoft_exchange_code(payload.code, redirect_uri)
user, is_new = await _sign_in_or_register(db, "microsoft", profile)
user, is_new = await _sign_in_or_register(
db,
"microsoft",
profile,
account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email,
)
refresh_token_str = create_refresh_token({"sub": str(user.id)})
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}),
refresh_token=create_refresh_token({"sub": str(user.id)}),
refresh_token=refresh_token_str,
is_new_user=is_new,
)

View File

@@ -2,19 +2,24 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_active_user
from app.core.database import get_db
from app.core.admin_database import get_admin_db
from app.models.account import Account
from app.models.assistant_chat import AssistantChat
from app.models.psa_connection import PsaConnection
from app.models.session import Session
from app.models.tree import Tree
from app.models.user import User
from app.schemas.onboarding import OnboardingStatus
from app.schemas.onboarding import (
OnboardingStatus,
OnboardingStepRequest,
OnboardingStepResponse,
)
router = APIRouter(prefix="/users", tags=["onboarding"])
@@ -85,6 +90,10 @@ async def get_onboarding_status(
)
connected_psa = (psa_q.scalar() or 0) > 0
# New (Phase 2 — Task 41)
email_verified = current_user.email_verified_at is not None
shop_setup_done = (current_user.onboarding_step_completed or 0) >= 1
return OnboardingStatus(
created_flow=created_flow,
ran_session=ran_session,
@@ -94,6 +103,8 @@ async def get_onboarding_status(
connected_psa=connected_psa,
is_team_user=is_team_user,
dismissed=current_user.onboarding_dismissed,
email_verified=email_verified,
shop_setup_done=shop_setup_done,
)
@@ -109,3 +120,98 @@ async def dismiss_onboarding(
# Return updated status (reuse the GET logic)
return await get_onboarding_status(db=db, current_user=current_user)
# ---------------------------------------------------------------------------
# Welcome wizard endpoints (Phase 2)
#
# These persist Step 1/2/3 progress for the post-signup welcome wizard.
# Mounted on /users/me/* (the parent router prefix is /users) so the wizard
# can run before email verification and during trial.
# ---------------------------------------------------------------------------
@router.patch("/me/onboarding-step", response_model=OnboardingStepResponse)
async def patch_onboarding_step(
body: OnboardingStepRequest,
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> OnboardingStepResponse:
"""Persist welcome-wizard progress for the current user.
Contract:
- step=1 + complete writes accounts.name, accounts.team_size_bucket,
users.role_at_signup, then sets users.onboarding_step_completed=1.
- step=2 + complete writes accounts.primary_psa, then sets
users.onboarding_step_completed=2.
- step=3 + complete just sets users.onboarding_step_completed=3
(invites are POSTed separately).
- action="skip" ignores `data` entirely and only advances the step.
- The new step must be >= current onboarding_step_completed (None=>0);
otherwise 400. Idempotent re-PATCH of the same step succeeds.
"""
current_step = current_user.onboarding_step_completed or 0
if body.step < current_step:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "step_cannot_decrease",
"current_step": current_step,
"requested_step": body.step,
},
)
if body.action == "complete" and body.data is not None and body.step in (1, 2):
# Load the user's account for field writes. Step 3 has no data writes.
account_result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = account_result.scalar_one_or_none()
if account is None:
# Should never happen — user is required to have an account_id.
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="account_not_found",
)
if body.step == 1:
data = body.data
if data.company_name is not None:
account.name = data.company_name
if data.team_size_bucket is not None:
account.team_size_bucket = data.team_size_bucket
if data.role_at_signup is not None:
current_user.role_at_signup = data.role_at_signup
elif body.step == 2:
data = body.data
if data.primary_psa is not None:
account.primary_psa = data.primary_psa
current_user.onboarding_step_completed = body.step
await db.commit()
await db.refresh(current_user)
return OnboardingStepResponse(
onboarding_step_completed=current_user.onboarding_step_completed,
onboarding_dismissed=current_user.onboarding_dismissed,
)
@router.post("/me/onboarding-dismiss-rest", response_model=OnboardingStepResponse)
async def dismiss_onboarding_rest(
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> OnboardingStepResponse:
"""Set users.onboarding_dismissed=TRUE — backs the wizard's "Skip the rest" button.
Returns the same shape as the step PATCH so the frontend can update its
local store from a single response.
"""
current_user.onboarding_dismissed = True
await db.commit()
await db.refresh(current_user)
return OnboardingStepResponse(
onboarding_step_completed=current_user.onboarding_step_completed,
onboarding_dismissed=current_user.onboarding_dismissed,
)

View File

@@ -0,0 +1,58 @@
"""Public plans endpoint — no auth required.
GET /api/v1/plans/public
Returns the public-safe view of `plan_billing` joined with
`plan_limits.max_users` (exposed as `max_seats`), filtered to
`is_public=True AND is_archived=False`, ordered by sort_order ASC, plan ASC.
Distinct from `/admin/plan-limits` (admin-only, returns ALL plans including
archived/internal). This endpoint exists to power the marketing /pricing page
without exposing the rest of the admin-only billing surface.
"""
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_database import get_admin_db
from app.models.plan_billing import PlanBilling
from app.models.plan_limits import PlanLimits
from app.schemas.billing import PublicPlanResponse
router = APIRouter(prefix="/plans", tags=["plans"])
@router.get("/public", response_model=list[PublicPlanResponse])
async def list_public_plans(
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> list[PublicPlanResponse]:
"""List public, non-archived plans for the marketing /pricing page.
Public — no auth. Uses `get_admin_db` because this is a cross-tenant read
of the global plan catalog (same pattern as `/config/public`).
"""
stmt = (
select(PlanBilling, PlanLimits.max_users)
.outerjoin(PlanLimits, PlanBilling.plan == PlanLimits.plan)
.where(PlanBilling.is_public.is_(True))
.where(PlanBilling.is_archived.is_(False))
.order_by(PlanBilling.sort_order.asc(), PlanBilling.plan.asc())
)
rows = (await db.execute(stmt)).all()
return [
PublicPlanResponse(
plan=billing.plan,
display_name=billing.display_name,
description=billing.description,
monthly_price_cents=billing.monthly_price_cents,
annual_price_cents=billing.annual_price_cents,
max_seats=max_users,
sort_order=billing.sort_order,
is_public=billing.is_public,
)
for billing, max_users in rows
]

View File

@@ -0,0 +1,114 @@
"""Public Talk-to-Sales endpoint — no auth required.
POST /api/v1/sales-leads
- Inserts a sales_leads row.
- Fires (best-effort) a notification email to settings.SALES_LEAD_RECIPIENT_EMAIL.
- Emits a server-side PostHog event (best-effort).
- Rate-limited per IP (5/hour).
"""
from __future__ import annotations
import asyncio
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.admin_database import get_admin_db
from app.core.config import settings
from app.core.email import EmailService
from app.core.rate_limit import limiter
from app.models.sales_lead import SalesLead
from app.schemas.sales_lead import SalesLeadCreate, SalesLeadCreateResponse
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/sales-leads", tags=["sales"])
async def _send_notification_email(lead: SalesLead) -> None:
"""Fire-and-forget wrapper. EmailService methods never raise, but we
still wrap in a try/except to defend against future regressions."""
try:
await EmailService.send_sales_lead_notification(
to_email=settings.SALES_LEAD_RECIPIENT_EMAIL,
lead=lead,
)
except Exception:
logger.warning(
"Sales lead notification email failed for lead %s",
lead.id,
exc_info=True,
)
def _capture_posthog_event(lead: SalesLead) -> None:
"""Emit `talk_to_sales_form_submitted` server-side. Best-effort.
Backend PostHog SDK isn't initialized in the project today; this function
is the single instrumentation point so wiring it up later is a one-line
change. The call is wrapped so any future failure can never fail the
request.
"""
try:
# Lazy import — keeps the dependency optional. When the backend
# PostHog client is wired in (likely as `app.core.analytics.posthog`),
# swap the import path here and the event will fire automatically.
try:
from app.core.analytics import posthog # type: ignore[attr-defined]
except ImportError:
logger.debug(
"PostHog server-side capture skipped — client not configured"
)
return
distinct_id = lead.posthog_distinct_id or f"sales_lead:{lead.id}"
posthog.capture(
distinct_id=distinct_id,
event="talk_to_sales_form_submitted",
properties={
"source": lead.source,
"company": lead.company,
"team_size": lead.team_size,
},
)
except Exception:
logger.warning(
"PostHog capture failed for sales lead %s",
lead.id,
exc_info=True,
)
@router.post("", response_model=SalesLeadCreateResponse, status_code=201)
@limiter.limit("5/hour")
async def create_sales_lead(
request: Request,
data: SalesLeadCreate,
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> SalesLeadCreateResponse:
"""Public Talk-to-Sales submission.
Creates a sales_leads row, fires (best-effort) a notification email and a
server-side PostHog event. Rate-limited per IP at 5/hour.
"""
lead = SalesLead(
email=str(data.email).lower(),
name=data.name,
company=data.company,
team_size=data.team_size,
message=data.message,
source=data.source,
posthog_distinct_id=data.posthog_distinct_id,
)
db.add(lead)
await db.commit()
await db.refresh(lead)
# Fire-and-forget: email + analytics. Failures must not fail the request.
asyncio.create_task(_send_notification_email(lead))
_capture_posthog_event(lead)
return SalesLeadCreateResponse(id=lead.id, status="received")

View File

@@ -26,8 +26,10 @@ from app.api.endpoints import (
billing,
beta_feedback,
beta_signup,
sales_leads,
branding,
categories,
config as config_endpoints,
copilot,
device_types,
draft_templates,
@@ -43,6 +45,7 @@ from app.api.endpoints import (
notifications,
oauth as oauth_endpoints,
onboarding,
plans_public,
public_templates,
ratings,
scripts,
@@ -68,6 +71,7 @@ from app.api.endpoints import (
uploads,
webhooks,
accounts,
account_invite_lookup,
)
api_router = APIRouter()
@@ -88,9 +92,13 @@ api_router.include_router(billing.router) # Reachable when subscription lock
api_router.include_router(shared.router) # Public share links (no auth)
api_router.include_router(shares.public_router) # Public session share links (optional auth)
api_router.include_router(beta_signup.router)
api_router.include_router(sales_leads.router) # Talk-to-Sales (no auth, rate-limited)
api_router.include_router(webhooks.router) # Stripe webhook receiver
api_router.include_router(public_templates.router) # Public gallery (no auth, rate-limited)
api_router.include_router(survey.router) # Public survey flow (no auth, rate-limited)
api_router.include_router(config_endpoints.router) # Public runtime feature flags
api_router.include_router(account_invite_lookup.router) # Public invite-code lookup for /accept-invite
api_router.include_router(plans_public.router) # Public plan catalog for /pricing page
# ---------------------------------------------------------------------------
# Admin endpoints — super_admin only

View File

@@ -84,6 +84,7 @@ class Settings(BaseSettings):
RESEND_API_KEY: Optional[str] = None
FROM_EMAIL: str = "ResolutionFlow <invites@resolutionflow.com>"
FEEDBACK_EMAIL: Optional[str] = None
SALES_LEAD_RECIPIENT_EMAIL: str = "sales@resolutionflow.com"
@property
def email_enabled(self) -> bool:

View File

@@ -1,6 +1,11 @@
import logging
from typing import TYPE_CHECKING
from app.core.config import settings
if TYPE_CHECKING:
from app.models.sales_lead import SalesLead
logger = logging.getLogger(__name__)
@@ -484,6 +489,99 @@ class EmailService:
logger.exception("Failed to send beta signup notification for %s", signup_email)
return False
@staticmethod
async def send_sales_lead_notification(
to_email: str,
lead: "SalesLead",
) -> bool:
"""Notify the sales recipient about a new Talk-to-Sales submission.
Fire-and-forget. Returns False (and logs) on any failure; never raises.
"""
if not settings.email_enabled:
logger.warning(
"Sales lead email not sent — RESEND_API_KEY not configured (lead %s)",
lead.id,
)
return False
try:
import resend
import html as html_mod
from datetime import datetime, timezone
resend.api_key = settings.RESEND_API_KEY
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
safe_email = html_mod.escape(lead.email)
safe_name = html_mod.escape(lead.name)
safe_company = html_mod.escape(lead.company)
safe_team_size = html_mod.escape(lead.team_size or "")
safe_source = html_mod.escape(lead.source)
safe_message = html_mod.escape(lead.message or "(no message)")
subject = f"[ResolutionFlow Sales] New lead — {safe_company} ({safe_email})"
email_html = f"""<!DOCTYPE html>
<html><head><meta charset="utf-8"><meta name="viewport" content="width=device-width"></head>
<body style="margin:0;padding:0;background:#101114;font-family:'Inter',Helvetica,Arial,sans-serif;">
<table width="100%" cellpadding="0" cellspacing="0" style="background:#101114;padding:40px 0;">
<tr><td align="center">
<table width="560" cellpadding="0" cellspacing="0" style="background:#14161a;border:1px solid rgba(255,255,255,0.06);border-radius:16px;">
<tr><td style="padding:40px 40px 24px;text-align:center;">
<h1 style="margin:0;color:#f8fafc;font-size:24px;font-weight:600;">Resolution<span style="color:#06b6d4;">Flow</span></h1>
<p style="margin:8px 0 0;color:#5a6170;font-size:14px;">New Sales Lead</p>
</td></tr>
<tr><td style="padding:0 40px 16px;">
<p style="margin:0;color:#8891a0;font-size:16px;line-height:1.6;">
Source: <strong style="color:#f8fafc;">{safe_source}</strong>
</p>
</td></tr>
<tr><td style="padding:0 40px 16px;">
<table width="100%" cellpadding="0" cellspacing="0" style="background:rgba(0,0,0,0.3);border:1px solid rgba(255,255,255,0.06);border-radius:12px;">
<tr><td style="padding:16px;">
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Name</p>
<p style="margin:0 0 12px;color:#f8fafc;font-size:16px;font-weight:600;">{safe_name}</p>
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Email</p>
<p style="margin:0 0 12px;color:#22d3ee;font-size:16px;font-weight:600;">{safe_email}</p>
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Company</p>
<p style="margin:0 0 12px;color:#f8fafc;font-size:16px;font-weight:600;">{safe_company}</p>
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Team Size</p>
<p style="margin:0;color:#f8fafc;font-size:16px;font-weight:600;">{safe_team_size}</p>
</td></tr>
</table>
</td></tr>
<tr><td style="padding:0 40px 16px;">
<p style="margin:0 0 4px;color:#5a6170;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Message</p>
<p style="margin:0;color:#8891a0;font-size:14px;line-height:1.6;white-space:pre-wrap;">{safe_message}</p>
</td></tr>
<tr><td style="padding:0 40px 32px;">
<p style="margin:0;color:#5a6170;font-size:12px;text-align:center;">
Submitted at {date_str} · Lead ID: {lead.id}
</p>
</td></tr>
</table>
</td></tr>
</table>
</body></html>"""
resend.Emails.send({
"from": settings.FROM_EMAIL,
"to": [to_email],
"reply_to": lead.email,
"subject": subject,
"html": email_html,
})
logger.info("Sales lead notification sent for %s (lead %s)", lead.email, lead.id)
return True
except Exception:
logger.exception(
"Failed to send sales lead notification for %s (lead %s)",
lead.email,
lead.id,
)
return False
@staticmethod
async def send_notification_email(
to_email: str,

View File

@@ -172,6 +172,21 @@ class PlanLimitResponse(BaseModel):
from_attributes = True
class PlanLimitWithBillingResponse(PlanLimitResponse):
"""PlanLimits + plan_billing fields merged. Billing fields are None when no
plan_billing row exists for the plan yet."""
display_name: Optional[str] = None
description: Optional[str] = None
monthly_price_cents: Optional[int] = None
annual_price_cents: Optional[int] = None
stripe_product_id: Optional[str] = None
stripe_monthly_price_id: Optional[str] = None
stripe_annual_price_id: Optional[str] = None
is_public: Optional[bool] = None
is_archived: Optional[bool] = None
sort_order: Optional[int] = None
class PlanLimitUpdate(BaseModel):
plan: str
max_trees: Optional[int] = None
@@ -180,6 +195,19 @@ class PlanLimitUpdate(BaseModel):
custom_branding: bool = False
priority_support: bool = False
export_formats: list = Field(default_factory=lambda: ["markdown", "text"])
# plan_billing fields — all optional, partial-update semantics. If any are
# set in the body, the admin endpoint upserts the plan_billing row in the
# same transaction.
display_name: Optional[str] = None
description: Optional[str] = None
monthly_price_cents: Optional[int] = None
annual_price_cents: Optional[int] = None
stripe_product_id: Optional[str] = None
stripe_monthly_price_id: Optional[str] = None
stripe_annual_price_id: Optional[str] = None
is_public: Optional[bool] = None
is_archived: Optional[bool] = None
sort_order: Optional[int] = None
class AccountOverrideCreate(BaseModel):

View File

@@ -13,6 +13,10 @@ class CheckoutSessionResponse(BaseModel):
url: str
class BillingPortalSessionResponse(BaseModel):
url: str
class SubscriptionState(BaseModel):
status: str
plan: str
@@ -38,3 +42,23 @@ class BillingStateResponse(BaseModel):
plan_billing: Optional[PlanBillingState]
plan_limits: Dict[str, Any]
enabled_features: Dict[str, bool]
class PublicPlanResponse(BaseModel):
"""Public-safe view of a billable plan, used by the marketing /pricing page.
Sourced from `plan_billing` joined with `plan_limits.max_users` (exposed
here as `max_seats`). Always filtered server-side to is_public=True and
is_archived=False, so `is_public` is a constant True for any row returned
here — included for clarity and forward compatibility.
"""
plan: str
display_name: str
description: Optional[str] = None
monthly_price_cents: Optional[int] = None
annual_price_cents: Optional[int] = None
max_seats: Optional[int] = None
sort_order: int
is_public: bool = True
model_config = {"from_attributes": True}

View File

@@ -0,0 +1,18 @@
"""Pydantic schemas for public runtime configuration."""
from __future__ import annotations
from typing import List
from pydantic import BaseModel
class PublicConfigResponse(BaseModel):
"""Runtime feature flags + OAuth provider list exposed to anonymous clients.
Read once by the frontend at app load to decide whether to render the
self-serve signup flow and which OAuth buttons to show.
"""
self_serve_enabled: bool
oauth_providers: List[str]

View File

@@ -4,6 +4,11 @@ from pydantic import BaseModel
class OAuthCallbackPayload(BaseModel):
code: str
state: str | None = None
# When the OAuth flow originated from /accept-invite, the frontend round-trips
# the invite code + invited email so the backend can link the new user to the
# invited account instead of creating a personal one.
account_invite_code: str | None = None
invited_email: str | None = None
class OAuthCallbackResponse(BaseModel):
@@ -11,3 +16,17 @@ class OAuthCallbackResponse(BaseModel):
refresh_token: str
token_type: str = "bearer"
is_new_user: bool
class InviteLookupResponse(BaseModel):
"""Public response surface for GET /accounts/invites/{code}/lookup.
Returns the minimum context needed for the AcceptInvitePage:
account name (so we can title the card), inviter name (for the resend
fallback message), invited email (locked into the form), and role.
"""
account_name: str
inviter_name: str
invited_email: str
role: str

View File

@@ -1,12 +1,55 @@
from pydantic import BaseModel
from typing import Literal, Optional
from pydantic import BaseModel, Field
class OnboardingStatus(BaseModel):
created_flow: bool
ran_session: bool
exported_session: bool
# Kept for backward-compat during deploy; new code paths should not branch on this.
tried_ai_assistant: bool
invited_teammate: bool
connected_psa: bool
is_team_user: bool
dismissed: bool
# New (Phase 2 — Task 41) — drive the unified next-step card + checklist.
email_verified: bool
shop_setup_done: bool
# --- Welcome wizard (Phase 2) ----------------------------------------------
TeamSizeBucket = Literal["1-2", "3-5", "6-10", "11-25", "26+"]
RoleAtSignup = Literal["owner", "lead_tech", "tech", "other"]
PrimaryPsa = Literal["connectwise", "autotask", "halopsa", "none"]
WizardStep = Literal[1, 2, 3]
WizardAction = Literal["complete", "skip"]
class OnboardingStepData(BaseModel):
"""Optional payload carried with `action="complete"` for steps 1 and 2.
Step 1 fields: company_name, team_size_bucket, role_at_signup
Step 2 fields: primary_psa
Step 3 has no data (invitations posted separately).
"""
# Step 1
company_name: Optional[str] = Field(default=None, max_length=255)
team_size_bucket: Optional[TeamSizeBucket] = None
role_at_signup: Optional[RoleAtSignup] = None
# Step 2
primary_psa: Optional[PrimaryPsa] = None
class OnboardingStepRequest(BaseModel):
step: WizardStep
action: WizardAction
data: Optional[OnboardingStepData] = None
class OnboardingStepResponse(BaseModel):
onboarding_step_completed: Optional[int]
onboarding_dismissed: bool

View File

@@ -0,0 +1,27 @@
"""Pydantic schemas for Talk-to-Sales submissions."""
from typing import Literal, Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, EmailStr, Field
SalesLeadSource = Literal["pricing_page", "register_footer", "landing_page"]
class SalesLeadCreate(BaseModel):
"""Public Talk-to-Sales form submission."""
model_config = ConfigDict(str_strip_whitespace=True)
email: EmailStr
name: str = Field(..., min_length=1, max_length=255)
company: str = Field(..., min_length=1, max_length=255)
team_size: Optional[str] = Field(default=None, max_length=20)
message: Optional[str] = Field(default=None, max_length=5000)
source: SalesLeadSource
posthog_distinct_id: Optional[str] = Field(default=None, max_length=255)
class SalesLeadCreateResponse(BaseModel):
id: UUID
status: Literal["received"] = "received"

View File

@@ -58,6 +58,8 @@ class UserResponse(UserBase):
timezone: str = "UTC"
avatar_url: Optional[str] = None
email_verified_at: Optional[datetime] = None
onboarding_step_completed: Optional[int] = None
onboarding_dismissed: bool = False
class Config:
from_attributes = True

View File

@@ -1,6 +1,7 @@
"""Single billing service module. Stripe is the only impl — no provider
abstraction. Account row is canonical local state; Stripe is canonical
remote state; the webhook handler bridges the two."""
import logging
from datetime import datetime, timezone, timedelta
import stripe
@@ -17,8 +18,32 @@ from app.models.subscription import Subscription
TRIAL_DAYS = 14
logger = logging.getLogger(__name__)
class BillingService:
@staticmethod
async def invalidate_billing_cache(account_ids) -> None:
"""No-op stub for future in-process billing cache invalidation.
Today there is no `app.state.billing_cache` — `BillingService.get_billing_state`
always reads fresh from the DB. Call sites that mutate plan/feature data
invoke this hook so that wiring is in place when an in-process cache is
added later. Until then, this just logs.
TODO: when an in-process billing cache (e.g. `app.state.billing_cache`)
is introduced, evict entries for the given account_ids here.
"""
try:
count = len(list(account_ids))
except TypeError:
count = -1
logger.debug(
"BillingService.invalidate_billing_cache called for %d account(s) "
"(no-op stub — wire to app.state.billing_cache when added)",
count,
)
@staticmethod
async def start_trial(db: AsyncSession, account_id) -> Subscription:
"""Idempotent. Creates a trialing Subscription on Pro for the account if
@@ -105,6 +130,25 @@ class BillingService:
)
return session.url
@staticmethod
async def open_customer_portal(account: Account) -> str:
"""Create a Stripe-hosted Customer Portal session and return the URL.
Raises RuntimeError if Stripe isn't configured (endpoint maps to 503).
Raises ValueError if the account has no stripe_customer_id yet — the
user must complete a checkout first (endpoint maps to 400).
"""
if not settings.stripe_enabled:
raise RuntimeError("Stripe not configured")
if account.stripe_customer_id is None:
raise ValueError("no_stripe_customer")
stripe.api_key = settings.STRIPE_SECRET_KEY
session = stripe.billing_portal.Session.create(
customer=account.stripe_customer_id,
return_url=f"{settings.FRONTEND_URL}/account/billing",
)
return session.url
@staticmethod
async def get_billing_state(db: AsyncSession, account):
"""Aggregate Subscription + PlanLimits + PlanBilling + resolved feature
@@ -166,28 +210,44 @@ class BillingService:
) -> bool:
"""Idempotent. Returns True if the event was applied; False if it had
already been processed (idempotent ack). The webhook handler returns 200
either way."""
either way.
Atomic: the StripeEvent idempotency mark and the handler's state
mutations are committed in a single transaction. If the handler raises
the entire transaction (idempotency mark + partial mutations) is rolled
back, so a Stripe retry will re-run the handler. Without this, a
handler that fails mid-flight would leave the StripeEvent row persisted
and silently desync subscription state from Stripe.
"""
db.add(StripeEvent(
id=event_id,
event_type=event_type,
payload_excerpt=_excerpt(payload),
))
try:
db.add(StripeEvent(
id=event_id,
event_type=event_type,
payload_excerpt=_excerpt(payload),
))
await db.commit()
await db.flush()
except IntegrityError:
# Duplicate event_id — already processed (or in flight). Ack with False.
await db.rollback()
return False
if event_type == "checkout.session.completed":
await _handle_checkout_completed(db, payload)
elif event_type == "customer.subscription.updated":
await _handle_subscription_updated(db, payload)
elif event_type == "customer.subscription.deleted":
await _handle_subscription_deleted(db, payload)
elif event_type == "invoice.payment_failed":
await _handle_payment_failed(db, payload)
elif event_type == "invoice.payment_succeeded":
await _handle_payment_succeeded(db, payload)
try:
if event_type == "checkout.session.completed":
await _handle_checkout_completed(db, payload)
elif event_type == "customer.subscription.updated":
await _handle_subscription_updated(db, payload)
elif event_type == "customer.subscription.deleted":
await _handle_subscription_deleted(db, payload)
elif event_type == "invoice.payment_failed":
await _handle_payment_failed(db, payload)
elif event_type == "invoice.payment_succeeded":
await _handle_payment_succeeded(db, payload)
await db.commit()
except Exception:
# Roll back the StripeEvent insert + any partial handler mutations
# so Stripe's retry can re-run cleanly.
await db.rollback()
raise
return True
@@ -238,7 +298,7 @@ async def _handle_checkout_completed(db: AsyncSession, payload: dict):
)).scalar_one_or_none()
if pb is not None:
sub.plan = pb.plan
await db.commit()
# No commit — apply_subscription_event commits once for the full event.
async def _handle_subscription_updated(db: AsyncSession, payload: dict):
@@ -253,7 +313,7 @@ async def _handle_subscription_updated(db: AsyncSession, payload: dict):
sub.current_period_end = datetime.fromtimestamp(obj["current_period_end"], tz=timezone.utc)
sub.cancel_at_period_end = obj.get("cancel_at_period_end", False)
sub.seat_limit = obj["items"]["data"][0]["quantity"]
await db.commit()
# No commit — apply_subscription_event commits once for the full event.
async def _handle_subscription_deleted(db: AsyncSession, payload: dict):
@@ -264,7 +324,7 @@ async def _handle_subscription_deleted(db: AsyncSession, payload: dict):
if sub is None:
return
sub.status = "canceled"
await db.commit()
# No commit — apply_subscription_event commits once for the full event.
async def _handle_payment_failed(db: AsyncSession, payload: dict):
@@ -278,7 +338,7 @@ async def _handle_payment_failed(db: AsyncSession, payload: dict):
if sub is None:
return
sub.status = "past_due"
await db.commit()
# No commit — apply_subscription_event commits once for the full event.
async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
@@ -293,4 +353,4 @@ async def _handle_payment_succeeded(db: AsyncSession, payload: dict):
return
if sub.status == "past_due":
sub.status = "active"
await db.commit()
# No commit — apply_subscription_event commits once for the full event.