feat(billing): plan taxonomy reconciliation + Stripe sync + internal-tester allowlist (#164)
Co-authored-by: Michael Chihlas <michael@resolutionflow.com> Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
This commit was merged in pull request #164.
This commit is contained in:
@@ -64,6 +64,40 @@ async def get_current_user(
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_optional(
|
||||
request: Request,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> Optional[User]:
|
||||
"""Best-effort current user for endpoints that work both anonymous and authed.
|
||||
|
||||
Returns None on missing/invalid/expired token instead of raising. Used by
|
||||
surfaces like /config/public that anonymous clients can hit but where an
|
||||
authenticated user gets a tailored response (e.g. INTERNAL_TESTER_EMAILS
|
||||
allowlist override).
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization") or request.headers.get("authorization")
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
token = auth_header.split(None, 1)[1].strip()
|
||||
if not token:
|
||||
return None
|
||||
|
||||
payload = decode_token(token)
|
||||
if payload is None or payload.get("type") != "access":
|
||||
return None
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
return None
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
result = await db.execute(select(User).where(User.id == user_uuid))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_refresh_token_payload(
|
||||
token: Annotated[str, Depends(oauth2_scheme)]
|
||||
) -> dict:
|
||||
|
||||
@@ -972,7 +972,7 @@ async def update_user_plan(
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Change a user's subscription plan (super admin only)."""
|
||||
if data.plan not in ("free", "pro", "team"):
|
||||
if data.plan not in ("free", "pro", "starter", "enterprise"):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
|
||||
user, subscription = await _get_user_subscription(user_id, db)
|
||||
old_plan = subscription.plan
|
||||
@@ -991,7 +991,7 @@ async def update_account_plan(
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Change an account subscription plan (super admin only)."""
|
||||
if data.plan not in ("free", "pro", "team"):
|
||||
if data.plan not in ("free", "pro", "starter", "enterprise"):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
|
||||
account, subscription = await _get_account_subscription(account_id, db)
|
||||
old_plan = subscription.plan
|
||||
|
||||
@@ -28,7 +28,7 @@ async def get_dashboard_metrics(
|
||||
) or 0
|
||||
paid_accounts = await db.scalar(
|
||||
select(func.count()).select_from(Subscription).where(
|
||||
Subscription.plan.in_(["pro", "team"])
|
||||
Subscription.plan.in_(["pro", "starter", "enterprise"])
|
||||
)
|
||||
) or 0
|
||||
total_trees = await db.scalar(
|
||||
|
||||
@@ -150,7 +150,7 @@ async def register(
|
||||
# and so paid/trial-bearing codes still apply when supplied.
|
||||
if (
|
||||
settings.REQUIRE_INVITE_CODE
|
||||
and not settings.SELF_SERVE_ENABLED
|
||||
and not settings.is_self_serve_active_for(user_data.email)
|
||||
and not user_data.invite_code
|
||||
):
|
||||
raise HTTPException(
|
||||
|
||||
@@ -11,22 +11,31 @@ frontend codegen and other call sites if needed.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.api.deps import get_current_user_optional
|
||||
from app.core.config import settings
|
||||
from app.models.user import User
|
||||
from app.schemas.config import PublicConfigResponse
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
@router.get("/public", response_model=PublicConfigResponse)
|
||||
async def get_public_config() -> PublicConfigResponse:
|
||||
async def get_public_config(
|
||||
current_user: Annotated[Optional[User], Depends(get_current_user_optional)],
|
||||
) -> PublicConfigResponse:
|
||||
"""Return public-safe runtime config.
|
||||
|
||||
`oauth_providers` reflects which OAuth client IDs are configured server
|
||||
side; the frontend uses it to render only buttons that will actually
|
||||
succeed. `self_serve_enabled` is the master switch for the new public
|
||||
self-serve signup flow.
|
||||
self-serve signup flow; an authenticated caller whose email is on the
|
||||
INTERNAL_TESTER_EMAILS allowlist sees `True` even when the global flag
|
||||
is off, so internal validation in prod test mode can exercise the full
|
||||
surface before the public flip.
|
||||
"""
|
||||
providers: list[str] = []
|
||||
if settings.GOOGLE_CLIENT_ID:
|
||||
@@ -34,7 +43,8 @@ async def get_public_config() -> PublicConfigResponse:
|
||||
if settings.MS_CLIENT_ID:
|
||||
providers.append("microsoft")
|
||||
|
||||
user_email = current_user.email if current_user else None
|
||||
return PublicConfigResponse(
|
||||
self_serve_enabled=settings.SELF_SERVE_ENABLED,
|
||||
self_serve_enabled=settings.is_self_serve_active_for(user_email),
|
||||
oauth_providers=providers,
|
||||
)
|
||||
|
||||
@@ -97,6 +97,40 @@ class Settings(BaseSettings):
|
||||
STRIPE_WEBHOOK_SECRET: Optional[str] = None
|
||||
SELF_SERVE_ENABLED: bool = False
|
||||
|
||||
# Internal tester allowlist for soft cutover. Comma-separated emails;
|
||||
# when SELF_SERVE_ENABLED is False, listed users still see the self-serve
|
||||
# surfaces (pricing page, invite-code-optional registration, etc.) so the
|
||||
# full flow can be exercised in prod test mode before public flip.
|
||||
INTERNAL_TESTER_EMAILS: list[str] = []
|
||||
|
||||
@field_validator("INTERNAL_TESTER_EMAILS", mode="before")
|
||||
@classmethod
|
||||
def split_internal_tester_emails(cls, v) -> list[str]:
|
||||
"""Parse a comma-separated string into a normalized lowercase list."""
|
||||
if v is None or v == "":
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
return [e.strip().lower() for e in v if e and e.strip()]
|
||||
if isinstance(v, str):
|
||||
return [e.strip().lower() for e in v.split(",") if e.strip()]
|
||||
return []
|
||||
|
||||
def is_internal_tester(self, email: Optional[str]) -> bool:
|
||||
"""Case-insensitive allowlist check. None/empty email is never a tester."""
|
||||
if not email:
|
||||
return False
|
||||
return email.lower() in self.INTERNAL_TESTER_EMAILS
|
||||
|
||||
def is_self_serve_active_for(self, email: Optional[str]) -> bool:
|
||||
"""True if self-serve surfaces should render for this user.
|
||||
|
||||
Either the global flag is on, or the user is on the internal-tester
|
||||
allowlist. Anonymous calls (email is None) only see the global flag.
|
||||
"""
|
||||
if self.SELF_SERVE_ENABLED:
|
||||
return True
|
||||
return self.is_internal_tester(email)
|
||||
|
||||
@property
|
||||
def stripe_enabled(self) -> bool:
|
||||
"""Check if Stripe is configured."""
|
||||
|
||||
@@ -37,12 +37,12 @@ class Subscription(Base):
|
||||
@property
|
||||
def is_paid(self) -> bool:
|
||||
# Excludes complimentary and trialing so MRR/paid-customer metrics aren't inflated.
|
||||
return self.plan in ("pro", "team") and self.status not in ("complimentary", "trialing")
|
||||
return self.plan in ("pro", "starter", "enterprise") and self.status not in ("complimentary", "trialing")
|
||||
|
||||
@property
|
||||
def has_pro_entitlement(self) -> bool:
|
||||
"""True if the account can access Pro features right now."""
|
||||
if self.plan in ("pro", "team"):
|
||||
if self.plan in ("pro", "starter", "enterprise"):
|
||||
if self.status in ("active", "complimentary"):
|
||||
return True
|
||||
if self.status == "trialing" and self.current_period_end is not None:
|
||||
|
||||
@@ -125,7 +125,7 @@ class AdminAccountDetailResponse(AdminAccountListItem):
|
||||
|
||||
class AdminAccountCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
plan: Literal["free", "pro", "team"] = "free"
|
||||
plan: Literal["free", "pro", "starter", "enterprise"] = "free"
|
||||
owner_email: Optional[EmailStr] = Field(None, description="Email of an existing user to set as owner")
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class CheckoutSessionCreate(BaseModel):
|
||||
plan: Literal["pro", "starter", "team", "enterprise"]
|
||||
plan: Literal["pro", "starter", "enterprise"]
|
||||
seats: int
|
||||
billing_interval: Literal["monthly", "annual"] = "monthly"
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ class InviteCodeCreate(BaseModel):
|
||||
expires_at: Optional[datetime] = Field(None, description="Optional expiration time")
|
||||
note: Optional[str] = Field(None, max_length=255, description="Note about who this code is for")
|
||||
email: Optional[EmailStr] = Field(None, description="Recipient email for invite delivery")
|
||||
assigned_plan: Literal["free", "pro", "team"] = Field("free", description="Plan to assign on registration")
|
||||
assigned_plan: Literal["free", "pro", "starter", "enterprise"] = Field("free", description="Plan to assign on registration")
|
||||
trial_duration_days: Optional[int] = Field(None, ge=1, le=90, description="Trial duration in days (1-90)")
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@@ -41,7 +41,7 @@ class SubscriptionDetails(BaseModel):
|
||||
|
||||
|
||||
class SubscriptionPlanUpdate(BaseModel):
|
||||
plan: str # free, pro, team
|
||||
plan: str # free, pro, starter, enterprise
|
||||
|
||||
model_config = {"json_schema_extra": {"examples": [{"plan": "pro"}]}}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user