From d05b475a411439a0183f6c481f0b3891566f8775 Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Wed, 6 May 2026 20:32:09 -0400 Subject: [PATCH] feat(admin): extend /admin/plan-limits to manage plan_billing fields MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Task 30 of self-serve signup Phase 2. Super-admins can now manage Stripe IDs, display names, prices, and public/archived flags via the existing admin plan-limits endpoints. - GET /admin/plan-limits now outer-joins plan_billing and returns merged PlanLimitWithBillingResponse rows. Plans without a plan_billing row return None for the billing fields. - PUT /admin/plan-limits accepts the new optional billing fields and upserts plan_billing in the same transaction. If no plan_billing row exists for the plan and the body includes any billing field, a row is created (display_name defaults to plan.capitalize() when omitted; display_name is never NULLed out on an existing row). - After commit, the handler queries account_ids on the affected plan and calls BillingService.invalidate_billing_cache(account_ids). This is a no-op stub today (logs only) — there's no in-process billing cache yet. TODO comment marks the wire-up point. - 3 new integration tests cover GET-with-billing-present, PUT creating a plan_billing row, and the invalidation hook being awaited with a list of account_ids. Co-Authored-By: Claude Opus 4.7 --- .../app/api/endpoints/admin_plan_limits.py | 125 ++++++++++- backend/app/schemas/admin.py | 28 +++ backend/app/services/billing.py | 25 +++ backend/tests/test_admin_plan_limits.py | 206 ++++++++++++++++++ 4 files changed, 375 insertions(+), 9 deletions(-) diff --git a/backend/app/api/endpoints/admin_plan_limits.py b/backend/app/api/endpoints/admin_plan_limits.py index 387081f5..52ea09b4 100644 --- a/backend/app/api/endpoints/admin_plan_limits.py +++ b/backend/app/api/endpoints/admin_plan_limits.py @@ -8,34 +8,101 @@ from app.core.database import get_db from app.core.audit import log_audit from app.models.user import User from app.models.plan_limits import PlanLimits +from app.models.plan_billing import PlanBilling from app.models.account import Account from app.models.account_limit_override import AccountLimitOverride +from app.models.subscription import Subscription from app.schemas.admin import ( - PlanLimitResponse, PlanLimitUpdate, + PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse, AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse, ) from app.api.deps import require_admin +from app.services.billing import BillingService router = APIRouter(prefix="/admin", tags=["admin-plan-limits"]) -@router.get("/plan-limits", response_model=list[PlanLimitResponse]) +# Fields on PlanLimitUpdate that map to plan_billing (not plan_limits). +_PLAN_BILLING_FIELDS = ( + "display_name", + "description", + "monthly_price_cents", + "annual_price_cents", + "stripe_product_id", + "stripe_monthly_price_id", + "stripe_annual_price_id", + "is_public", + "is_archived", + "sort_order", +) + +# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model. +# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit +# null for any of them would otherwise trigger a NOT NULL violation at commit. +_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({ + "display_name", + "is_public", + "is_archived", + "sort_order", +}) + + +def _merge_plan_with_billing( + plan: PlanLimits, billing: PlanBilling | None +) -> PlanLimitWithBillingResponse: + """Build a merged response. Billing fields are None when no plan_billing row + exists for the plan.""" + payload = { + "plan": plan.plan, + "max_trees": plan.max_trees, + "max_sessions_per_month": plan.max_sessions_per_month, + "max_users": plan.max_users, + "custom_branding": plan.custom_branding, + "priority_support": plan.priority_support, + "export_formats": plan.export_formats or [], + } + if billing is not None: + payload.update({ + "display_name": billing.display_name, + "description": billing.description, + "monthly_price_cents": billing.monthly_price_cents, + "annual_price_cents": billing.annual_price_cents, + "stripe_product_id": billing.stripe_product_id, + "stripe_monthly_price_id": billing.stripe_monthly_price_id, + "stripe_annual_price_id": billing.stripe_annual_price_id, + "is_public": billing.is_public, + "is_archived": billing.is_archived, + "sort_order": billing.sort_order, + }) + return PlanLimitWithBillingResponse(**payload) + + +@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse]) async def list_plan_limits( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(require_admin)], ): - """List all plan limit configurations.""" - result = await db.execute(select(PlanLimits)) - return result.scalars().all() + """List all plan limit configurations, merged with plan_billing fields + where present. Plans without a plan_billing row return None for the + billing fields.""" + rows = (await db.execute( + select(PlanLimits, PlanBilling) + .outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan) + )).all() + return [_merge_plan_with_billing(pl, pb) for pl, pb in rows] -@router.put("/plan-limits", response_model=PlanLimitResponse) +@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse) async def update_plan_limits( data: PlanLimitUpdate, db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(require_admin)], ): - """Update a plan's limits.""" + """Update a plan's limits and (if any plan_billing field is included) + upsert the matching plan_billing row in the same transaction. After + commit, invalidates the in-process billing cache for accounts on this + plan (currently a no-op — see BillingService.invalidate_billing_cache). + """ result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan)) plan = result.scalar_one_or_none() if not plan: @@ -48,10 +115,50 @@ async def update_plan_limits( plan.priority_support = data.priority_support plan.export_formats = data.export_formats - await log_audit(db, current_user.id, "plan_limits.update", "plan_limits", details={"plan": data.plan}) + # Did the request include any plan_billing field? (Pydantic gives us + # `model_fields_set` to distinguish "user passed null" from "field omitted".) + billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS) + billing: PlanBilling | None = None + if billing_fields_set: + billing = (await db.execute( + select(PlanBilling).where(PlanBilling.plan == data.plan) + )).scalar_one_or_none() + + if billing is None: + # Create. display_name is required on the model — derive from the + # plan name when the caller didn't supply one (e.g. "pro" → "Pro"). + display_name = data.display_name or data.plan.capitalize() + billing = PlanBilling(plan=data.plan, display_name=display_name) + db.add(billing) + + # Apply only the fields the caller actually included. Allows partial + # updates without clobbering existing values. + for field in billing_fields_set: + value = getattr(data, field) + if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS: + # Don't NULL out a NOT NULL column on update. + continue + setattr(billing, field, value) + + await log_audit( + db, current_user.id, "plan_limits.update", "plan_limits", + details={"plan": data.plan, "updated_billing": bool(billing_fields_set)}, + ) await db.commit() await db.refresh(plan) - return plan + if billing is not None: + await db.refresh(billing) + + # Invalidate any in-process billing cache for accounts on this plan. + # TODO: invalidate app.state.billing_cache when added. + account_ids = [ + row[0] for row in (await db.execute( + select(Subscription.account_id).where(Subscription.plan == data.plan) + )).all() + ] + await BillingService.invalidate_billing_cache(account_ids) + + return _merge_plan_with_billing(plan, billing) @router.get("/account-overrides", response_model=list[AccountOverrideResponse]) diff --git a/backend/app/schemas/admin.py b/backend/app/schemas/admin.py index 72c63d43..a223d994 100644 --- a/backend/app/schemas/admin.py +++ b/backend/app/schemas/admin.py @@ -172,6 +172,21 @@ class PlanLimitResponse(BaseModel): from_attributes = True +class PlanLimitWithBillingResponse(PlanLimitResponse): + """PlanLimits + plan_billing fields merged. Billing fields are None when no + plan_billing row exists for the plan yet.""" + display_name: Optional[str] = None + description: Optional[str] = None + monthly_price_cents: Optional[int] = None + annual_price_cents: Optional[int] = None + stripe_product_id: Optional[str] = None + stripe_monthly_price_id: Optional[str] = None + stripe_annual_price_id: Optional[str] = None + is_public: Optional[bool] = None + is_archived: Optional[bool] = None + sort_order: Optional[int] = None + + class PlanLimitUpdate(BaseModel): plan: str max_trees: Optional[int] = None @@ -180,6 +195,19 @@ class PlanLimitUpdate(BaseModel): custom_branding: bool = False priority_support: bool = False export_formats: list = Field(default_factory=lambda: ["markdown", "text"]) + # plan_billing fields — all optional, partial-update semantics. If any are + # set in the body, the admin endpoint upserts the plan_billing row in the + # same transaction. + display_name: Optional[str] = None + description: Optional[str] = None + monthly_price_cents: Optional[int] = None + annual_price_cents: Optional[int] = None + stripe_product_id: Optional[str] = None + stripe_monthly_price_id: Optional[str] = None + stripe_annual_price_id: Optional[str] = None + is_public: Optional[bool] = None + is_archived: Optional[bool] = None + sort_order: Optional[int] = None class AccountOverrideCreate(BaseModel): diff --git a/backend/app/services/billing.py b/backend/app/services/billing.py index b662ed47..1ae0a999 100644 --- a/backend/app/services/billing.py +++ b/backend/app/services/billing.py @@ -1,6 +1,7 @@ """Single billing service module. Stripe is the only impl — no provider abstraction. Account row is canonical local state; Stripe is canonical remote state; the webhook handler bridges the two.""" +import logging from datetime import datetime, timezone, timedelta import stripe @@ -17,8 +18,32 @@ from app.models.subscription import Subscription TRIAL_DAYS = 14 +logger = logging.getLogger(__name__) + class BillingService: + @staticmethod + async def invalidate_billing_cache(account_ids) -> None: + """No-op stub for future in-process billing cache invalidation. + + Today there is no `app.state.billing_cache` — `BillingService.get_billing_state` + always reads fresh from the DB. Call sites that mutate plan/feature data + invoke this hook so that wiring is in place when an in-process cache is + added later. Until then, this just logs. + + TODO: when an in-process billing cache (e.g. `app.state.billing_cache`) + is introduced, evict entries for the given account_ids here. + """ + try: + count = len(list(account_ids)) + except TypeError: + count = -1 + logger.debug( + "BillingService.invalidate_billing_cache called for %d account(s) " + "(no-op stub — wire to app.state.billing_cache when added)", + count, + ) + @staticmethod async def start_trial(db: AsyncSession, account_id) -> Subscription: """Idempotent. Creates a trialing Subscription on Pro for the account if diff --git a/backend/tests/test_admin_plan_limits.py b/backend/tests/test_admin_plan_limits.py index 7e701b16..8eb22d45 100644 --- a/backend/tests/test_admin_plan_limits.py +++ b/backend/tests/test_admin_plan_limits.py @@ -1,7 +1,12 @@ """Integration tests for admin plan limits and account override endpoints.""" +from unittest.mock import AsyncMock, patch + import pytest from httpx import AsyncClient +from sqlalchemy import select + +from app.models.plan_billing import PlanBilling class TestAdminPlanLimits: @@ -56,3 +61,204 @@ class TestAdminPlanLimits: """Non-admin gets 403.""" response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers) assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present( + self, client: AsyncClient, admin_auth_headers: dict, test_db + ): + """GET /admin/plan-limits returns plan_billing fields when a row exists, + and None for plans that don't have one yet.""" + # Seed a plan_billing row for "pro". + existing = (await test_db.execute( + select(PlanBilling).where(PlanBilling.plan == "pro") + )).scalar_one_or_none() + if existing is None: + test_db.add(PlanBilling( + plan="pro", + display_name="Pro", + description="For working teams", + monthly_price_cents=4900, + annual_price_cents=49000, + stripe_product_id="prod_seed", + stripe_monthly_price_id="price_seed_m", + stripe_annual_price_id="price_seed_a", + is_public=True, + is_archived=False, + sort_order=10, + )) + await test_db.commit() + + response = await client.get( + "/api/v1/admin/plan-limits", headers=admin_auth_headers + ) + assert response.status_code == 200 + plans_by_name = {p["plan"]: p for p in response.json()} + + assert "pro" in plans_by_name + pro = plans_by_name["pro"] + assert pro["display_name"] == "Pro" + assert pro["monthly_price_cents"] == 4900 + assert pro["stripe_monthly_price_id"] == "price_seed_m" + assert pro["is_public"] is True + assert pro["is_archived"] is False + assert pro["sort_order"] == 10 + + # A plan without a plan_billing row should still return, with None + # billing fields. + if "free" in plans_by_name: + free = plans_by_name["free"] + # free has no plan_billing row in the seed → fields are None. + no_billing_row = (await test_db.execute( + select(PlanBilling).where(PlanBilling.plan == "free") + )).scalar_one_or_none() is None + if no_billing_row: + assert free["display_name"] is None + assert free["monthly_price_cents"] is None + assert free["stripe_product_id"] is None + + @pytest.mark.asyncio + async def test_admin_plan_limits_put_creates_plan_billing_row( + self, client: AsyncClient, admin_auth_headers: dict, test_db + ): + """PUT /admin/plan-limits upserts a plan_billing row when billing + fields are included in the body.""" + # Ensure no plan_billing row exists for "team" yet. + existing = (await test_db.execute( + select(PlanBilling).where(PlanBilling.plan == "team") + )).scalar_one_or_none() + if existing is not None: + await test_db.delete(existing) + await test_db.commit() + + response = await client.put( + "/api/v1/admin/plan-limits", + json={ + "plan": "team", + "max_trees": None, + "max_sessions_per_month": None, + "max_users": None, + "custom_branding": True, + "priority_support": True, + "export_formats": ["markdown", "text", "pdf"], + "display_name": "Team", + "description": "For growing shops", + "monthly_price_cents": 9900, + "annual_price_cents": 99000, + "stripe_product_id": "prod_team_test", + "stripe_monthly_price_id": "price_team_m", + "stripe_annual_price_id": "price_team_a", + "is_public": True, + "is_archived": False, + "sort_order": 20, + }, + headers=admin_auth_headers, + ) + assert response.status_code == 200, response.text + body = response.json() + assert body["display_name"] == "Team" + assert body["monthly_price_cents"] == 9900 + assert body["stripe_product_id"] == "prod_team_test" + assert body["sort_order"] == 20 + + # Confirm the row was actually persisted. + await test_db.commit() # ensure session sees other-session writes + pb = (await test_db.execute( + select(PlanBilling).where(PlanBilling.plan == "team") + )).scalar_one_or_none() + assert pb is not None + assert pb.display_name == "Team" + assert pb.monthly_price_cents == 9900 + assert pb.stripe_monthly_price_id == "price_team_m" + assert pb.is_public is True + + @pytest.mark.asyncio + async def test_admin_plan_limits_put_does_not_null_out_required_fields( + self, client: AsyncClient, admin_auth_headers: dict, test_db + ): + """PUT /admin/plan-limits must not NULL out NOT NULL columns on the + plan_billing row when the caller passes explicit nulls. The set of + guarded fields is {display_name, is_public, is_archived, sort_order}. + """ + # Seed a plan_billing row for "team" with non-default values for every + # NOT NULL field so we can detect any clobbering. + existing = (await test_db.execute( + select(PlanBilling).where(PlanBilling.plan == "team") + )).scalar_one_or_none() + if existing is not None: + await test_db.delete(existing) + await test_db.commit() + + seeded = PlanBilling( + plan="team", + display_name="Team Seeded", + is_public=False, + is_archived=True, + sort_order=5, + ) + test_db.add(seeded) + await test_db.commit() + + response = await client.put( + "/api/v1/admin/plan-limits", + json={ + "plan": "team", + "max_trees": None, + "max_sessions_per_month": None, + "max_users": None, + "custom_branding": True, + "priority_support": True, + "export_formats": ["markdown", "text"], + # Explicit nulls for every NOT NULL plan_billing field. + "display_name": None, + "is_public": None, + "is_archived": None, + "sort_order": None, + }, + headers=admin_auth_headers, + ) + assert response.status_code == 200, response.text + + # Confirm the seeded NOT NULL values were preserved. + await test_db.commit() # ensure session sees writes from the request + pb = (await test_db.execute( + select(PlanBilling).where(PlanBilling.plan == "team") + )).scalar_one_or_none() + assert pb is not None + assert pb.display_name == "Team Seeded" + assert pb.is_public is False + assert pb.is_archived is True + assert pb.sort_order == 5 + + @pytest.mark.asyncio + async def test_admin_plan_limits_put_invalidates_billing_cache( + self, client: AsyncClient, admin_auth_headers: dict + ): + """PUT /admin/plan-limits calls BillingService.invalidate_billing_cache + with the account_ids on the affected plan.""" + # Patch the staticmethod on the class. The endpoint imports + # BillingService at module load, so patch the symbol on the class + # itself — both the import and the dotted reference resolve to it. + with patch( + "app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache", + new_callable=AsyncMock, + ) as spy: + response = await client.put( + "/api/v1/admin/plan-limits", + json={ + "plan": "pro", + "max_trees": 25, + "max_sessions_per_month": 500, + "max_users": 10, + "custom_branding": True, + "priority_support": True, + "export_formats": ["markdown", "text"], + }, + headers=admin_auth_headers, + ) + assert response.status_code == 200, response.text + spy.assert_awaited_once() + (account_ids_arg,) = spy.await_args.args + # admin fixture seeds an active Pro Subscription, so we expect at + # least one account_id in the invalidation list. + assert isinstance(account_ids_arg, list) + assert len(account_ids_arg) >= 1