feat(billing): plan taxonomy reconciliation + Stripe sync + internal-tester allowlist (#164)
All checks were successful
CI / frontend (push) Successful in 6m40s
Mirror to GitHub / mirror (push) Successful in 7s
CI / e2e (push) Successful in 10m7s
CI / backend (push) Successful in 10m34s

Co-authored-by: Michael Chihlas <michael@resolutionflow.com>
Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
This commit was merged in pull request #164.
This commit is contained in:
2026-05-11 05:07:07 +00:00
committed by chihlasm
parent dad5e1f546
commit 3f04911070
38 changed files with 745 additions and 110 deletions

View File

@@ -29,4 +29,14 @@ CW_CLIENT_ID=<CONNECTWISE CLIENT ID>
# When unset, app/core/config.py:stripe_enabled returns False and Stripe code paths short-circuit.
STRIPE_SECRET_KEY=sk_test_
STRIPE_PUBLISHABLE_KEY=pk_test_
STRIPE_WEBHOOK_SECRET=whsec_
STRIPE_WEBHOOK_SECRET=whsec_
# Self-serve cutover
# SELF_SERVE_ENABLED is the master switch for the public self-serve signup
# flow (pricing page, invite-code-optional registration). Default is false
# until Phase O cutover.
# INTERNAL_TESTER_EMAILS is a comma-separated allowlist that bypasses the
# global flag for specific users — used for prod test-mode validation
# before the public flip. Empty by default.
SELF_SERVE_ENABLED=false
INTERNAL_TESTER_EMAILS=

View File

@@ -0,0 +1,84 @@
"""add_starter_rename_team_to_enterprise
Revision ID: 4ce3e594cb87
Revises: c6cbfc534fad
Create Date: 2026-05-07 19:36:27.172082
Plan tier taxonomy reconciliation. Marketing surface and Stripe products
named "Starter / Pro / Enterprise"; backend was on "free / pro / team".
This migration:
1. Defensively migrates any existing subscriptions on plan='team' to
plan='enterprise' (dev has zero such rows; prod is expected to have
none, but the UPDATE is safe and idempotent).
2. Renames the plan_limits row 'team' -> 'enterprise'. plan_billing
and plan_feature_defaults are FK-referenced but currently empty;
the rename works because PostgreSQL allows updating PK values when
no FK rows reference them.
3. Inserts a new plan_limits row for 'starter' between free and pro.
Resource visibility (Tree.visibility, StepLibrary.visibility) also uses
the string 'team' for "shared with my account" — that is a separate
domain and is intentionally not touched.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '4ce3e594cb87'
down_revision: Union[str, None] = 'c6cbfc534fad'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute("UPDATE subscriptions SET plan = 'enterprise' WHERE plan = 'team'")
op.execute("UPDATE plan_limits SET plan = 'enterprise' WHERE plan = 'team'")
op.execute("""
INSERT INTO plan_limits (
plan,
max_trees,
max_sessions_per_month,
max_users,
custom_branding,
priority_support,
export_formats,
max_ai_builds_per_month,
max_ai_builds_per_24h,
kb_accelerator_enabled,
kb_max_lifetime_conversions,
kb_batch_max_size,
kb_allowed_formats,
kb_detailed_analysis,
kb_conversational_refinement,
kb_step_library_matching,
kb_history_limit
) VALUES (
'starter',
10,
75,
1,
FALSE,
FALSE,
'["markdown", "text", "html"]'::jsonb,
15,
5,
FALSE,
NULL,
NULL,
'["txt", "paste", "md"]'::jsonb,
FALSE,
FALSE,
FALSE,
NULL
)
ON CONFLICT (plan) DO NOTHING
""")
def downgrade() -> None:
op.execute("DELETE FROM plan_limits WHERE plan = 'starter'")
op.execute("UPDATE plan_limits SET plan = 'team' WHERE plan = 'enterprise'")
op.execute("UPDATE subscriptions SET plan = 'team' WHERE plan = 'enterprise'")

View File

@@ -64,6 +64,40 @@ async def get_current_user(
return user
async def get_current_user_optional(
request: Request,
db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> Optional[User]:
"""Best-effort current user for endpoints that work both anonymous and authed.
Returns None on missing/invalid/expired token instead of raising. Used by
surfaces like /config/public that anonymous clients can hit but where an
authenticated user gets a tailored response (e.g. INTERNAL_TESTER_EMAILS
allowlist override).
"""
auth_header = request.headers.get("Authorization") or request.headers.get("authorization")
if not auth_header or not auth_header.lower().startswith("bearer "):
return None
token = auth_header.split(None, 1)[1].strip()
if not token:
return None
payload = decode_token(token)
if payload is None or payload.get("type") != "access":
return None
user_id = payload.get("sub")
if user_id is None:
return None
try:
user_uuid = UUID(user_id)
except ValueError:
return None
result = await db.execute(select(User).where(User.id == user_uuid))
return result.scalar_one_or_none()
async def get_refresh_token_payload(
token: Annotated[str, Depends(oauth2_scheme)]
) -> dict:

View File

@@ -972,7 +972,7 @@ async def update_user_plan(
current_user: Annotated[User, Depends(require_admin)],
):
"""Change a user's subscription plan (super admin only)."""
if data.plan not in ("free", "pro", "team"):
if data.plan not in ("free", "pro", "starter", "enterprise"):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
user, subscription = await _get_user_subscription(user_id, db)
old_plan = subscription.plan
@@ -991,7 +991,7 @@ async def update_account_plan(
current_user: Annotated[User, Depends(require_admin)],
):
"""Change an account subscription plan (super admin only)."""
if data.plan not in ("free", "pro", "team"):
if data.plan not in ("free", "pro", "starter", "enterprise"):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
account, subscription = await _get_account_subscription(account_id, db)
old_plan = subscription.plan

View File

@@ -28,7 +28,7 @@ async def get_dashboard_metrics(
) or 0
paid_accounts = await db.scalar(
select(func.count()).select_from(Subscription).where(
Subscription.plan.in_(["pro", "team"])
Subscription.plan.in_(["pro", "starter", "enterprise"])
)
) or 0
total_trees = await db.scalar(

View File

@@ -150,7 +150,7 @@ async def register(
# and so paid/trial-bearing codes still apply when supplied.
if (
settings.REQUIRE_INVITE_CODE
and not settings.SELF_SERVE_ENABLED
and not settings.is_self_serve_active_for(user_data.email)
and not user_data.invite_code
):
raise HTTPException(

View File

@@ -11,22 +11,31 @@ frontend codegen and other call sites if needed.
from __future__ import annotations
from fastapi import APIRouter
from typing import Annotated, Optional
from fastapi import APIRouter, Depends
from app.api.deps import get_current_user_optional
from app.core.config import settings
from app.models.user import User
from app.schemas.config import PublicConfigResponse
router = APIRouter(prefix="/config", tags=["config"])
@router.get("/public", response_model=PublicConfigResponse)
async def get_public_config() -> PublicConfigResponse:
async def get_public_config(
current_user: Annotated[Optional[User], Depends(get_current_user_optional)],
) -> PublicConfigResponse:
"""Return public-safe runtime config.
`oauth_providers` reflects which OAuth client IDs are configured server
side; the frontend uses it to render only buttons that will actually
succeed. `self_serve_enabled` is the master switch for the new public
self-serve signup flow.
self-serve signup flow; an authenticated caller whose email is on the
INTERNAL_TESTER_EMAILS allowlist sees `True` even when the global flag
is off, so internal validation in prod test mode can exercise the full
surface before the public flip.
"""
providers: list[str] = []
if settings.GOOGLE_CLIENT_ID:
@@ -34,7 +43,8 @@ async def get_public_config() -> PublicConfigResponse:
if settings.MS_CLIENT_ID:
providers.append("microsoft")
user_email = current_user.email if current_user else None
return PublicConfigResponse(
self_serve_enabled=settings.SELF_SERVE_ENABLED,
self_serve_enabled=settings.is_self_serve_active_for(user_email),
oauth_providers=providers,
)

View File

@@ -97,6 +97,40 @@ class Settings(BaseSettings):
STRIPE_WEBHOOK_SECRET: Optional[str] = None
SELF_SERVE_ENABLED: bool = False
# Internal tester allowlist for soft cutover. Comma-separated emails;
# when SELF_SERVE_ENABLED is False, listed users still see the self-serve
# surfaces (pricing page, invite-code-optional registration, etc.) so the
# full flow can be exercised in prod test mode before public flip.
INTERNAL_TESTER_EMAILS: list[str] = []
@field_validator("INTERNAL_TESTER_EMAILS", mode="before")
@classmethod
def split_internal_tester_emails(cls, v) -> list[str]:
"""Parse a comma-separated string into a normalized lowercase list."""
if v is None or v == "":
return []
if isinstance(v, list):
return [e.strip().lower() for e in v if e and e.strip()]
if isinstance(v, str):
return [e.strip().lower() for e in v.split(",") if e.strip()]
return []
def is_internal_tester(self, email: Optional[str]) -> bool:
"""Case-insensitive allowlist check. None/empty email is never a tester."""
if not email:
return False
return email.lower() in self.INTERNAL_TESTER_EMAILS
def is_self_serve_active_for(self, email: Optional[str]) -> bool:
"""True if self-serve surfaces should render for this user.
Either the global flag is on, or the user is on the internal-tester
allowlist. Anonymous calls (email is None) only see the global flag.
"""
if self.SELF_SERVE_ENABLED:
return True
return self.is_internal_tester(email)
@property
def stripe_enabled(self) -> bool:
"""Check if Stripe is configured."""

View File

@@ -37,12 +37,12 @@ class Subscription(Base):
@property
def is_paid(self) -> bool:
# Excludes complimentary and trialing so MRR/paid-customer metrics aren't inflated.
return self.plan in ("pro", "team") and self.status not in ("complimentary", "trialing")
return self.plan in ("pro", "starter", "enterprise") and self.status not in ("complimentary", "trialing")
@property
def has_pro_entitlement(self) -> bool:
"""True if the account can access Pro features right now."""
if self.plan in ("pro", "team"):
if self.plan in ("pro", "starter", "enterprise"):
if self.status in ("active", "complimentary"):
return True
if self.status == "trialing" and self.current_period_end is not None:

View File

@@ -125,7 +125,7 @@ class AdminAccountDetailResponse(AdminAccountListItem):
class AdminAccountCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
plan: Literal["free", "pro", "team"] = "free"
plan: Literal["free", "pro", "starter", "enterprise"] = "free"
owner_email: Optional[EmailStr] = Field(None, description="Email of an existing user to set as owner")

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel
class CheckoutSessionCreate(BaseModel):
plan: Literal["pro", "starter", "team", "enterprise"]
plan: Literal["pro", "starter", "enterprise"]
seats: int
billing_interval: Literal["monthly", "annual"] = "monthly"

View File

@@ -9,7 +9,7 @@ class InviteCodeCreate(BaseModel):
expires_at: Optional[datetime] = Field(None, description="Optional expiration time")
note: Optional[str] = Field(None, max_length=255, description="Note about who this code is for")
email: Optional[EmailStr] = Field(None, description="Recipient email for invite delivery")
assigned_plan: Literal["free", "pro", "team"] = Field("free", description="Plan to assign on registration")
assigned_plan: Literal["free", "pro", "starter", "enterprise"] = Field("free", description="Plan to assign on registration")
trial_duration_days: Optional[int] = Field(None, ge=1, le=90, description="Trial duration in days (1-90)")
@model_validator(mode="after")

View File

@@ -41,7 +41,7 @@ class SubscriptionDetails(BaseModel):
class SubscriptionPlanUpdate(BaseModel):
plan: str # free, pro, team
plan: str # free, pro, starter, enterprise
model_config = {"json_schema_extra": {"examples": [{"plan": "pro"}]}}

View File

@@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""Sync plan_billing rows from Stripe products and prices.
Reads the active Stripe environment (test or live, determined by
STRIPE_SECRET_KEY in env), looks up the canonical ResolutionFlow products
by exact name match, picks the active monthly recurring price for tiers
that have one, and upserts plan_billing rows.
Idempotent. Safe to re-run after price changes, after live cutover, or
after rotating Stripe keys.
Tier mapping (name in Stripe -> plan slug in plan_limits):
ResolutionFlow Starter -> starter (monthly price required)
ResolutionFlow Pro -> pro (monthly price required)
ResolutionFlow Enterprise -> enterprise (no price, sales-led)
Annual prices are intentionally not supported in this iteration. The
plan_billing schema allows annual fields (stripe_annual_price_id,
annual_price_cents); this script leaves them NULL.
Usage:
docker exec -w /app resolutionflow_backend python -m scripts.sync_stripe_plan_ids
docker exec -w /app resolutionflow_backend python -m scripts.sync_stripe_plan_ids --dry-run
"""
import argparse
import asyncio
import logging
import sys
from typing import Optional
import stripe
from app.core.config import settings
from app.core.database import async_session_maker
from sqlalchemy import text
logger = logging.getLogger("sync_stripe_plan_ids")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
)
PLAN_NAME_TO_SLUG = {
"ResolutionFlow Starter": "starter",
"ResolutionFlow Pro": "pro",
"ResolutionFlow Enterprise": "enterprise",
}
PLANS_REQUIRING_PRICE = {"starter", "pro"}
PLAN_DEFAULTS = {
"starter": {"sort_order": 10, "is_public": True},
"pro": {"sort_order": 20, "is_public": True},
"enterprise": {"sort_order": 30, "is_public": True},
}
def find_product_by_name(target: str) -> Optional[stripe.Product]:
"""Page through active products and return the first exact name match."""
for product in stripe.Product.list(active=True, limit=100).auto_paging_iter():
if product.name == target:
return product
return None
def find_active_monthly_price(product_id: str) -> Optional[stripe.Price]:
"""Return the active recurring monthly price for a product, or None."""
candidates = [
p
for p in stripe.Price.list(product=product_id, active=True, limit=100).auto_paging_iter()
if p.type == "recurring"
and p.recurring is not None
and p.recurring.get("interval") == "month"
and p.recurring.get("interval_count", 1) == 1
]
if not candidates:
return None
if len(candidates) > 1:
logger.warning(
"Product %s has %d active monthly recurring prices; picking %s. "
"Archive the others to silence this warning.",
product_id, len(candidates), candidates[0].id,
)
return candidates[0]
async def upsert_plan_billing(
plan: str,
display_name: str,
description: Optional[str],
monthly_price_cents: Optional[int],
stripe_product_id: Optional[str],
stripe_monthly_price_id: Optional[str],
sort_order: int,
is_public: bool,
dry_run: bool,
) -> None:
"""Upsert one plan_billing row. Annual fields stay NULL."""
if dry_run:
logger.info(
"[dry-run] would upsert plan=%s display=%s monthly_cents=%s "
"product=%s monthly_price=%s",
plan, display_name, monthly_price_cents,
stripe_product_id, stripe_monthly_price_id,
)
return
sql = text("""
INSERT INTO plan_billing (
plan, display_name, description,
monthly_price_cents, annual_price_cents,
stripe_product_id, stripe_monthly_price_id, stripe_annual_price_id,
is_public, is_archived, sort_order
) VALUES (
:plan, :display_name, :description,
:monthly_price_cents, NULL,
:stripe_product_id, :stripe_monthly_price_id, NULL,
:is_public, FALSE, :sort_order
)
ON CONFLICT (plan) DO UPDATE SET
display_name = EXCLUDED.display_name,
description = EXCLUDED.description,
monthly_price_cents = EXCLUDED.monthly_price_cents,
stripe_product_id = EXCLUDED.stripe_product_id,
stripe_monthly_price_id = EXCLUDED.stripe_monthly_price_id,
is_public = EXCLUDED.is_public,
sort_order = EXCLUDED.sort_order,
updated_at = NOW()
""")
async with async_session_maker() as session:
await session.execute(sql, {
"plan": plan,
"display_name": display_name,
"description": description,
"monthly_price_cents": monthly_price_cents,
"stripe_product_id": stripe_product_id,
"stripe_monthly_price_id": stripe_monthly_price_id,
"is_public": is_public,
"sort_order": sort_order,
})
await session.commit()
logger.info("upserted plan_billing for plan=%s", plan)
async def main(dry_run: bool) -> int:
if not settings.STRIPE_SECRET_KEY:
logger.error("STRIPE_SECRET_KEY is not set. Refusing to run.")
return 2
stripe.api_key = settings.STRIPE_SECRET_KEY
mode = "live" if settings.STRIPE_SECRET_KEY.startswith("sk_live_") else "test"
logger.info("connected to Stripe in %s mode", mode)
errors: list[str] = []
for product_name, plan in PLAN_NAME_TO_SLUG.items():
defaults = PLAN_DEFAULTS[plan]
product = find_product_by_name(product_name)
if product is None:
errors.append(f"Stripe product not found: {product_name!r}")
continue
price = None
if plan in PLANS_REQUIRING_PRICE:
price = find_active_monthly_price(product.id)
if price is None:
errors.append(
f"No active monthly recurring price for {product_name!r} "
f"(product {product.id})"
)
continue
await upsert_plan_billing(
plan=plan,
display_name=product.name,
description=product.description,
monthly_price_cents=price.unit_amount if price else None,
stripe_product_id=product.id,
stripe_monthly_price_id=price.id if price else None,
sort_order=defaults["sort_order"],
is_public=defaults["is_public"],
dry_run=dry_run,
)
if errors:
for e in errors:
logger.error(e)
return 1
logger.info("done")
return 0
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--dry-run", action="store_true", help="Log actions without writing.")
args = parser.parse_args()
sys.exit(asyncio.run(main(dry_run=args.dry_run)))

View File

@@ -172,8 +172,9 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats)
VALUES
('free', 3, 20, 1, false, false, '["markdown", "text"]'),
('starter', 10, 75, 1, false, false, '["markdown", "text", "html"]'),
('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'),
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
('enterprise', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
"""))
# Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by

View File

@@ -122,9 +122,9 @@ class TestAdminPlanLimits:
):
"""PUT /admin/plan-limits upserts a plan_billing row when billing
fields are included in the body."""
# Ensure no plan_billing row exists for "team" yet.
# Ensure no plan_billing row exists for "enterprise" yet.
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
select(PlanBilling).where(PlanBilling.plan == "enterprise")
)).scalar_one_or_none()
if existing is not None:
await test_db.delete(existing)
@@ -133,7 +133,7 @@ class TestAdminPlanLimits:
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "team",
"plan": "enterprise",
"max_trees": None,
"max_sessions_per_month": None,
"max_users": None,
@@ -163,7 +163,7 @@ class TestAdminPlanLimits:
# Confirm the row was actually persisted.
await test_db.commit() # ensure session sees other-session writes
pb = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
select(PlanBilling).where(PlanBilling.plan == "enterprise")
)).scalar_one_or_none()
assert pb is not None
assert pb.display_name == "Team"
@@ -179,17 +179,17 @@ class TestAdminPlanLimits:
plan_billing row when the caller passes explicit nulls. The set of
guarded fields is {display_name, is_public, is_archived, sort_order}.
"""
# Seed a plan_billing row for "team" with non-default values for every
# Seed a plan_billing row for "enterprise" with non-default values for every
# NOT NULL field so we can detect any clobbering.
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
select(PlanBilling).where(PlanBilling.plan == "enterprise")
)).scalar_one_or_none()
if existing is not None:
await test_db.delete(existing)
await test_db.commit()
seeded = PlanBilling(
plan="team",
plan="enterprise",
display_name="Team Seeded",
is_public=False,
is_archived=True,
@@ -201,7 +201,7 @@ class TestAdminPlanLimits:
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "team",
"plan": "enterprise",
"max_trees": None,
"max_sessions_per_month": None,
"max_users": None,
@@ -221,7 +221,7 @@ class TestAdminPlanLimits:
# Confirm the seeded NOT NULL values were preserved.
await test_db.commit() # ensure session sees writes from the request
pb = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
select(PlanBilling).where(PlanBilling.plan == "enterprise")
)).scalar_one_or_none()
assert pb is not None
assert pb.display_name == "Team Seeded"

View File

@@ -49,6 +49,58 @@ class TestConfigPublic:
assert response.status_code == 200
assert response.json()["oauth_providers"] == ["microsoft"]
@pytest.mark.asyncio
async def test_get_config_public_returns_true_for_internal_tester(
self,
client: AsyncClient,
auth_headers: dict,
test_user: dict,
monkeypatch: pytest.MonkeyPatch,
):
"""Authenticated user whose email is on INTERNAL_TESTER_EMAILS sees
self_serve_enabled=True even when the global flag is off."""
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
monkeypatch.setattr(settings, "MS_CLIENT_ID", None)
monkeypatch.setattr(settings, "INTERNAL_TESTER_EMAILS", [test_user["email"].lower()])
response = await client.get("/api/v1/config/public", headers=auth_headers)
assert response.status_code == 200
assert response.json()["self_serve_enabled"] is True
@pytest.mark.asyncio
async def test_get_config_public_returns_false_for_non_tester_when_global_off(
self,
client: AsyncClient,
auth_headers: dict,
monkeypatch: pytest.MonkeyPatch,
):
"""Authenticated user NOT on the allowlist sees the global flag —
prevents accidental opt-in via stale credentials or empty allowlist."""
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
monkeypatch.setattr(settings, "MS_CLIENT_ID", None)
monkeypatch.setattr(settings, "INTERNAL_TESTER_EMAILS", ["someone-else@example.com"])
response = await client.get("/api/v1/config/public", headers=auth_headers)
assert response.status_code == 200
assert response.json()["self_serve_enabled"] is False
@pytest.mark.asyncio
async def test_get_config_public_anonymous_ignores_allowlist(
self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
):
"""Anonymous callers always see the global flag — the allowlist is
keyed on authenticated identity, not request content."""
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
monkeypatch.setattr(settings, "MS_CLIENT_ID", None)
monkeypatch.setattr(settings, "INTERNAL_TESTER_EMAILS", ["anon-tester@example.com"])
response = await client.get("/api/v1/config/public")
assert response.status_code == 200
assert response.json()["self_serve_enabled"] is False
class TestRegisterInviteCodeGate:
"""Regression + new-behavior tests for /auth/register vs SELF_SERVE_ENABLED."""
@@ -98,3 +150,55 @@ class TestRegisterInviteCodeGate:
assert body["email"] == "self-serve@example.com"
assert body["account_role"] == "owner"
assert "account_id" in body
@pytest.mark.asyncio
async def test_register_invite_code_optional_for_internal_tester(
self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
):
"""SELF_SERVE_ENABLED is False but the registering email is on
INTERNAL_TESTER_EMAILS — registration should succeed without an
invite code, matching the per-email soft-cutover behavior."""
monkeypatch.setattr(settings, "REQUIRE_INVITE_CODE", True)
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
monkeypatch.setattr(
settings, "INTERNAL_TESTER_EMAILS", ["tester@example.com"]
)
response = await client.post(
"/api/v1/auth/register",
json={
"email": "tester@example.com",
"password": "SecurePass123!",
"name": "Internal Tester",
},
)
assert response.status_code == 201, response.text
body = response.json()
assert body["email"] == "tester@example.com"
assert body["account_role"] == "owner"
@pytest.mark.asyncio
async def test_register_blocked_for_non_tester_when_self_serve_disabled(
self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
):
"""Registering with an email NOT on the allowlist still 400s when
self-serve is off and no invite code is provided. Prevents the
allowlist from leaking to public users."""
monkeypatch.setattr(settings, "REQUIRE_INVITE_CODE", True)
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
monkeypatch.setattr(
settings, "INTERNAL_TESTER_EMAILS", ["other@example.com"]
)
response = await client.post(
"/api/v1/auth/register",
json={
"email": "outsider@example.com",
"password": "SecurePass123!",
"name": "Outsider",
},
)
assert response.status_code == 400
assert "invite code is required" in response.json()["detail"].lower()

View File

@@ -49,7 +49,7 @@ class TestInviteCodeCreation:
):
response = await client.post(
"/api/v1/invites",
json={"assigned_plan": "team", "email": "beta@example.com"},
json={"assigned_plan": "enterprise", "email": "beta@example.com"},
headers=admin_auth_headers,
)
assert response.status_code == 201
@@ -149,7 +149,7 @@ class TestRegistrationWithInvitePlan:
# Create team invite without trial
resp = await client.post(
"/api/v1/invites",
json={"assigned_plan": "team"},
json={"assigned_plan": "enterprise"},
headers=admin_auth_headers,
)
code = resp.json()["code"]
@@ -172,7 +172,7 @@ class TestRegistrationWithInvitePlan:
sub = (await test_db.execute(
select(Subscription).where(Subscription.account_id == user.account_id)
)).scalar_one()
assert sub.plan == "team"
assert sub.plan == "enterprise"
assert sub.status == "active"

View File

@@ -14,7 +14,12 @@ from app.models.plan_limits import PlanLimits
async def _seed_plan_limits(test_db, plan: str, max_users: int | None) -> None:
"""Ensure a plan_limits row exists for the given plan name."""
"""Ensure a plan_limits row exists with the given max_users.
Upserts: conftest seeds the canonical plans (free/starter/pro/enterprise)
so this helper has to overwrite max_users when a test wants different
values for fixture-driven assertions.
"""
existing = await test_db.get(PlanLimits, plan)
if existing is None:
test_db.add(
@@ -28,7 +33,9 @@ async def _seed_plan_limits(test_db, plan: str, max_users: int | None) -> None:
export_formats=["markdown", "text"],
)
)
await test_db.commit()
else:
existing.max_users = max_users
await test_db.commit()
class TestGetPlansPublic: