From ba36c4707551856ecaf4b80248213a8a821d5e20 Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Thu, 7 May 2026 15:59:21 -0400 Subject: [PATCH] feat(billing): reconcile plan taxonomy and add Stripe sync script The marketing surface (PricingPage, Stripe products) was wired for "Starter / Pro / Enterprise" while the backend was on "free / pro / team", leaving plan_billing unseeded and BillingPlan accepting a literal that violated the FK to plan_limits. This change: - Migration 4ce3e594cb87: defensive UPDATE of any subscriptions on plan='team' to 'enterprise' (dev has zero), renames the plan_limits row team -> enterprise, inserts a starter row with caps interpolated between free and pro (max_trees=10, sessions=75, ai=15/mo). - Renames the plan tier across schemas (invite_code, billing, admin, subscription comment), is_paid/has_pro_entitlement checks in the Subscription model, admin/admin_dashboard plan validators, and the frontend useSubscription isPaidPlan check. Resource visibility uses the same string 'team' in a separate domain (Tree/StepLibrary visibility) and is intentionally untouched. - New backend/scripts/sync_stripe_plan_ids.py: idempotent upsert of plan_billing rows from Stripe products by exact name match. Picks the active monthly recurring price for tiers that have one; leaves annual fields NULL by design. Works against test or live keys. - Test fixture updates: conftest seeds the new taxonomy, the public plans helper is a true upsert so tests can override max_users, and team -> enterprise across test_admin_plan_limits and test_invite_plan. Verified: 86/86 passing across the subscription/billing/plan/invite/ admin sweep; sync script run against test mode populates plan_billing correctly for all three tiers. Co-Authored-By: Claude Opus 4.7 (1M context) --- ...7_add_starter_rename_team_to_enterprise.py | 84 ++++++++ backend/app/api/endpoints/admin.py | 4 +- backend/app/api/endpoints/admin_dashboard.py | 2 +- backend/app/models/subscription.py | 4 +- backend/app/schemas/admin.py | 2 +- backend/app/schemas/billing.py | 2 +- backend/app/schemas/invite_code.py | 2 +- backend/app/schemas/subscription.py | 2 +- backend/scripts/sync_stripe_plan_ids.py | 199 ++++++++++++++++++ backend/tests/conftest.py | 3 +- backend/tests/test_admin_plan_limits.py | 18 +- backend/tests/test_invite_plan.py | 6 +- backend/tests/test_plans_public.py | 11 +- frontend/src/hooks/useSubscription.ts | 2 +- 14 files changed, 316 insertions(+), 25 deletions(-) create mode 100644 backend/alembic/versions/4ce3e594cb87_add_starter_rename_team_to_enterprise.py create mode 100644 backend/scripts/sync_stripe_plan_ids.py diff --git a/backend/alembic/versions/4ce3e594cb87_add_starter_rename_team_to_enterprise.py b/backend/alembic/versions/4ce3e594cb87_add_starter_rename_team_to_enterprise.py new file mode 100644 index 00000000..e468bd0f --- /dev/null +++ b/backend/alembic/versions/4ce3e594cb87_add_starter_rename_team_to_enterprise.py @@ -0,0 +1,84 @@ +"""add_starter_rename_team_to_enterprise + +Revision ID: 4ce3e594cb87 +Revises: c6cbfc534fad +Create Date: 2026-05-07 19:36:27.172082 + +Plan tier taxonomy reconciliation. Marketing surface and Stripe products +named "Starter / Pro / Enterprise"; backend was on "free / pro / team". +This migration: + + 1. Defensively migrates any existing subscriptions on plan='team' to + plan='enterprise' (dev has zero such rows; prod is expected to have + none, but the UPDATE is safe and idempotent). + 2. Renames the plan_limits row 'team' -> 'enterprise'. plan_billing + and plan_feature_defaults are FK-referenced but currently empty; + the rename works because PostgreSQL allows updating PK values when + no FK rows reference them. + 3. Inserts a new plan_limits row for 'starter' between free and pro. + +Resource visibility (Tree.visibility, StepLibrary.visibility) also uses +the string 'team' for "shared with my account" — that is a separate +domain and is intentionally not touched. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = '4ce3e594cb87' +down_revision: Union[str, None] = 'c6cbfc534fad' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute("UPDATE subscriptions SET plan = 'enterprise' WHERE plan = 'team'") + op.execute("UPDATE plan_limits SET plan = 'enterprise' WHERE plan = 'team'") + op.execute(""" + INSERT INTO plan_limits ( + plan, + max_trees, + max_sessions_per_month, + max_users, + custom_branding, + priority_support, + export_formats, + max_ai_builds_per_month, + max_ai_builds_per_24h, + kb_accelerator_enabled, + kb_max_lifetime_conversions, + kb_batch_max_size, + kb_allowed_formats, + kb_detailed_analysis, + kb_conversational_refinement, + kb_step_library_matching, + kb_history_limit + ) VALUES ( + 'starter', + 10, + 75, + 1, + FALSE, + FALSE, + '["markdown", "text", "html"]'::jsonb, + 15, + 5, + FALSE, + NULL, + NULL, + '["txt", "paste", "md"]'::jsonb, + FALSE, + FALSE, + FALSE, + NULL + ) + ON CONFLICT (plan) DO NOTHING + """) + + +def downgrade() -> None: + op.execute("DELETE FROM plan_limits WHERE plan = 'starter'") + op.execute("UPDATE plan_limits SET plan = 'team' WHERE plan = 'enterprise'") + op.execute("UPDATE subscriptions SET plan = 'team' WHERE plan = 'enterprise'") diff --git a/backend/app/api/endpoints/admin.py b/backend/app/api/endpoints/admin.py index ae606fb5..64033edf 100644 --- a/backend/app/api/endpoints/admin.py +++ b/backend/app/api/endpoints/admin.py @@ -972,7 +972,7 @@ async def update_user_plan( current_user: Annotated[User, Depends(require_admin)], ): """Change a user's subscription plan (super admin only).""" - if data.plan not in ("free", "pro", "team"): + if data.plan not in ("free", "pro", "starter", "enterprise"): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan") user, subscription = await _get_user_subscription(user_id, db) old_plan = subscription.plan @@ -991,7 +991,7 @@ async def update_account_plan( current_user: Annotated[User, Depends(require_admin)], ): """Change an account subscription plan (super admin only).""" - if data.plan not in ("free", "pro", "team"): + if data.plan not in ("free", "pro", "starter", "enterprise"): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan") account, subscription = await _get_account_subscription(account_id, db) old_plan = subscription.plan diff --git a/backend/app/api/endpoints/admin_dashboard.py b/backend/app/api/endpoints/admin_dashboard.py index 90859b18..0712558c 100644 --- a/backend/app/api/endpoints/admin_dashboard.py +++ b/backend/app/api/endpoints/admin_dashboard.py @@ -28,7 +28,7 @@ async def get_dashboard_metrics( ) or 0 paid_accounts = await db.scalar( select(func.count()).select_from(Subscription).where( - Subscription.plan.in_(["pro", "team"]) + Subscription.plan.in_(["pro", "starter", "enterprise"]) ) ) or 0 total_trees = await db.scalar( diff --git a/backend/app/models/subscription.py b/backend/app/models/subscription.py index 11024582..b4fa284e 100644 --- a/backend/app/models/subscription.py +++ b/backend/app/models/subscription.py @@ -37,12 +37,12 @@ class Subscription(Base): @property def is_paid(self) -> bool: # Excludes complimentary and trialing so MRR/paid-customer metrics aren't inflated. - return self.plan in ("pro", "team") and self.status not in ("complimentary", "trialing") + return self.plan in ("pro", "starter", "enterprise") and self.status not in ("complimentary", "trialing") @property def has_pro_entitlement(self) -> bool: """True if the account can access Pro features right now.""" - if self.plan in ("pro", "team"): + if self.plan in ("pro", "starter", "enterprise"): if self.status in ("active", "complimentary"): return True if self.status == "trialing" and self.current_period_end is not None: diff --git a/backend/app/schemas/admin.py b/backend/app/schemas/admin.py index a223d994..5d28670f 100644 --- a/backend/app/schemas/admin.py +++ b/backend/app/schemas/admin.py @@ -125,7 +125,7 @@ class AdminAccountDetailResponse(AdminAccountListItem): class AdminAccountCreate(BaseModel): name: str = Field(..., min_length=1, max_length=255) - plan: Literal["free", "pro", "team"] = "free" + plan: Literal["free", "pro", "starter", "enterprise"] = "free" owner_email: Optional[EmailStr] = Field(None, description="Email of an existing user to set as owner") diff --git a/backend/app/schemas/billing.py b/backend/app/schemas/billing.py index aaae78e6..48709c65 100644 --- a/backend/app/schemas/billing.py +++ b/backend/app/schemas/billing.py @@ -4,7 +4,7 @@ from pydantic import BaseModel class CheckoutSessionCreate(BaseModel): - plan: Literal["pro", "starter", "team", "enterprise"] + plan: Literal["pro", "starter", "enterprise"] seats: int billing_interval: Literal["monthly", "annual"] = "monthly" diff --git a/backend/app/schemas/invite_code.py b/backend/app/schemas/invite_code.py index 851403c8..6a917b6e 100644 --- a/backend/app/schemas/invite_code.py +++ b/backend/app/schemas/invite_code.py @@ -9,7 +9,7 @@ class InviteCodeCreate(BaseModel): expires_at: Optional[datetime] = Field(None, description="Optional expiration time") note: Optional[str] = Field(None, max_length=255, description="Note about who this code is for") email: Optional[EmailStr] = Field(None, description="Recipient email for invite delivery") - assigned_plan: Literal["free", "pro", "team"] = Field("free", description="Plan to assign on registration") + assigned_plan: Literal["free", "pro", "starter", "enterprise"] = Field("free", description="Plan to assign on registration") trial_duration_days: Optional[int] = Field(None, ge=1, le=90, description="Trial duration in days (1-90)") @model_validator(mode="after") diff --git a/backend/app/schemas/subscription.py b/backend/app/schemas/subscription.py index 80889fa0..e8f85385 100644 --- a/backend/app/schemas/subscription.py +++ b/backend/app/schemas/subscription.py @@ -41,7 +41,7 @@ class SubscriptionDetails(BaseModel): class SubscriptionPlanUpdate(BaseModel): - plan: str # free, pro, team + plan: str # free, pro, starter, enterprise model_config = {"json_schema_extra": {"examples": [{"plan": "pro"}]}} diff --git a/backend/scripts/sync_stripe_plan_ids.py b/backend/scripts/sync_stripe_plan_ids.py new file mode 100644 index 00000000..257f6a38 --- /dev/null +++ b/backend/scripts/sync_stripe_plan_ids.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +"""Sync plan_billing rows from Stripe products and prices. + +Reads the active Stripe environment (test or live, determined by +STRIPE_SECRET_KEY in env), looks up the canonical ResolutionFlow products +by exact name match, picks the active monthly recurring price for tiers +that have one, and upserts plan_billing rows. + +Idempotent. Safe to re-run after price changes, after live cutover, or +after rotating Stripe keys. + +Tier mapping (name in Stripe -> plan slug in plan_limits): + ResolutionFlow Starter -> starter (monthly price required) + ResolutionFlow Pro -> pro (monthly price required) + ResolutionFlow Enterprise -> enterprise (no price, sales-led) + +Annual prices are intentionally not supported in this iteration. The +plan_billing schema allows annual fields (stripe_annual_price_id, +annual_price_cents); this script leaves them NULL. + +Usage: + docker exec -w /app resolutionflow_backend python -m scripts.sync_stripe_plan_ids + docker exec -w /app resolutionflow_backend python -m scripts.sync_stripe_plan_ids --dry-run +""" +import argparse +import asyncio +import logging +import sys +from typing import Optional + +import stripe + +from app.core.config import settings +from app.core.database import async_session_maker +from sqlalchemy import text + + +logger = logging.getLogger("sync_stripe_plan_ids") +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", +) + + +PLAN_NAME_TO_SLUG = { + "ResolutionFlow Starter": "starter", + "ResolutionFlow Pro": "pro", + "ResolutionFlow Enterprise": "enterprise", +} + +PLANS_REQUIRING_PRICE = {"starter", "pro"} + +PLAN_DEFAULTS = { + "starter": {"sort_order": 10, "is_public": True}, + "pro": {"sort_order": 20, "is_public": True}, + "enterprise": {"sort_order": 30, "is_public": True}, +} + + +def find_product_by_name(target: str) -> Optional[stripe.Product]: + """Page through active products and return the first exact name match.""" + for product in stripe.Product.list(active=True, limit=100).auto_paging_iter(): + if product.name == target: + return product + return None + + +def find_active_monthly_price(product_id: str) -> Optional[stripe.Price]: + """Return the active recurring monthly price for a product, or None.""" + candidates = [ + p + for p in stripe.Price.list(product=product_id, active=True, limit=100).auto_paging_iter() + if p.type == "recurring" + and p.recurring is not None + and p.recurring.get("interval") == "month" + and p.recurring.get("interval_count", 1) == 1 + ] + if not candidates: + return None + if len(candidates) > 1: + logger.warning( + "Product %s has %d active monthly recurring prices; picking %s. " + "Archive the others to silence this warning.", + product_id, len(candidates), candidates[0].id, + ) + return candidates[0] + + +async def upsert_plan_billing( + plan: str, + display_name: str, + description: Optional[str], + monthly_price_cents: Optional[int], + stripe_product_id: Optional[str], + stripe_monthly_price_id: Optional[str], + sort_order: int, + is_public: bool, + dry_run: bool, +) -> None: + """Upsert one plan_billing row. Annual fields stay NULL.""" + if dry_run: + logger.info( + "[dry-run] would upsert plan=%s display=%s monthly_cents=%s " + "product=%s monthly_price=%s", + plan, display_name, monthly_price_cents, + stripe_product_id, stripe_monthly_price_id, + ) + return + + sql = text(""" + INSERT INTO plan_billing ( + plan, display_name, description, + monthly_price_cents, annual_price_cents, + stripe_product_id, stripe_monthly_price_id, stripe_annual_price_id, + is_public, is_archived, sort_order + ) VALUES ( + :plan, :display_name, :description, + :monthly_price_cents, NULL, + :stripe_product_id, :stripe_monthly_price_id, NULL, + :is_public, FALSE, :sort_order + ) + ON CONFLICT (plan) DO UPDATE SET + display_name = EXCLUDED.display_name, + description = EXCLUDED.description, + monthly_price_cents = EXCLUDED.monthly_price_cents, + stripe_product_id = EXCLUDED.stripe_product_id, + stripe_monthly_price_id = EXCLUDED.stripe_monthly_price_id, + is_public = EXCLUDED.is_public, + sort_order = EXCLUDED.sort_order, + updated_at = NOW() + """) + async with async_session_maker() as session: + await session.execute(sql, { + "plan": plan, + "display_name": display_name, + "description": description, + "monthly_price_cents": monthly_price_cents, + "stripe_product_id": stripe_product_id, + "stripe_monthly_price_id": stripe_monthly_price_id, + "is_public": is_public, + "sort_order": sort_order, + }) + await session.commit() + logger.info("upserted plan_billing for plan=%s", plan) + + +async def main(dry_run: bool) -> int: + if not settings.STRIPE_SECRET_KEY: + logger.error("STRIPE_SECRET_KEY is not set. Refusing to run.") + return 2 + + stripe.api_key = settings.STRIPE_SECRET_KEY + mode = "live" if settings.STRIPE_SECRET_KEY.startswith("sk_live_") else "test" + logger.info("connected to Stripe in %s mode", mode) + + errors: list[str] = [] + + for product_name, plan in PLAN_NAME_TO_SLUG.items(): + defaults = PLAN_DEFAULTS[plan] + product = find_product_by_name(product_name) + if product is None: + errors.append(f"Stripe product not found: {product_name!r}") + continue + + price = None + if plan in PLANS_REQUIRING_PRICE: + price = find_active_monthly_price(product.id) + if price is None: + errors.append( + f"No active monthly recurring price for {product_name!r} " + f"(product {product.id})" + ) + continue + + await upsert_plan_billing( + plan=plan, + display_name=product.name, + description=product.description, + monthly_price_cents=price.unit_amount if price else None, + stripe_product_id=product.id, + stripe_monthly_price_id=price.id if price else None, + sort_order=defaults["sort_order"], + is_public=defaults["is_public"], + dry_run=dry_run, + ) + + if errors: + for e in errors: + logger.error(e) + return 1 + logger.info("done") + return 0 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dry-run", action="store_true", help="Log actions without writing.") + args = parser.parse_args() + sys.exit(asyncio.run(main(dry_run=args.dry_run))) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 5f1d21c9..cefd8beb 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -172,8 +172,9 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]: INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats) VALUES ('free', 3, 20, 1, false, false, '["markdown", "text"]'), + ('starter', 10, 75, 1, false, false, '["markdown", "text", "html"]'), ('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'), - ('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]') + ('enterprise', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]') """)) # Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by diff --git a/backend/tests/test_admin_plan_limits.py b/backend/tests/test_admin_plan_limits.py index 8eb22d45..643ee75c 100644 --- a/backend/tests/test_admin_plan_limits.py +++ b/backend/tests/test_admin_plan_limits.py @@ -122,9 +122,9 @@ class TestAdminPlanLimits: ): """PUT /admin/plan-limits upserts a plan_billing row when billing fields are included in the body.""" - # Ensure no plan_billing row exists for "team" yet. + # Ensure no plan_billing row exists for "enterprise" yet. existing = (await test_db.execute( - select(PlanBilling).where(PlanBilling.plan == "team") + select(PlanBilling).where(PlanBilling.plan == "enterprise") )).scalar_one_or_none() if existing is not None: await test_db.delete(existing) @@ -133,7 +133,7 @@ class TestAdminPlanLimits: response = await client.put( "/api/v1/admin/plan-limits", json={ - "plan": "team", + "plan": "enterprise", "max_trees": None, "max_sessions_per_month": None, "max_users": None, @@ -163,7 +163,7 @@ class TestAdminPlanLimits: # Confirm the row was actually persisted. await test_db.commit() # ensure session sees other-session writes pb = (await test_db.execute( - select(PlanBilling).where(PlanBilling.plan == "team") + select(PlanBilling).where(PlanBilling.plan == "enterprise") )).scalar_one_or_none() assert pb is not None assert pb.display_name == "Team" @@ -179,17 +179,17 @@ class TestAdminPlanLimits: plan_billing row when the caller passes explicit nulls. The set of guarded fields is {display_name, is_public, is_archived, sort_order}. """ - # Seed a plan_billing row for "team" with non-default values for every + # Seed a plan_billing row for "enterprise" with non-default values for every # NOT NULL field so we can detect any clobbering. existing = (await test_db.execute( - select(PlanBilling).where(PlanBilling.plan == "team") + select(PlanBilling).where(PlanBilling.plan == "enterprise") )).scalar_one_or_none() if existing is not None: await test_db.delete(existing) await test_db.commit() seeded = PlanBilling( - plan="team", + plan="enterprise", display_name="Team Seeded", is_public=False, is_archived=True, @@ -201,7 +201,7 @@ class TestAdminPlanLimits: response = await client.put( "/api/v1/admin/plan-limits", json={ - "plan": "team", + "plan": "enterprise", "max_trees": None, "max_sessions_per_month": None, "max_users": None, @@ -221,7 +221,7 @@ class TestAdminPlanLimits: # Confirm the seeded NOT NULL values were preserved. await test_db.commit() # ensure session sees writes from the request pb = (await test_db.execute( - select(PlanBilling).where(PlanBilling.plan == "team") + select(PlanBilling).where(PlanBilling.plan == "enterprise") )).scalar_one_or_none() assert pb is not None assert pb.display_name == "Team Seeded" diff --git a/backend/tests/test_invite_plan.py b/backend/tests/test_invite_plan.py index d33c8b99..e6b31051 100644 --- a/backend/tests/test_invite_plan.py +++ b/backend/tests/test_invite_plan.py @@ -49,7 +49,7 @@ class TestInviteCodeCreation: ): response = await client.post( "/api/v1/invites", - json={"assigned_plan": "team", "email": "beta@example.com"}, + json={"assigned_plan": "enterprise", "email": "beta@example.com"}, headers=admin_auth_headers, ) assert response.status_code == 201 @@ -149,7 +149,7 @@ class TestRegistrationWithInvitePlan: # Create team invite without trial resp = await client.post( "/api/v1/invites", - json={"assigned_plan": "team"}, + json={"assigned_plan": "enterprise"}, headers=admin_auth_headers, ) code = resp.json()["code"] @@ -172,7 +172,7 @@ class TestRegistrationWithInvitePlan: sub = (await test_db.execute( select(Subscription).where(Subscription.account_id == user.account_id) )).scalar_one() - assert sub.plan == "team" + assert sub.plan == "enterprise" assert sub.status == "active" diff --git a/backend/tests/test_plans_public.py b/backend/tests/test_plans_public.py index a676009a..f4f17260 100644 --- a/backend/tests/test_plans_public.py +++ b/backend/tests/test_plans_public.py @@ -14,7 +14,12 @@ from app.models.plan_limits import PlanLimits async def _seed_plan_limits(test_db, plan: str, max_users: int | None) -> None: - """Ensure a plan_limits row exists for the given plan name.""" + """Ensure a plan_limits row exists with the given max_users. + + Upserts: conftest seeds the canonical plans (free/starter/pro/enterprise) + so this helper has to overwrite max_users when a test wants different + values for fixture-driven assertions. + """ existing = await test_db.get(PlanLimits, plan) if existing is None: test_db.add( @@ -28,7 +33,9 @@ async def _seed_plan_limits(test_db, plan: str, max_users: int | None) -> None: export_formats=["markdown", "text"], ) ) - await test_db.commit() + else: + existing.max_users = max_users + await test_db.commit() class TestGetPlansPublic: diff --git a/frontend/src/hooks/useSubscription.ts b/frontend/src/hooks/useSubscription.ts index 715a86bb..d3b31381 100644 --- a/frontend/src/hooks/useSubscription.ts +++ b/frontend/src/hooks/useSubscription.ts @@ -8,7 +8,7 @@ export function useSubscription() { const usage = subscription?.usage ?? null const isActive = subscription?.subscription.status === 'active' || subscription?.subscription.status === 'trialing' - const isPaidPlan = plan === 'pro' || plan === 'team' + const isPaidPlan = plan === 'pro' || plan === 'starter' || plan === 'enterprise' const canUseFeature = (feature: 'custom_branding' | 'priority_support'): boolean => { if (!limits) return false