feat: self-serve signup Phase 2 (frontend cutover) (#162)
Co-authored-by: Michael Chihlas <michael@resolutionflow.com> Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
This commit was merged in pull request #162.
This commit is contained in:
@@ -235,6 +235,7 @@ _SUBSCRIPTION_GUARD_ALLOWLIST = {
|
||||
"/api/v1/billing/portal-session",
|
||||
"/api/v1/users/me",
|
||||
"/api/v1/users/me/onboarding-step",
|
||||
"/api/v1/users/me/onboarding-dismiss-rest",
|
||||
}
|
||||
|
||||
|
||||
@@ -298,6 +299,8 @@ _EMAIL_VERIFICATION_ALLOWLIST = {
|
||||
"/api/v1/auth/email/verify",
|
||||
"/api/v1/auth/password/change",
|
||||
"/api/v1/users/me",
|
||||
"/api/v1/users/me/onboarding-step",
|
||||
"/api/v1/users/me/onboarding-dismiss-rest",
|
||||
"/api/v1/billing/state",
|
||||
"/api/v1/billing/checkout-session",
|
||||
"/api/v1/billing/portal-session",
|
||||
|
||||
54
backend/app/api/endpoints/account_invite_lookup.py
Normal file
54
backend/app/api/endpoints/account_invite_lookup.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Public endpoint for resolving an account invite code into display info.
|
||||
|
||||
Mounted as a public route (no tenant context, no auth) — used by the
|
||||
/accept-invite page on the frontend so an invitee can see what account they
|
||||
are about to join before they sign up. Uses the BYPASSRLS admin session
|
||||
factory because account_invites is account-scoped under Phase 4 RLS but the
|
||||
caller has no tenant identity yet.
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.schemas.oauth import InviteLookupResponse
|
||||
|
||||
router = APIRouter(prefix="/accounts", tags=["account-invite-lookup"])
|
||||
|
||||
|
||||
@router.get("/invites/{code}/lookup", response_model=InviteLookupResponse)
|
||||
async def lookup_invite(
|
||||
code: str,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> InviteLookupResponse:
|
||||
"""Return minimal display data for a valid (unused, unexpired, not revoked)
|
||||
invite. Returns 404 with `invite_invalid_or_expired_or_revoked` for any
|
||||
invalid state — the AcceptInvitePage shows a single "ask the inviter to
|
||||
resend" message regardless of which condition failed (anti-enumeration)."""
|
||||
result = await db.execute(
|
||||
select(AccountInvite)
|
||||
.where(AccountInvite.code == code)
|
||||
.options(
|
||||
joinedload(AccountInvite.account),
|
||||
joinedload(AccountInvite.invited_by),
|
||||
)
|
||||
)
|
||||
invite = result.scalar_one_or_none()
|
||||
|
||||
if invite is None or not invite.is_valid:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": "invite_invalid_or_expired_or_revoked"},
|
||||
)
|
||||
|
||||
return InviteLookupResponse(
|
||||
account_name=invite.account.name,
|
||||
inviter_name=invite.invited_by.name,
|
||||
invited_email=invite.email,
|
||||
role=invite.role,
|
||||
)
|
||||
@@ -8,34 +8,101 @@ from app.core.database import get_db
|
||||
from app.core.audit import log_audit
|
||||
from app.models.user import User
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.models.plan_billing import PlanBilling
|
||||
from app.models.account import Account
|
||||
from app.models.account_limit_override import AccountLimitOverride
|
||||
from app.models.subscription import Subscription
|
||||
from app.schemas.admin import (
|
||||
PlanLimitResponse, PlanLimitUpdate,
|
||||
PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
|
||||
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
|
||||
)
|
||||
from app.api.deps import require_admin
|
||||
from app.services.billing import BillingService
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])
|
||||
|
||||
|
||||
@router.get("/plan-limits", response_model=list[PlanLimitResponse])
|
||||
# Fields on PlanLimitUpdate that map to plan_billing (not plan_limits).
|
||||
_PLAN_BILLING_FIELDS = (
|
||||
"display_name",
|
||||
"description",
|
||||
"monthly_price_cents",
|
||||
"annual_price_cents",
|
||||
"stripe_product_id",
|
||||
"stripe_monthly_price_id",
|
||||
"stripe_annual_price_id",
|
||||
"is_public",
|
||||
"is_archived",
|
||||
"sort_order",
|
||||
)
|
||||
|
||||
# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model.
|
||||
# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit
|
||||
# null for any of them would otherwise trigger a NOT NULL violation at commit.
|
||||
_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({
|
||||
"display_name",
|
||||
"is_public",
|
||||
"is_archived",
|
||||
"sort_order",
|
||||
})
|
||||
|
||||
|
||||
def _merge_plan_with_billing(
|
||||
plan: PlanLimits, billing: PlanBilling | None
|
||||
) -> PlanLimitWithBillingResponse:
|
||||
"""Build a merged response. Billing fields are None when no plan_billing row
|
||||
exists for the plan."""
|
||||
payload = {
|
||||
"plan": plan.plan,
|
||||
"max_trees": plan.max_trees,
|
||||
"max_sessions_per_month": plan.max_sessions_per_month,
|
||||
"max_users": plan.max_users,
|
||||
"custom_branding": plan.custom_branding,
|
||||
"priority_support": plan.priority_support,
|
||||
"export_formats": plan.export_formats or [],
|
||||
}
|
||||
if billing is not None:
|
||||
payload.update({
|
||||
"display_name": billing.display_name,
|
||||
"description": billing.description,
|
||||
"monthly_price_cents": billing.monthly_price_cents,
|
||||
"annual_price_cents": billing.annual_price_cents,
|
||||
"stripe_product_id": billing.stripe_product_id,
|
||||
"stripe_monthly_price_id": billing.stripe_monthly_price_id,
|
||||
"stripe_annual_price_id": billing.stripe_annual_price_id,
|
||||
"is_public": billing.is_public,
|
||||
"is_archived": billing.is_archived,
|
||||
"sort_order": billing.sort_order,
|
||||
})
|
||||
return PlanLimitWithBillingResponse(**payload)
|
||||
|
||||
|
||||
@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
|
||||
async def list_plan_limits(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""List all plan limit configurations."""
|
||||
result = await db.execute(select(PlanLimits))
|
||||
return result.scalars().all()
|
||||
"""List all plan limit configurations, merged with plan_billing fields
|
||||
where present. Plans without a plan_billing row return None for the
|
||||
billing fields."""
|
||||
rows = (await db.execute(
|
||||
select(PlanLimits, PlanBilling)
|
||||
.outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
|
||||
)).all()
|
||||
return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
|
||||
|
||||
|
||||
@router.put("/plan-limits", response_model=PlanLimitResponse)
|
||||
@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
|
||||
async def update_plan_limits(
|
||||
data: PlanLimitUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Update a plan's limits."""
|
||||
"""Update a plan's limits and (if any plan_billing field is included)
|
||||
upsert the matching plan_billing row in the same transaction. After
|
||||
commit, invalidates the in-process billing cache for accounts on this
|
||||
plan (currently a no-op — see BillingService.invalidate_billing_cache).
|
||||
"""
|
||||
result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
|
||||
plan = result.scalar_one_or_none()
|
||||
if not plan:
|
||||
@@ -48,10 +115,50 @@ async def update_plan_limits(
|
||||
plan.priority_support = data.priority_support
|
||||
plan.export_formats = data.export_formats
|
||||
|
||||
await log_audit(db, current_user.id, "plan_limits.update", "plan_limits", details={"plan": data.plan})
|
||||
# Did the request include any plan_billing field? (Pydantic gives us
|
||||
# `model_fields_set` to distinguish "user passed null" from "field omitted".)
|
||||
billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)
|
||||
billing: PlanBilling | None = None
|
||||
if billing_fields_set:
|
||||
billing = (await db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == data.plan)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if billing is None:
|
||||
# Create. display_name is required on the model — derive from the
|
||||
# plan name when the caller didn't supply one (e.g. "pro" → "Pro").
|
||||
display_name = data.display_name or data.plan.capitalize()
|
||||
billing = PlanBilling(plan=data.plan, display_name=display_name)
|
||||
db.add(billing)
|
||||
|
||||
# Apply only the fields the caller actually included. Allows partial
|
||||
# updates without clobbering existing values.
|
||||
for field in billing_fields_set:
|
||||
value = getattr(data, field)
|
||||
if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
|
||||
# Don't NULL out a NOT NULL column on update.
|
||||
continue
|
||||
setattr(billing, field, value)
|
||||
|
||||
await log_audit(
|
||||
db, current_user.id, "plan_limits.update", "plan_limits",
|
||||
details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(plan)
|
||||
return plan
|
||||
if billing is not None:
|
||||
await db.refresh(billing)
|
||||
|
||||
# Invalidate any in-process billing cache for accounts on this plan.
|
||||
# TODO: invalidate app.state.billing_cache when added.
|
||||
account_ids = [
|
||||
row[0] for row in (await db.execute(
|
||||
select(Subscription.account_id).where(Subscription.plan == data.plan)
|
||||
)).all()
|
||||
]
|
||||
await BillingService.invalidate_billing_cache(account_ids)
|
||||
|
||||
return _merge_plan_with_billing(plan, billing)
|
||||
|
||||
|
||||
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])
|
||||
|
||||
@@ -47,8 +47,16 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
|
||||
async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
|
||||
"""Decode a refresh token JWT and store its hash in the database."""
|
||||
async def store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
|
||||
"""Decode a refresh token JWT and store its hash in the database.
|
||||
|
||||
Module-public so OAuth callback endpoints (and any future token-issuing
|
||||
surface) can register the JTI in the ``refresh_tokens`` table the same
|
||||
way ``/auth/login`` does. Without this the first ``/auth/refresh`` call
|
||||
will reject the token as "revoked" because no row exists.
|
||||
|
||||
Caller is responsible for committing the session.
|
||||
"""
|
||||
payload = decode_token(refresh_token_str)
|
||||
if payload and payload.get("jti"):
|
||||
token_record = RefreshToken(
|
||||
@@ -136,7 +144,15 @@ async def register(
|
||||
# Validate platform invite code (skip if account invite was provided)
|
||||
invite_code_record = None
|
||||
if not account_invite_record:
|
||||
if settings.REQUIRE_INVITE_CODE and not user_data.invite_code:
|
||||
# When SELF_SERVE_ENABLED is on, the platform invite gate is bypassed
|
||||
# entirely — public self-serve signup is the whole point. The
|
||||
# invite_code field stays in the schema for backward compatibility
|
||||
# and so paid/trial-bearing codes still apply when supplied.
|
||||
if (
|
||||
settings.REQUIRE_INVITE_CODE
|
||||
and not settings.SELF_SERVE_ENABLED
|
||||
and not user_data.invite_code
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invite code is required"
|
||||
@@ -312,7 +328,7 @@ async def login(
|
||||
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
|
||||
|
||||
# Store refresh token hash in DB
|
||||
await _store_refresh_token(db, refresh_token_str, user.id)
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
await db.commit()
|
||||
|
||||
return Token(
|
||||
@@ -347,7 +363,7 @@ async def login_json(
|
||||
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
|
||||
|
||||
# Store refresh token hash in DB
|
||||
await _store_refresh_token(db, refresh_token_str, user.id)
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
await db.commit()
|
||||
|
||||
return Token(
|
||||
@@ -405,7 +421,7 @@ async def refresh_token(
|
||||
new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
|
||||
|
||||
# Store new refresh token
|
||||
await _store_refresh_token(db, new_refresh_token_str, user.id)
|
||||
await store_refresh_token(db, new_refresh_token_str, user.id)
|
||||
await db.commit()
|
||||
|
||||
return Token(
|
||||
|
||||
@@ -1,31 +1,44 @@
|
||||
"""Public beta signup endpoint — no auth required."""
|
||||
"""Legacy beta signup endpoint — redirects to /register?from=beta.
|
||||
|
||||
Phase 2 (self-serve signup) makes the public register flow the canonical
|
||||
front door. The old `/api/v1/beta-signup` POST endpoint is kept mounted to
|
||||
preserve any external links that still hit it, but now responds with a
|
||||
307 Temporary Redirect to `/register?from=beta` so the user lands in the
|
||||
real signup flow. The `?from=beta` marker lets the frontend tag the
|
||||
signup origin for analytics.
|
||||
|
||||
Note: there is no `beta_signup` database table — the original endpoint
|
||||
only fired a notification email. There is therefore no waitlist to email
|
||||
and no migration to run when retiring the endpoint.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from app.core.email import EmailService
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/beta-signup", tags=["beta"])
|
||||
|
||||
|
||||
class BetaSignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
# Local-dev fallback when FRONTEND_URL isn't configured. The redirect must
|
||||
# be absolute — a relative URL would resolve against the API origin
|
||||
# (api.resolutionflow.com), which has no /register page.
|
||||
_DEFAULT_FRONTEND_URL = "http://localhost:5173"
|
||||
|
||||
|
||||
class BetaSignupResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
@router.post("", include_in_schema=False)
|
||||
async def beta_signup_redirect() -> RedirectResponse:
|
||||
"""Redirect legacy beta-signup POST to the public register page.
|
||||
|
||||
|
||||
@router.post("", response_model=BetaSignupResponse)
|
||||
async def beta_signup(data: BetaSignupRequest):
|
||||
"""Collect beta interest — sends notification to beta@resolutionflow.com."""
|
||||
sent = await EmailService.send_beta_signup_notification(data.email)
|
||||
if not sent:
|
||||
logger.warning("Beta signup recorded (email delivery skipped): %s", data.email)
|
||||
return BetaSignupResponse(
|
||||
success=True,
|
||||
message="Thanks! We'll be in touch with beta access details.",
|
||||
Returns 307 so any client following the redirect preserves the HTTP
|
||||
method; the frontend treats `/register?from=beta` as the canonical
|
||||
entry point and reads the `from` query param for analytics.
|
||||
"""
|
||||
frontend_url = settings.FRONTEND_URL or _DEFAULT_FRONTEND_URL
|
||||
return RedirectResponse(
|
||||
url=f"{frontend_url}/register?from=beta",
|
||||
status_code=307,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.core.config import settings
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.schemas.billing import (
|
||||
BillingPortalSessionResponse,
|
||||
BillingStateResponse,
|
||||
CheckoutSessionCreate,
|
||||
CheckoutSessionResponse,
|
||||
@@ -50,3 +51,26 @@ async def get_billing_state(
|
||||
)).scalar_one()
|
||||
state = await BillingService.get_billing_state(db, account)
|
||||
return BillingStateResponse(**state)
|
||||
|
||||
|
||||
@router.get("/portal-session", response_model=BillingPortalSessionResponse)
|
||||
async def get_billing_portal_session(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> BillingPortalSessionResponse:
|
||||
"""Return a Stripe-hosted Customer Portal URL for the account so the user
|
||||
can update card / cancel. Allowlisted from the subscription + email-verify
|
||||
guards (a canceled or unverified-past-grace user must still be able to
|
||||
update billing)."""
|
||||
if not settings.stripe_enabled:
|
||||
raise HTTPException(status_code=503, detail={"error": "stripe_not_configured"})
|
||||
|
||||
account = (await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)).scalar_one()
|
||||
|
||||
try:
|
||||
url = await BillingService.open_customer_portal(account)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail={"error": "no_stripe_customer"})
|
||||
return BillingPortalSessionResponse(url=url)
|
||||
|
||||
40
backend/app/api/endpoints/config.py
Normal file
40
backend/app/api/endpoints/config.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Public runtime configuration endpoint.
|
||||
|
||||
GET /api/v1/config/public
|
||||
Returns the small set of runtime flags the frontend needs at app load
|
||||
to decide whether to render the self-serve signup flow and which OAuth
|
||||
buttons to show. No authentication required.
|
||||
|
||||
The response model lives in `app.schemas.config` so it can be reused by
|
||||
frontend codegen and other call sites if needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.config import PublicConfigResponse
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
@router.get("/public", response_model=PublicConfigResponse)
|
||||
async def get_public_config() -> PublicConfigResponse:
|
||||
"""Return public-safe runtime config.
|
||||
|
||||
`oauth_providers` reflects which OAuth client IDs are configured server
|
||||
side; the frontend uses it to render only buttons that will actually
|
||||
succeed. `self_serve_enabled` is the master switch for the new public
|
||||
self-serve signup flow.
|
||||
"""
|
||||
providers: list[str] = []
|
||||
if settings.GOOGLE_CLIENT_ID:
|
||||
providers.append("google")
|
||||
if settings.MS_CLIENT_ID:
|
||||
providers.append("microsoft")
|
||||
|
||||
return PublicConfigResponse(
|
||||
self_serve_enabled=settings.SELF_SERVE_ENABLED,
|
||||
oauth_providers=providers,
|
||||
)
|
||||
@@ -7,10 +7,12 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.endpoints.auth import store_refresh_token
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token, create_refresh_token
|
||||
from app.models.account import Account
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.oauth_identity import OAuthIdentity
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import OAuthCallbackPayload, OAuthCallbackResponse
|
||||
@@ -31,9 +33,21 @@ def _generate_display_code(length: int = 8) -> str:
|
||||
|
||||
|
||||
async def _sign_in_or_register(
|
||||
db: AsyncSession, provider: str, profile: OAuthProfile
|
||||
db: AsyncSession,
|
||||
provider: str,
|
||||
profile: OAuthProfile,
|
||||
*,
|
||||
account_invite_code: str | None = None,
|
||||
invited_email: str | None = None,
|
||||
) -> tuple[User, bool]:
|
||||
"""Returns (user, is_new_user). Idempotent on (provider, provider_subject)."""
|
||||
"""Returns (user, is_new_user). Idempotent on (provider, provider_subject).
|
||||
|
||||
When ``account_invite_code`` is supplied (from the /accept-invite flow),
|
||||
a brand-new user is created inside the invited account instead of getting
|
||||
a personal account + Pro trial. Mismatch between the OAuth profile email
|
||||
and ``invited_email`` raises ``invite_email_mismatch`` per the spec
|
||||
contract that mirrors the email+password register path.
|
||||
"""
|
||||
identity = (
|
||||
await db.execute(
|
||||
select(OAuthIdentity).where(
|
||||
@@ -53,28 +67,96 @@ async def _sign_in_or_register(
|
||||
await db.execute(select(User).where(User.email == profile.email))
|
||||
).scalar_one_or_none()
|
||||
is_new_user = user is None
|
||||
|
||||
# If the user arrived via an invite link but already has a ResolutionFlow
|
||||
# account (e.g., previously signed up with email+password), silently
|
||||
# linking the OAuth identity to that existing account would bypass the
|
||||
# invite — they'd stay in their personal account and the invite would
|
||||
# never be consumed. Fail loud instead so they can sign in and accept the
|
||||
# invite from the dashboard. The "invited user wants to transfer accounts"
|
||||
# case is a v2 concern.
|
||||
if account_invite_code and not is_new_user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "email_already_registered_use_login",
|
||||
"message": (
|
||||
"An account already exists for this email. Please sign in "
|
||||
"instead, then accept the invite from your dashboard."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
invite_record: AccountInvite | None = None
|
||||
if is_new_user and account_invite_code:
|
||||
# SELECT FOR UPDATE so two concurrent OAuth callbacks can't both
|
||||
# consume the same invite code.
|
||||
invite_record = (
|
||||
await db.execute(
|
||||
select(AccountInvite)
|
||||
.where(AccountInvite.code == account_invite_code)
|
||||
.with_for_update()
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if invite_record is None or not invite_record.is_valid:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "invite_invalid_or_expired_or_revoked"},
|
||||
)
|
||||
# Verify the OAuth profile email matches what was invited. We compare
|
||||
# against the invite row directly (source of truth), but also accept
|
||||
# the client-supplied invited_email as a defensive equality check.
|
||||
if invite_record.email.lower() != profile.email.lower():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "invite_email_mismatch"},
|
||||
)
|
||||
if invited_email and invited_email.lower() != invite_record.email.lower():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "invite_email_mismatch"},
|
||||
)
|
||||
|
||||
if is_new_user:
|
||||
account = Account(
|
||||
name=f"{profile.name}'s Account",
|
||||
display_code=_generate_display_code(),
|
||||
)
|
||||
db.add(account)
|
||||
await db.flush()
|
||||
user = User(
|
||||
email=profile.email,
|
||||
name=profile.name,
|
||||
password_hash=None,
|
||||
account_id=account.id,
|
||||
account_role="owner",
|
||||
role="engineer",
|
||||
email_verified_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
account.owner_id = user.id
|
||||
await db.flush()
|
||||
# start_trial commits internally; flushed account/user above.
|
||||
await BillingService.start_trial(db, account.id)
|
||||
if invite_record is not None:
|
||||
# Join the invited account directly — no personal account, no
|
||||
# trial creation.
|
||||
user = User(
|
||||
email=profile.email,
|
||||
name=profile.name,
|
||||
password_hash=None,
|
||||
account_id=invite_record.account_id,
|
||||
account_role=invite_record.role,
|
||||
role="engineer",
|
||||
email_verified_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
invite_record.accepted_by_id = user.id
|
||||
invite_record.used_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
else:
|
||||
account = Account(
|
||||
name=f"{profile.name}'s Account",
|
||||
display_code=_generate_display_code(),
|
||||
)
|
||||
db.add(account)
|
||||
await db.flush()
|
||||
user = User(
|
||||
email=profile.email,
|
||||
name=profile.name,
|
||||
password_hash=None,
|
||||
account_id=account.id,
|
||||
account_role="owner",
|
||||
role="engineer",
|
||||
email_verified_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
account.owner_id = user.id
|
||||
await db.flush()
|
||||
# start_trial commits internally; flushed account/user above.
|
||||
await BillingService.start_trial(db, account.id)
|
||||
|
||||
db.add(
|
||||
OAuthIdentity(
|
||||
@@ -98,10 +180,23 @@ async def google_callback(
|
||||
raise HTTPException(status_code=503, detail="Google sign-in not configured")
|
||||
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/google/callback"
|
||||
profile = await google_exchange_code(payload.code, redirect_uri)
|
||||
user, is_new = await _sign_in_or_register(db, "google", profile)
|
||||
user, is_new = await _sign_in_or_register(
|
||||
db,
|
||||
"google",
|
||||
profile,
|
||||
account_invite_code=payload.account_invite_code,
|
||||
invited_email=payload.invited_email,
|
||||
)
|
||||
refresh_token_str = create_refresh_token({"sub": str(user.id)})
|
||||
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
|
||||
# reject this token as "revoked" (the rotation logic requires a row to
|
||||
# mark as used). _sign_in_or_register already committed; this needs a
|
||||
# second commit.
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
await db.commit()
|
||||
return OAuthCallbackResponse(
|
||||
access_token=create_access_token({"sub": str(user.id)}),
|
||||
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||
refresh_token=refresh_token_str,
|
||||
is_new_user=is_new,
|
||||
)
|
||||
|
||||
@@ -115,9 +210,22 @@ async def microsoft_callback(
|
||||
raise HTTPException(status_code=503, detail="Microsoft sign-in not configured")
|
||||
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/microsoft/callback"
|
||||
profile = await microsoft_exchange_code(payload.code, redirect_uri)
|
||||
user, is_new = await _sign_in_or_register(db, "microsoft", profile)
|
||||
user, is_new = await _sign_in_or_register(
|
||||
db,
|
||||
"microsoft",
|
||||
profile,
|
||||
account_invite_code=payload.account_invite_code,
|
||||
invited_email=payload.invited_email,
|
||||
)
|
||||
refresh_token_str = create_refresh_token({"sub": str(user.id)})
|
||||
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
|
||||
# reject this token as "revoked" (the rotation logic requires a row to
|
||||
# mark as used). _sign_in_or_register already committed; this needs a
|
||||
# second commit.
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
await db.commit()
|
||||
return OAuthCallbackResponse(
|
||||
access_token=create_access_token({"sub": str(user.id)}),
|
||||
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||
refresh_token=refresh_token_str,
|
||||
is_new_user=is_new,
|
||||
)
|
||||
|
||||
@@ -2,19 +2,24 @@
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.account import Account
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
from app.models.psa_connection import PsaConnection
|
||||
from app.models.session import Session
|
||||
from app.models.tree import Tree
|
||||
from app.models.user import User
|
||||
from app.schemas.onboarding import OnboardingStatus
|
||||
from app.schemas.onboarding import (
|
||||
OnboardingStatus,
|
||||
OnboardingStepRequest,
|
||||
OnboardingStepResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["onboarding"])
|
||||
|
||||
@@ -85,6 +90,10 @@ async def get_onboarding_status(
|
||||
)
|
||||
connected_psa = (psa_q.scalar() or 0) > 0
|
||||
|
||||
# New (Phase 2 — Task 41)
|
||||
email_verified = current_user.email_verified_at is not None
|
||||
shop_setup_done = (current_user.onboarding_step_completed or 0) >= 1
|
||||
|
||||
return OnboardingStatus(
|
||||
created_flow=created_flow,
|
||||
ran_session=ran_session,
|
||||
@@ -94,6 +103,8 @@ async def get_onboarding_status(
|
||||
connected_psa=connected_psa,
|
||||
is_team_user=is_team_user,
|
||||
dismissed=current_user.onboarding_dismissed,
|
||||
email_verified=email_verified,
|
||||
shop_setup_done=shop_setup_done,
|
||||
)
|
||||
|
||||
|
||||
@@ -109,3 +120,98 @@ async def dismiss_onboarding(
|
||||
|
||||
# Return updated status (reuse the GET logic)
|
||||
return await get_onboarding_status(db=db, current_user=current_user)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Welcome wizard endpoints (Phase 2)
|
||||
#
|
||||
# These persist Step 1/2/3 progress for the post-signup welcome wizard.
|
||||
# Mounted on /users/me/* (the parent router prefix is /users) so the wizard
|
||||
# can run before email verification and during trial.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/me/onboarding-step", response_model=OnboardingStepResponse)
|
||||
async def patch_onboarding_step(
|
||||
body: OnboardingStepRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> OnboardingStepResponse:
|
||||
"""Persist welcome-wizard progress for the current user.
|
||||
|
||||
Contract:
|
||||
- step=1 + complete writes accounts.name, accounts.team_size_bucket,
|
||||
users.role_at_signup, then sets users.onboarding_step_completed=1.
|
||||
- step=2 + complete writes accounts.primary_psa, then sets
|
||||
users.onboarding_step_completed=2.
|
||||
- step=3 + complete just sets users.onboarding_step_completed=3
|
||||
(invites are POSTed separately).
|
||||
- action="skip" ignores `data` entirely and only advances the step.
|
||||
- The new step must be >= current onboarding_step_completed (None=>0);
|
||||
otherwise 400. Idempotent re-PATCH of the same step succeeds.
|
||||
"""
|
||||
current_step = current_user.onboarding_step_completed or 0
|
||||
if body.step < current_step:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "step_cannot_decrease",
|
||||
"current_step": current_step,
|
||||
"requested_step": body.step,
|
||||
},
|
||||
)
|
||||
|
||||
if body.action == "complete" and body.data is not None and body.step in (1, 2):
|
||||
# Load the user's account for field writes. Step 3 has no data writes.
|
||||
account_result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = account_result.scalar_one_or_none()
|
||||
if account is None:
|
||||
# Should never happen — user is required to have an account_id.
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="account_not_found",
|
||||
)
|
||||
|
||||
if body.step == 1:
|
||||
data = body.data
|
||||
if data.company_name is not None:
|
||||
account.name = data.company_name
|
||||
if data.team_size_bucket is not None:
|
||||
account.team_size_bucket = data.team_size_bucket
|
||||
if data.role_at_signup is not None:
|
||||
current_user.role_at_signup = data.role_at_signup
|
||||
elif body.step == 2:
|
||||
data = body.data
|
||||
if data.primary_psa is not None:
|
||||
account.primary_psa = data.primary_psa
|
||||
|
||||
current_user.onboarding_step_completed = body.step
|
||||
await db.commit()
|
||||
await db.refresh(current_user)
|
||||
|
||||
return OnboardingStepResponse(
|
||||
onboarding_step_completed=current_user.onboarding_step_completed,
|
||||
onboarding_dismissed=current_user.onboarding_dismissed,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/me/onboarding-dismiss-rest", response_model=OnboardingStepResponse)
|
||||
async def dismiss_onboarding_rest(
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> OnboardingStepResponse:
|
||||
"""Set users.onboarding_dismissed=TRUE — backs the wizard's "Skip the rest" button.
|
||||
|
||||
Returns the same shape as the step PATCH so the frontend can update its
|
||||
local store from a single response.
|
||||
"""
|
||||
current_user.onboarding_dismissed = True
|
||||
await db.commit()
|
||||
await db.refresh(current_user)
|
||||
|
||||
return OnboardingStepResponse(
|
||||
onboarding_step_completed=current_user.onboarding_step_completed,
|
||||
onboarding_dismissed=current_user.onboarding_dismissed,
|
||||
)
|
||||
|
||||
58
backend/app/api/endpoints/plans_public.py
Normal file
58
backend/app/api/endpoints/plans_public.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Public plans endpoint — no auth required.
|
||||
|
||||
GET /api/v1/plans/public
|
||||
Returns the public-safe view of `plan_billing` joined with
|
||||
`plan_limits.max_users` (exposed as `max_seats`), filtered to
|
||||
`is_public=True AND is_archived=False`, ordered by sort_order ASC, plan ASC.
|
||||
|
||||
Distinct from `/admin/plan-limits` (admin-only, returns ALL plans including
|
||||
archived/internal). This endpoint exists to power the marketing /pricing page
|
||||
without exposing the rest of the admin-only billing surface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.plan_billing import PlanBilling
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.schemas.billing import PublicPlanResponse
|
||||
|
||||
router = APIRouter(prefix="/plans", tags=["plans"])
|
||||
|
||||
|
||||
@router.get("/public", response_model=list[PublicPlanResponse])
|
||||
async def list_public_plans(
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> list[PublicPlanResponse]:
|
||||
"""List public, non-archived plans for the marketing /pricing page.
|
||||
|
||||
Public — no auth. Uses `get_admin_db` because this is a cross-tenant read
|
||||
of the global plan catalog (same pattern as `/config/public`).
|
||||
"""
|
||||
stmt = (
|
||||
select(PlanBilling, PlanLimits.max_users)
|
||||
.outerjoin(PlanLimits, PlanBilling.plan == PlanLimits.plan)
|
||||
.where(PlanBilling.is_public.is_(True))
|
||||
.where(PlanBilling.is_archived.is_(False))
|
||||
.order_by(PlanBilling.sort_order.asc(), PlanBilling.plan.asc())
|
||||
)
|
||||
rows = (await db.execute(stmt)).all()
|
||||
return [
|
||||
PublicPlanResponse(
|
||||
plan=billing.plan,
|
||||
display_name=billing.display_name,
|
||||
description=billing.description,
|
||||
monthly_price_cents=billing.monthly_price_cents,
|
||||
annual_price_cents=billing.annual_price_cents,
|
||||
max_seats=max_users,
|
||||
sort_order=billing.sort_order,
|
||||
is_public=billing.is_public,
|
||||
)
|
||||
for billing, max_users in rows
|
||||
]
|
||||
114
backend/app/api/endpoints/sales_leads.py
Normal file
114
backend/app/api/endpoints/sales_leads.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Public Talk-to-Sales endpoint — no auth required.
|
||||
|
||||
POST /api/v1/sales-leads
|
||||
- Inserts a sales_leads row.
|
||||
- Fires (best-effort) a notification email to settings.SALES_LEAD_RECIPIENT_EMAIL.
|
||||
- Emits a server-side PostHog event (best-effort).
|
||||
- Rate-limited per IP (5/hour).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.config import settings
|
||||
from app.core.email import EmailService
|
||||
from app.core.rate_limit import limiter
|
||||
from app.models.sales_lead import SalesLead
|
||||
from app.schemas.sales_lead import SalesLeadCreate, SalesLeadCreateResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/sales-leads", tags=["sales"])
|
||||
|
||||
|
||||
async def _send_notification_email(lead: SalesLead) -> None:
|
||||
"""Fire-and-forget wrapper. EmailService methods never raise, but we
|
||||
still wrap in a try/except to defend against future regressions."""
|
||||
try:
|
||||
await EmailService.send_sales_lead_notification(
|
||||
to_email=settings.SALES_LEAD_RECIPIENT_EMAIL,
|
||||
lead=lead,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Sales lead notification email failed for lead %s",
|
||||
lead.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def _capture_posthog_event(lead: SalesLead) -> None:
|
||||
"""Emit `talk_to_sales_form_submitted` server-side. Best-effort.
|
||||
|
||||
Backend PostHog SDK isn't initialized in the project today; this function
|
||||
is the single instrumentation point so wiring it up later is a one-line
|
||||
change. The call is wrapped so any future failure can never fail the
|
||||
request.
|
||||
"""
|
||||
try:
|
||||
# Lazy import — keeps the dependency optional. When the backend
|
||||
# PostHog client is wired in (likely as `app.core.analytics.posthog`),
|
||||
# swap the import path here and the event will fire automatically.
|
||||
try:
|
||||
from app.core.analytics import posthog # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
"PostHog server-side capture skipped — client not configured"
|
||||
)
|
||||
return
|
||||
|
||||
distinct_id = lead.posthog_distinct_id or f"sales_lead:{lead.id}"
|
||||
posthog.capture(
|
||||
distinct_id=distinct_id,
|
||||
event="talk_to_sales_form_submitted",
|
||||
properties={
|
||||
"source": lead.source,
|
||||
"company": lead.company,
|
||||
"team_size": lead.team_size,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"PostHog capture failed for sales lead %s",
|
||||
lead.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=SalesLeadCreateResponse, status_code=201)
|
||||
@limiter.limit("5/hour")
|
||||
async def create_sales_lead(
|
||||
request: Request,
|
||||
data: SalesLeadCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> SalesLeadCreateResponse:
|
||||
"""Public Talk-to-Sales submission.
|
||||
|
||||
Creates a sales_leads row, fires (best-effort) a notification email and a
|
||||
server-side PostHog event. Rate-limited per IP at 5/hour.
|
||||
"""
|
||||
lead = SalesLead(
|
||||
email=str(data.email).lower(),
|
||||
name=data.name,
|
||||
company=data.company,
|
||||
team_size=data.team_size,
|
||||
message=data.message,
|
||||
source=data.source,
|
||||
posthog_distinct_id=data.posthog_distinct_id,
|
||||
)
|
||||
db.add(lead)
|
||||
await db.commit()
|
||||
await db.refresh(lead)
|
||||
|
||||
# Fire-and-forget: email + analytics. Failures must not fail the request.
|
||||
asyncio.create_task(_send_notification_email(lead))
|
||||
_capture_posthog_event(lead)
|
||||
|
||||
return SalesLeadCreateResponse(id=lead.id, status="received")
|
||||
@@ -26,8 +26,10 @@ from app.api.endpoints import (
|
||||
billing,
|
||||
beta_feedback,
|
||||
beta_signup,
|
||||
sales_leads,
|
||||
branding,
|
||||
categories,
|
||||
config as config_endpoints,
|
||||
copilot,
|
||||
device_types,
|
||||
draft_templates,
|
||||
@@ -43,6 +45,7 @@ from app.api.endpoints import (
|
||||
notifications,
|
||||
oauth as oauth_endpoints,
|
||||
onboarding,
|
||||
plans_public,
|
||||
public_templates,
|
||||
ratings,
|
||||
scripts,
|
||||
@@ -68,6 +71,7 @@ from app.api.endpoints import (
|
||||
uploads,
|
||||
webhooks,
|
||||
accounts,
|
||||
account_invite_lookup,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
@@ -88,9 +92,13 @@ api_router.include_router(billing.router) # Reachable when subscription lock
|
||||
api_router.include_router(shared.router) # Public share links (no auth)
|
||||
api_router.include_router(shares.public_router) # Public session share links (optional auth)
|
||||
api_router.include_router(beta_signup.router)
|
||||
api_router.include_router(sales_leads.router) # Talk-to-Sales (no auth, rate-limited)
|
||||
api_router.include_router(webhooks.router) # Stripe webhook receiver
|
||||
api_router.include_router(public_templates.router) # Public gallery (no auth, rate-limited)
|
||||
api_router.include_router(survey.router) # Public survey flow (no auth, rate-limited)
|
||||
api_router.include_router(config_endpoints.router) # Public runtime feature flags
|
||||
api_router.include_router(account_invite_lookup.router) # Public invite-code lookup for /accept-invite
|
||||
api_router.include_router(plans_public.router) # Public plan catalog for /pricing page
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin endpoints — super_admin only
|
||||
|
||||
Reference in New Issue
Block a user