Files
resolutionflow/backend/tests/test_admin_plan_limits.py
Michael Chihlas ba36c47075 feat(billing): reconcile plan taxonomy and add Stripe sync script
The marketing surface (PricingPage, Stripe products) was wired for
"Starter / Pro / Enterprise" while the backend was on "free / pro / team",
leaving plan_billing unseeded and BillingPlan accepting a literal that
violated the FK to plan_limits.

This change:

- Migration 4ce3e594cb87: defensive UPDATE of any subscriptions on
  plan='team' to 'enterprise' (dev has zero), renames the plan_limits
  row team -> enterprise, inserts a starter row with caps interpolated
  between free and pro (max_trees=10, sessions=75, ai=15/mo).
- Renames the plan tier across schemas (invite_code, billing, admin,
  subscription comment), is_paid/has_pro_entitlement checks in the
  Subscription model, admin/admin_dashboard plan validators, and the
  frontend useSubscription isPaidPlan check. Resource visibility uses
  the same string 'team' in a separate domain (Tree/StepLibrary
  visibility) and is intentionally untouched.
- New backend/scripts/sync_stripe_plan_ids.py: idempotent upsert of
  plan_billing rows from Stripe products by exact name match. Picks
  the active monthly recurring price for tiers that have one; leaves
  annual fields NULL by design. Works against test or live keys.
- Test fixture updates: conftest seeds the new taxonomy, the public
  plans helper is a true upsert so tests can override max_users, and
  team -> enterprise across test_admin_plan_limits and test_invite_plan.

Verified: 86/86 passing across the subscription/billing/plan/invite/
admin sweep; sync script run against test mode populates plan_billing
correctly for all three tiers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-07 15:59:42 -04:00

265 lines
10 KiB
Python

