The marketing surface (PricingPage, Stripe products) was wired for "Starter / Pro / Enterprise" while the backend was on "free / pro / team", leaving plan_billing unseeded and BillingPlan accepting a literal that violated the FK to plan_limits.

This change:

- Migration 4ce3e594cb87: defensive UPDATE of any subscriptions on plan='team' to 'enterprise' (dev has zero), renames the plan_limits row team -> enterprise, inserts a starter row with caps interpolated between free and pro (max_trees=10, sessions=75, ai=15/mo).
- Renames the plan tier across schemas (invite_code, billing, admin, subscription comment), is_paid/has_pro_entitlement checks in the Subscription model, admin/admin_dashboard plan validators, and the frontend useSubscription isPaidPlan check. Resource visibility uses the same string 'team' in a separate domain (Tree/StepLibrary visibility) and is intentionally untouched.
- New backend/scripts/sync_stripe_plan_ids.py: idempotent upsert of plan_billing rows from Stripe products by exact name match. Picks the active monthly recurring price for tiers that have one; leaves annual fields NULL by design. Works against test or live keys.
- Test fixture updates: conftest seeds the new taxonomy, the public plans helper is a true upsert so tests can override max_users, and team -> enterprise across test_admin_plan_limits and test_invite_plan.

Verified: 86/86 passing across the subscription/billing/plan/invite/admin sweep; sync script run against test mode populates plan_billing correctly for all three tiers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
265 lines
10 KiB
Python
265 lines
10 KiB
Python
"""Integration tests for admin plan limits and account override endpoints."""
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from httpx import AsyncClient
|
|
from sqlalchemy import select
|
|
|
|
from app.models.plan_billing import PlanBilling
|
|
|
|
|
|
class TestAdminPlanLimits:
    """Integration tests for the admin plan-limits and account-override endpoints.

    Every test drives the HTTP API through the ``client`` fixture. Tests that
    need to seed or verify ``plan_billing`` rows also use the ``test_db``
    session fixture directly.
    """

    @pytest.mark.asyncio
    async def test_list_plan_limits(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """List all plan limits."""
        response = await client.get("/api/v1/admin/plan-limits", headers=admin_auth_headers)
        assert response.status_code == 200
        plans = response.json()
        # conftest seeds the plan taxonomy; the former "team" tier is now
        # named "enterprise".
        assert len(plans) >= 3
        plan_names = [p["plan"] for p in plans]
        assert "free" in plan_names

    @pytest.mark.asyncio
    async def test_update_plan_limits(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """Update a plan's limits."""
        response = await client.put(
            "/api/v1/admin/plan-limits",
            json={
                "plan": "free",
                "max_trees": 5,
                "max_sessions_per_month": 30,
                "max_users": 2,
                "custom_branding": False,
                "priority_support": False,
                "export_formats": ["markdown", "text"],
            },
            headers=admin_auth_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["max_trees"] == 5

    @pytest.mark.asyncio
    async def test_list_account_overrides(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """List account overrides."""
        response = await client.get("/api/v1/admin/account-overrides", headers=admin_auth_headers)
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    @pytest.mark.asyncio
    async def test_non_admin_cannot_access(
        self, client: AsyncClient, auth_headers: dict
    ):
        """Non-admin gets 403."""
        response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers)
        assert response.status_code == 403

    @pytest.mark.asyncio
    async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present(
        self, client: AsyncClient, admin_auth_headers: dict, test_db
    ):
        """GET /admin/plan-limits returns plan_billing fields when a row exists,
        and None for plans that don't have one yet."""
        # Seed a plan_billing row for "pro" unless one is already present.
        existing = (await test_db.execute(
            select(PlanBilling).where(PlanBilling.plan == "pro")
        )).scalar_one_or_none()
        if existing is None:
            test_db.add(PlanBilling(
                plan="pro",
                display_name="Pro",
                description="For working teams",
                monthly_price_cents=4900,
                annual_price_cents=49000,
                stripe_product_id="prod_seed",
                stripe_monthly_price_id="price_seed_m",
                stripe_annual_price_id="price_seed_a",
                is_public=True,
                is_archived=False,
                sort_order=10,
            ))
            await test_db.commit()

        response = await client.get(
            "/api/v1/admin/plan-limits", headers=admin_auth_headers
        )
        assert response.status_code == 200
        plans_by_name = {p["plan"]: p for p in response.json()}

        assert "pro" in plans_by_name
        pro = plans_by_name["pro"]
        assert pro["display_name"] == "Pro"
        assert pro["monthly_price_cents"] == 4900
        assert pro["stripe_monthly_price_id"] == "price_seed_m"
        assert pro["is_public"] is True
        assert pro["is_archived"] is False
        assert pro["sort_order"] == 10

        # A plan without a plan_billing row should still return, with None
        # billing fields.
        if "free" in plans_by_name:
            free = plans_by_name["free"]
            # Only assert the None fields when "free" really has no
            # plan_billing row in this database.
            no_billing_row = (await test_db.execute(
                select(PlanBilling).where(PlanBilling.plan == "free")
            )).scalar_one_or_none() is None
            if no_billing_row:
                assert free["display_name"] is None
                assert free["monthly_price_cents"] is None
                assert free["stripe_product_id"] is None

    @pytest.mark.asyncio
    async def test_admin_plan_limits_put_creates_plan_billing_row(
        self, client: AsyncClient, admin_auth_headers: dict, test_db
    ):
        """PUT /admin/plan-limits upserts a plan_billing row when billing
        fields are included in the body."""
        # Ensure no plan_billing row exists for "enterprise" yet.
        existing = (await test_db.execute(
            select(PlanBilling).where(PlanBilling.plan == "enterprise")
        )).scalar_one_or_none()
        if existing is not None:
            await test_db.delete(existing)
            await test_db.commit()

        response = await client.put(
            "/api/v1/admin/plan-limits",
            json={
                "plan": "enterprise",
                "max_trees": None,
                "max_sessions_per_month": None,
                "max_users": None,
                "custom_branding": True,
                "priority_support": True,
                "export_formats": ["markdown", "text", "pdf"],
                "display_name": "Team",
                "description": "For growing shops",
                "monthly_price_cents": 9900,
                "annual_price_cents": 99000,
                "stripe_product_id": "prod_team_test",
                "stripe_monthly_price_id": "price_team_m",
                "stripe_annual_price_id": "price_team_a",
                "is_public": True,
                "is_archived": False,
                "sort_order": 20,
            },
            headers=admin_auth_headers,
        )
        assert response.status_code == 200, response.text
        body = response.json()
        assert body["display_name"] == "Team"
        assert body["monthly_price_cents"] == 9900
        assert body["stripe_product_id"] == "prod_team_test"
        assert body["sort_order"] == 20

        # Confirm the row was actually persisted.
        await test_db.commit()  # ensure session sees other-session writes
        pb = (await test_db.execute(
            select(PlanBilling)
            .where(PlanBilling.plan == "enterprise")
            # Reload attributes from the database even if an instance is
            # already in this session's identity map.
            .execution_options(populate_existing=True)
        )).scalar_one_or_none()
        assert pb is not None
        assert pb.display_name == "Team"
        assert pb.monthly_price_cents == 9900
        assert pb.stripe_monthly_price_id == "price_team_m"
        assert pb.is_public is True

    @pytest.mark.asyncio
    async def test_admin_plan_limits_put_does_not_null_out_required_fields(
        self, client: AsyncClient, admin_auth_headers: dict, test_db
    ):
        """PUT /admin/plan-limits must not NULL out NOT NULL columns on the
        plan_billing row when the caller passes explicit nulls. The set of
        guarded fields is {display_name, is_public, is_archived, sort_order}.
        """
        # Seed a plan_billing row for "enterprise" with non-default values for every
        # NOT NULL field so we can detect any clobbering.
        existing = (await test_db.execute(
            select(PlanBilling).where(PlanBilling.plan == "enterprise")
        )).scalar_one_or_none()
        if existing is not None:
            await test_db.delete(existing)
            await test_db.commit()

        seeded = PlanBilling(
            plan="enterprise",
            display_name="Team Seeded",
            is_public=False,
            is_archived=True,
            sort_order=5,
        )
        test_db.add(seeded)
        await test_db.commit()

        response = await client.put(
            "/api/v1/admin/plan-limits",
            json={
                "plan": "enterprise",
                "max_trees": None,
                "max_sessions_per_month": None,
                "max_users": None,
                "custom_branding": True,
                "priority_support": True,
                "export_formats": ["markdown", "text"],
                # Explicit nulls for every NOT NULL plan_billing field.
                "display_name": None,
                "is_public": None,
                "is_archived": None,
                "sort_order": None,
            },
            headers=admin_auth_headers,
        )
        assert response.status_code == 200, response.text

        # Confirm the seeded NOT NULL values were preserved.
        await test_db.commit()  # ensure session sees writes from the request
        pb = (await test_db.execute(
            select(PlanBilling)
            .where(PlanBilling.plan == "enterprise")
            # `seeded` is still in this session's identity map. Without
            # populate_existing, a session configured with
            # expire_on_commit=False would return its in-memory attributes
            # and the assertions below could not detect clobbering.
            # NOTE(review): the session configuration lives in conftest and
            # is not visible here; this option is harmless either way.
            .execution_options(populate_existing=True)
        )).scalar_one_or_none()
        assert pb is not None
        assert pb.display_name == "Team Seeded"
        assert pb.is_public is False
        assert pb.is_archived is True
        assert pb.sort_order == 5

    @pytest.mark.asyncio
    async def test_admin_plan_limits_put_invalidates_billing_cache(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """PUT /admin/plan-limits calls BillingService.invalidate_billing_cache
        with the account_ids on the affected plan."""
        # Patch the staticmethod on the class. The endpoint imports
        # BillingService at module load, so patch the symbol on the class
        # itself — both the import and the dotted reference resolve to it.
        with patch(
            "app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache",
            new_callable=AsyncMock,
        ) as spy:
            response = await client.put(
                "/api/v1/admin/plan-limits",
                json={
                    "plan": "pro",
                    "max_trees": 25,
                    "max_sessions_per_month": 500,
                    "max_users": 10,
                    "custom_branding": True,
                    "priority_support": True,
                    "export_formats": ["markdown", "text"],
                },
                headers=admin_auth_headers,
            )
            assert response.status_code == 200, response.text
            spy.assert_awaited_once()
            # The unpacking also asserts that exactly one positional
            # argument was passed.
            (account_ids_arg,) = spy.await_args.args
            # admin fixture seeds an active Pro Subscription, so we expect at
            # least one account_id in the invalidation list.
            assert isinstance(account_ids_arg, list)
            assert len(account_ids_arg) >= 1