From 2f8ec3775e1e58b694f60776870b336c6b27fab3 Mon Sep 17 00:00:00 2001
From: Michael Chihlas
Date: Wed, 6 May 2026 19:55:42 -0400
Subject: [PATCH 01/29] feat(billing): add BillingService.open_customer_portal
+ GET endpoint
Authed users can now request a Stripe-hosted Customer Portal URL for card
updates and cancellation via GET /api/v1/billing/portal-session. The path is
already in both _SUBSCRIPTION_GUARD_ALLOWLIST and _EMAIL_VERIFICATION_ALLOWLIST
so canceled or unverified-past-grace users can still update billing.
- Returns 503 with {"error": "stripe_not_configured"} when STRIPE_SECRET_KEY unset.
- Returns 400 with {"error": "no_stripe_customer"} when account has no
stripe_customer_id (must complete checkout first).
Co-Authored-By: Claude Opus 4.7
---
backend/app/api/endpoints/billing.py | 26 ++++++++-
backend/app/schemas/billing.py | 4 ++
backend/app/services/billing.py | 19 +++++++
backend/tests/test_billing_portal.py | 83 ++++++++++++++++++++++++++++
4 files changed, 131 insertions(+), 1 deletion(-)
create mode 100644 backend/tests/test_billing_portal.py
diff --git a/backend/app/api/endpoints/billing.py b/backend/app/api/endpoints/billing.py
index 23d067d4..7fa8694e 100644
--- a/backend/app/api/endpoints/billing.py
+++ b/backend/app/api/endpoints/billing.py
@@ -1,6 +1,6 @@
from typing import Annotated
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -10,6 +10,7 @@ from app.core.config import settings
from app.models.account import Account
from app.models.user import User
from app.schemas.billing import (
+ BillingPortalSessionResponse,
BillingStateResponse,
CheckoutSessionCreate,
CheckoutSessionResponse,
@@ -50,3 +51,26 @@ async def get_billing_state(
)).scalar_one()
state = await BillingService.get_billing_state(db, account)
return BillingStateResponse(**state)
+
+
+@router.get("/portal-session", response_model=BillingPortalSessionResponse)
+async def get_billing_portal_session(
+ current_user: Annotated[User, Depends(get_current_active_user)],
+ db: Annotated[AsyncSession, Depends(get_admin_db)],
+) -> BillingPortalSessionResponse:
+ """Return a Stripe-hosted Customer Portal URL for the account so the user
+ can update card / cancel. Allowlisted from the subscription + email-verify
+ guards (a canceled or unverified-past-grace user must still be able to
+ update billing)."""
+ if not settings.stripe_enabled:
+ raise HTTPException(status_code=503, detail={"error": "stripe_not_configured"})
+
+ account = (await db.execute(
+ select(Account).where(Account.id == current_user.account_id)
+ )).scalar_one()
+
+ try:
+ url = await BillingService.open_customer_portal(account)
+ except ValueError:
+ raise HTTPException(status_code=400, detail={"error": "no_stripe_customer"})
+ return BillingPortalSessionResponse(url=url)
diff --git a/backend/app/schemas/billing.py b/backend/app/schemas/billing.py
index ebe9ab9d..0a1bcf98 100644
--- a/backend/app/schemas/billing.py
+++ b/backend/app/schemas/billing.py
@@ -13,6 +13,10 @@ class CheckoutSessionResponse(BaseModel):
url: str
+class BillingPortalSessionResponse(BaseModel):
+ url: str
+
+
class SubscriptionState(BaseModel):
status: str
plan: str
diff --git a/backend/app/services/billing.py b/backend/app/services/billing.py
index a104a5b1..b662ed47 100644
--- a/backend/app/services/billing.py
+++ b/backend/app/services/billing.py
@@ -105,6 +105,25 @@ class BillingService:
)
return session.url
+ @staticmethod
+ async def open_customer_portal(account: Account) -> str:
+ """Create a Stripe-hosted Customer Portal session and return the URL.
+
+ Raises RuntimeError if Stripe isn't configured (endpoint maps to 503).
+ Raises ValueError if the account has no stripe_customer_id yet — the
+ user must complete a checkout first (endpoint maps to 400).
+ """
+ if not settings.stripe_enabled:
+ raise RuntimeError("Stripe not configured")
+ if account.stripe_customer_id is None:
+ raise ValueError("no_stripe_customer")
+ stripe.api_key = settings.STRIPE_SECRET_KEY
+ session = stripe.billing_portal.Session.create(
+ customer=account.stripe_customer_id,
+ return_url=f"{settings.FRONTEND_URL}/account/billing",
+ )
+ return session.url
+
@staticmethod
async def get_billing_state(db: AsyncSession, account):
"""Aggregate Subscription + PlanLimits + PlanBilling + resolved feature
diff --git a/backend/tests/test_billing_portal.py b/backend/tests/test_billing_portal.py
new file mode 100644
index 00000000..76841a7a
--- /dev/null
+++ b/backend/tests/test_billing_portal.py
@@ -0,0 +1,83 @@
+import uuid
+import pytest
+from unittest.mock import patch, MagicMock
+from sqlalchemy import select
+
+from app.models.account import Account
+
+
+@pytest.mark.asyncio
+async def test_billing_portal_returns_url_for_account_with_stripe_customer(
+ client, test_db, test_user, auth_headers, monkeypatch
+):
+ """Happy path: account has a stripe_customer_id and Stripe is configured →
+ GET /billing/portal-session returns the portal URL."""
+ from app.core.config import settings
+ monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
+ monkeypatch.setattr(settings, "FRONTEND_URL", "https://app.example.com")
+
+ account_id = uuid.UUID(test_user["user_data"]["account_id"])
+ account = (await test_db.execute(
+ select(Account).where(Account.id == account_id)
+ )).scalar_one()
+ account.stripe_customer_id = "cus_test_456"
+ await test_db.commit()
+
+ fake_session = MagicMock()
+ fake_session.url = "https://billing.stripe.com/p/session/test_abc"
+
+ with patch(
+ "stripe.billing_portal.Session.create",
+ return_value=fake_session,
+ ) as portal_mock:
+ response = await client.get(
+ "/api/v1/billing/portal-session",
+ headers=auth_headers,
+ )
+
+ assert response.status_code == 200, response.json()
+ assert response.json() == {"url": "https://billing.stripe.com/p/session/test_abc"}
+ portal_mock.assert_called_once()
+ call_kwargs = portal_mock.call_args.kwargs
+ assert call_kwargs["customer"] == "cus_test_456"
+ assert call_kwargs["return_url"] == "https://app.example.com/account/billing"
+
+
+@pytest.mark.asyncio
+async def test_billing_portal_returns_503_when_stripe_not_configured(
+ client, test_db, test_user, auth_headers, monkeypatch
+):
+ """STRIPE_SECRET_KEY unset → settings.stripe_enabled is False → 503."""
+ from app.core.config import settings
+ monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", None)
+
+ response = await client.get(
+ "/api/v1/billing/portal-session",
+ headers=auth_headers,
+ )
+ assert response.status_code == 503
+ assert response.json()["detail"]["error"] == "stripe_not_configured"
+
+
+@pytest.mark.asyncio
+async def test_billing_portal_returns_400_when_account_has_no_stripe_customer(
+ client, test_db, test_user, auth_headers, monkeypatch
+):
+ """Account with no stripe_customer_id (never completed checkout) → 400
+ with `no_stripe_customer` error."""
+ from app.core.config import settings
+ monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
+
+ # test_user fixture seeds an account with no stripe_customer_id by default.
+ account_id = uuid.UUID(test_user["user_data"]["account_id"])
+ account = (await test_db.execute(
+ select(Account).where(Account.id == account_id)
+ )).scalar_one()
+ assert account.stripe_customer_id is None
+
+ response = await client.get(
+ "/api/v1/billing/portal-session",
+ headers=auth_headers,
+ )
+ assert response.status_code == 400
+ assert response.json()["detail"]["error"] == "no_stripe_customer"
--
2.49.1
From 16f5e4ce051aa04a40fd66557e52f90d196001fb Mon Sep 17 00:00:00 2001
From: Michael Chihlas
Date: Wed, 6 May 2026 20:04:43 -0400
Subject: [PATCH 02/29] feat(onboarding): add PATCH /users/me/onboarding-step +
dismiss-rest
Persists welcome-wizard Step 1/2/3 progress for self-serve signup Phase 2.
PATCH validates step cannot decrease, ignores `data` on action="skip", and
is idempotent on re-PATCH of the same step. POST /users/me/onboarding-dismiss-rest
backs the wizard's "Skip the rest" button.
Both routes added to _EMAIL_VERIFICATION_ALLOWLIST and _SUBSCRIPTION_GUARD_ALLOWLIST
so the wizard runs before email verification and during the trial. 4 integration
tests cover field writes, skip semantics, decrease guard, and dismiss-rest.
Co-Authored-By: Claude Opus 4.7
---
backend/app/api/deps.py | 3 +
backend/app/api/endpoints/onboarding.py | 104 ++++++++++++++++-
backend/app/schemas/onboarding.py | 41 ++++++-
backend/tests/test_onboarding_step.py | 149 ++++++++++++++++++++++++
4 files changed, 294 insertions(+), 3 deletions(-)
create mode 100644 backend/tests/test_onboarding_step.py
diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py
index 9fbd5815..32ada630 100644
--- a/backend/app/api/deps.py
+++ b/backend/app/api/deps.py
@@ -235,6 +235,7 @@ _SUBSCRIPTION_GUARD_ALLOWLIST = {
"/api/v1/billing/portal-session",
"/api/v1/users/me",
"/api/v1/users/me/onboarding-step",
+ "/api/v1/users/me/onboarding-dismiss-rest",
}
@@ -298,6 +299,8 @@ _EMAIL_VERIFICATION_ALLOWLIST = {
"/api/v1/auth/email/verify",
"/api/v1/auth/password/change",
"/api/v1/users/me",
+ "/api/v1/users/me/onboarding-step",
+ "/api/v1/users/me/onboarding-dismiss-rest",
"/api/v1/billing/state",
"/api/v1/billing/checkout-session",
"/api/v1/billing/portal-session",
diff --git a/backend/app/api/endpoints/onboarding.py b/backend/app/api/endpoints/onboarding.py
index 534f58a6..4cecf091 100644
--- a/backend/app/api/endpoints/onboarding.py
+++ b/backend/app/api/endpoints/onboarding.py
@@ -2,19 +2,24 @@
from typing import Annotated
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_active_user
from app.core.database import get_db
from app.core.admin_database import get_admin_db
+from app.models.account import Account
from app.models.assistant_chat import AssistantChat
from app.models.psa_connection import PsaConnection
from app.models.session import Session
from app.models.tree import Tree
from app.models.user import User
-from app.schemas.onboarding import OnboardingStatus
+from app.schemas.onboarding import (
+ OnboardingStatus,
+ OnboardingStepRequest,
+ OnboardingStepResponse,
+)
router = APIRouter(prefix="/users", tags=["onboarding"])
@@ -109,3 +114,98 @@ async def dismiss_onboarding(
# Return updated status (reuse the GET logic)
return await get_onboarding_status(db=db, current_user=current_user)
+
+
+# ---------------------------------------------------------------------------
+# Welcome wizard endpoints (Phase 2)
+#
+# These persist Step 1/2/3 progress for the post-signup welcome wizard.
+# Mounted on /users/me/* (the parent router prefix is /users) so the wizard
+# can run before email verification and during trial.
+# ---------------------------------------------------------------------------
+
+
+@router.patch("/me/onboarding-step", response_model=OnboardingStepResponse)
+async def patch_onboarding_step(
+ body: OnboardingStepRequest,
+ db: Annotated[AsyncSession, Depends(get_admin_db)],
+ current_user: Annotated[User, Depends(get_current_active_user)],
+) -> OnboardingStepResponse:
+ """Persist welcome-wizard progress for the current user.
+
+ Contract:
+ - step=1 + complete writes accounts.name, accounts.team_size_bucket,
+ users.role_at_signup, then sets users.onboarding_step_completed=1.
+ - step=2 + complete writes accounts.primary_psa, then sets
+ users.onboarding_step_completed=2.
+ - step=3 + complete just sets users.onboarding_step_completed=3
+ (invites are POSTed separately).
+ - action="skip" ignores `data` entirely and only advances the step.
+ - The new step must be >= current onboarding_step_completed (None=>0);
+ otherwise 400. Idempotent re-PATCH of the same step succeeds.
+ """
+ current_step = current_user.onboarding_step_completed or 0
+ if body.step < current_step:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail={
+ "error": "step_cannot_decrease",
+ "current_step": current_step,
+ "requested_step": body.step,
+ },
+ )
+
+ if body.action == "complete" and body.data is not None and body.step in (1, 2):
+ # Load the user's account for field writes. Step 3 has no data writes.
+ account_result = await db.execute(
+ select(Account).where(Account.id == current_user.account_id)
+ )
+ account = account_result.scalar_one_or_none()
+ if account is None:
+ # Should never happen — user is required to have an account_id.
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="account_not_found",
+ )
+
+ if body.step == 1:
+ data = body.data
+ if data.company_name is not None:
+ account.name = data.company_name
+ if data.team_size_bucket is not None:
+ account.team_size_bucket = data.team_size_bucket
+ if data.role_at_signup is not None:
+ current_user.role_at_signup = data.role_at_signup
+ elif body.step == 2:
+ data = body.data
+ if data.primary_psa is not None:
+ account.primary_psa = data.primary_psa
+
+ current_user.onboarding_step_completed = body.step
+ await db.commit()
+ await db.refresh(current_user)
+
+ return OnboardingStepResponse(
+ onboarding_step_completed=current_user.onboarding_step_completed,
+ onboarding_dismissed=current_user.onboarding_dismissed,
+ )
+
+
+@router.post("/me/onboarding-dismiss-rest", response_model=OnboardingStepResponse)
+async def dismiss_onboarding_rest(
+ db: Annotated[AsyncSession, Depends(get_admin_db)],
+ current_user: Annotated[User, Depends(get_current_active_user)],
+) -> OnboardingStepResponse:
+ """Set users.onboarding_dismissed=TRUE — backs the wizard's "Skip the rest" button.
+
+ Returns the same shape as the step PATCH so the frontend can update its
+ local store from a single response.
+ """
+ current_user.onboarding_dismissed = True
+ await db.commit()
+ await db.refresh(current_user)
+
+ return OnboardingStepResponse(
+ onboarding_step_completed=current_user.onboarding_step_completed,
+ onboarding_dismissed=current_user.onboarding_dismissed,
+ )
diff --git a/backend/app/schemas/onboarding.py b/backend/app/schemas/onboarding.py
index d21647b5..303e1ceb 100644
--- a/backend/app/schemas/onboarding.py
+++ b/backend/app/schemas/onboarding.py
@@ -1,4 +1,6 @@
-from pydantic import BaseModel
+from typing import Literal, Optional
+
+from pydantic import BaseModel, Field
class OnboardingStatus(BaseModel):
@@ -10,3 +12,40 @@ class OnboardingStatus(BaseModel):
connected_psa: bool
is_team_user: bool
dismissed: bool
+
+
+# --- Welcome wizard (Phase 2) ----------------------------------------------
+
+
+TeamSizeBucket = Literal["1-2", "3-5", "6-10", "11-25", "26+"]
+RoleAtSignup = Literal["owner", "lead_tech", "tech", "other"]
+PrimaryPsa = Literal["connectwise", "autotask", "halopsa", "none"]
+WizardStep = Literal[1, 2, 3]
+WizardAction = Literal["complete", "skip"]
+
+
+class OnboardingStepData(BaseModel):
+ """Optional payload carried with `action="complete"` for steps 1 and 2.
+
+ Step 1 fields: company_name, team_size_bucket, role_at_signup
+ Step 2 fields: primary_psa
+ Step 3 has no data (invitations posted separately).
+ """
+
+ # Step 1
+ company_name: Optional[str] = Field(default=None, max_length=255)
+ team_size_bucket: Optional[TeamSizeBucket] = None
+ role_at_signup: Optional[RoleAtSignup] = None
+ # Step 2
+ primary_psa: Optional[PrimaryPsa] = None
+
+
+class OnboardingStepRequest(BaseModel):
+ step: WizardStep
+ action: WizardAction
+ data: Optional[OnboardingStepData] = None
+
+
+class OnboardingStepResponse(BaseModel):
+ onboarding_step_completed: Optional[int]
+ onboarding_dismissed: bool
diff --git a/backend/tests/test_onboarding_step.py b/backend/tests/test_onboarding_step.py
new file mode 100644
index 00000000..eaf9a2a9
--- /dev/null
+++ b/backend/tests/test_onboarding_step.py
@@ -0,0 +1,149 @@
+"""Tests for welcome-wizard onboarding-step endpoints (Phase 2)."""
+
+import pytest
+from sqlalchemy import select
+
+from app.models.account import Account
+from app.models.user import User
+
+
+@pytest.mark.asyncio
+async def test_onboarding_step1_complete_writes_account_name_and_team_size_and_role(
+ client, auth_headers, test_db, test_user
+):
+ """Step 1 + complete writes account.name + team_size_bucket + user.role_at_signup
+ and advances onboarding_step_completed to 1."""
+ response = await client.patch(
+ "/api/v1/users/me/onboarding-step",
+ headers=auth_headers,
+ json={
+ "step": 1,
+ "action": "complete",
+ "data": {
+ "company_name": "Acme MSP",
+ "team_size_bucket": "3-5",
+ "role_at_signup": "owner",
+ },
+ },
+ )
+ assert response.status_code == 200, response.text
+ data = response.json()
+ assert data["onboarding_step_completed"] == 1
+ assert data["onboarding_dismissed"] is False
+
+ # Verify persisted writes
+ account_id = test_user["user_data"]["account_id"]
+ user_email = test_user["email"]
+
+ acct = (
+ await test_db.execute(select(Account).where(Account.id == account_id))
+ ).scalar_one()
+ assert acct.name == "Acme MSP"
+ assert acct.team_size_bucket == "3-5"
+
+ user = (
+ await test_db.execute(select(User).where(User.email == user_email))
+ ).scalar_one()
+ assert user.role_at_signup == "owner"
+ assert user.onboarding_step_completed == 1
+
+
+@pytest.mark.asyncio
+async def test_onboarding_step2_skip_advances_without_psa(
+ client, auth_headers, test_db, test_user
+):
+ """Step 2 + skip ignores data entirely and only advances the step counter
+ (no primary_psa write)."""
+ # Capture original account.primary_psa so we can assert it's untouched.
+ account_id = test_user["user_data"]["account_id"]
+ acct_before = (
+ await test_db.execute(select(Account).where(Account.id == account_id))
+ ).scalar_one()
+ psa_before = acct_before.primary_psa # likely None
+
+ # Advance step 1 first so step 2 is allowed.
+ r1 = await client.patch(
+ "/api/v1/users/me/onboarding-step",
+ headers=auth_headers,
+ json={"step": 1, "action": "skip"},
+ )
+ assert r1.status_code == 200, r1.text
+
+ # Skip step 2 — even if data is present it must be ignored.
+ r2 = await client.patch(
+ "/api/v1/users/me/onboarding-step",
+ headers=auth_headers,
+ json={
+ "step": 2,
+ "action": "skip",
+ "data": {"primary_psa": "connectwise"},
+ },
+ )
+ assert r2.status_code == 200, r2.text
+ assert r2.json()["onboarding_step_completed"] == 2
+
+ # Re-fetch account: primary_psa must NOT have been written.
+ test_db.expire_all()
+ acct_after = (
+ await test_db.execute(select(Account).where(Account.id == account_id))
+ ).scalar_one()
+ assert acct_after.primary_psa == psa_before
+
+
+@pytest.mark.asyncio
+async def test_onboarding_step_cannot_decrease(client, auth_headers):
+ """A step=2 PATCH followed by step=1 must return 400."""
+ # Advance to step 2.
+ r1 = await client.patch(
+ "/api/v1/users/me/onboarding-step",
+ headers=auth_headers,
+ json={"step": 1, "action": "skip"},
+ )
+ assert r1.status_code == 200, r1.text
+ r2 = await client.patch(
+ "/api/v1/users/me/onboarding-step",
+ headers=auth_headers,
+ json={"step": 2, "action": "skip"},
+ )
+ assert r2.status_code == 200, r2.text
+ assert r2.json()["onboarding_step_completed"] == 2
+
+ # Try to go back to step 1 — must fail.
+ r3 = await client.patch(
+ "/api/v1/users/me/onboarding-step",
+ headers=auth_headers,
+ json={"step": 1, "action": "skip"},
+ )
+ assert r3.status_code == 400, r3.text
+
+ # Idempotent re-PATCH of same step succeeds.
+ r4 = await client.patch(
+ "/api/v1/users/me/onboarding-step",
+ headers=auth_headers,
+ json={"step": 2, "action": "skip"},
+ )
+ assert r4.status_code == 200, r4.text
+ assert r4.json()["onboarding_step_completed"] == 2
+
+
+@pytest.mark.asyncio
+async def test_onboarding_dismiss_rest_sets_flag(
+ client, auth_headers, test_db, test_user
+):
+ """POST /users/me/onboarding-dismiss-rest sets users.onboarding_dismissed=TRUE."""
+ response = await client.post(
+ "/api/v1/users/me/onboarding-dismiss-rest",
+ headers=auth_headers,
+ )
+ assert response.status_code == 200, response.text
+ data = response.json()
+ assert data["onboarding_dismissed"] is True
+ # step counter is whatever it was (None for a fresh user).
+ assert "onboarding_step_completed" in data
+
+ # Verify persisted.
+ user_email = test_user["email"]
+ user = (
+ await test_db.execute(select(User).where(User.email == user_email))
+ ).scalar_one()
+ assert user.onboarding_dismissed is True
--
2.49.1
From 694279f89e22d4444a5fffc3d663d47e1b3335a7 Mon Sep 17 00:00:00 2001
From: Michael Chihlas
Date: Wed, 6 May 2026 20:12:03 -0400
Subject: [PATCH 03/29] feat(sales): add POST /sales-leads public endpoint
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Phase 2 Task 29 — public Talk-to-Sales submission endpoint.
- New POST /api/v1/sales-leads (public, no auth, rate-limited 5/hour per IP).
- Inserts a sales_leads row, fires best-effort notification email and
PostHog server-side capture; failures are logged but never fail the
request.
- New EmailService.send_sales_lead_notification static method.
- New SALES_LEAD_RECIPIENT_EMAIL setting (defaults to sales@resolutionflow.com).
- Schemas: SalesLeadCreate / SalesLeadCreateResponse with literal source enum.
- Tests: happy path (row + email), email-failure resilience, and rate-limit
enforcement (re-enables the slowapi limiter for the rate-limit assertion
since DEBUG=true disables it by default in tests).
PostHog server-side instrumentation point is wired in but no-ops gracefully
until app.core.analytics.posthog exists — turning it on is a one-line
change when the backend SDK is configured.
Co-Authored-By: Claude Opus 4.7
---
backend/app/api/endpoints/sales_leads.py | 114 +++++++++++++++++++
backend/app/api/router.py | 2 +
backend/app/core/config.py | 1 +
backend/app/core/email.py | 98 +++++++++++++++++
backend/app/schemas/sales_lead.py | 27 +++++
backend/tests/test_sales_leads.py | 134 +++++++++++++++++++++++
6 files changed, 376 insertions(+)
create mode 100644 backend/app/api/endpoints/sales_leads.py
create mode 100644 backend/app/schemas/sales_lead.py
create mode 100644 backend/tests/test_sales_leads.py
diff --git a/backend/app/api/endpoints/sales_leads.py b/backend/app/api/endpoints/sales_leads.py
new file mode 100644
index 00000000..5f786319
--- /dev/null
+++ b/backend/app/api/endpoints/sales_leads.py
@@ -0,0 +1,114 @@
+"""Public Talk-to-Sales endpoint — no auth required.
+
+POST /api/v1/sales-leads
+ - Inserts a sales_leads row.
+ - Fires (best-effort) a notification email to settings.SALES_LEAD_RECIPIENT_EMAIL.
+ - Emits a server-side PostHog event (best-effort).
+ - Rate-limited per IP (5/hour).
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Request
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.core.admin_database import get_admin_db
+from app.core.config import settings
+from app.core.email import EmailService
+from app.core.rate_limit import limiter
+from app.models.sales_lead import SalesLead
+from app.schemas.sales_lead import SalesLeadCreate, SalesLeadCreateResponse
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/sales-leads", tags=["sales"])
+
+
+async def _send_notification_email(lead: SalesLead) -> None:
+ """Fire-and-forget wrapper. EmailService methods never raise, but we
+ still wrap in a try/except to defend against future regressions."""
+ try:
+ await EmailService.send_sales_lead_notification(
+ to_email=settings.SALES_LEAD_RECIPIENT_EMAIL,
+ lead=lead,
+ )
+ except Exception:
+ logger.warning(
+ "Sales lead notification email failed for lead %s",
+ lead.id,
+ exc_info=True,
+ )
+
+
+def _capture_posthog_event(lead: SalesLead) -> None:
+ """Emit `talk_to_sales_form_submitted` server-side. Best-effort.
+
+ Backend PostHog SDK isn't initialized in the project today; this function
+ is the single instrumentation point so wiring it up later is a one-line
+ change. The call is wrapped so any future failure can never fail the
+ request.
+ """
+ try:
+ # Lazy import — keeps the dependency optional. When the backend
+ # PostHog client is wired in (likely as `app.core.analytics.posthog`),
+ # swap the import path here and the event will fire automatically.
+ try:
+ from app.core.analytics import posthog # type: ignore[attr-defined]
+ except ImportError:
+ logger.debug(
+ "PostHog server-side capture skipped — client not configured"
+ )
+ return
+
+ distinct_id = lead.posthog_distinct_id or f"sales_lead:{lead.id}"
+ posthog.capture(
+ distinct_id=distinct_id,
+ event="talk_to_sales_form_submitted",
+ properties={
+ "source": lead.source,
+ "company": lead.company,
+ "team_size": lead.team_size,
+ },
+ )
+ except Exception:
+ logger.warning(
+ "PostHog capture failed for sales lead %s",
+ lead.id,
+ exc_info=True,
+ )
+
+
+@router.post("", response_model=SalesLeadCreateResponse, status_code=201)
+@limiter.limit("5/hour")
+async def create_sales_lead(
+ request: Request,
+ data: SalesLeadCreate,
+ db: Annotated[AsyncSession, Depends(get_admin_db)],
+) -> SalesLeadCreateResponse:
+ """Public Talk-to-Sales submission.
+
+ Creates a sales_leads row, fires (best-effort) a notification email and a
+ server-side PostHog event. Rate-limited per IP at 5/hour.
+ """
+ lead = SalesLead(
+ email=str(data.email).lower(),
+ name=data.name,
+ company=data.company,
+ team_size=data.team_size,
+ message=data.message,
+ source=data.source,
+ posthog_distinct_id=data.posthog_distinct_id,
+ )
+ db.add(lead)
+ await db.commit()
+ await db.refresh(lead)
+
+ # Fire-and-forget: email + analytics. Failures must not fail the request.
+ asyncio.create_task(_send_notification_email(lead))
+ _capture_posthog_event(lead)
+
+ return SalesLeadCreateResponse(id=lead.id, status="received")
diff --git a/backend/app/api/router.py b/backend/app/api/router.py
index 01ce9a8a..5ef8122c 100644
--- a/backend/app/api/router.py
+++ b/backend/app/api/router.py
@@ -26,6 +26,7 @@ from app.api.endpoints import (
billing,
beta_feedback,
beta_signup,
+ sales_leads,
branding,
categories,
copilot,
@@ -88,6 +89,7 @@ api_router.include_router(billing.router) # Reachable when subscription lock
api_router.include_router(shared.router) # Public share links (no auth)
api_router.include_router(shares.public_router) # Public session share links (optional auth)
api_router.include_router(beta_signup.router)
+api_router.include_router(sales_leads.router) # Talk-to-Sales (no auth, rate-limited)
api_router.include_router(webhooks.router) # Stripe webhook receiver
api_router.include_router(public_templates.router) # Public gallery (no auth, rate-limited)
api_router.include_router(survey.router) # Public survey flow (no auth, rate-limited)
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index f2c28593..815c95db 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -84,6 +84,7 @@ class Settings(BaseSettings):
RESEND_API_KEY: Optional[str] = None
FROM_EMAIL: str = "ResolutionFlow "
FEEDBACK_EMAIL: Optional[str] = None
+ SALES_LEAD_RECIPIENT_EMAIL: str = "sales@resolutionflow.com"
@property
def email_enabled(self) -> bool:
diff --git a/backend/app/core/email.py b/backend/app/core/email.py
index 313d5db0..0bb62b94 100644
--- a/backend/app/core/email.py
+++ b/backend/app/core/email.py
@@ -1,6 +1,11 @@
import logging
+from typing import TYPE_CHECKING
+
from app.core.config import settings
+if TYPE_CHECKING:
+ from app.models.sales_lead import SalesLead
+
logger = logging.getLogger(__name__)
@@ -484,6 +489,99 @@ class EmailService:
logger.exception("Failed to send beta signup notification for %s", signup_email)
return False
+ @staticmethod
+ async def send_sales_lead_notification(
+ to_email: str,
+ lead: "SalesLead",
+ ) -> bool:
+ """Notify the sales recipient about a new Talk-to-Sales submission.
+
+ Fire-and-forget. Returns False (and logs) on any failure; never raises.
+ """
+ if not settings.email_enabled:
+ logger.warning(
+ "Sales lead email not sent — RESEND_API_KEY not configured (lead %s)",
+ lead.id,
+ )
+ return False
+
+ try:
+ import resend
+ import html as html_mod
+ from datetime import datetime, timezone
+
+ resend.api_key = settings.RESEND_API_KEY
+
+ date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
+ safe_email = html_mod.escape(lead.email)
+ safe_name = html_mod.escape(lead.name)
+ safe_company = html_mod.escape(lead.company)
+ safe_team_size = html_mod.escape(lead.team_size or "—")
+ safe_source = html_mod.escape(lead.source)
+ safe_message = html_mod.escape(lead.message or "(no message)")
+ subject = f"[ResolutionFlow Sales] New lead — {safe_company} ({safe_email})"
+
+ email_html = f"""
+
+
+
+
+
+
+
ResolutionFlow
+
New Sales Lead
+
+
+
+ Source: {safe_source}
+
+
+
+
+
+
Name
+
{safe_name}
+
Email
+
{safe_email}
+
Company
+
{safe_company}
+
Team Size
+
{safe_team_size}
+
+
+
+
+
Message
+
{safe_message}
+
+
+
+ Submitted at {date_str} · Lead ID: {lead.id}
+
+
+
+
+
+"""
+
+ resend.Emails.send({
+ "from": settings.FROM_EMAIL,
+ "to": [to_email],
+ "reply_to": lead.email,
+ "subject": subject,
+ "html": email_html,
+ })
+ logger.info("Sales lead notification sent for %s (lead %s)", lead.email, lead.id)
+ return True
+
+ except Exception:
+ logger.exception(
+ "Failed to send sales lead notification for %s (lead %s)",
+ lead.email,
+ lead.id,
+ )
+ return False
+
@staticmethod
async def send_notification_email(
to_email: str,
diff --git a/backend/app/schemas/sales_lead.py b/backend/app/schemas/sales_lead.py
new file mode 100644
index 00000000..9247e91e
--- /dev/null
+++ b/backend/app/schemas/sales_lead.py
@@ -0,0 +1,27 @@
+"""Pydantic schemas for Talk-to-Sales submissions."""
+
+from typing import Literal, Optional
+from uuid import UUID
+
+from pydantic import BaseModel, ConfigDict, EmailStr, Field
+
+SalesLeadSource = Literal["pricing_page", "register_footer", "landing_page"]
+
+
+class SalesLeadCreate(BaseModel):
+ """Public Talk-to-Sales form submission."""
+
+ model_config = ConfigDict(str_strip_whitespace=True)
+
+ email: EmailStr
+ name: str = Field(..., min_length=1, max_length=255)
+ company: str = Field(..., min_length=1, max_length=255)
+ team_size: Optional[str] = Field(default=None, max_length=20)
+ message: Optional[str] = Field(default=None, max_length=5000)
+ source: SalesLeadSource
+ posthog_distinct_id: Optional[str] = Field(default=None, max_length=255)
+
+
+class SalesLeadCreateResponse(BaseModel):
+ id: UUID
+ status: Literal["received"] = "received"
diff --git a/backend/tests/test_sales_leads.py b/backend/tests/test_sales_leads.py
new file mode 100644
index 00000000..c3620ab8
--- /dev/null
+++ b/backend/tests/test_sales_leads.py
@@ -0,0 +1,134 @@
+"""Integration tests for the public Talk-to-Sales endpoint.
+
+POST /api/v1/sales-leads — no auth, rate-limited 5/hour per IP.
+"""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+import sqlalchemy as sa
+
+
+@pytest.mark.asyncio
+async def test_sales_lead_creates_row_and_sends_notification_email(client, test_db):
+ """Happy path: row inserted, notification email fired, 201 returned."""
+
+ payload = {
+ "email": "buyer@acme.example",
+ "name": "Pat Buyer",
+ "company": "Acme MSP",
+ "team_size": "11-50",
+ "message": "We're evaluating ResolutionFlow for our NOC team.",
+ "source": "pricing_page",
+ "posthog_distinct_id": "ph_distinct_123",
+ }
+
+ with patch(
+ "app.api.endpoints.sales_leads.EmailService.send_sales_lead_notification",
+ new=AsyncMock(return_value=True),
+ ) as mock_email:
+ response = await client.post("/api/v1/sales-leads", json=payload)
+
+ assert response.status_code == 201, response.text
+ body = response.json()
+ assert body["status"] == "received"
+ assert "id" in body
+
+ # Notification email was attempted (asyncio.create_task — give it a tick).
+ import asyncio
+ await asyncio.sleep(0)
+ await asyncio.sleep(0)
+ assert mock_email.await_count == 1
+ kwargs = mock_email.await_args.kwargs
+ assert kwargs["to_email"] # default placeholder until cutover
+ assert kwargs["lead"].email == "buyer@acme.example"
+ assert kwargs["lead"].source == "pricing_page"
+
+ # Row was inserted with normalized email + all fields preserved.
+ result = await test_db.execute(
+ sa.text("SELECT email, name, company, team_size, message, source, posthog_distinct_id, status FROM sales_leads")
+ )
+ rows = result.all()
+ assert len(rows) == 1
+ row = rows[0]
+ assert row.email == "buyer@acme.example"
+ assert row.name == "Pat Buyer"
+ assert row.company == "Acme MSP"
+ assert row.team_size == "11-50"
+ assert row.message == "We're evaluating ResolutionFlow for our NOC team."
+ assert row.source == "pricing_page"
+ assert row.posthog_distinct_id == "ph_distinct_123"
+ assert row.status == "new"
+
+
+@pytest.mark.asyncio
+async def test_sales_lead_email_failure_does_not_fail_request(client, test_db):
+ """If the email send raises, the API still returns 201 and the row persists."""
+
+ payload = {
+ "email": "buyer2@acme.example",
+ "name": "Sam Lead",
+ "company": "Acme MSP",
+ "source": "register_footer",
+ }
+
+ with patch(
+ "app.api.endpoints.sales_leads.EmailService.send_sales_lead_notification",
+ new=AsyncMock(side_effect=RuntimeError("resend exploded")),
+ ):
+ response = await client.post("/api/v1/sales-leads", json=payload)
+
+ assert response.status_code == 201, response.text
+
+ # Row must still be persisted even though email failed.
+ import asyncio
+ await asyncio.sleep(0)
+ result = await test_db.execute(
+ sa.text("SELECT count(*) FROM sales_leads WHERE email = 'buyer2@acme.example'")
+ )
+ assert result.scalar() == 1
+
+
+@pytest.mark.asyncio
+async def test_sales_lead_rate_limited_after_5_per_hour(client):
+ """The 6th submission within an hour from the same IP returns 429.
+
+ The default `limiter` is disabled in tests (DEBUG=true). We re-enable it
+ for this test, then reset its state on teardown so other tests aren't
+ affected.
+ """
+ from app.core.rate_limit import limiter
+
+ was_enabled = limiter.enabled
+ limiter.enabled = True
+ try:
+ limiter.reset()
+
+ with patch(
+ "app.api.endpoints.sales_leads.EmailService.send_sales_lead_notification",
+ new=AsyncMock(return_value=True),
+ ):
+ for i in range(5):
+ payload = {
+ "email": f"lead{i}@acme.example",
+ "name": f"Lead {i}",
+ "company": "Acme MSP",
+ "source": "landing_page",
+ }
+ resp = await client.post("/api/v1/sales-leads", json=payload)
+ assert resp.status_code == 201, f"submission {i}: {resp.text}"
+
+ # 6th should be rate-limited.
+ resp = await client.post(
+ "/api/v1/sales-leads",
+ json={
+ "email": "lead6@acme.example",
+ "name": "Lead 6",
+ "company": "Acme MSP",
+ "source": "landing_page",
+ },
+ )
+ assert resp.status_code == 429, resp.text
+ finally:
+ limiter.reset()
+ limiter.enabled = was_enabled
--
2.49.1
From d05b475a411439a0183f6c481f0b3891566f8775 Mon Sep 17 00:00:00 2001
From: Michael Chihlas
Date: Wed, 6 May 2026 20:32:09 -0400
Subject: [PATCH 04/29] feat(admin): extend /admin/plan-limits to manage
plan_billing fields
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Task 30 of self-serve signup Phase 2. Super-admins can now manage Stripe
IDs, display names, prices, and public/archived flags via the existing
admin plan-limits endpoints.
- GET /admin/plan-limits now outer-joins plan_billing and returns
merged PlanLimitWithBillingResponse rows. Plans without a
plan_billing row return None for the billing fields.
- PUT /admin/plan-limits accepts the new optional billing fields and
upserts plan_billing in the same transaction. If no plan_billing
row exists for the plan and the body includes any billing field, a
row is created (display_name defaults to plan.capitalize() when
omitted; display_name is never NULLed out on an existing row).
- After commit, the handler queries account_ids on the affected plan
and calls BillingService.invalidate_billing_cache(account_ids).
This is a no-op stub today (logs only) — there's no in-process
billing cache yet. TODO comment marks the wire-up point.
- 3 new integration tests cover GET-with-billing-present, PUT creating
a plan_billing row, and the invalidation hook being awaited with a
list of account_ids.
Co-Authored-By: Claude Opus 4.7
---
.../app/api/endpoints/admin_plan_limits.py | 125 ++++++++++-
backend/app/schemas/admin.py | 28 +++
backend/app/services/billing.py | 25 +++
backend/tests/test_admin_plan_limits.py | 206 ++++++++++++++++++
4 files changed, 375 insertions(+), 9 deletions(-)
diff --git a/backend/app/api/endpoints/admin_plan_limits.py b/backend/app/api/endpoints/admin_plan_limits.py
index 387081f5..52ea09b4 100644
--- a/backend/app/api/endpoints/admin_plan_limits.py
+++ b/backend/app/api/endpoints/admin_plan_limits.py
@@ -8,34 +8,101 @@ from app.core.database import get_db
from app.core.audit import log_audit
from app.models.user import User
from app.models.plan_limits import PlanLimits
+from app.models.plan_billing import PlanBilling
from app.models.account import Account
from app.models.account_limit_override import AccountLimitOverride
+from app.models.subscription import Subscription
from app.schemas.admin import (
- PlanLimitResponse, PlanLimitUpdate,
+ PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
)
from app.api.deps import require_admin
+from app.services.billing import BillingService
router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])
-@router.get("/plan-limits", response_model=list[PlanLimitResponse])
+# Fields on PlanLimitUpdate that map to plan_billing (not plan_limits).
+_PLAN_BILLING_FIELDS = (
+ "display_name",
+ "description",
+ "monthly_price_cents",
+ "annual_price_cents",
+ "stripe_product_id",
+ "stripe_monthly_price_id",
+ "stripe_annual_price_id",
+ "is_public",
+ "is_archived",
+ "sort_order",
+)
+
+# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model.
+# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit
+# null for any of them would otherwise trigger a NOT NULL violation at commit.
+_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({
+ "display_name",
+ "is_public",
+ "is_archived",
+ "sort_order",
+})
+
+
+def _merge_plan_with_billing(
+ plan: PlanLimits, billing: PlanBilling | None
+) -> PlanLimitWithBillingResponse:
+ """Build a merged response. Billing fields are None when no plan_billing row
+ exists for the plan."""
+ payload = {
+ "plan": plan.plan,
+ "max_trees": plan.max_trees,
+ "max_sessions_per_month": plan.max_sessions_per_month,
+ "max_users": plan.max_users,
+ "custom_branding": plan.custom_branding,
+ "priority_support": plan.priority_support,
+ "export_formats": plan.export_formats or [],
+ }
+ if billing is not None:
+ payload.update({
+ "display_name": billing.display_name,
+ "description": billing.description,
+ "monthly_price_cents": billing.monthly_price_cents,
+ "annual_price_cents": billing.annual_price_cents,
+ "stripe_product_id": billing.stripe_product_id,
+ "stripe_monthly_price_id": billing.stripe_monthly_price_id,
+ "stripe_annual_price_id": billing.stripe_annual_price_id,
+ "is_public": billing.is_public,
+ "is_archived": billing.is_archived,
+ "sort_order": billing.sort_order,
+ })
+ return PlanLimitWithBillingResponse(**payload)
+
+
+@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
async def list_plan_limits(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
- """List all plan limit configurations."""
- result = await db.execute(select(PlanLimits))
- return result.scalars().all()
+ """List all plan limit configurations, merged with plan_billing fields
+ where present. Plans without a plan_billing row return None for the
+ billing fields."""
+ rows = (await db.execute(
+ select(PlanLimits, PlanBilling)
+ .outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
+ )).all()
+ return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
-@router.put("/plan-limits", response_model=PlanLimitResponse)
+@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
async def update_plan_limits(
data: PlanLimitUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
- """Update a plan's limits."""
+ """Update a plan's limits and (if any plan_billing field is included)
+ upsert the matching plan_billing row in the same transaction. After
+ commit, invalidates the in-process billing cache for accounts on this
+ plan (currently a no-op — see BillingService.invalidate_billing_cache).
+ """
result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
plan = result.scalar_one_or_none()
if not plan:
@@ -48,10 +115,50 @@ async def update_plan_limits(
plan.priority_support = data.priority_support
plan.export_formats = data.export_formats
- await log_audit(db, current_user.id, "plan_limits.update", "plan_limits", details={"plan": data.plan})
+ # Did the request include any plan_billing field? (Pydantic gives us
+ # `model_fields_set` to distinguish "user passed null" from "field omitted".)
+ billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)
+ billing: PlanBilling | None = None
+ if billing_fields_set:
+ billing = (await db.execute(
+ select(PlanBilling).where(PlanBilling.plan == data.plan)
+ )).scalar_one_or_none()
+
+ if billing is None:
+ # Create. display_name is required on the model — derive from the
+ # plan name when the caller didn't supply one (e.g. "pro" → "Pro").
+ display_name = data.display_name or data.plan.capitalize()
+ billing = PlanBilling(plan=data.plan, display_name=display_name)
+ db.add(billing)
+
+ # Apply only the fields the caller actually included. Allows partial
+ # updates without clobbering existing values.
+ for field in billing_fields_set:
+ value = getattr(data, field)
+ if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
+ # Don't NULL out a NOT NULL column on update.
+ continue
+ setattr(billing, field, value)
+
+ await log_audit(
+ db, current_user.id, "plan_limits.update", "plan_limits",
+ details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
+ )
await db.commit()
await db.refresh(plan)
- return plan
+ if billing is not None:
+ await db.refresh(billing)
+
+ # Invalidate any in-process billing cache for accounts on this plan.
+ # TODO: invalidate app.state.billing_cache when added.
+ account_ids = [
+ row[0] for row in (await db.execute(
+ select(Subscription.account_id).where(Subscription.plan == data.plan)
+ )).all()
+ ]
+ await BillingService.invalidate_billing_cache(account_ids)
+
+ return _merge_plan_with_billing(plan, billing)
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])
diff --git a/backend/app/schemas/admin.py b/backend/app/schemas/admin.py
index 72c63d43..a223d994 100644
--- a/backend/app/schemas/admin.py
+++ b/backend/app/schemas/admin.py
@@ -172,6 +172,21 @@ class PlanLimitResponse(BaseModel):
from_attributes = True
+class PlanLimitWithBillingResponse(PlanLimitResponse):
+ """PlanLimits + plan_billing fields merged. Billing fields are None when no
+ plan_billing row exists for the plan yet."""
+ display_name: Optional[str] = None
+ description: Optional[str] = None
+ monthly_price_cents: Optional[int] = None
+ annual_price_cents: Optional[int] = None
+ stripe_product_id: Optional[str] = None
+ stripe_monthly_price_id: Optional[str] = None
+ stripe_annual_price_id: Optional[str] = None
+ is_public: Optional[bool] = None
+ is_archived: Optional[bool] = None
+ sort_order: Optional[int] = None
+
+
class PlanLimitUpdate(BaseModel):
plan: str
max_trees: Optional[int] = None
@@ -180,6 +195,19 @@ class PlanLimitUpdate(BaseModel):
custom_branding: bool = False
priority_support: bool = False
export_formats: list = Field(default_factory=lambda: ["markdown", "text"])
+ # plan_billing fields — all optional, partial-update semantics. If any are
+ # set in the body, the admin endpoint upserts the plan_billing row in the
+ # same transaction.
+ display_name: Optional[str] = None
+ description: Optional[str] = None
+ monthly_price_cents: Optional[int] = None
+ annual_price_cents: Optional[int] = None
+ stripe_product_id: Optional[str] = None
+ stripe_monthly_price_id: Optional[str] = None
+ stripe_annual_price_id: Optional[str] = None
+ is_public: Optional[bool] = None
+ is_archived: Optional[bool] = None
+ sort_order: Optional[int] = None
class AccountOverrideCreate(BaseModel):
diff --git a/backend/app/services/billing.py b/backend/app/services/billing.py
index b662ed47..1ae0a999 100644
--- a/backend/app/services/billing.py
+++ b/backend/app/services/billing.py
@@ -1,6 +1,7 @@
"""Single billing service module. Stripe is the only impl — no provider
abstraction. Account row is canonical local state; Stripe is canonical
remote state; the webhook handler bridges the two."""
+import logging
from datetime import datetime, timezone, timedelta
import stripe
@@ -17,8 +18,32 @@ from app.models.subscription import Subscription
TRIAL_DAYS = 14
+logger = logging.getLogger(__name__)
+
class BillingService:
+ @staticmethod
+ async def invalidate_billing_cache(account_ids) -> None:
+ """No-op stub for future in-process billing cache invalidation.
+
+ Today there is no `app.state.billing_cache` — `BillingService.get_billing_state`
+ always reads fresh from the DB. Call sites that mutate plan/feature data
+ invoke this hook so that wiring is in place when an in-process cache is
+ added later. Until then, this just logs.
+
+ TODO: when an in-process billing cache (e.g. `app.state.billing_cache`)
+ is introduced, evict entries for the given account_ids here.
+ """
+ try:
+ count = len(list(account_ids))
+ except TypeError:
+ count = -1
+ logger.debug(
+ "BillingService.invalidate_billing_cache called for %d account(s) "
+ "(no-op stub — wire to app.state.billing_cache when added)",
+ count,
+ )
+
@staticmethod
async def start_trial(db: AsyncSession, account_id) -> Subscription:
"""Idempotent. Creates a trialing Subscription on Pro for the account if
diff --git a/backend/tests/test_admin_plan_limits.py b/backend/tests/test_admin_plan_limits.py
index 7e701b16..8eb22d45 100644
--- a/backend/tests/test_admin_plan_limits.py
+++ b/backend/tests/test_admin_plan_limits.py
@@ -1,7 +1,12 @@
"""Integration tests for admin plan limits and account override endpoints."""
+from unittest.mock import AsyncMock, patch
+
import pytest
from httpx import AsyncClient
+from sqlalchemy import select
+
+from app.models.plan_billing import PlanBilling
class TestAdminPlanLimits:
@@ -56,3 +61,204 @@ class TestAdminPlanLimits:
"""Non-admin gets 403."""
response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers)
assert response.status_code == 403
+
+ @pytest.mark.asyncio
+ async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present(
+ self, client: AsyncClient, admin_auth_headers: dict, test_db
+ ):
+ """GET /admin/plan-limits returns plan_billing fields when a row exists,
+ and None for plans that don't have one yet."""
+ # Seed a plan_billing row for "pro".
+ existing = (await test_db.execute(
+ select(PlanBilling).where(PlanBilling.plan == "pro")
+ )).scalar_one_or_none()
+ if existing is None:
+ test_db.add(PlanBilling(
+ plan="pro",
+ display_name="Pro",
+ description="For working teams",
+ monthly_price_cents=4900,
+ annual_price_cents=49000,
+ stripe_product_id="prod_seed",
+ stripe_monthly_price_id="price_seed_m",
+ stripe_annual_price_id="price_seed_a",
+ is_public=True,
+ is_archived=False,
+ sort_order=10,
+ ))
+ await test_db.commit()
+
+ response = await client.get(
+ "/api/v1/admin/plan-limits", headers=admin_auth_headers
+ )
+ assert response.status_code == 200
+ plans_by_name = {p["plan"]: p for p in response.json()}
+
+ assert "pro" in plans_by_name
+ pro = plans_by_name["pro"]
+ assert pro["display_name"] == "Pro"
+ assert pro["monthly_price_cents"] == 4900
+ assert pro["stripe_monthly_price_id"] == "price_seed_m"
+ assert pro["is_public"] is True
+ assert pro["is_archived"] is False
+ assert pro["sort_order"] == 10
+
+ # A plan without a plan_billing row should still return, with None
+ # billing fields.
+ if "free" in plans_by_name:
+ free = plans_by_name["free"]
+ # free has no plan_billing row in the seed → fields are None.
+ no_billing_row = (await test_db.execute(
+ select(PlanBilling).where(PlanBilling.plan == "free")
+ )).scalar_one_or_none() is None
+ if no_billing_row:
+ assert free["display_name"] is None
+ assert free["monthly_price_cents"] is None
+ assert free["stripe_product_id"] is None
+
+ @pytest.mark.asyncio
+ async def test_admin_plan_limits_put_creates_plan_billing_row(
+ self, client: AsyncClient, admin_auth_headers: dict, test_db
+ ):
+ """PUT /admin/plan-limits upserts a plan_billing row when billing
+ fields are included in the body."""
+ # Ensure no plan_billing row exists for "team" yet.
+ existing = (await test_db.execute(
+ select(PlanBilling).where(PlanBilling.plan == "team")
+ )).scalar_one_or_none()
+ if existing is not None:
+ await test_db.delete(existing)
+ await test_db.commit()
+
+ response = await client.put(
+ "/api/v1/admin/plan-limits",
+ json={
+ "plan": "team",
+ "max_trees": None,
+ "max_sessions_per_month": None,
+ "max_users": None,
+ "custom_branding": True,
+ "priority_support": True,
+ "export_formats": ["markdown", "text", "pdf"],
+ "display_name": "Team",
+ "description": "For growing shops",
+ "monthly_price_cents": 9900,
+ "annual_price_cents": 99000,
+ "stripe_product_id": "prod_team_test",
+ "stripe_monthly_price_id": "price_team_m",
+ "stripe_annual_price_id": "price_team_a",
+ "is_public": True,
+ "is_archived": False,
+ "sort_order": 20,
+ },
+ headers=admin_auth_headers,
+ )
+ assert response.status_code == 200, response.text
+ body = response.json()
+ assert body["display_name"] == "Team"
+ assert body["monthly_price_cents"] == 9900
+ assert body["stripe_product_id"] == "prod_team_test"
+ assert body["sort_order"] == 20
+
+ # Confirm the row was actually persisted.
+ await test_db.commit() # ensure session sees other-session writes
+ pb = (await test_db.execute(
+ select(PlanBilling).where(PlanBilling.plan == "team")
+ )).scalar_one_or_none()
+ assert pb is not None
+ assert pb.display_name == "Team"
+ assert pb.monthly_price_cents == 9900
+ assert pb.stripe_monthly_price_id == "price_team_m"
+ assert pb.is_public is True
+
+ @pytest.mark.asyncio
+ async def test_admin_plan_limits_put_does_not_null_out_required_fields(
+ self, client: AsyncClient, admin_auth_headers: dict, test_db
+ ):
+ """PUT /admin/plan-limits must not NULL out NOT NULL columns on the
+ plan_billing row when the caller passes explicit nulls. The set of
+ guarded fields is {display_name, is_public, is_archived, sort_order}.
+ """
+ # Seed a plan_billing row for "team" with non-default values for every
+ # NOT NULL field so we can detect any clobbering.
+ existing = (await test_db.execute(
+ select(PlanBilling).where(PlanBilling.plan == "team")
+ )).scalar_one_or_none()
+ if existing is not None:
+ await test_db.delete(existing)
+ await test_db.commit()
+
+ seeded = PlanBilling(
+ plan="team",
+ display_name="Team Seeded",
+ is_public=False,
+ is_archived=True,
+ sort_order=5,
+ )
+ test_db.add(seeded)
+ await test_db.commit()
+
+ response = await client.put(
+ "/api/v1/admin/plan-limits",
+ json={
+ "plan": "team",
+ "max_trees": None,
+ "max_sessions_per_month": None,
+ "max_users": None,
+ "custom_branding": True,
+ "priority_support": True,
+ "export_formats": ["markdown", "text"],
+ # Explicit nulls for every NOT NULL plan_billing field.
+ "display_name": None,
+ "is_public": None,
+ "is_archived": None,
+ "sort_order": None,
+ },
+ headers=admin_auth_headers,
+ )
+ assert response.status_code == 200, response.text
+
+ # Confirm the seeded NOT NULL values were preserved.
+ await test_db.commit() # ensure session sees writes from the request
+ pb = (await test_db.execute(
+ select(PlanBilling).where(PlanBilling.plan == "team")
+ )).scalar_one_or_none()
+ assert pb is not None
+ assert pb.display_name == "Team Seeded"
+ assert pb.is_public is False
+ assert pb.is_archived is True
+ assert pb.sort_order == 5
+
+ @pytest.mark.asyncio
+ async def test_admin_plan_limits_put_invalidates_billing_cache(
+ self, client: AsyncClient, admin_auth_headers: dict
+ ):
+ """PUT /admin/plan-limits calls BillingService.invalidate_billing_cache
+ with the account_ids on the affected plan."""
+ # Patch the staticmethod on the class. The endpoint imports
+ # BillingService at module load, so patch the symbol on the class
+ # itself — both the import and the dotted reference resolve to it.
+ with patch(
+ "app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache",
+ new_callable=AsyncMock,
+ ) as spy:
+ response = await client.put(
+ "/api/v1/admin/plan-limits",
+ json={
+ "plan": "pro",
+ "max_trees": 25,
+ "max_sessions_per_month": 500,
+ "max_users": 10,
+ "custom_branding": True,
+ "priority_support": True,
+ "export_formats": ["markdown", "text"],
+ },
+ headers=admin_auth_headers,
+ )
+ assert response.status_code == 200, response.text
+ spy.assert_awaited_once()
+ (account_ids_arg,) = spy.await_args.args
+ # admin fixture seeds an active Pro Subscription, so we expect at
+ # least one account_id in the invalidation list.
+ assert isinstance(account_ids_arg, list)
+ assert len(account_ids_arg) >= 1
--
2.49.1
From 80baf89b003eef4a34f8fc99a611fda9a787431e Mon Sep 17 00:00:00 2001
From: Michael Chihlas
Date: Wed, 6 May 2026 20:38:50 -0400
Subject: [PATCH 05/29] feat(config): add SELF_SERVE_ENABLED flag + GET
/config/public
Phase 2 Task 31. Single flag now controls whether the public-facing
self-serve flow is exposed.
- New public endpoint GET /api/v1/config/public returns
{self_serve_enabled, oauth_providers}. oauth_providers includes
"google" if GOOGLE_CLIENT_ID is set and "microsoft" if MS_CLIENT_ID
is set. No auth required; consumed once by the frontend at load.
- POST /auth/register: when SELF_SERVE_ENABLED=true the platform
invite-code requirement is bypassed even with REQUIRE_INVITE_CODE=true.
invite_code stays in the schema for backward compat and still applies
when supplied. With the flag off, the gate behaves exactly as before.
- Adds backend/app/schemas/config.py with PublicConfigResponse and
registers the new router in the public/unauthenticated section.
- Adds 3 integration tests in tests/test_config_public.py covering the
flag round-trip, the regression case (flag off keeps the 400), and
the new behavior (flag on bypasses the gate, creates user + Pro trial).
Co-Authored-By: Claude Opus 4.7
---
backend/app/api/endpoints/auth.py | 10 ++-
backend/app/api/endpoints/config.py | 40 +++++++++++
backend/app/api/router.py | 2 +
backend/app/schemas/config.py | 18 +++++
backend/tests/test_config_public.py | 100 ++++++++++++++++++++++++++++
5 files changed, 169 insertions(+), 1 deletion(-)
create mode 100644 backend/app/api/endpoints/config.py
create mode 100644 backend/app/schemas/config.py
create mode 100644 backend/tests/test_config_public.py
diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py
index 44507328..7f5e017f 100644
--- a/backend/app/api/endpoints/auth.py
+++ b/backend/app/api/endpoints/auth.py
@@ -136,7 +136,15 @@ async def register(
# Validate platform invite code (skip if account invite was provided)
invite_code_record = None
if not account_invite_record:
- if settings.REQUIRE_INVITE_CODE and not user_data.invite_code:
+ # When SELF_SERVE_ENABLED is on, the platform invite gate is bypassed
+ # entirely — public self-serve signup is the whole point. The
+ # invite_code field stays in the schema for backward compatibility
+ # and so paid/trial-bearing codes still apply when supplied.
+ if (
+ settings.REQUIRE_INVITE_CODE
+ and not settings.SELF_SERVE_ENABLED
+ and not user_data.invite_code
+ ):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invite code is required"
diff --git a/backend/app/api/endpoints/config.py b/backend/app/api/endpoints/config.py
new file mode 100644
index 00000000..a621e738
--- /dev/null
+++ b/backend/app/api/endpoints/config.py
@@ -0,0 +1,40 @@
+"""Public runtime configuration endpoint.
+
+GET /api/v1/config/public
+ Returns the small set of runtime flags the frontend needs at app load
+ to decide whether to render the self-serve signup flow and which OAuth
+ buttons to show. No authentication required.
+
+The response model lives in `app.schemas.config` so it can be reused by
+frontend codegen and other call sites if needed.
+"""
+
+from __future__ import annotations
+
+from fastapi import APIRouter
+
+from app.core.config import settings
+from app.schemas.config import PublicConfigResponse
+
+router = APIRouter(prefix="/config", tags=["config"])
+
+
+@router.get("/public", response_model=PublicConfigResponse)
+async def get_public_config() -> PublicConfigResponse:
+ """Return public-safe runtime config.
+
+ `oauth_providers` reflects which OAuth client IDs are configured server
+ side; the frontend uses it to render only buttons that will actually
+ succeed. `self_serve_enabled` is the master switch for the new public
+ self-serve signup flow.
+ """
+ providers: list[str] = []
+ if settings.GOOGLE_CLIENT_ID:
+ providers.append("google")
+ if settings.MS_CLIENT_ID:
+ providers.append("microsoft")
+
+ return PublicConfigResponse(
+ self_serve_enabled=settings.SELF_SERVE_ENABLED,
+ oauth_providers=providers,
+ )
diff --git a/backend/app/api/router.py b/backend/app/api/router.py
index 5ef8122c..155fa304 100644
--- a/backend/app/api/router.py
+++ b/backend/app/api/router.py
@@ -29,6 +29,7 @@ from app.api.endpoints import (
sales_leads,
branding,
categories,
+ config as config_endpoints,
copilot,
device_types,
draft_templates,
@@ -93,6 +94,7 @@ api_router.include_router(sales_leads.router) # Talk-to-Sales (no auth, rate-li
api_router.include_router(webhooks.router) # Stripe webhook receiver
api_router.include_router(public_templates.router) # Public gallery (no auth, rate-limited)
api_router.include_router(survey.router) # Public survey flow (no auth, rate-limited)
+api_router.include_router(config_endpoints.router) # Public runtime feature flags
# ---------------------------------------------------------------------------
# Admin endpoints — super_admin only
diff --git a/backend/app/schemas/config.py b/backend/app/schemas/config.py
new file mode 100644
index 00000000..c9937d8a
--- /dev/null
+++ b/backend/app/schemas/config.py
@@ -0,0 +1,18 @@
+"""Pydantic schemas for public runtime configuration."""
+
+from __future__ import annotations
+
+from typing import List
+
+from pydantic import BaseModel
+
+
+class PublicConfigResponse(BaseModel):
+ """Runtime feature flags + OAuth provider list exposed to anonymous clients.
+
+ Read once by the frontend at app load to decide whether to render the
+ self-serve signup flow and which OAuth buttons to show.
+ """
+
+ self_serve_enabled: bool
+ oauth_providers: List[str]
diff --git a/backend/tests/test_config_public.py b/backend/tests/test_config_public.py
new file mode 100644
index 00000000..c68738a3
--- /dev/null
+++ b/backend/tests/test_config_public.py
@@ -0,0 +1,100 @@
+"""Integration tests for the public runtime config endpoint.
+
+Covers GET /api/v1/config/public and the SELF_SERVE_ENABLED interaction
+with the existing /auth/register invite-code gate.
+"""
+
+from __future__ import annotations
+
+import pytest
+from httpx import AsyncClient
+
+from app.core.config import settings
+
+
+class TestConfigPublic:
+ """GET /api/v1/config/public — anonymous, no auth."""
+
+ @pytest.mark.asyncio
+ async def test_get_config_public_returns_self_serve_flag(
+ self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
+ ):
+ """Endpoint reflects the current SELF_SERVE_ENABLED setting and the
+ configured OAuth providers, with no auth required."""
+ # Default-off: SELF_SERVE_ENABLED is False unless explicitly set.
+ monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
+ monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
+ monkeypatch.setattr(settings, "MS_CLIENT_ID", None)
+
+ response = await client.get("/api/v1/config/public")
+ assert response.status_code == 200
+ body = response.json()
+ assert body == {"self_serve_enabled": False, "oauth_providers": []}
+
+ # Flip it on, with both OAuth providers configured.
+ monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", True)
+ monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "google-test-id")
+ monkeypatch.setattr(settings, "MS_CLIENT_ID", "ms-test-id")
+
+ response = await client.get("/api/v1/config/public")
+ assert response.status_code == 200
+ body = response.json()
+ assert body["self_serve_enabled"] is True
+ assert body["oauth_providers"] == ["google", "microsoft"]
+
+ # Only Microsoft configured.
+ monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
+ monkeypatch.setattr(settings, "MS_CLIENT_ID", "ms-test-id")
+ response = await client.get("/api/v1/config/public")
+ assert response.status_code == 200
+ assert response.json()["oauth_providers"] == ["microsoft"]
+
+
+class TestRegisterInviteCodeGate:
+ """Regression + new-behavior tests for /auth/register vs SELF_SERVE_ENABLED."""
+
+ @pytest.mark.asyncio
+ async def test_register_invite_code_required_when_self_serve_disabled(
+ self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
+ ):
+ """Pre-self-serve behavior: REQUIRE_INVITE_CODE=True without an
+ invite code (and no account-invite) must still 400."""
+ monkeypatch.setattr(settings, "REQUIRE_INVITE_CODE", True)
+ monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
+
+ response = await client.post(
+ "/api/v1/auth/register",
+ json={
+ "email": "no-invite@example.com",
+ "password": "SecurePass123!",
+ "name": "No Invite",
+ },
+ )
+
+ assert response.status_code == 400
+ assert "invite code is required" in response.json()["detail"].lower()
+
+ @pytest.mark.asyncio
+ async def test_register_invite_code_optional_when_self_serve_enabled(
+ self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
+ ):
+ """Self-serve on: registration succeeds with no invite code even
+ when REQUIRE_INVITE_CODE is True. The user, personal account, and
+ a Pro trial subscription are all created."""
+ monkeypatch.setattr(settings, "REQUIRE_INVITE_CODE", True)
+ monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", True)
+
+ response = await client.post(
+ "/api/v1/auth/register",
+ json={
+ "email": "self-serve@example.com",
+ "password": "SecurePass123!",
+ "name": "Self Serve",
+ },
+ )
+
+ assert response.status_code == 201, response.text
+ body = response.json()
+ assert body["email"] == "self-serve@example.com"
+ assert body["account_role"] == "owner"
+ assert "account_id" in body
--
2.49.1
From 7a9cb4b03b4dd11c0cdd55d28a2460e9853a8935 Mon Sep 17 00:00:00 2001
From: Michael Chihlas
Date: Wed, 6 May 2026 20:44:20 -0400
Subject: [PATCH 06/29] feat(billing): add useBillingStore and /billing/state
integration
T32: Single frontend source of truth for subscription / plan / feature
state. New Zustand `useBillingStore` fetches `/billing/state` (auto-fetch
on login via authStore, reset on logout), exposes `refetch` for
post-Checkout refresh, and is supported by a `useBillingPoll` hook
that re-fetches every 60s while authenticated. The new `billingApi`
client transforms the snake_case backend payload to camelCase at a
single boundary so the rest of the frontend never sees `plan_billing`
or `enabled_features`.
Co-Authored-By: Claude Opus 4.7
---
frontend/src/api/billing.ts | 27 +++++
frontend/src/api/index.ts | 1 +
frontend/src/components/layout/AppLayout.tsx | 4 +
frontend/src/hooks/useBillingPoll.ts | 32 +++++
frontend/src/store/authStore.ts | 7 ++
frontend/src/store/billingStore.test.ts | 118 +++++++++++++++++++
frontend/src/store/billingStore.ts | 82 +++++++++++++
frontend/src/types/billing.ts | 51 ++++++++
frontend/src/types/index.ts | 8 ++
9 files changed, 330 insertions(+)
create mode 100644 frontend/src/api/billing.ts
create mode 100644 frontend/src/hooks/useBillingPoll.ts
create mode 100644 frontend/src/store/billingStore.test.ts
create mode 100644 frontend/src/store/billingStore.ts
create mode 100644 frontend/src/types/billing.ts
diff --git a/frontend/src/api/billing.ts b/frontend/src/api/billing.ts
new file mode 100644
index 00000000..2ba56173
--- /dev/null
+++ b/frontend/src/api/billing.ts
@@ -0,0 +1,27 @@
+import apiClient from './client'
+import type { BillingStateApiResponse, BillingStatePayload } from '@/types'
+
+/**
+ * Single boundary where the snake_case backend payload is transformed
+ * into the camelCase shape used by the rest of the frontend.
+ *
+ * Keeping the transform here means the store, hooks, and components
+ * never see snake_case keys.
+ */
+function transformBillingState(raw: BillingStateApiResponse): BillingStatePayload {
+ return {
+ subscription: raw.subscription ?? null,
+ planBilling: raw.plan_billing ?? null,
+ planLimits: raw.plan_limits ?? {},
+ enabledFeatures: raw.enabled_features ?? {},
+ }
+}
+
+export const billingApi = {
+ async getState(): Promise {
+ const response = await apiClient.get('/billing/state')
+ return transformBillingState(response.data)
+ },
+}
+
+export default billingApi
diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts
index 50440df8..5084a2aa 100644
--- a/frontend/src/api/index.ts
+++ b/frontend/src/api/index.ts
@@ -9,6 +9,7 @@ export { default as foldersApi } from './folders'
export { default as stepsApi } from './steps'
export { default as stepCategoriesApi } from './stepCategories'
export { default as accountsApi } from './accounts'
+export { default as billingApi } from './billing'
export { default as adminApi } from './admin'
export { treeMarkdownApi } from './treeMarkdown'
export { default as analyticsApi } from './analytics'
diff --git a/frontend/src/components/layout/AppLayout.tsx b/frontend/src/components/layout/AppLayout.tsx
index 4a4b08a6..ab4722d2 100644
--- a/frontend/src/components/layout/AppLayout.tsx
+++ b/frontend/src/components/layout/AppLayout.tsx
@@ -4,6 +4,7 @@ import { Menu, X, LayoutGrid, Clock, AlertTriangle, GitBranch, Wand2, BarChart3,
import { useAuthStore } from '@/store/authStore'
import { usePermissions } from '@/hooks/usePermissions'
import { useUserPreferencesStore } from '@/store/userPreferencesStore'
+import { useBillingPoll } from '@/hooks/useBillingPoll'
import { BrandLogo } from '@/components/common/BrandLogo'
import { TopBar } from './TopBar'
import { Sidebar } from './Sidebar'
@@ -13,6 +14,9 @@ import { FeedbackWidget } from '@/components/common/FeedbackWidget'
import { cn } from '@/lib/utils'
export function AppLayout() {
+ // Poll /billing/state every 60s while authenticated. Hook no-ops when logged out.
+ useBillingPoll()
+
const location = useLocation()
const navigate = useNavigate()
const { user, logout } = useAuthStore()
diff --git a/frontend/src/hooks/useBillingPoll.ts b/frontend/src/hooks/useBillingPoll.ts
new file mode 100644
index 00000000..cfc397fd
--- /dev/null
+++ b/frontend/src/hooks/useBillingPoll.ts
@@ -0,0 +1,32 @@
+import { useEffect } from 'react'
+import { useAuthStore } from '@/store/authStore'
+import { useBillingStore } from '@/store/billingStore'
+
+const POLL_INTERVAL_MS = 60_000
+
+/**
+ * Re-fetches billing state every 60s while a user is logged in.
+ *
+ * Mount once at the top of the authenticated dashboard tree. Polling
+ * automatically pauses when the auth store reports no logged-in user.
+ *
+ * Note: this is a v1 simple-interval implementation; a later task may
+ * swap to SSE / visibility-aware polling.
+ */
+export function useBillingPoll(): void {
+ const isAuthenticated = useAuthStore((s) => s.isAuthenticated)
+
+ useEffect(() => {
+ if (!isAuthenticated) return
+
+ const id = window.setInterval(() => {
+ void useBillingStore.getState().refetch()
+ }, POLL_INTERVAL_MS)
+
+ return () => {
+ window.clearInterval(id)
+ }
+ }, [isAuthenticated])
+}
+
+export default useBillingPoll
diff --git a/frontend/src/store/authStore.ts b/frontend/src/store/authStore.ts
index 3465b626..68e40459 100644
--- a/frontend/src/store/authStore.ts
+++ b/frontend/src/store/authStore.ts
@@ -6,6 +6,7 @@ import { authApi } from '@/api/auth'
import { identifyUser, resetAnalytics, analytics } from '@/lib/analytics'
import { apiClient } from '@/api/client'
import { clearCachedQuota } from '@/hooks/useCachedQuota'
+import { useBillingStore } from '@/store/billingStore'
interface AuthState {
user: User | null
@@ -85,6 +86,7 @@ export const useAuthStore = create()(
localStorage.removeItem('access_token')
localStorage.removeItem('refresh_token')
clearCachedQuota()
+ useBillingStore.getState().reset()
Sentry.setUser(null)
resetAnalytics()
set({ user: null, token: null, account: null, subscription: null, isAuthenticated: false, error: null })
@@ -117,6 +119,11 @@ export const useAuthStore = create()(
identifyUser({ id: user.id, email: user.email, role: user.role, is_super_admin: user.is_super_admin, account_id: account?.id })
set({ user, account, subscription, isLoading: false })
+
+ // Kick off billing-state fetch alongside auth — fire-and-forget so
+ // a billing error never breaks login. The billing store records
+ // its own error state.
+ void useBillingStore.getState().fetch()
} catch (error: unknown) {
const message = error instanceof Error ? error.message : 'Failed to fetch user'
set({ error: message, isLoading: false })
diff --git a/frontend/src/store/billingStore.test.ts b/frontend/src/store/billingStore.test.ts
new file mode 100644
index 00000000..e3977b70
--- /dev/null
+++ b/frontend/src/store/billingStore.test.ts
@@ -0,0 +1,118 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest'
+import { useBillingStore } from './billingStore'
+import { billingApi } from '@/api/billing'
+import type { BillingStatePayload } from '@/types'
+
+vi.mock('@/api/billing', () => ({
+ billingApi: {
+ getState: vi.fn(),
+ },
+ default: {
+ getState: vi.fn(),
+ },
+}))
+
+const mockGetState = billingApi.getState as ReturnType
+
+const INITIAL_PAYLOAD: BillingStatePayload = {
+ subscription: {
+ status: 'trialing',
+ plan: 'pro',
+ current_period_start: '2026-05-01T00:00:00Z',
+ current_period_end: '2026-05-15T00:00:00Z',
+ cancel_at_period_end: false,
+ seat_limit: 5,
+ has_pro_entitlement: true,
+ is_paid: false,
+ },
+ planBilling: {
+ display_name: 'Pro',
+ description: 'Pro plan',
+ monthly_price_cents: 4900,
+ annual_price_cents: 49000,
+ },
+ planLimits: { seats: 5 },
+ enabledFeatures: { ai_assistant: true },
+}
+
+describe('useBillingStore', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ // Reset store to empty initial state.
+ useBillingStore.setState({
+ subscription: null,
+ planBilling: null,
+ planLimits: {},
+ enabledFeatures: {},
+ isLoading: false,
+ error: null,
+ })
+ })
+
+ it('useBillingStore fetches on login and populates subscription', async () => {
+ mockGetState.mockResolvedValueOnce(INITIAL_PAYLOAD)
+
+ // Sanity: starts empty.
+ expect(useBillingStore.getState().subscription).toBeNull()
+
+ await useBillingStore.getState().fetch()
+
+ const state = useBillingStore.getState()
+ expect(mockGetState).toHaveBeenCalledOnce()
+ expect(state.subscription).toEqual(INITIAL_PAYLOAD.subscription)
+ expect(state.planBilling).toEqual(INITIAL_PAYLOAD.planBilling)
+ expect(state.planLimits).toEqual(INITIAL_PAYLOAD.planLimits)
+ expect(state.enabledFeatures).toEqual(INITIAL_PAYLOAD.enabledFeatures)
+ expect(state.isLoading).toBe(false)
+ expect(state.error).toBeNull()
+ })
+
+ it('useBillingStore resets on logout', async () => {
+ mockGetState.mockResolvedValueOnce(INITIAL_PAYLOAD)
+ await useBillingStore.getState().fetch()
+ expect(useBillingStore.getState().subscription).not.toBeNull()
+
+ useBillingStore.getState().reset()
+
+ const state = useBillingStore.getState()
+ expect(state.subscription).toBeNull()
+ expect(state.planBilling).toBeNull()
+ expect(state.planLimits).toEqual({})
+ expect(state.enabledFeatures).toEqual({})
+ expect(state.isLoading).toBe(false)
+ expect(state.error).toBeNull()
+ })
+
+ it('useBillingStore refetch overwrites stale data', async () => {
+ mockGetState.mockResolvedValueOnce(INITIAL_PAYLOAD)
+ await useBillingStore.getState().fetch()
+ expect(useBillingStore.getState().subscription?.status).toBe('trialing')
+
+ const updatedPayload: BillingStatePayload = {
+ ...INITIAL_PAYLOAD,
+ subscription: {
+ ...INITIAL_PAYLOAD.subscription!,
+ status: 'active',
+ is_paid: true,
+ },
+ enabledFeatures: { ai_assistant: true, advanced_reports: true },
+ }
+ // Hold the refetch promise open so we can observe mid-flight isLoading=true.
+ let resolveSecond: (value: BillingStatePayload) => void = () => {}
+ mockGetState.mockImplementationOnce(
+ () => new Promise((resolve) => { resolveSecond = resolve })
+ )
+
+ const refetchPromise = useBillingStore.getState().refetch()
+ expect(useBillingStore.getState().isLoading).toBe(true)
+ resolveSecond(updatedPayload)
+ await refetchPromise
+
+ const state = useBillingStore.getState()
+ expect(mockGetState).toHaveBeenCalledTimes(2)
+ expect(state.subscription?.status).toBe('active')
+ expect(state.subscription?.is_paid).toBe(true)
+ expect(state.enabledFeatures).toEqual({ ai_assistant: true, advanced_reports: true })
+ expect(state.isLoading).toBe(false)
+ })
+})
diff --git a/frontend/src/store/billingStore.ts b/frontend/src/store/billingStore.ts
new file mode 100644
index 00000000..594a7a9e
--- /dev/null
+++ b/frontend/src/store/billingStore.ts
@@ -0,0 +1,82 @@
+import { create } from 'zustand'
+import { billingApi } from '@/api/billing'
+import type {
+ BillingSubscriptionState,
+ PlanBillingState,
+} from '@/types'
+
+interface BillingState {
+ subscription: BillingSubscriptionState | null
+ planBilling: PlanBillingState | null
+ planLimits: Record
+ enabledFeatures: Record
+ isLoading: boolean
+ error: string | null
+}
+
+interface BillingActions {
+ /** Fetch billing state. Sets `isLoading` while in flight. */
+ fetch: () => Promise
+ /** Same as `fetch` but intended for explicit refresh after Stripe Checkout. */
+ refetch: () => Promise
+ /** Reset to empty initial state — call on logout. */
+ reset: () => void
+}
+
+export type BillingStore = BillingState & BillingActions
+
+const INITIAL_STATE: BillingState = {
+ subscription: null,
+ planBilling: null,
+ planLimits: {},
+ enabledFeatures: {},
+ isLoading: false,
+ error: null,
+}
+
+export const useBillingStore = create((set) => ({
+ ...INITIAL_STATE,
+
+ fetch: async () => {
+ set({ isLoading: true, error: null })
+ try {
+ const data = await billingApi.getState()
+ set({
+ subscription: data.subscription,
+ planBilling: data.planBilling,
+ planLimits: data.planLimits,
+ enabledFeatures: data.enabledFeatures,
+ isLoading: false,
+ error: null,
+ })
+ } catch (error: unknown) {
+ // 401s are handled globally by the apiClient response interceptor
+ // (token-refresh + logout), so we just record any other error here.
+ const message = error instanceof Error ? error.message : 'Failed to load billing state'
+ set({ isLoading: false, error: message })
+ }
+ },
+
+ refetch: async () => {
+ // Same semantics as fetch — separate name documents intent at the call site.
+ set({ isLoading: true, error: null })
+ try {
+ const data = await billingApi.getState()
+ set({
+ subscription: data.subscription,
+ planBilling: data.planBilling,
+ planLimits: data.planLimits,
+ enabledFeatures: data.enabledFeatures,
+ isLoading: false,
+ error: null,
+ })
+ } catch (error: unknown) {
+ const message = error instanceof Error ? error.message : 'Failed to load billing state'
+ set({ isLoading: false, error: message })
+ }
+ },
+
+ reset: () => set({ ...INITIAL_STATE }),
+}))
+
+export default useBillingStore
diff --git a/frontend/src/types/billing.ts b/frontend/src/types/billing.ts
new file mode 100644
index 00000000..f0654038
--- /dev/null
+++ b/frontend/src/types/billing.ts
@@ -0,0 +1,51 @@
+/**
+ * Billing state types for the unified `/billing/state` endpoint.
+ *
+ * The backend returns snake_case keys (`plan_billing`, `enabled_features`);
+ * the API client (`frontend/src/api/billing.ts`) transforms the payload to
+ * camelCase before it reaches the rest of the frontend.
+ */
+
+export type SubscriptionStatus =
+ | 'trialing'
+ | 'active'
+ | 'past_due'
+ | 'canceled'
+ | 'incomplete'
+ | 'complimentary'
+
+export interface SubscriptionState {
+ status: SubscriptionStatus
+ plan: string
+ /** ISO 8601 string or null */
+ current_period_start: string | null
+ /** ISO 8601 string or null */
+ current_period_end: string | null
+ cancel_at_period_end: boolean
+ seat_limit: number | null
+ has_pro_entitlement: boolean
+ is_paid: boolean
+}
+
+export interface PlanBillingState {
+ display_name: string
+ description: string | null
+ monthly_price_cents: number | null
+ annual_price_cents: number | null
+}
+
+/** Camel-cased billing-state payload, post-transform. */
+export interface BillingStatePayload {
+ subscription: SubscriptionState | null
+ planBilling: PlanBillingState | null
+ planLimits: Record
+ enabledFeatures: Record
+}
+
+/** Raw snake_case payload returned by the backend. */
+export interface BillingStateApiResponse {
+ subscription: SubscriptionState | null
+ plan_billing: PlanBillingState | null
+ plan_limits: Record
+ enabled_features: Record
+}
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index bfd9e759..5dc98ebb 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -93,6 +93,14 @@ export type {
KBQuotaResponse,
} from './kbAccelerator'
+export type {
+ SubscriptionStatus,
+ SubscriptionState as BillingSubscriptionState,
+ PlanBillingState,
+ BillingStatePayload,
+ BillingStateApiResponse,
+} from './billing'
+
export * from './scripts'
export * from './script-builder'
export * from './integrations'
--
2.49.1
From 0b5ed9aa104eda09148fc161a99544cd0b7aaf2d Mon Sep 17 00:00:00 2001
From: Michael Chihlas
Date: Wed, 6 May 2026 20:52:18 -0400
Subject: [PATCH 07/29] feat(billing): add useFeature, useFeatureLimit,
useTrialBanner hooks
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Phase 2 Task 33. Components can now ask "is this feature on?", "how many
sessions left?", and "what stage is the trial in?" without re-implementing
the read against useBillingStore.
- useFeature(flagKey): boolean — reads enabledFeatures from store
- useFeatureLimit(field): { used, limit, percentage, isAtLimit, isLoading }
with non-blocking 60s module-level cache and graceful 404 degradation
- useTrialBanner(): derives stage from subscription status + trial countdown,
returns null on initial load to prevent flicker
- usageApi.getCount(field) — calls /api/v1/usage/{field}; backend endpoint
is not yet implemented (planned), so the hook degrades to used=0
Co-Authored-By: Claude Opus 4.7
---
frontend/src/api/index.ts | 1 +
frontend/src/api/usage.ts | 23 ++++
frontend/src/hooks/useFeature.test.ts | 44 +++++++
frontend/src/hooks/useFeature.ts | 16 +++
frontend/src/hooks/useFeatureLimit.test.ts | 112 ++++++++++++++++++
frontend/src/hooks/useFeatureLimit.ts | 125 ++++++++++++++++++++
frontend/src/hooks/useTrialBanner.test.ts | 131 +++++++++++++++++++++
frontend/src/hooks/useTrialBanner.ts | 86 ++++++++++++++
8 files changed, 538 insertions(+)
create mode 100644 frontend/src/api/usage.ts
create mode 100644 frontend/src/hooks/useFeature.test.ts
create mode 100644 frontend/src/hooks/useFeature.ts
create mode 100644 frontend/src/hooks/useFeatureLimit.test.ts
create mode 100644 frontend/src/hooks/useFeatureLimit.ts
create mode 100644 frontend/src/hooks/useTrialBanner.test.ts
create mode 100644 frontend/src/hooks/useTrialBanner.ts
diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts
index 5084a2aa..1d0ca2f4 100644
--- a/frontend/src/api/index.ts
+++ b/frontend/src/api/index.ts
@@ -10,6 +10,7 @@ export { default as stepsApi } from './steps'
export { default as stepCategoriesApi } from './stepCategories'
export { default as accountsApi } from './accounts'
export { default as billingApi } from './billing'
+export { default as usageApi } from './usage'
export { default as adminApi } from './admin'
export { treeMarkdownApi } from './treeMarkdown'
export { default as analyticsApi } from './analytics'
diff --git a/frontend/src/api/usage.ts b/frontend/src/api/usage.ts
new file mode 100644
index 00000000..f08f7f44
--- /dev/null
+++ b/frontend/src/api/usage.ts
@@ -0,0 +1,23 @@
+import apiClient from './client'
+
+/**
+ * Usage counters API.
+ *
+ * TODO: backend `/usage/{field}` endpoint not yet implemented (planned).
+ * Tracked under self-serve signup Phase 2 — Task 33 calls this lazily; today
+ * it 404s and the consuming hook (`useFeatureLimit`) cleanly degrades to
+ * `used = 0`.
+ */
+export const usageApi = {
+ /**
+ * Fetch the current count for a usage field (e.g. `active_users`,
+ * `flowpilot_sessions_this_month`). The field name is the same key used in
+ * `BillingState.planLimits`.
+ */
+ async getCount(field: string): Promise<{ used: number }> {
+ const response = await apiClient.get<{ used: number }>(`/usage/${field}`)
+ return response.data
+ },
+}
+
+export default usageApi
diff --git a/frontend/src/hooks/useFeature.test.ts b/frontend/src/hooks/useFeature.test.ts
new file mode 100644
index 00000000..38f35d29
--- /dev/null
+++ b/frontend/src/hooks/useFeature.test.ts
@@ -0,0 +1,44 @@
+import { describe, it, expect, beforeEach } from 'vitest'
+import { renderHook, act } from '@testing-library/react'
+import { useFeature } from './useFeature'
+import { useBillingStore } from '@/store/billingStore'
+
+describe('useFeature', () => {
+ beforeEach(() => {
+ useBillingStore.setState({
+ subscription: null,
+ planBilling: null,
+ planLimits: {},
+ enabledFeatures: {},
+ isLoading: false,
+ error: null,
+ })
+ })
+
+ it('returns false when flag absent', () => {
+ const { result } = renderHook(() => useFeature('does_not_exist'))
+ expect(result.current).toBe(false)
+ })
+
+ it('returns true when flag is enabled', () => {
+ useBillingStore.setState({ enabledFeatures: { ai_builder: true } })
+ const { result } = renderHook(() => useFeature('ai_builder'))
+ expect(result.current).toBe(true)
+ })
+
+ it('returns false when flag is explicitly disabled', () => {
+ useBillingStore.setState({ enabledFeatures: { ai_builder: false } })
+ const { result } = renderHook(() => useFeature('ai_builder'))
+ expect(result.current).toBe(false)
+ })
+
+ it('updates when store changes (subscribes to store)', () => {
+ const { result } = renderHook(() => useFeature('foo'))
+ expect(result.current).toBe(false)
+
+ act(() => {
+ useBillingStore.setState({ enabledFeatures: { foo: true } })
+ })
+ expect(result.current).toBe(true)
+ })
+})
diff --git a/frontend/src/hooks/useFeature.ts b/frontend/src/hooks/useFeature.ts
new file mode 100644
index 00000000..12f971d7
--- /dev/null
+++ b/frontend/src/hooks/useFeature.ts
@@ -0,0 +1,16 @@
+import { useBillingStore } from '@/store/billingStore'
+
+/**
+ * Returns whether a feature flag is enabled for the current account.
+ *
+ * Reads from `useBillingStore.enabledFeatures`, which is populated by
+ * `GET /billing/state`. Returns `false` when the flag is absent (closed-by-default).
+ *
+ * The hook subscribes to the store so updates from `refetch()` propagate
+ * without manual refetch in the component.
+ */
+export function useFeature(flagKey: string): boolean {
+ return useBillingStore((state) => Boolean(state.enabledFeatures[flagKey]))
+}
+
+export default useFeature
diff --git a/frontend/src/hooks/useFeatureLimit.test.ts b/frontend/src/hooks/useFeatureLimit.test.ts
new file mode 100644
index 00000000..8561463d
--- /dev/null
+++ b/frontend/src/hooks/useFeatureLimit.test.ts
@@ -0,0 +1,112 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest'
+import { renderHook, waitFor } from '@testing-library/react'
+import { useFeatureLimit, clearUsageCache } from './useFeatureLimit'
+import { useBillingStore } from '@/store/billingStore'
+
+vi.mock('@/api/usage', () => ({
+ usageApi: {
+ getCount: vi.fn(),
+ },
+}))
+
+import { usageApi } from '@/api/usage'
+
+const mockedGetCount = vi.mocked(usageApi.getCount)
+
+describe('useFeatureLimit', () => {
+ beforeEach(() => {
+ clearUsageCache()
+ mockedGetCount.mockReset()
+ useBillingStore.setState({
+ subscription: null,
+ planBilling: null,
+ planLimits: {},
+ enabledFeatures: {},
+ isLoading: false,
+ error: null,
+ })
+ })
+
+ it('transitions isLoading -> loaded', async () => {
+ useBillingStore.setState({ planLimits: { active_users: 10 } })
+ mockedGetCount.mockResolvedValueOnce({ used: 4 })
+
+ const { result } = renderHook(() => useFeatureLimit('active_users'))
+
+ // Non-blocking initial state.
+ expect(result.current.isLoading).toBe(true)
+ expect(result.current.used).toBe(0)
+ expect(result.current.limit).toBe(10)
+
+ await waitFor(() => {
+ expect(result.current.isLoading).toBe(false)
+ })
+
+ expect(result.current.used).toBe(4)
+ expect(result.current.limit).toBe(10)
+ expect(result.current.percentage).toBe(40)
+ expect(result.current.isAtLimit).toBe(false)
+ })
+
+ it('flags isAtLimit when used >= limit', async () => {
+ useBillingStore.setState({ planLimits: { seats: 3 } })
+ mockedGetCount.mockResolvedValueOnce({ used: 3 })
+
+ const { result } = renderHook(() => useFeatureLimit('seats'))
+ await waitFor(() => expect(result.current.isLoading).toBe(false))
+
+ expect(result.current.isAtLimit).toBe(true)
+ expect(result.current.percentage).toBe(100)
+ })
+
+ it('returns null percentage when limit is null (unlimited)', async () => {
+ useBillingStore.setState({ planLimits: { sessions: null } })
+ mockedGetCount.mockResolvedValueOnce({ used: 7 })
+
+ const { result } = renderHook(() => useFeatureLimit('sessions'))
+ await waitFor(() => expect(result.current.isLoading).toBe(false))
+
+ expect(result.current.limit).toBe(null)
+ expect(result.current.percentage).toBe(null)
+ expect(result.current.isAtLimit).toBe(false)
+ })
+
+ it('resets isLoading=true synchronously when `field` prop changes', async () => {
+ useBillingStore.setState({ planLimits: { max_trees: 5, max_users: 10 } })
+ mockedGetCount.mockResolvedValueOnce({ used: 2 }) // for max_trees
+ mockedGetCount.mockResolvedValueOnce({ used: 3 }) // for max_users (slow)
+
+ const { result, rerender } = renderHook(
+ ({ field }: { field: string }) => useFeatureLimit(field),
+ { initialProps: { field: 'max_trees' } },
+ )
+
+ // First field resolves.
+ await waitFor(() => expect(result.current.isLoading).toBe(false))
+ expect(result.current.used).toBe(2)
+ expect(result.current.limit).toBe(5)
+
+ // Switch field. Next render must report isLoading=true (no stale data
+ // bleed-through) before the new fetch resolves.
+ rerender({ field: 'max_users' })
+ expect(result.current.isLoading).toBe(true)
+ expect(result.current.used).toBe(0)
+ expect(result.current.limit).toBe(10)
+
+ await waitFor(() => expect(result.current.isLoading).toBe(false))
+ expect(result.current.used).toBe(3)
+ expect(result.current.limit).toBe(10)
+ })
+
+ it('degrades to used=0 on fetch error (404 from missing endpoint)', async () => {
+ useBillingStore.setState({ planLimits: { active_users: 5 } })
+ mockedGetCount.mockRejectedValueOnce(new Error('Request failed with status 404'))
+
+ const { result } = renderHook(() => useFeatureLimit('active_users'))
+ await waitFor(() => expect(result.current.isLoading).toBe(false))
+
+ expect(result.current.used).toBe(0)
+ expect(result.current.limit).toBe(5)
+ expect(result.current.percentage).toBe(0)
+ })
+})
diff --git a/frontend/src/hooks/useFeatureLimit.ts b/frontend/src/hooks/useFeatureLimit.ts
new file mode 100644
index 00000000..4d6b05a9
--- /dev/null
+++ b/frontend/src/hooks/useFeatureLimit.ts
@@ -0,0 +1,125 @@
+import { useEffect, useRef, useState } from 'react'
+import { useBillingStore } from '@/store/billingStore'
+import { usageApi } from '@/api/usage'
+
+const CACHE_TTL_MS = 60 * 1000
+
+interface CacheEntry {
+ used: number
+ timestamp: number
+}
+
+const cache = new Map()
+
+/** Clear the usage cache (call on logout to prevent stale data across users). */
+export function clearUsageCache() {
+ cache.clear()
+}
+
+export interface FeatureLimitResult {
+ used: number
+ limit: number | null
+ /** null when limit is null (unlimited) or unknown */
+ percentage: number | null
+ isAtLimit: boolean
+ isLoading: boolean
+}
+
+function coerceLimit(raw: unknown): number | null {
+ if (typeof raw === 'number' && Number.isFinite(raw)) return raw
+ if (raw === null || raw === undefined) return null
+ // The store types planLimits as Record; the backend
+ // currently returns numbers, but defensively handle string ints too.
+ if (typeof raw === 'string') {
+ const n = Number(raw)
+ return Number.isFinite(n) ? n : null
+ }
+ return null
+}
+
+/**
+ * Returns progress against a quantitative plan limit.
+ *
+ * `limit` comes from `useBillingStore.planLimits[field]`, which is read
+ * synchronously from the store. `used` is fetched lazily from
+ * `GET /api/v1/usage/{field}` on mount and cached for 60s in a module-level
+ * map keyed by field.
+ *
+ * Render is non-blocking: the hook returns `isLoading=true` (with `used=0`)
+ * until the usage fetch resolves. On 404 or any error the hook degrades to
+ * `used=0` with `isLoading=false` rather than surfacing the error — the
+ * `/usage/{field}` endpoint is not yet implemented on the backend (planned).
+ */
+export function useFeatureLimit(field: string): FeatureLimitResult {
+ const limit = useBillingStore((state) => coerceLimit(state.planLimits[field]))
+
+ // Initialize from cache on first mount only; subsequent `field` changes
+ // are handled inside the effect below so the render-phase result reflects
+ // the new field synchronously (no stale `used`/`isLoading` for one tick).
+ const initialCached = useRef(undefined)
+ if (initialCached.current === undefined) {
+ initialCached.current = cache.get(field)
+ }
+ const initialFresh =
+ initialCached.current && Date.now() - initialCached.current.timestamp < CACHE_TTL_MS
+ const [used, setUsed] = useState(initialFresh ? initialCached.current!.used : 0)
+ const [isLoading, setIsLoading] = useState(!initialFresh)
+
+ // Track the field that the current `used`/`isLoading` state describes.
+ // When `field` changes, we synchronously reset state in render so callers
+ // never see stale data for the previous field.
+ const stateField = useRef(field)
+ if (stateField.current !== field) {
+ stateField.current = field
+ const existing = cache.get(field)
+ const freshNow = existing && Date.now() - existing.timestamp < CACHE_TTL_MS
+ if (freshNow) {
+ setUsed(existing!.used)
+ setIsLoading(false)
+ } else {
+ setUsed(0)
+ setIsLoading(true)
+ }
+ }
+
+ useEffect(() => {
+ const existing = cache.get(field)
+ if (existing && Date.now() - existing.timestamp < CACHE_TTL_MS) {
+ setUsed(existing.used)
+ setIsLoading(false)
+ return
+ }
+
+ let cancelled = false
+ setIsLoading(true)
+ usageApi
+ .getCount(field)
+ .then((result) => {
+ if (cancelled) return
+ cache.set(field, { used: result.used, timestamp: Date.now() })
+ setUsed(result.used)
+ })
+ .catch(() => {
+ // TODO: backend /usage/{field} endpoint not yet implemented (planned).
+ // 404s and other errors degrade to used=0 silently — no toast.
+ if (cancelled) return
+ setUsed(0)
+ })
+ .finally(() => {
+ if (cancelled) return
+ setIsLoading(false)
+ })
+
+ return () => {
+ cancelled = true
+ }
+ }, [field])
+
+ const percentage =
+ limit === null || limit <= 0 ? null : Math.min(100, Math.round((used / limit) * 100))
+ const isAtLimit = limit !== null && used >= limit
+
+ return { used, limit, percentage, isAtLimit, isLoading }
+}
+
+export default useFeatureLimit
diff --git a/frontend/src/hooks/useTrialBanner.test.ts b/frontend/src/hooks/useTrialBanner.test.ts
new file mode 100644
index 00000000..93cc2843
--- /dev/null
+++ b/frontend/src/hooks/useTrialBanner.test.ts
@@ -0,0 +1,131 @@
+import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'
+import { renderHook } from '@testing-library/react'
+import { useTrialBanner } from './useTrialBanner'
+import { useBillingStore } from '@/store/billingStore'
+import type { SubscriptionState } from '@/types/billing'
+
+const FROZEN_NOW = new Date('2026-05-06T00:00:00Z')
+
+function makeSub(overrides: Partial): SubscriptionState {
+ return {
+ status: 'trialing',
+ plan: 'starter',
+ current_period_start: '2026-05-01T00:00:00Z',
+ current_period_end: null,
+ cancel_at_period_end: false,
+ seat_limit: null,
+ has_pro_entitlement: false,
+ is_paid: false,
+ ...overrides,
+ }
+}
+
+function setSubscription(overrides: Partial) {
+ useBillingStore.setState({ subscription: makeSub(overrides) })
+}
+
+describe('useTrialBanner', () => {
+ beforeEach(() => {
+ vi.useFakeTimers()
+ vi.setSystemTime(FROZEN_NOW)
+ useBillingStore.setState({
+ subscription: null,
+ planBilling: null,
+ planLimits: {},
+ enabledFeatures: {},
+ isLoading: false,
+ error: null,
+ })
+ })
+
+ afterEach(() => {
+ vi.useRealTimers()
+ })
+
+ describe('stage matches subscription state matrix', () => {
+ it('returns null when subscription is null (no flicker on initial load)', () => {
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe(null)
+ expect(result.current.daysRemaining).toBe(null)
+ })
+
+ it('complimentary status -> complimentary stage', () => {
+ setSubscription({ status: 'complimentary' })
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe('complimentary')
+ })
+
+ it('active status -> paid stage', () => {
+ setSubscription({ status: 'active' })
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe('paid')
+ })
+
+ it('past_due status -> past_due stage', () => {
+ setSubscription({ status: 'past_due' })
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe('past_due')
+ })
+
+ it('canceled status -> canceled stage', () => {
+ setSubscription({ status: 'canceled' })
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe('canceled')
+ })
+
+ it('trialing >3 days remaining -> pristine', () => {
+ // 7 days from frozen now.
+ setSubscription({
+ status: 'trialing',
+ current_period_end: '2026-05-13T00:00:00Z',
+ })
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe('pristine')
+ expect(result.current.daysRemaining).toBe(7)
+ })
+
+ it('trialing 1-3 days remaining -> warning', () => {
+ // 2 days from frozen now.
+ setSubscription({
+ status: 'trialing',
+ current_period_end: '2026-05-08T00:00:00Z',
+ })
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe('warning')
+ expect(result.current.daysRemaining).toBe(2)
+ })
+
+ it('trialing exactly 24 hours remaining -> warning (boundary, not urgent)', () => {
+ // Exactly 1.0 fractional day from frozen now — must sit on the warning
+ // side per spec (1–3 days inclusive of 1).
+ setSubscription({
+ status: 'trialing',
+ current_period_end: '2026-05-07T00:00:00Z',
+ })
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe('warning')
+ expect(result.current.daysRemaining).toBe(1)
+ })
+
+ it('trialing <1 day remaining -> urgent', () => {
+ // 12 hours from frozen now -> Math.ceil(0.5) = 1 day.
+ setSubscription({
+ status: 'trialing',
+ current_period_end: '2026-05-06T12:00:00Z',
+ })
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe('urgent')
+ expect(result.current.daysRemaining).toBe(1)
+ })
+
+ it('trialing past period_end -> expired', () => {
+ setSubscription({
+ status: 'trialing',
+ current_period_end: '2026-05-01T00:00:00Z',
+ })
+ const { result } = renderHook(() => useTrialBanner())
+ expect(result.current.stage).toBe('expired')
+ expect(result.current.daysRemaining).toBe(0)
+ })
+ })
+})
diff --git a/frontend/src/hooks/useTrialBanner.ts b/frontend/src/hooks/useTrialBanner.ts
new file mode 100644
index 00000000..cc3ebe47
--- /dev/null
+++ b/frontend/src/hooks/useTrialBanner.ts
@@ -0,0 +1,86 @@
+import { useBillingStore } from '@/store/billingStore'
+
+export type TrialBannerStage =
+ | 'pristine'
+ | 'warning'
+ | 'urgent'
+ | 'expired'
+ | 'complimentary'
+ | 'paid'
+ | 'past_due'
+ | 'canceled'
+
+export interface TrialBannerResult {
+ stage: TrialBannerStage | null
+ daysRemaining: number | null
+}
+
+const MS_PER_DAY = 24 * 60 * 60 * 1000
+
+/**
+ * Derives the trial-banner display stage from the current subscription.
+ *
+ * Returns `{ stage: null, daysRemaining: null }` when subscription data is
+ * not yet loaded — this prevents the banner flickering on initial render.
+ *
+ * Subscribes to `useBillingStore` so updates from `refetch()` after a Stripe
+ * checkout propagate automatically.
+ */
+export function useTrialBanner(): TrialBannerResult {
+ const subscription = useBillingStore((state) => state.subscription)
+
+ if (!subscription) {
+ return { stage: null, daysRemaining: null }
+ }
+
+ switch (subscription.status) {
+ case 'complimentary':
+ return { stage: 'complimentary', daysRemaining: null }
+ case 'active':
+ return { stage: 'paid', daysRemaining: null }
+ case 'past_due':
+ return { stage: 'past_due', daysRemaining: null }
+ case 'canceled':
+ return { stage: 'canceled', daysRemaining: null }
+ case 'trialing': {
+ const end = subscription.current_period_end
+ ? new Date(subscription.current_period_end).getTime()
+ : null
+ if (end === null || Number.isNaN(end)) {
+ // Trialing without a period end is malformed; treat as expired so the
+ // upgrade prompt still surfaces rather than silently swallowing it.
+ return { stage: 'expired', daysRemaining: null }
+ }
+ const now = Date.now()
+ if (end <= now) {
+ return { stage: 'expired', daysRemaining: 0 }
+ }
+ const msRemaining = end - now
+ // Use fractional days for stage thresholds so exactly 24h remaining
+ // sits on the warning side (1.0), not urgent. The displayed integer
+ // countdown still uses Math.ceil so "0.5 days" renders as "1 day".
+ const fractionalDays = msRemaining / MS_PER_DAY
+ const daysRemaining = Math.ceil(fractionalDays)
+ // Spec thresholds:
+ // >3 days remaining → pristine
+ // 1–3 days → warning (inclusive of exactly 1)
+ // <1 day → urgent
+ let stage: TrialBannerStage = 'pristine'
+ if (fractionalDays < 1) stage = 'urgent'
+ else if (fractionalDays <= 3) stage = 'warning'
+ return { stage, daysRemaining }
+ }
+ case 'incomplete':
+ // Not in the spec's matrix; surface as null so the banner stays hidden
+ // until checkout actually resolves.
+ return { stage: null, daysRemaining: null }
+ default: {
+ // Defensive fallthrough for unknown statuses — keep the banner hidden.
+ const _exhaustive: never = subscription.status as never
+ void _exhaustive
+ return { stage: null, daysRemaining: null }
+ }
+ }
+}
+
+export default useTrialBanner
--
2.49.1
From ece82225f202f9a903cb525481145e97bb8a3b51 Mon Sep 17 00:00:00 2001
From: Michael Chihlas
Date: Wed, 6 May 2026 21:01:53 -0400
Subject: [PATCH 08/29] feat(billing): add FeatureGate, UpgradePrompt,
EmailVerificationGate components
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Three drop-in gating components for the self-serve signup flow.
- FeatureGate reads useFeature(flag) and renders children when enabled,
else a fallback (default UpgradePrompt). UX-only — security boundary
remains require_feature on the backend.
- UpgradePrompt resolves a feature key to display name + required plan
via an inline catalog and links to /account/billing/select-plan.
- EmailVerificationGate gates protected content behind a 6-day grace
period; renders a minimal EmailVerificationWall (resend + sign out)
on Day 7+ unverified. Wall design will be refined in Task 37.
Co-Authored-By: Claude Opus 4.7
---
.../common/EmailVerificationGate.tsx | 56 ++++++++
.../common/EmailVerificationWall.tsx | 88 +++++++++++++
.../src/components/common/FeatureGate.tsx | 42 ++++++
.../src/components/common/UpgradePrompt.tsx | 111 ++++++++++++++++
.../__tests__/EmailVerificationGate.test.tsx | 123 ++++++++++++++++++
.../common/__tests__/FeatureGate.test.tsx | 67 ++++++++++
.../common/__tests__/UpgradePrompt.test.tsx | 30 +++++
7 files changed, 517 insertions(+)
create mode 100644 frontend/src/components/common/EmailVerificationGate.tsx
create mode 100644 frontend/src/components/common/EmailVerificationWall.tsx
create mode 100644 frontend/src/components/common/FeatureGate.tsx
create mode 100644 frontend/src/components/common/UpgradePrompt.tsx
create mode 100644 frontend/src/components/common/__tests__/EmailVerificationGate.test.tsx
create mode 100644 frontend/src/components/common/__tests__/FeatureGate.test.tsx
create mode 100644 frontend/src/components/common/__tests__/UpgradePrompt.test.tsx
diff --git a/frontend/src/components/common/EmailVerificationGate.tsx b/frontend/src/components/common/EmailVerificationGate.tsx
new file mode 100644
index 00000000..cc17cdac
--- /dev/null
+++ b/frontend/src/components/common/EmailVerificationGate.tsx
@@ -0,0 +1,56 @@
+import type { ReactNode } from 'react'
+import { useAuthStore } from '@/store/authStore'
+import { EmailVerificationWall } from './EmailVerificationWall'
+
+interface EmailVerificationGateProps {
+ children: ReactNode
+ /**
+ * Override the grace period (in days). Day `gracePeriodDays + 1` and beyond
+ * trigger the wall. Defaults to 6 — the spec says Day 1–6 unverified renders
+ * children and Day 7+ renders the wall.
+ */
+ gracePeriodDays?: number
+}
+
+const MS_PER_DAY = 24 * 60 * 60 * 1000
+
+/** Whole days elapsed between two ISO timestamps (floored). */
+function daysSince(iso: string, now: number = Date.now()): number {
+ const created = Date.parse(iso)
+ if (Number.isNaN(created)) {
+ // Defensive: bad timestamp — treat as just-signed-up so we don't
+ // accidentally lock anyone out.
+ return 0
+ }
+ return Math.floor((now - created) / MS_PER_DAY)
+}
+
+/**
+ * Wraps protected content. While the current user is past the grace period
+ * without having verified their email, renders ``
+ * instead of children.
+ *
+ * Behavior:
+ * - No user (signed out): renders children (let route guards handle auth).
+ * - User has `email_verified_at`: renders children.
+ * - Day 1–6 unverified: renders children (banner is shown elsewhere).
+ * - Day 7+ unverified: renders the wall.
+ */
+export function EmailVerificationGate({
+ children,
+ gracePeriodDays = 6,
+}: EmailVerificationGateProps) {
+ const user = useAuthStore((s) => s.user)
+
+ if (!user) return <>{children}>
+ if (user.email_verified_at) return <>{children}>
+
+ const elapsed = daysSince(user.created_at)
+ if (elapsed > gracePeriodDays) {
+ return
+ }
+
+ return <>{children}>
+}
+
+export default EmailVerificationGate
diff --git a/frontend/src/components/common/EmailVerificationWall.tsx b/frontend/src/components/common/EmailVerificationWall.tsx
new file mode 100644
index 00000000..abb95e62
--- /dev/null
+++ b/frontend/src/components/common/EmailVerificationWall.tsx
@@ -0,0 +1,88 @@
+import { useState } from 'react'
+import { Loader2, MailCheck } from 'lucide-react'
+import { authApi } from '@/api/auth'
+import { useAuthStore } from '@/store/authStore'
+import { toast } from '@/lib/toast'
+import { cn } from '@/lib/utils'
+
+interface EmailVerificationWallProps {
+ className?: string
+}
+
+/**
+ * Hard wall shown after the email-verification grace period expires.
+ *
+ * Minimal v1 — Task 37 will refine copy, layout, and add the
+ * `/verify-email?token=...` route handling. Until then this gives
+ * Day 7+ unverified users a way to re-send the verification email
+ * or sign out.
+ */
+export function EmailVerificationWall({ className }: EmailVerificationWallProps) {
+ const user = useAuthStore((s) => s.user)
+ const logout = useAuthStore((s) => s.logout)
+ const [isSending, setIsSending] = useState(false)
+
+ const handleResend = async () => {
+ setIsSending(true)
+ try {
+ await authApi.sendVerificationEmail()
+ toast.success('Verification email sent')
+ } catch {
+ toast.error('Failed to send verification email')
+ } finally {
+ setIsSending(false)
+ }
+ }
+
+ const handleLogout = async () => {
+ try {
+ await logout()
+ } catch {
+ // logout swallows API errors internally
+ }
+ }
+
+ return (
+
+
+
+
+
+
+ Verify your email to continue
+
+
+ {user?.email
+ ? `We sent a verification link to ${user.email}. Click it to unlock your account.`
+ : 'Check your inbox for the verification link we sent when you signed up.'}
+
+
+
+
+
+
+
+ )
+}
+
+export default EmailVerificationWall
diff --git a/frontend/src/components/common/FeatureGate.tsx b/frontend/src/components/common/FeatureGate.tsx
new file mode 100644
index 00000000..e27237d4
--- /dev/null
+++ b/frontend/src/components/common/FeatureGate.tsx
@@ -0,0 +1,42 @@
+import type { ReactNode } from 'react'
+import { useFeature } from '@/hooks/useFeature'
+import { UpgradePrompt } from './UpgradePrompt'
+
+interface FeatureGateProps {
+ /** Feature flag key (e.g. `psa_integration`). Must match a backend `feature_flags.flag_key`. */
+ feature: string
+ /**
+ * Rendered when the feature is enabled for the current account.
+ */
+ children: ReactNode
+ /**
+ * Rendered when the feature is disabled. Defaults to ``.
+ * Pass `null` to render nothing.
+ */
+ fallback?: ReactNode
+}
+
+/**
+ * Conditionally renders `children` based on whether `feature` is enabled
+ * for the current account.
+ *
+ * This is a UX affordance — the security boundary is the backend
+ * `require_feature` dependency. Never trust this gate for authorization.
+ */
+export function FeatureGate({ feature, children, fallback }: FeatureGateProps) {
+ const enabled = useFeature(feature)
+
+ if (enabled) {
+ return <>{children}>
+ }
+
+ // Use explicit fallback when provided, otherwise render the standard prompt.
+ // `null` is a valid fallback (renders nothing).
+ if (fallback !== undefined) {
+ return <>{fallback}>
+ }
+
+ return
+}
+
+export default FeatureGate
diff --git a/frontend/src/components/common/UpgradePrompt.tsx b/frontend/src/components/common/UpgradePrompt.tsx
new file mode 100644
index 00000000..7780d717
--- /dev/null
+++ b/frontend/src/components/common/UpgradePrompt.tsx
@@ -0,0 +1,111 @@
+import { Lock, Sparkles } from 'lucide-react'
+import { Link } from 'react-router-dom'
+import { cn } from '@/lib/utils'
+
+interface UpgradePromptProps {
+ feature: string
+ className?: string
+}
+
+interface FeatureMeta {
+ /** Display name shown in the prompt heading. */
+ displayName: string
+ /** Plan that unlocks this feature. */
+ requiredPlan: string
+ /** Optional one-line value pitch. */
+ description?: string
+}
+
+/**
+ * Mapping from feature flag key to display metadata.
+ *
+ * v1: small inline table maintained here. If this grows, lift to
+ * `frontend/src/lib/featureCatalog.ts` and source from a backend endpoint.
+ *
+ * Keys must match `feature_flags.flag_key` on the backend.
+ */
+const FEATURE_CATALOG: Record = {
+ psa_integration: {
+ displayName: 'PSA Integration',
+ requiredPlan: 'Pro',
+ description: 'Sync tickets and assets with your PSA in real time.',
+ },
+ kb_accelerator: {
+ displayName: 'Knowledge Base Accelerator',
+ requiredPlan: 'Pro',
+ description: 'Auto-generate troubleshooting flows from your existing KB.',
+ },
+ ai_builder: {
+ displayName: 'AI Builder',
+ requiredPlan: 'Pro',
+ description: 'Generate decision trees from natural-language prompts.',
+ },
+ branching_logic: {
+ displayName: 'Branching Logic',
+ requiredPlan: 'Pro',
+ },
+ custom_branding: {
+ displayName: 'Custom Branding',
+ requiredPlan: 'Pro',
+ },
+ api_access: {
+ displayName: 'API Access',
+ requiredPlan: 'Pro',
+ },
+ sso: {
+ displayName: 'Single Sign-On',
+ requiredPlan: 'Enterprise',
+ },
+}
+
+/** Humanize an unknown feature key for the fallback display name. */
+function humanizeFeatureKey(key: string): string {
+ return key
+ .split('_')
+ .map((part) => part.charAt(0).toUpperCase() + part.slice(1))
+ .join(' ')
+}
+
+/**
+ * Standardized "this feature is on Pro" affordance.
+ *
+ * Renders a locked panel with a CTA that routes to the plan-selection page.
+ * The actual gating is enforced server-side via `require_feature` — this is UX.
+ */
+export function UpgradePrompt({ feature, className }: UpgradePromptProps) {
+ const meta = FEATURE_CATALOG[feature]
+ const displayName = meta?.displayName ?? humanizeFeatureKey(feature)
+ const requiredPlan = meta?.requiredPlan ?? 'Pro'
+ const description = meta?.description
+
+ return (
+
+ {lookup.status === 'missing-code'
+ ? 'The invite link is missing its code.'
+ : 'This invite has expired, been used, or been revoked.'}{' '}
+ Ask the person who invited you to resend it.
+
+ >
+ )
+}
+
+export default AcceptInvitePage
diff --git a/frontend/src/pages/OAuthCallbackPage.tsx b/frontend/src/pages/OAuthCallbackPage.tsx
index 5e0b8a1d..19ec82fc 100644
--- a/frontend/src/pages/OAuthCallbackPage.tsx
+++ b/frontend/src/pages/OAuthCallbackPage.tsx
@@ -4,6 +4,7 @@ import { authApi } from '@/api/auth'
import { useAuthStore } from '@/store/authStore'
import { BrandLogo } from '@/components/common/BrandLogo'
import { PageMeta } from '@/components/common/PageMeta'
+import { decodeOAuthState } from '@/lib/oauthState'
type Provider = 'google' | 'microsoft'
@@ -13,8 +14,16 @@ type Provider = 'google' | 'microsoft'
* public routes (NOT inside ProtectedRoute).
*
* Reads `?code=...` from the URL, POSTs it to the backend, stores the
- * returned tokens, hydrates the auth store via fetchUser(), and redirects
- * to /welcome (new user) or / (returning user).
+ * returned tokens, hydrates the auth store via fetchUser(), and redirects.
+ *
+ * Two state forms are supported:
+ * - Legacy: `state` is a raw random hex string. CSRF check against
+ * sessionStorage('rf-oauth-state').
+ * - /accept-invite: `state` is base64url(JSON({csrf, accountInviteCode,
+ * invitedEmail})). The CSRF value is compared against
+ * sessionStorage('rf-oauth-state'); the invite fields are forwarded to
+ * the backend so the new user joins the invited account instead of
+ * getting a personal one.
*/
export function OAuthCallbackPage() {
const navigate = useNavigate()
@@ -35,9 +44,10 @@ export function OAuthCallbackPage() {
const oauthError = search.get('error')
const returnedState = search.get('state')
- // CSRF: validate state round-trip against the value RegisterPage stashed
- // in sessionStorage before redirecting to the provider. Always clear the
- // stored value so a stale entry can't be re-used by a later attempt.
+ // CSRF: validate state round-trip against the value RegisterPage /
+ // AcceptInvitePage stashed in sessionStorage before redirecting to the
+ // provider. Always clear the stored value so a stale entry can't be
+ // re-used by a later attempt.
let storedState: string | null = null
try {
storedState = sessionStorage.getItem('rf-oauth-state')
@@ -51,7 +61,17 @@ export function OAuthCallbackPage() {
setError(`OAuth error: ${oauthError}`)
return
}
- if (!storedState || returnedState !== storedState) {
+ if (!storedState || !returnedState) {
+ setError('Invalid OAuth state — possible CSRF. Please try again.')
+ return
+ }
+
+ // The decoded form encodes the original CSRF value; compare that.
+ const decoded = decodeOAuthState(returnedState)
+ const matchesCsrf = decoded
+ ? decoded.csrf === storedState
+ : returnedState === storedState
+ if (!matchesCsrf) {
setError('Invalid OAuth state — possible CSRF. Please try again.')
return
}
@@ -63,10 +83,16 @@ export function OAuthCallbackPage() {
let cancelled = false
void (async () => {
try {
+ const inviteOptions = decoded
+ ? {
+ accountInviteCode: decoded.accountInviteCode,
+ invitedEmail: decoded.invitedEmail,
+ }
+ : undefined
const result =
provider === 'microsoft'
- ? await authApi.microsoftCallback(code)
- : await authApi.googleCallback(code)
+ ? await authApi.microsoftCallback(code, inviteOptions)
+ : await authApi.googleCallback(code, inviteOptions)
if (cancelled) return
// Persist tokens for apiClient interceptor + zustand store.
@@ -81,7 +107,15 @@ export function OAuthCallbackPage() {
await fetchUser()
if (cancelled) return
- const dest = result.is_new_user ? '/welcome' : '/'
+ // Invitee path lands on the dashboard with the teammate-welcome
+ // marker; new self-serve owners go to the welcome wizard; returning
+ // users to /.
+ let dest = '/'
+ if (decoded?.accountInviteCode) {
+ dest = '/?welcome=teammate'
+ } else if (result.is_new_user) {
+ dest = '/welcome'
+ }
navigate(dest, { replace: true })
} catch (err: unknown) {
if (cancelled) return
@@ -89,8 +123,28 @@ export function OAuthCallbackPage() {
response?: { data?: { detail?: unknown } }
}
const detail = axiosErr.response?.data?.detail
- const msg =
- (typeof detail === 'string' ? detail : null) ||
+ // Backend returns { error: "invite_email_mismatch" } etc.
+ let msg: string | null = null
+ if (typeof detail === 'string') {
+ msg = detail
+ } else if (
+ detail &&
+ typeof detail === 'object' &&
+ 'error' in (detail as Record)
+ ) {
+ const code = (detail as { error: string }).error
+ if (code === 'invite_email_mismatch') {
+ msg =
+ 'The email on your provider account does not match the invited email. ' +
+ 'Sign in with the matching account, or ask your inviter to resend.'
+ } else if (code === 'invite_invalid_or_expired_or_revoked') {
+ msg = 'This invite is no longer valid. Ask your inviter to resend.'
+ } else {
+ msg = code
+ }
+ }
+ msg =
+ msg ||
(err instanceof Error ? err.message : 'Sign-in failed')
setError(msg)
}
diff --git a/frontend/src/pages/__tests__/AcceptInvitePage.test.tsx b/frontend/src/pages/__tests__/AcceptInvitePage.test.tsx
new file mode 100644
index 00000000..3f8ebf76
--- /dev/null
+++ b/frontend/src/pages/__tests__/AcceptInvitePage.test.tsx
@@ -0,0 +1,123 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest'
+import { render, screen, waitFor } from '@testing-library/react'
+import { MemoryRouter } from 'react-router-dom'
+import { HelmetProvider } from 'react-helmet-async'
+
+import { AcceptInvitePage } from '../AcceptInvitePage'
+import { inviteApi } from '@/api/invite'
+import {
+ __resetAppConfigCache,
+ __setAppConfigCache,
+} from '@/hooks/useAppConfig'
+
+vi.mock('@/api/invite', () => ({
+ inviteApi: {
+ lookupAccountInvite: vi.fn(),
+ validateCode: vi.fn(),
+ },
+}))
+
+vi.mock('@/store/authStore', () => ({
+ useAuthStore: () => ({
+ register: vi.fn().mockResolvedValue(undefined),
+ isLoading: false,
+ error: null,
+ clearError: vi.fn(),
+ }),
+}))
+
+function renderPage(initialPath: string) {
+ return render(
+
+
+
+
+ ,
+ )
+}
+
+describe('AcceptInvitePage', () => {
+ beforeEach(() => {
+ __resetAppConfigCache()
+ __setAppConfigCache({
+ self_serve_enabled: true,
+ oauth_providers: ['google', 'microsoft'],
+ })
+ vi.clearAllMocks()
+ })
+
+ it('shows account name + locked email + accept buttons for a valid code', async () => {
+ vi.mocked(inviteApi.lookupAccountInvite).mockResolvedValue({
+ account_name: 'Acme MSP',
+ inviter_name: 'Alice Owner',
+ invited_email: 'bob@acme.example',
+ role: 'engineer',
+ })
+
+ renderPage('/accept-invite?code=VALIDINVITECODE0011223344556677')
+
+ // Inviter context (also confirms the lookup completed and rendered)
+ await waitFor(() => {
+ expect(
+ screen.getByText(/Alice Owner invited you as engineer/),
+ ).toBeInTheDocument()
+ })
+ // Account name surfaces in the heading line.
+ expect(
+ screen.getByText((_content, node) => {
+ return (
+ node?.tagName.toLowerCase() === 'span' &&
+ /Acme MSP/.test(node.textContent || '')
+ )
+ }),
+ ).toBeInTheDocument()
+
+ // Locked email — not an editable input
+ const emailDisplay = screen.getByTestId('invited-email')
+ expect(emailDisplay.tagName.toLowerCase()).not.toBe('input')
+ expect(emailDisplay).toHaveTextContent('bob@acme.example')
+ expect(screen.queryByLabelText(/email address/i)).not.toBeInTheDocument()
+
+ // OAuth buttons + password submit all rendered
+ expect(screen.getByTestId('oauth-google')).toBeInTheDocument()
+ expect(screen.getByTestId('oauth-microsoft')).toBeInTheDocument()
+ expect(screen.getByTestId('accept-submit')).toBeInTheDocument()
+ expect(screen.getByTestId('accept-submit')).toHaveTextContent(/Join Acme MSP/)
+
+ expect(inviteApi.lookupAccountInvite).toHaveBeenCalledWith(
+ 'VALIDINVITECODE0011223344556677',
+ )
+ })
+
+ it('shows resend message + mailto link for an invalid invite code', async () => {
+ vi.mocked(inviteApi.lookupAccountInvite).mockRejectedValue(
+ Object.assign(new Error('not found'), {
+ response: {
+ status: 404,
+ data: { detail: { error: 'invite_invalid_or_expired_or_revoked' } },
+ },
+ }),
+ )
+
+ renderPage('/accept-invite?code=BADCODE')
+
+ await waitFor(() => {
+ expect(
+ screen.getByText(/This invite is no longer valid/i),
+ ).toBeInTheDocument()
+ })
+ expect(
+ screen.getByText(/Ask the person who invited you to resend it/i),
+ ).toBeInTheDocument()
+
+ const resendLink = screen.getByRole('link', { name: /Email your inviter/i })
+ expect(resendLink).toHaveAttribute(
+ 'href',
+ expect.stringMatching(/^mailto:/),
+ )
+
+ // No accept form rendered when invite is invalid.
+ expect(screen.queryByTestId('accept-submit')).not.toBeInTheDocument()
+ expect(screen.queryByTestId('oauth-google')).not.toBeInTheDocument()
+ })
+})
diff --git a/frontend/src/router.tsx b/frontend/src/router.tsx
index f0b3e50b..401b4969 100644
--- a/frontend/src/router.tsx
+++ b/frontend/src/router.tsx
@@ -26,6 +26,7 @@ const TermsPage = lazyWithRetry(() => import('@/pages/TermsPage'))
// Standalone auth pages
const VerifyEmailPage = lazyWithRetry(() => import('@/pages/VerifyEmailPage'))
const OAuthCallbackPage = lazyWithRetry(() => import('@/pages/OAuthCallbackPage'))
+const AcceptInvitePage = lazyWithRetry(() => import('@/pages/AcceptInvitePage'))
const ChangePasswordPage = lazyWithRetry(() => import('@/pages/ChangePasswordPage'))
const ForgotPasswordPage = lazyWithRetry(() => import('@/pages/ForgotPasswordPage'))
const ResetPasswordPage = lazyWithRetry(() => import('@/pages/ResetPasswordPage'))
@@ -150,6 +151,11 @@ export const router = sentryCreateBrowserRouter([
element: page(VerifyEmailPage),
errorElement: ,
},
+ {
+ path: '/accept-invite',
+ element: page(AcceptInvitePage),
+ errorElement: ,
+ },
{
path: '/auth/google/callback',
element: page(OAuthCallbackPage),
diff --git a/frontend/src/types/user.ts b/frontend/src/types/user.ts
index 471957e9..8f65b34f 100644
--- a/frontend/src/types/user.ts
+++ b/frontend/src/types/user.ts
@@ -26,6 +26,8 @@ export interface UserCreate {
name: string
role?: UserRole
invite_code?: string
+ /** Account invite code to join an existing account (issued via /accounts/me/invites). */
+ account_invite_code?: string
}
export interface UserLogin {
--
2.49.1
From 7d939a4acfd64310824ccbd2c71582ed8d8f1b7c Mon Sep 17 00:00:00 2001
From: Michael Chihlas
Date: Wed, 6 May 2026 21:41:30 -0400
Subject: [PATCH 11/29] feat(auth): add email verification banner, wall,
/verify-email page
Wires up the soft 7-day email-verification grace period UX.
- EmailVerificationBanner now uses the design-system warning tokens
(bg-warning-dim / text-warning) and hides itself once the grace
period expires, so the wall takes over without double-messaging.
- EmailVerificationWall picks up data-testids on the resend and
sign-out CTAs.
- VerifyEmailPage gains a single-fire useRef guard (so React 19
strict-mode double-invoke doesn't burn the token), an
already-verified short-circuit that skips the API call, success
state with auth-store refresh + redirect to /?verified=1, and
an error state with a resend CTA.
Tests: banner hides past day-7, banner resend triggers API call,
verify success refreshes + redirects, verify short-circuits when
already verified, single-fire guard holds across remount.
Co-Authored-By: Claude Opus 4.7 (1M context)
---
.../common/EmailVerificationWall.tsx | 2 +
frontend/src/components/layout/AppLayout.tsx | 5 +-
.../layout/EmailVerificationBanner.tsx | 54 +++-
.../layout/__tests__/AppLayout.test.tsx | 123 +++++++++
.../EmailVerificationBanner.test.tsx | 119 ++++++++
frontend/src/pages/VerifyEmailPage.tsx | 256 ++++++++++++++----
.../pages/__tests__/VerifyEmailPage.test.tsx | 174 ++++++++++++
7 files changed, 673 insertions(+), 60 deletions(-)
create mode 100644 frontend/src/components/layout/__tests__/AppLayout.test.tsx
create mode 100644 frontend/src/components/layout/__tests__/EmailVerificationBanner.test.tsx
create mode 100644 frontend/src/pages/__tests__/VerifyEmailPage.test.tsx
diff --git a/frontend/src/components/common/EmailVerificationWall.tsx b/frontend/src/components/common/EmailVerificationWall.tsx
index abb95e62..8eb5cab2 100644
--- a/frontend/src/components/common/EmailVerificationWall.tsx
+++ b/frontend/src/components/common/EmailVerificationWall.tsx
@@ -67,6 +67,7 @@ export function EmailVerificationWall({ className }: EmailVerificationWallProps)
type="button"
onClick={handleResend}
disabled={isSending}
+ data-testid="resend-button"
className="inline-flex items-center justify-center gap-2 rounded-md bg-primary px-4 py-2 text-sm font-medium text-primary-foreground transition-colors hover:bg-primary/90 disabled:opacity-50"
>
{isSending && }
@@ -75,6 +76,7 @@ export function EmailVerificationWall({ className }: EmailVerificationWallProps)
diff --git a/frontend/src/components/layout/EmailVerificationBanner.tsx b/frontend/src/components/layout/EmailVerificationBanner.tsx
index e65bdb2f..91ecd0c9 100644
--- a/frontend/src/components/layout/EmailVerificationBanner.tsx
+++ b/frontend/src/components/layout/EmailVerificationBanner.tsx
@@ -5,7 +5,39 @@ import { useAuthStore } from '@/store/authStore'
import { cn } from '@/lib/utils'
import { toast } from '@/lib/toast'
-export function EmailVerificationBanner() {
+const MS_PER_DAY = 24 * 60 * 60 * 1000
+
+/**
+ * Whole days elapsed between an ISO timestamp and now (floored).
+ *
+ * Mirrors the helper in `EmailVerificationGate` — keep the two in sync so the
+ * banner hides on the same day the wall appears (Day 7+ unverified). Defensive
+ * on bad timestamps: treats unparseable input as "just signed up" so we never
+ * accidentally hide the banner on a real unverified user.
+ */
+function daysSince(iso: string, now: number = Date.now()): number {
+ const created = Date.parse(iso)
+ if (Number.isNaN(created)) return 0
+ return Math.floor((now - created) / MS_PER_DAY)
+}
+
+interface EmailVerificationBannerProps {
+ /**
+ * Override the grace period (in days). Day `gracePeriodDays + 1` and beyond
+ * suppress the banner — `EmailVerificationGate` shows the wall instead.
+ * Defaults to 6 (matches the gate).
+ */
+ gracePeriodDays?: number
+}
+
+/**
+ * Top-of-dashboard bar shown to users who signed up but haven't verified their
+ * email yet. Hides itself once the grace period expires (the wall takes over)
+ * and once the user dismisses it for the session.
+ */
+export function EmailVerificationBanner({
+ gracePeriodDays = 6,
+}: EmailVerificationBannerProps = {}) {
const user = useAuthStore((s) => s.user)
const [dismissed, setDismissed] = useState(false)
const [isSending, setIsSending] = useState(false)
@@ -19,6 +51,11 @@ export function EmailVerificationBanner() {
if (!user || user.email_verified_at || dismissed || !verificationEnabled) return null
+ // Past grace period: the wall takes over inside .
+ // Keep the banner out of the way so we don't double-show messaging.
+ const elapsed = daysSince(user.created_at)
+ if (elapsed > gracePeriodDays) return null
+
const handleResend = async () => {
setIsSending(true)
try {
@@ -32,22 +69,29 @@ export function EmailVerificationBanner() {
}
return (
-
-
-
+
+
+
Your email is not verified.
}
+ />
+
+
+ ,
+ )
+}
+
+describe('AppLayout — EmailVerificationGate wiring', () => {
+ beforeEach(() => {
+ vi.useFakeTimers()
+ vi.setSystemTime(FROZEN_NOW)
+ useAuthStore.setState({ user: null, token: null, isAuthenticated: false })
+ })
+
+ afterEach(() => {
+ vi.useRealTimers()
+ useAuthStore.setState({ user: null, token: null, isAuthenticated: false })
+ })
+
+ it('renders the wall and hides the child route on day 8 unverified', () => {
+ // created 8 days before frozen now -> elapsed=8, > grace=6 -> wall.
+ useAuthStore.setState({
+ user: makeUser({ created_at: '2026-04-28T00:00:00Z' }),
+ })
+
+ renderAppLayout()
+
+ expect(screen.getByTestId('email-verification-wall')).toBeInTheDocument()
+ expect(screen.queryByTestId('child-route-content')).not.toBeInTheDocument()
+ })
+
+ it('renders the child route within the grace period (day 1 unverified)', () => {
+ useAuthStore.setState({
+ user: makeUser({ created_at: '2026-05-05T00:00:00Z' }),
+ })
+
+ renderAppLayout()
+
+ expect(screen.getByTestId('child-route-content')).toBeInTheDocument()
+ expect(
+ screen.queryByTestId('email-verification-wall'),
+ ).not.toBeInTheDocument()
+ })
+})
diff --git a/frontend/src/components/layout/__tests__/EmailVerificationBanner.test.tsx b/frontend/src/components/layout/__tests__/EmailVerificationBanner.test.tsx
new file mode 100644
index 00000000..96222e60
--- /dev/null
+++ b/frontend/src/components/layout/__tests__/EmailVerificationBanner.test.tsx
@@ -0,0 +1,119 @@
+import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'
+import { render, screen, waitFor } from '@testing-library/react'
+import userEvent from '@testing-library/user-event'
+
+import { EmailVerificationBanner } from '../EmailVerificationBanner'
+import { useAuthStore } from '@/store/authStore'
+import { authApi } from '@/api/auth'
+import type { User } from '@/types'
+
+vi.mock('@/api/auth', () => ({
+ authApi: {
+ getVerificationStatus: vi.fn(),
+ sendVerificationEmail: vi.fn(),
+ },
+}))
+
+vi.mock('@/lib/toast', () => ({
+ toast: {
+ success: vi.fn(),
+ error: vi.fn(),
+ },
+}))
+
+function makeUser(overrides: Partial = {}): User {
+ return {
+ id: 'user-1',
+ email: 'test@example.com',
+ name: 'Test User',
+ role: 'engineer',
+ is_super_admin: false,
+ is_active: true,
+ must_change_password: false,
+ account_id: 'acct-1',
+ account_role: 'engineer',
+ team_id: null,
+ created_at: '2026-05-01T00:00:00Z',
+ last_login: null,
+ phone: null,
+ job_title: null,
+ timezone: 'UTC',
+ avatar_url: null,
+ email_verified_at: null,
+ ...overrides,
+ }
+}
+
+const FROZEN_NOW = new Date('2026-05-06T00:00:00Z')
+
+describe('EmailVerificationBanner', () => {
+ beforeEach(() => {
+ vi.useFakeTimers({ shouldAdvanceTime: true })
+ vi.setSystemTime(FROZEN_NOW)
+ useAuthStore.setState({ user: null, token: null, isAuthenticated: false })
+ vi.mocked(authApi.getVerificationStatus).mockResolvedValue({
+ enabled: true,
+ })
+ vi.mocked(authApi.sendVerificationEmail).mockResolvedValue(undefined)
+ })
+
+ afterEach(() => {
+ vi.useRealTimers()
+ vi.clearAllMocks()
+ })
+
+ it('hides past grace day-7+', async () => {
+ // Created 8 days before frozen now -> elapsed=8, > grace=6.
+ useAuthStore.setState({
+ user: makeUser({ created_at: '2026-04-28T00:00:00Z' }),
+ })
+
+ const { container } = render()
+
+ // Wait long enough for any pending verification-status fetch to resolve.
+ await waitFor(() => {
+ expect(authApi.getVerificationStatus).toHaveBeenCalled()
+ })
+
+ expect(
+ screen.queryByTestId('email-verification-banner'),
+ ).not.toBeInTheDocument()
+ expect(container.firstChild).toBeNull()
+ })
+
+ it('renders within the grace window', async () => {
+ // Created 1 day before frozen now -> elapsed=1, within grace.
+ useAuthStore.setState({
+ user: makeUser({ created_at: '2026-05-05T00:00:00Z' }),
+ })
+
+ render()
+
+ await waitFor(() => {
+ expect(
+ screen.getByTestId('email-verification-banner'),
+ ).toBeInTheDocument()
+ })
+ })
+
+ it('resend triggers API call', async () => {
+ useAuthStore.setState({
+ user: makeUser({ created_at: '2026-05-05T00:00:00Z' }),
+ })
+
+ render()
+
+ await waitFor(() => {
+ expect(
+ screen.getByTestId('email-verification-banner'),
+ ).toBeInTheDocument()
+ })
+
+ const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime })
+ await user.click(screen.getByTestId('banner-resend-button'))
+
+ await waitFor(() => {
+ expect(authApi.sendVerificationEmail).toHaveBeenCalledTimes(1)
+ })
+ })
+})
diff --git a/frontend/src/pages/VerifyEmailPage.tsx b/frontend/src/pages/VerifyEmailPage.tsx
index 14b21ce1..da83ea83 100644
--- a/frontend/src/pages/VerifyEmailPage.tsx
+++ b/frontend/src/pages/VerifyEmailPage.tsx
@@ -1,73 +1,221 @@
-import { useEffect, useState } from 'react'
-import { useSearchParams, Link } from 'react-router-dom'
-import { CheckCircle2, XCircle, Loader2 } from 'lucide-react'
+import { useEffect, useRef, useState } from 'react'
+import { useNavigate, useSearchParams, Link } from 'react-router-dom'
+import { CheckCircle2, XCircle, Loader2, MailCheck } from 'lucide-react'
import { authApi } from '@/api/auth'
+import { useAuthStore } from '@/store/authStore'
import { PageMeta } from '@/components/common/PageMeta'
+import { toast } from '@/lib/toast'
import { cn } from '@/lib/utils'
+type Status = 'loading' | 'success' | 'error' | 'already-verified' | 'no-token'
+
+const SUCCESS_REDIRECT_MS = 1200
+
+/**
+ * Standalone landing page for the email-verification link
+ * (`/verify-email?token=...`).
+ *
+ * Behavior:
+ * - If the user is already verified, short-circuit to a friendly
+ * "Already verified" state. No API call.
+ * - Else fire `POST /auth/email/verify` exactly once (a `useRef` guard keeps
+ * React 19 strict-mode double-invoke from double-firing the call). On
+ * success, refresh the auth store and bounce to `/?verified=1` so the
+ * dashboard surfaces a toast.
+ * - On error, show "Invalid or expired token" + a "Resend" CTA that calls
+ * `POST /auth/email/send-verification`.
+ */
export function VerifyEmailPage() {
const [searchParams] = useSearchParams()
+ const navigate = useNavigate()
const token = searchParams.get('token')
- const [status, setStatus] = useState<'loading' | 'success' | 'error'>(token ? 'loading' : 'error')
- const [errorMessage, setErrorMessage] = useState(token ? '' : 'No verification token provided')
+
+ const alreadyVerified = useAuthStore(
+ (s) => Boolean(s.user?.email_verified_at),
+ )
+
+ const initialStatus: Status = alreadyVerified
+ ? 'already-verified'
+ : token
+ ? 'loading'
+ : 'no-token'
+
+ const [status, setStatus] = useState(initialStatus)
+ const [errorMessage, setErrorMessage] = useState('')
+ const [isResending, setIsResending] = useState(false)
+
+ // Single-fire guard: React 19 strict mode runs effects twice on mount.
+ // Without this, the verify endpoint would burn the token on the first call
+ // and then 400 on the second, flashing an error past the success state.
+ const hasFiredRef = useRef(false)
useEffect(() => {
+ if (status !== 'loading') return
if (!token) return
+ if (hasFiredRef.current) return
+ hasFiredRef.current = true
- authApi.verifyEmail(token)
- .then(() => setStatus('success'))
- .catch((err) => {
- setStatus('error')
- const detail = (err as { response?: { data?: { detail?: string } } }).response?.data?.detail
- setErrorMessage(detail ?? 'Verification failed')
+ let cancelled = false
+
+ authApi
+ .verifyEmail(token)
+ .then(async () => {
+ // Refresh user so `email_verified_at` is populated everywhere.
+ try {
+ await useAuthStore.getState().fetchUser()
+ } catch {
+ // Non-fatal: server confirmed verification, the local user object
+ // will refresh on next page load.
+ }
+ if (cancelled) return
+ setStatus('success')
+ toast.success('Email verified')
+ // Brief success state, then redirect with a query flag so the
+ // dashboard can re-surface confirmation if it wants to.
+ window.setTimeout(() => {
+ navigate('/?verified=1', { replace: true })
+ }, SUCCESS_REDIRECT_MS)
})
- }, [token])
+ .catch((err) => {
+ if (cancelled) return
+ const detail = (err as { response?: { data?: { detail?: string } } })
+ .response?.data?.detail
+ setErrorMessage(detail ?? 'Invalid or expired verification link')
+ setStatus('error')
+ })
+
+ return () => {
+ cancelled = true
+ }
+ }, [status, token, navigate])
+
+ const handleResend = async () => {
+ setIsResending(true)
+ try {
+ await authApi.sendVerificationEmail()
+ toast.success('Verification email sent — check your inbox')
+ } catch {
+ toast.error('Failed to send verification email')
+ } finally {
+ setIsResending(false)
+ }
+ }
return (
<>
-
-
-
- {status === 'loading' && (
- <>
-
-
Verifying your email...
- >
- )}
- {status === 'success' && (
- <>
-
-
Email Verified
-
Your email has been successfully verified.
-
- Go to Dashboard
-
- >
- )}
- {status === 'error' && (
- <>
-
-