feat(billing): reconcile plan taxonomy and add Stripe sync script
The marketing surface (PricingPage, Stripe products) was wired for "Starter / Pro / Enterprise" while the backend was on "free / pro / team", leaving plan_billing unseeded and BillingPlan accepting a literal that violated the FK to plan_limits. This change: - Migration 4ce3e594cb87: defensive UPDATE of any subscriptions on plan='team' to 'enterprise' (dev has zero), renames the plan_limits row team -> enterprise, inserts a starter row with caps interpolated between free and pro (max_trees=10, sessions=75, ai=15/mo). - Renames the plan tier across schemas (invite_code, billing, admin, subscription comment), is_paid/has_pro_entitlement checks in the Subscription model, admin/admin_dashboard plan validators, and the frontend useSubscription isPaidPlan check. Resource visibility uses the same string 'team' in a separate domain (Tree/StepLibrary visibility) and is intentionally untouched. - New backend/scripts/sync_stripe_plan_ids.py: idempotent upsert of plan_billing rows from Stripe products by exact name match. Picks the active monthly recurring price for tiers that have one; leaves annual fields NULL by design. Works against test or live keys. - Test fixture updates: conftest seeds the new taxonomy, the public plans helper is a true upsert so tests can override max_users, and team -> enterprise across test_admin_plan_limits and test_invite_plan. Verified: 86/86 passing across the subscription/billing/plan/invite/ admin sweep; sync script run against test mode populates plan_billing correctly for all three tiers. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,84 @@
|
|||||||
|
"""add_starter_rename_team_to_enterprise
|
||||||
|
|
||||||
|
Revision ID: 4ce3e594cb87
|
||||||
|
Revises: c6cbfc534fad
|
||||||
|
Create Date: 2026-05-07 19:36:27.172082
|
||||||
|
|
||||||
|
Plan tier taxonomy reconciliation. Marketing surface and Stripe products
|
||||||
|
named "Starter / Pro / Enterprise"; backend was on "free / pro / team".
|
||||||
|
This migration:
|
||||||
|
|
||||||
|
1. Defensively migrates any existing subscriptions on plan='team' to
|
||||||
|
plan='enterprise' (dev has zero such rows; prod is expected to have
|
||||||
|
none, but the UPDATE is safe and idempotent).
|
||||||
|
2. Renames the plan_limits row 'team' -> 'enterprise'. plan_billing
|
||||||
|
and plan_feature_defaults are FK-referenced but currently empty;
|
||||||
|
the rename works because PostgreSQL allows updating PK values when
|
||||||
|
no FK rows reference them.
|
||||||
|
3. Inserts a new plan_limits row for 'starter' between free and pro.
|
||||||
|
|
||||||
|
Resource visibility (Tree.visibility, StepLibrary.visibility) also uses
|
||||||
|
the string 'team' for "shared with my account" — that is a separate
|
||||||
|
domain and is intentionally not touched.
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = '4ce3e594cb87'
|
||||||
|
down_revision: Union[str, None] = 'c6cbfc534fad'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute("UPDATE subscriptions SET plan = 'enterprise' WHERE plan = 'team'")
|
||||||
|
op.execute("UPDATE plan_limits SET plan = 'enterprise' WHERE plan = 'team'")
|
||||||
|
op.execute("""
|
||||||
|
INSERT INTO plan_limits (
|
||||||
|
plan,
|
||||||
|
max_trees,
|
||||||
|
max_sessions_per_month,
|
||||||
|
max_users,
|
||||||
|
custom_branding,
|
||||||
|
priority_support,
|
||||||
|
export_formats,
|
||||||
|
max_ai_builds_per_month,
|
||||||
|
max_ai_builds_per_24h,
|
||||||
|
kb_accelerator_enabled,
|
||||||
|
kb_max_lifetime_conversions,
|
||||||
|
kb_batch_max_size,
|
||||||
|
kb_allowed_formats,
|
||||||
|
kb_detailed_analysis,
|
||||||
|
kb_conversational_refinement,
|
||||||
|
kb_step_library_matching,
|
||||||
|
kb_history_limit
|
||||||
|
) VALUES (
|
||||||
|
'starter',
|
||||||
|
10,
|
||||||
|
75,
|
||||||
|
1,
|
||||||
|
FALSE,
|
||||||
|
FALSE,
|
||||||
|
'["markdown", "text", "html"]'::jsonb,
|
||||||
|
15,
|
||||||
|
5,
|
||||||
|
FALSE,
|
||||||
|
NULL,
|
||||||
|
NULL,
|
||||||
|
'["txt", "paste", "md"]'::jsonb,
|
||||||
|
FALSE,
|
||||||
|
FALSE,
|
||||||
|
FALSE,
|
||||||
|
NULL
|
||||||
|
)
|
||||||
|
ON CONFLICT (plan) DO NOTHING
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DELETE FROM plan_limits WHERE plan = 'starter'")
|
||||||
|
op.execute("UPDATE plan_limits SET plan = 'team' WHERE plan = 'enterprise'")
|
||||||
|
op.execute("UPDATE subscriptions SET plan = 'team' WHERE plan = 'enterprise'")
|
||||||
@@ -972,7 +972,7 @@ async def update_user_plan(
|
|||||||
current_user: Annotated[User, Depends(require_admin)],
|
current_user: Annotated[User, Depends(require_admin)],
|
||||||
):
|
):
|
||||||
"""Change a user's subscription plan (super admin only)."""
|
"""Change a user's subscription plan (super admin only)."""
|
||||||
if data.plan not in ("free", "pro", "team"):
|
if data.plan not in ("free", "pro", "starter", "enterprise"):
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
|
||||||
user, subscription = await _get_user_subscription(user_id, db)
|
user, subscription = await _get_user_subscription(user_id, db)
|
||||||
old_plan = subscription.plan
|
old_plan = subscription.plan
|
||||||
@@ -991,7 +991,7 @@ async def update_account_plan(
|
|||||||
current_user: Annotated[User, Depends(require_admin)],
|
current_user: Annotated[User, Depends(require_admin)],
|
||||||
):
|
):
|
||||||
"""Change an account subscription plan (super admin only)."""
|
"""Change an account subscription plan (super admin only)."""
|
||||||
if data.plan not in ("free", "pro", "team"):
|
if data.plan not in ("free", "pro", "starter", "enterprise"):
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
|
||||||
account, subscription = await _get_account_subscription(account_id, db)
|
account, subscription = await _get_account_subscription(account_id, db)
|
||||||
old_plan = subscription.plan
|
old_plan = subscription.plan
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ async def get_dashboard_metrics(
|
|||||||
) or 0
|
) or 0
|
||||||
paid_accounts = await db.scalar(
|
paid_accounts = await db.scalar(
|
||||||
select(func.count()).select_from(Subscription).where(
|
select(func.count()).select_from(Subscription).where(
|
||||||
Subscription.plan.in_(["pro", "team"])
|
Subscription.plan.in_(["pro", "starter", "enterprise"])
|
||||||
)
|
)
|
||||||
) or 0
|
) or 0
|
||||||
total_trees = await db.scalar(
|
total_trees = await db.scalar(
|
||||||
|
|||||||
@@ -37,12 +37,12 @@ class Subscription(Base):
|
|||||||
@property
|
@property
|
||||||
def is_paid(self) -> bool:
|
def is_paid(self) -> bool:
|
||||||
# Excludes complimentary and trialing so MRR/paid-customer metrics aren't inflated.
|
# Excludes complimentary and trialing so MRR/paid-customer metrics aren't inflated.
|
||||||
return self.plan in ("pro", "team") and self.status not in ("complimentary", "trialing")
|
return self.plan in ("pro", "starter", "enterprise") and self.status not in ("complimentary", "trialing")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_pro_entitlement(self) -> bool:
|
def has_pro_entitlement(self) -> bool:
|
||||||
"""True if the account can access Pro features right now."""
|
"""True if the account can access Pro features right now."""
|
||||||
if self.plan in ("pro", "team"):
|
if self.plan in ("pro", "starter", "enterprise"):
|
||||||
if self.status in ("active", "complimentary"):
|
if self.status in ("active", "complimentary"):
|
||||||
return True
|
return True
|
||||||
if self.status == "trialing" and self.current_period_end is not None:
|
if self.status == "trialing" and self.current_period_end is not None:
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ class AdminAccountDetailResponse(AdminAccountListItem):
|
|||||||
|
|
||||||
class AdminAccountCreate(BaseModel):
|
class AdminAccountCreate(BaseModel):
|
||||||
name: str = Field(..., min_length=1, max_length=255)
|
name: str = Field(..., min_length=1, max_length=255)
|
||||||
plan: Literal["free", "pro", "team"] = "free"
|
plan: Literal["free", "pro", "starter", "enterprise"] = "free"
|
||||||
owner_email: Optional[EmailStr] = Field(None, description="Email of an existing user to set as owner")
|
owner_email: Optional[EmailStr] = Field(None, description="Email of an existing user to set as owner")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
|
|
||||||
class CheckoutSessionCreate(BaseModel):
|
class CheckoutSessionCreate(BaseModel):
|
||||||
plan: Literal["pro", "starter", "team", "enterprise"]
|
plan: Literal["pro", "starter", "enterprise"]
|
||||||
seats: int
|
seats: int
|
||||||
billing_interval: Literal["monthly", "annual"] = "monthly"
|
billing_interval: Literal["monthly", "annual"] = "monthly"
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class InviteCodeCreate(BaseModel):
|
|||||||
expires_at: Optional[datetime] = Field(None, description="Optional expiration time")
|
expires_at: Optional[datetime] = Field(None, description="Optional expiration time")
|
||||||
note: Optional[str] = Field(None, max_length=255, description="Note about who this code is for")
|
note: Optional[str] = Field(None, max_length=255, description="Note about who this code is for")
|
||||||
email: Optional[EmailStr] = Field(None, description="Recipient email for invite delivery")
|
email: Optional[EmailStr] = Field(None, description="Recipient email for invite delivery")
|
||||||
assigned_plan: Literal["free", "pro", "team"] = Field("free", description="Plan to assign on registration")
|
assigned_plan: Literal["free", "pro", "starter", "enterprise"] = Field("free", description="Plan to assign on registration")
|
||||||
trial_duration_days: Optional[int] = Field(None, ge=1, le=90, description="Trial duration in days (1-90)")
|
trial_duration_days: Optional[int] = Field(None, ge=1, le=90, description="Trial duration in days (1-90)")
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class SubscriptionDetails(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class SubscriptionPlanUpdate(BaseModel):
|
class SubscriptionPlanUpdate(BaseModel):
|
||||||
plan: str # free, pro, team
|
plan: str # free, pro, starter, enterprise
|
||||||
|
|
||||||
model_config = {"json_schema_extra": {"examples": [{"plan": "pro"}]}}
|
model_config = {"json_schema_extra": {"examples": [{"plan": "pro"}]}}
|
||||||
|
|
||||||
|
|||||||
199
backend/scripts/sync_stripe_plan_ids.py
Normal file
199
backend/scripts/sync_stripe_plan_ids.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Sync plan_billing rows from Stripe products and prices.
|
||||||
|
|
||||||
|
Reads the active Stripe environment (test or live, determined by
|
||||||
|
STRIPE_SECRET_KEY in env), looks up the canonical ResolutionFlow products
|
||||||
|
by exact name match, picks the active monthly recurring price for tiers
|
||||||
|
that have one, and upserts plan_billing rows.
|
||||||
|
|
||||||
|
Idempotent. Safe to re-run after price changes, after live cutover, or
|
||||||
|
after rotating Stripe keys.
|
||||||
|
|
||||||
|
Tier mapping (name in Stripe -> plan slug in plan_limits):
|
||||||
|
ResolutionFlow Starter -> starter (monthly price required)
|
||||||
|
ResolutionFlow Pro -> pro (monthly price required)
|
||||||
|
ResolutionFlow Enterprise -> enterprise (no price, sales-led)
|
||||||
|
|
||||||
|
Annual prices are intentionally not supported in this iteration. The
|
||||||
|
plan_billing schema allows annual fields (stripe_annual_price_id,
|
||||||
|
annual_price_cents); this script leaves them NULL.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
docker exec -w /app resolutionflow_backend python -m scripts.sync_stripe_plan_ids
|
||||||
|
docker exec -w /app resolutionflow_backend python -m scripts.sync_stripe_plan_ids --dry-run
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import stripe
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.database import async_session_maker
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("sync_stripe_plan_ids")
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PLAN_NAME_TO_SLUG = {
|
||||||
|
"ResolutionFlow Starter": "starter",
|
||||||
|
"ResolutionFlow Pro": "pro",
|
||||||
|
"ResolutionFlow Enterprise": "enterprise",
|
||||||
|
}
|
||||||
|
|
||||||
|
PLANS_REQUIRING_PRICE = {"starter", "pro"}
|
||||||
|
|
||||||
|
PLAN_DEFAULTS = {
|
||||||
|
"starter": {"sort_order": 10, "is_public": True},
|
||||||
|
"pro": {"sort_order": 20, "is_public": True},
|
||||||
|
"enterprise": {"sort_order": 30, "is_public": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def find_product_by_name(target: str) -> Optional[stripe.Product]:
|
||||||
|
"""Page through active products and return the first exact name match."""
|
||||||
|
for product in stripe.Product.list(active=True, limit=100).auto_paging_iter():
|
||||||
|
if product.name == target:
|
||||||
|
return product
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def find_active_monthly_price(product_id: str) -> Optional[stripe.Price]:
|
||||||
|
"""Return the active recurring monthly price for a product, or None."""
|
||||||
|
candidates = [
|
||||||
|
p
|
||||||
|
for p in stripe.Price.list(product=product_id, active=True, limit=100).auto_paging_iter()
|
||||||
|
if p.type == "recurring"
|
||||||
|
and p.recurring is not None
|
||||||
|
and p.recurring.get("interval") == "month"
|
||||||
|
and p.recurring.get("interval_count", 1) == 1
|
||||||
|
]
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
if len(candidates) > 1:
|
||||||
|
logger.warning(
|
||||||
|
"Product %s has %d active monthly recurring prices; picking %s. "
|
||||||
|
"Archive the others to silence this warning.",
|
||||||
|
product_id, len(candidates), candidates[0].id,
|
||||||
|
)
|
||||||
|
return candidates[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def upsert_plan_billing(
|
||||||
|
plan: str,
|
||||||
|
display_name: str,
|
||||||
|
description: Optional[str],
|
||||||
|
monthly_price_cents: Optional[int],
|
||||||
|
stripe_product_id: Optional[str],
|
||||||
|
stripe_monthly_price_id: Optional[str],
|
||||||
|
sort_order: int,
|
||||||
|
is_public: bool,
|
||||||
|
dry_run: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Upsert one plan_billing row. Annual fields stay NULL."""
|
||||||
|
if dry_run:
|
||||||
|
logger.info(
|
||||||
|
"[dry-run] would upsert plan=%s display=%s monthly_cents=%s "
|
||||||
|
"product=%s monthly_price=%s",
|
||||||
|
plan, display_name, monthly_price_cents,
|
||||||
|
stripe_product_id, stripe_monthly_price_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
sql = text("""
|
||||||
|
INSERT INTO plan_billing (
|
||||||
|
plan, display_name, description,
|
||||||
|
monthly_price_cents, annual_price_cents,
|
||||||
|
stripe_product_id, stripe_monthly_price_id, stripe_annual_price_id,
|
||||||
|
is_public, is_archived, sort_order
|
||||||
|
) VALUES (
|
||||||
|
:plan, :display_name, :description,
|
||||||
|
:monthly_price_cents, NULL,
|
||||||
|
:stripe_product_id, :stripe_monthly_price_id, NULL,
|
||||||
|
:is_public, FALSE, :sort_order
|
||||||
|
)
|
||||||
|
ON CONFLICT (plan) DO UPDATE SET
|
||||||
|
display_name = EXCLUDED.display_name,
|
||||||
|
description = EXCLUDED.description,
|
||||||
|
monthly_price_cents = EXCLUDED.monthly_price_cents,
|
||||||
|
stripe_product_id = EXCLUDED.stripe_product_id,
|
||||||
|
stripe_monthly_price_id = EXCLUDED.stripe_monthly_price_id,
|
||||||
|
is_public = EXCLUDED.is_public,
|
||||||
|
sort_order = EXCLUDED.sort_order,
|
||||||
|
updated_at = NOW()
|
||||||
|
""")
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
await session.execute(sql, {
|
||||||
|
"plan": plan,
|
||||||
|
"display_name": display_name,
|
||||||
|
"description": description,
|
||||||
|
"monthly_price_cents": monthly_price_cents,
|
||||||
|
"stripe_product_id": stripe_product_id,
|
||||||
|
"stripe_monthly_price_id": stripe_monthly_price_id,
|
||||||
|
"is_public": is_public,
|
||||||
|
"sort_order": sort_order,
|
||||||
|
})
|
||||||
|
await session.commit()
|
||||||
|
logger.info("upserted plan_billing for plan=%s", plan)
|
||||||
|
|
||||||
|
|
||||||
|
async def main(dry_run: bool) -> int:
|
||||||
|
if not settings.STRIPE_SECRET_KEY:
|
||||||
|
logger.error("STRIPE_SECRET_KEY is not set. Refusing to run.")
|
||||||
|
return 2
|
||||||
|
|
||||||
|
stripe.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
mode = "live" if settings.STRIPE_SECRET_KEY.startswith("sk_live_") else "test"
|
||||||
|
logger.info("connected to Stripe in %s mode", mode)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
for product_name, plan in PLAN_NAME_TO_SLUG.items():
|
||||||
|
defaults = PLAN_DEFAULTS[plan]
|
||||||
|
product = find_product_by_name(product_name)
|
||||||
|
if product is None:
|
||||||
|
errors.append(f"Stripe product not found: {product_name!r}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
price = None
|
||||||
|
if plan in PLANS_REQUIRING_PRICE:
|
||||||
|
price = find_active_monthly_price(product.id)
|
||||||
|
if price is None:
|
||||||
|
errors.append(
|
||||||
|
f"No active monthly recurring price for {product_name!r} "
|
||||||
|
f"(product {product.id})"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await upsert_plan_billing(
|
||||||
|
plan=plan,
|
||||||
|
display_name=product.name,
|
||||||
|
description=product.description,
|
||||||
|
monthly_price_cents=price.unit_amount if price else None,
|
||||||
|
stripe_product_id=product.id,
|
||||||
|
stripe_monthly_price_id=price.id if price else None,
|
||||||
|
sort_order=defaults["sort_order"],
|
||||||
|
is_public=defaults["is_public"],
|
||||||
|
dry_run=dry_run,
|
||||||
|
)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
for e in errors:
|
||||||
|
logger.error(e)
|
||||||
|
return 1
|
||||||
|
logger.info("done")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument("--dry-run", action="store_true", help="Log actions without writing.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
sys.exit(asyncio.run(main(dry_run=args.dry_run)))
|
||||||
@@ -172,8 +172,9 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats)
|
INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats)
|
||||||
VALUES
|
VALUES
|
||||||
('free', 3, 20, 1, false, false, '["markdown", "text"]'),
|
('free', 3, 20, 1, false, false, '["markdown", "text"]'),
|
||||||
|
('starter', 10, 75, 1, false, false, '["markdown", "text", "html"]'),
|
||||||
('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'),
|
('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'),
|
||||||
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
|
('enterprise', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
|
||||||
"""))
|
"""))
|
||||||
|
|
||||||
# Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by
|
# Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by
|
||||||
|
|||||||
@@ -122,9 +122,9 @@ class TestAdminPlanLimits:
|
|||||||
):
|
):
|
||||||
"""PUT /admin/plan-limits upserts a plan_billing row when billing
|
"""PUT /admin/plan-limits upserts a plan_billing row when billing
|
||||||
fields are included in the body."""
|
fields are included in the body."""
|
||||||
# Ensure no plan_billing row exists for "team" yet.
|
# Ensure no plan_billing row exists for "enterprise" yet.
|
||||||
existing = (await test_db.execute(
|
existing = (await test_db.execute(
|
||||||
select(PlanBilling).where(PlanBilling.plan == "team")
|
select(PlanBilling).where(PlanBilling.plan == "enterprise")
|
||||||
)).scalar_one_or_none()
|
)).scalar_one_or_none()
|
||||||
if existing is not None:
|
if existing is not None:
|
||||||
await test_db.delete(existing)
|
await test_db.delete(existing)
|
||||||
@@ -133,7 +133,7 @@ class TestAdminPlanLimits:
|
|||||||
response = await client.put(
|
response = await client.put(
|
||||||
"/api/v1/admin/plan-limits",
|
"/api/v1/admin/plan-limits",
|
||||||
json={
|
json={
|
||||||
"plan": "team",
|
"plan": "enterprise",
|
||||||
"max_trees": None,
|
"max_trees": None,
|
||||||
"max_sessions_per_month": None,
|
"max_sessions_per_month": None,
|
||||||
"max_users": None,
|
"max_users": None,
|
||||||
@@ -163,7 +163,7 @@ class TestAdminPlanLimits:
|
|||||||
# Confirm the row was actually persisted.
|
# Confirm the row was actually persisted.
|
||||||
await test_db.commit() # ensure session sees other-session writes
|
await test_db.commit() # ensure session sees other-session writes
|
||||||
pb = (await test_db.execute(
|
pb = (await test_db.execute(
|
||||||
select(PlanBilling).where(PlanBilling.plan == "team")
|
select(PlanBilling).where(PlanBilling.plan == "enterprise")
|
||||||
)).scalar_one_or_none()
|
)).scalar_one_or_none()
|
||||||
assert pb is not None
|
assert pb is not None
|
||||||
assert pb.display_name == "Team"
|
assert pb.display_name == "Team"
|
||||||
@@ -179,17 +179,17 @@ class TestAdminPlanLimits:
|
|||||||
plan_billing row when the caller passes explicit nulls. The set of
|
plan_billing row when the caller passes explicit nulls. The set of
|
||||||
guarded fields is {display_name, is_public, is_archived, sort_order}.
|
guarded fields is {display_name, is_public, is_archived, sort_order}.
|
||||||
"""
|
"""
|
||||||
# Seed a plan_billing row for "team" with non-default values for every
|
# Seed a plan_billing row for "enterprise" with non-default values for every
|
||||||
# NOT NULL field so we can detect any clobbering.
|
# NOT NULL field so we can detect any clobbering.
|
||||||
existing = (await test_db.execute(
|
existing = (await test_db.execute(
|
||||||
select(PlanBilling).where(PlanBilling.plan == "team")
|
select(PlanBilling).where(PlanBilling.plan == "enterprise")
|
||||||
)).scalar_one_or_none()
|
)).scalar_one_or_none()
|
||||||
if existing is not None:
|
if existing is not None:
|
||||||
await test_db.delete(existing)
|
await test_db.delete(existing)
|
||||||
await test_db.commit()
|
await test_db.commit()
|
||||||
|
|
||||||
seeded = PlanBilling(
|
seeded = PlanBilling(
|
||||||
plan="team",
|
plan="enterprise",
|
||||||
display_name="Team Seeded",
|
display_name="Team Seeded",
|
||||||
is_public=False,
|
is_public=False,
|
||||||
is_archived=True,
|
is_archived=True,
|
||||||
@@ -201,7 +201,7 @@ class TestAdminPlanLimits:
|
|||||||
response = await client.put(
|
response = await client.put(
|
||||||
"/api/v1/admin/plan-limits",
|
"/api/v1/admin/plan-limits",
|
||||||
json={
|
json={
|
||||||
"plan": "team",
|
"plan": "enterprise",
|
||||||
"max_trees": None,
|
"max_trees": None,
|
||||||
"max_sessions_per_month": None,
|
"max_sessions_per_month": None,
|
||||||
"max_users": None,
|
"max_users": None,
|
||||||
@@ -221,7 +221,7 @@ class TestAdminPlanLimits:
|
|||||||
# Confirm the seeded NOT NULL values were preserved.
|
# Confirm the seeded NOT NULL values were preserved.
|
||||||
await test_db.commit() # ensure session sees writes from the request
|
await test_db.commit() # ensure session sees writes from the request
|
||||||
pb = (await test_db.execute(
|
pb = (await test_db.execute(
|
||||||
select(PlanBilling).where(PlanBilling.plan == "team")
|
select(PlanBilling).where(PlanBilling.plan == "enterprise")
|
||||||
)).scalar_one_or_none()
|
)).scalar_one_or_none()
|
||||||
assert pb is not None
|
assert pb is not None
|
||||||
assert pb.display_name == "Team Seeded"
|
assert pb.display_name == "Team Seeded"
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ class TestInviteCodeCreation:
|
|||||||
):
|
):
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/v1/invites",
|
"/api/v1/invites",
|
||||||
json={"assigned_plan": "team", "email": "beta@example.com"},
|
json={"assigned_plan": "enterprise", "email": "beta@example.com"},
|
||||||
headers=admin_auth_headers,
|
headers=admin_auth_headers,
|
||||||
)
|
)
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
@@ -149,7 +149,7 @@ class TestRegistrationWithInvitePlan:
|
|||||||
# Create team invite without trial
|
# Create team invite without trial
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
"/api/v1/invites",
|
"/api/v1/invites",
|
||||||
json={"assigned_plan": "team"},
|
json={"assigned_plan": "enterprise"},
|
||||||
headers=admin_auth_headers,
|
headers=admin_auth_headers,
|
||||||
)
|
)
|
||||||
code = resp.json()["code"]
|
code = resp.json()["code"]
|
||||||
@@ -172,7 +172,7 @@ class TestRegistrationWithInvitePlan:
|
|||||||
sub = (await test_db.execute(
|
sub = (await test_db.execute(
|
||||||
select(Subscription).where(Subscription.account_id == user.account_id)
|
select(Subscription).where(Subscription.account_id == user.account_id)
|
||||||
)).scalar_one()
|
)).scalar_one()
|
||||||
assert sub.plan == "team"
|
assert sub.plan == "enterprise"
|
||||||
assert sub.status == "active"
|
assert sub.status == "active"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,12 @@ from app.models.plan_limits import PlanLimits
|
|||||||
|
|
||||||
|
|
||||||
async def _seed_plan_limits(test_db, plan: str, max_users: int | None) -> None:
|
async def _seed_plan_limits(test_db, plan: str, max_users: int | None) -> None:
|
||||||
"""Ensure a plan_limits row exists for the given plan name."""
|
"""Ensure a plan_limits row exists with the given max_users.
|
||||||
|
|
||||||
|
Upserts: conftest seeds the canonical plans (free/starter/pro/enterprise)
|
||||||
|
so this helper has to overwrite max_users when a test wants different
|
||||||
|
values for fixture-driven assertions.
|
||||||
|
"""
|
||||||
existing = await test_db.get(PlanLimits, plan)
|
existing = await test_db.get(PlanLimits, plan)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
test_db.add(
|
test_db.add(
|
||||||
@@ -28,7 +33,9 @@ async def _seed_plan_limits(test_db, plan: str, max_users: int | None) -> None:
|
|||||||
export_formats=["markdown", "text"],
|
export_formats=["markdown", "text"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await test_db.commit()
|
else:
|
||||||
|
existing.max_users = max_users
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
|
||||||
class TestGetPlansPublic:
|
class TestGetPlansPublic:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ export function useSubscription() {
|
|||||||
const usage = subscription?.usage ?? null
|
const usage = subscription?.usage ?? null
|
||||||
const isActive = subscription?.subscription.status === 'active' || subscription?.subscription.status === 'trialing'
|
const isActive = subscription?.subscription.status === 'active' || subscription?.subscription.status === 'trialing'
|
||||||
|
|
||||||
const isPaidPlan = plan === 'pro' || plan === 'team'
|
const isPaidPlan = plan === 'pro' || plan === 'starter' || plan === 'enterprise'
|
||||||
|
|
||||||
const canUseFeature = (feature: 'custom_branding' | 'priority_support'): boolean => {
|
const canUseFeature = (feature: 'custom_branding' | 'priority_support'): boolean => {
|
||||||
if (!limits) return false
|
if (!limits) return false
|
||||||
|
|||||||
Reference in New Issue
Block a user