"""Integration tests for admin plan limits and account override endpoints."""
from unittest.mock import AsyncMock, patch
import pytest
from httpx import AsyncClient
from sqlalchemy import select
from app.models.plan_billing import PlanBilling
class TestAdminPlanLimits:
    """Integration tests for /api/v1/admin/plan-limits and account overrides.

    Every test drives the API through the ``client`` fixture. Tests that
    need to seed or inspect ``plan_billing`` rows additionally use the
    ``test_db`` session fixture.
    """

    @staticmethod
    async def _fetch_plan_billing(test_db, plan: str):
        """Return the PlanBilling row for ``plan``, or None if there is none.

        ``populate_existing`` makes the SELECT overwrite any instance that is
        already cached in the session's identity map, so assertions compare
        against what is in the database rather than against attributes this
        test assigned earlier.
        NOTE(review): conftest is not visible from this file; this matters
        when the session is created with ``expire_on_commit=False`` and is
        harmless otherwise.
        """
        stmt = (
            select(PlanBilling)
            .where(PlanBilling.plan == plan)
            .execution_options(populate_existing=True)
        )
        return (await test_db.execute(stmt)).scalar_one_or_none()

    @classmethod
    async def _delete_plan_billing(cls, test_db, plan: str) -> None:
        """Delete the PlanBilling row for ``plan`` if one exists, then commit."""
        existing = await cls._fetch_plan_billing(test_db, plan)
        if existing is not None:
            await test_db.delete(existing)
            await test_db.commit()

    @pytest.mark.asyncio
    async def test_list_plan_limits(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """List all plan limits."""
        response = await client.get(
            "/api/v1/admin/plan-limits", headers=admin_auth_headers
        )
        assert response.status_code == 200
        plans = response.json()
        # conftest seeds at least free, pro and enterprise ("enterprise" is
        # the tier that used to be called "team").
        assert len(plans) >= 3
        plan_names = [p["plan"] for p in plans]
        assert "free" in plan_names

    @pytest.mark.asyncio
    async def test_update_plan_limits(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """Update a plan's limits."""
        response = await client.put(
            "/api/v1/admin/plan-limits",
            json={
                "plan": "free",
                "max_trees": 5,
                "max_sessions_per_month": 30,
                "max_users": 2,
                "custom_branding": False,
                "priority_support": False,
                "export_formats": ["markdown", "text"],
            },
            headers=admin_auth_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["max_trees"] == 5

    @pytest.mark.asyncio
    async def test_list_account_overrides(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """List account overrides."""
        response = await client.get(
            "/api/v1/admin/account-overrides", headers=admin_auth_headers
        )
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    @pytest.mark.asyncio
    async def test_non_admin_cannot_access(
        self, client: AsyncClient, auth_headers: dict
    ):
        """Non-admin gets 403."""
        response = await client.get(
            "/api/v1/admin/plan-limits", headers=auth_headers
        )
        assert response.status_code == 403

    @pytest.mark.asyncio
    async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present(
        self, client: AsyncClient, admin_auth_headers: dict, test_db
    ):
        """GET /admin/plan-limits returns plan_billing fields when a row exists,
        and None for plans that don't have one yet."""
        # Seed a plan_billing row for "pro". This is an upsert: if a row is
        # already there (from a fixture or an earlier test) it is overwritten
        # so the value assertions below hold regardless of prior state.
        seed_values = {
            "display_name": "Pro",
            "description": "For working teams",
            "monthly_price_cents": 4900,
            "annual_price_cents": 49000,
            "stripe_product_id": "prod_seed",
            "stripe_monthly_price_id": "price_seed_m",
            "stripe_annual_price_id": "price_seed_a",
            "is_public": True,
            "is_archived": False,
            "sort_order": 10,
        }
        existing = await self._fetch_plan_billing(test_db, "pro")
        if existing is None:
            test_db.add(PlanBilling(plan="pro", **seed_values))
        else:
            for field_name, value in seed_values.items():
                setattr(existing, field_name, value)
        await test_db.commit()

        response = await client.get(
            "/api/v1/admin/plan-limits", headers=admin_auth_headers
        )
        assert response.status_code == 200
        plans_by_name = {p["plan"]: p for p in response.json()}

        assert "pro" in plans_by_name
        pro = plans_by_name["pro"]
        assert pro["display_name"] == "Pro"
        assert pro["monthly_price_cents"] == 4900
        assert pro["stripe_monthly_price_id"] == "price_seed_m"
        assert pro["is_public"] is True
        assert pro["is_archived"] is False
        assert pro["sort_order"] == 10

        # A plan without a plan_billing row should still return, with None
        # billing fields.
        if "free" in plans_by_name:
            free = plans_by_name["free"]
            # Only check the None fields when "free" really has no
            # plan_billing row in this database.
            no_billing_row = (
                await self._fetch_plan_billing(test_db, "free")
            ) is None
            if no_billing_row:
                assert free["display_name"] is None
                assert free["monthly_price_cents"] is None
                assert free["stripe_product_id"] is None

    @pytest.mark.asyncio
    async def test_admin_plan_limits_put_creates_plan_billing_row(
        self, client: AsyncClient, admin_auth_headers: dict, test_db
    ):
        """PUT /admin/plan-limits upserts a plan_billing row when billing
        fields are included in the body."""
        # Ensure no plan_billing row exists for "enterprise" yet.
        await self._delete_plan_billing(test_db, "enterprise")

        response = await client.put(
            "/api/v1/admin/plan-limits",
            json={
                "plan": "enterprise",
                "max_trees": None,
                "max_sessions_per_month": None,
                "max_users": None,
                "custom_branding": True,
                "priority_support": True,
                "export_formats": ["markdown", "text", "pdf"],
                "display_name": "Team",
                "description": "For growing shops",
                "monthly_price_cents": 9900,
                "annual_price_cents": 99000,
                "stripe_product_id": "prod_team_test",
                "stripe_monthly_price_id": "price_team_m",
                "stripe_annual_price_id": "price_team_a",
                "is_public": True,
                "is_archived": False,
                "sort_order": 20,
            },
            headers=admin_auth_headers,
        )
        assert response.status_code == 200, response.text
        body = response.json()
        assert body["display_name"] == "Team"
        assert body["monthly_price_cents"] == 9900
        assert body["stripe_product_id"] == "prod_team_test"
        assert body["sort_order"] == 20

        # Confirm the row was actually persisted.
        await test_db.commit()  # ensure session sees other-session writes
        pb = await self._fetch_plan_billing(test_db, "enterprise")
        assert pb is not None
        assert pb.display_name == "Team"
        assert pb.monthly_price_cents == 9900
        assert pb.stripe_monthly_price_id == "price_team_m"
        assert pb.is_public is True

    @pytest.mark.asyncio
    async def test_admin_plan_limits_put_does_not_null_out_required_fields(
        self, client: AsyncClient, admin_auth_headers: dict, test_db
    ):
        """PUT /admin/plan-limits must not NULL out NOT NULL columns on the
        plan_billing row when the caller passes explicit nulls. The set of
        guarded fields is {display_name, is_public, is_archived, sort_order}.
        """
        # Seed a plan_billing row for "enterprise" with non-default values
        # for every NOT NULL field so we can detect any clobbering.
        await self._delete_plan_billing(test_db, "enterprise")
        seeded = PlanBilling(
            plan="enterprise",
            display_name="Team Seeded",
            is_public=False,
            is_archived=True,
            sort_order=5,
        )
        test_db.add(seeded)
        await test_db.commit()

        response = await client.put(
            "/api/v1/admin/plan-limits",
            json={
                "plan": "enterprise",
                "max_trees": None,
                "max_sessions_per_month": None,
                "max_users": None,
                "custom_branding": True,
                "priority_support": True,
                "export_formats": ["markdown", "text"],
                # Explicit nulls for every NOT NULL plan_billing field.
                "display_name": None,
                "is_public": None,
                "is_archived": None,
                "sort_order": None,
            },
            headers=admin_auth_headers,
        )
        assert response.status_code == 200, response.text

        # Confirm the seeded NOT NULL values were preserved. The helper
        # re-reads with populate_existing, so the ``seeded`` instance still
        # held by this session cannot mask a clobbered database row.
        await test_db.commit()  # ensure session sees writes from the request
        pb = await self._fetch_plan_billing(test_db, "enterprise")
        assert pb is not None
        assert pb.display_name == "Team Seeded"
        assert pb.is_public is False
        assert pb.is_archived is True
        assert pb.sort_order == 5

    @pytest.mark.asyncio
    async def test_admin_plan_limits_put_invalidates_billing_cache(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """PUT /admin/plan-limits calls BillingService.invalidate_billing_cache
        with the account_ids on the affected plan."""
        # Patch the staticmethod on the class. The endpoint imports
        # BillingService at module load, so patch the symbol on the class
        # itself — both the import and the dotted reference resolve to it.
        with patch(
            "app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache",
            new_callable=AsyncMock,
        ) as spy:
            response = await client.put(
                "/api/v1/admin/plan-limits",
                json={
                    "plan": "pro",
                    "max_trees": 25,
                    "max_sessions_per_month": 500,
                    "max_users": 10,
                    "custom_branding": True,
                    "priority_support": True,
                    "export_formats": ["markdown", "text"],
                },
                headers=admin_auth_headers,
            )

        assert response.status_code == 200, response.text
        spy.assert_awaited_once()
        # The spy must have been awaited with exactly one positional argument.
        (account_ids_arg,) = spy.await_args.args
        # admin fixture seeds an active Pro Subscription, so we expect at
        # least one account_id in the invalidation list.
        assert isinstance(account_ids_arg, list)
        assert len(account_ids_arg) >= 1