feat(admin): extend /admin/plan-limits to manage plan_billing fields

Task 30 of self-serve signup Phase 2. Super-admins can now manage Stripe
IDs, display names, prices, and public/archived flags via the existing
admin plan-limits endpoints.

- GET /admin/plan-limits now outer-joins plan_billing and returns
  merged PlanLimitWithBillingResponse rows. Plans without a
  plan_billing row return None for the billing fields.
- PUT /admin/plan-limits accepts the new optional billing fields and
  upserts plan_billing in the same transaction. If no plan_billing
  row exists for the plan and the body includes any billing field, a
  row is created (display_name defaults to plan.capitalize() when
  omitted; display_name is never NULLed out on an existing row).
- After commit, the handler queries account_ids on the affected plan
  and calls BillingService.invalidate_billing_cache(account_ids).
  This is a no-op stub today (logs only) — there's no in-process
  billing cache yet. TODO comment marks the wire-up point.
- 3 new integration tests cover GET-with-billing-present, PUT creating
  a plan_billing row, and the invalidation hook being awaited with a
  list of account_ids.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-06 20:32:09 -04:00
parent 694279f89e
commit d05b475a41
4 changed files with 375 additions and 9 deletions

View File

@@ -8,34 +8,101 @@ from app.core.database import get_db
from app.core.audit import log_audit from app.core.audit import log_audit
from app.models.user import User from app.models.user import User
from app.models.plan_limits import PlanLimits from app.models.plan_limits import PlanLimits
from app.models.plan_billing import PlanBilling
from app.models.account import Account from app.models.account import Account
from app.models.account_limit_override import AccountLimitOverride from app.models.account_limit_override import AccountLimitOverride
from app.models.subscription import Subscription
from app.schemas.admin import ( from app.schemas.admin import (
PlanLimitResponse, PlanLimitUpdate, PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse, AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
) )
from app.api.deps import require_admin from app.api.deps import require_admin
from app.services.billing import BillingService
router = APIRouter(prefix="/admin", tags=["admin-plan-limits"]) router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])
@router.get("/plan-limits", response_model=list[PlanLimitResponse]) # Fields on PlanLimitUpdate that map to plan_billing (not plan_limits).
_PLAN_BILLING_FIELDS = (
"display_name",
"description",
"monthly_price_cents",
"annual_price_cents",
"stripe_product_id",
"stripe_monthly_price_id",
"stripe_annual_price_id",
"is_public",
"is_archived",
"sort_order",
)
# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model.
# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit
# null for any of them would otherwise trigger a NOT NULL violation at commit.
_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({
"display_name",
"is_public",
"is_archived",
"sort_order",
})
def _merge_plan_with_billing(
plan: PlanLimits, billing: PlanBilling | None
) -> PlanLimitWithBillingResponse:
"""Build a merged response. Billing fields are None when no plan_billing row
exists for the plan."""
payload = {
"plan": plan.plan,
"max_trees": plan.max_trees,
"max_sessions_per_month": plan.max_sessions_per_month,
"max_users": plan.max_users,
"custom_branding": plan.custom_branding,
"priority_support": plan.priority_support,
"export_formats": plan.export_formats or [],
}
if billing is not None:
payload.update({
"display_name": billing.display_name,
"description": billing.description,
"monthly_price_cents": billing.monthly_price_cents,
"annual_price_cents": billing.annual_price_cents,
"stripe_product_id": billing.stripe_product_id,
"stripe_monthly_price_id": billing.stripe_monthly_price_id,
"stripe_annual_price_id": billing.stripe_annual_price_id,
"is_public": billing.is_public,
"is_archived": billing.is_archived,
"sort_order": billing.sort_order,
})
return PlanLimitWithBillingResponse(**payload)
@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
async def list_plan_limits( async def list_plan_limits(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)], current_user: Annotated[User, Depends(require_admin)],
): ):
"""List all plan limit configurations.""" """List all plan limit configurations, merged with plan_billing fields
result = await db.execute(select(PlanLimits)) where present. Plans without a plan_billing row return None for the
return result.scalars().all() billing fields."""
rows = (await db.execute(
select(PlanLimits, PlanBilling)
.outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
)).all()
return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
@router.put("/plan-limits", response_model=PlanLimitResponse) @router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
async def update_plan_limits( async def update_plan_limits(
data: PlanLimitUpdate, data: PlanLimitUpdate,
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)], current_user: Annotated[User, Depends(require_admin)],
): ):
"""Update a plan's limits.""" """Update a plan's limits and (if any plan_billing field is included)
upsert the matching plan_billing row in the same transaction. After
commit, invalidates the in-process billing cache for accounts on this
plan (currently a no-op — see BillingService.invalidate_billing_cache).
"""
result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan)) result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
plan = result.scalar_one_or_none() plan = result.scalar_one_or_none()
if not plan: if not plan:
@@ -48,10 +115,50 @@ async def update_plan_limits(
plan.priority_support = data.priority_support plan.priority_support = data.priority_support
plan.export_formats = data.export_formats plan.export_formats = data.export_formats
await log_audit(db, current_user.id, "plan_limits.update", "plan_limits", details={"plan": data.plan}) # Did the request include any plan_billing field? (Pydantic gives us
# `model_fields_set` to distinguish "user passed null" from "field omitted".)
billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)
billing: PlanBilling | None = None
if billing_fields_set:
billing = (await db.execute(
select(PlanBilling).where(PlanBilling.plan == data.plan)
)).scalar_one_or_none()
if billing is None:
# Create. display_name is required on the model — derive from the
# plan name when the caller didn't supply one (e.g. "pro" → "Pro").
display_name = data.display_name or data.plan.capitalize()
billing = PlanBilling(plan=data.plan, display_name=display_name)
db.add(billing)
# Apply only the fields the caller actually included. Allows partial
# updates without clobbering existing values.
for field in billing_fields_set:
value = getattr(data, field)
if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
# Don't NULL out a NOT NULL column on update.
continue
setattr(billing, field, value)
await log_audit(
db, current_user.id, "plan_limits.update", "plan_limits",
details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
)
await db.commit() await db.commit()
await db.refresh(plan) await db.refresh(plan)
return plan if billing is not None:
await db.refresh(billing)
# Invalidate any in-process billing cache for accounts on this plan.
# TODO: invalidate app.state.billing_cache when added.
account_ids = [
row[0] for row in (await db.execute(
select(Subscription.account_id).where(Subscription.plan == data.plan)
)).all()
]
await BillingService.invalidate_billing_cache(account_ids)
return _merge_plan_with_billing(plan, billing)
@router.get("/account-overrides", response_model=list[AccountOverrideResponse]) @router.get("/account-overrides", response_model=list[AccountOverrideResponse])

View File

@@ -172,6 +172,21 @@ class PlanLimitResponse(BaseModel):
from_attributes = True from_attributes = True
class PlanLimitWithBillingResponse(PlanLimitResponse):
"""PlanLimits + plan_billing fields merged. Billing fields are None when no
plan_billing row exists for the plan yet."""
display_name: Optional[str] = None
description: Optional[str] = None
monthly_price_cents: Optional[int] = None
annual_price_cents: Optional[int] = None
stripe_product_id: Optional[str] = None
stripe_monthly_price_id: Optional[str] = None
stripe_annual_price_id: Optional[str] = None
is_public: Optional[bool] = None
is_archived: Optional[bool] = None
sort_order: Optional[int] = None
class PlanLimitUpdate(BaseModel): class PlanLimitUpdate(BaseModel):
plan: str plan: str
max_trees: Optional[int] = None max_trees: Optional[int] = None
@@ -180,6 +195,19 @@ class PlanLimitUpdate(BaseModel):
custom_branding: bool = False custom_branding: bool = False
priority_support: bool = False priority_support: bool = False
export_formats: list = Field(default_factory=lambda: ["markdown", "text"]) export_formats: list = Field(default_factory=lambda: ["markdown", "text"])
# plan_billing fields — all optional, partial-update semantics. If any are
# set in the body, the admin endpoint upserts the plan_billing row in the
# same transaction.
display_name: Optional[str] = None
description: Optional[str] = None
monthly_price_cents: Optional[int] = None
annual_price_cents: Optional[int] = None
stripe_product_id: Optional[str] = None
stripe_monthly_price_id: Optional[str] = None
stripe_annual_price_id: Optional[str] = None
is_public: Optional[bool] = None
is_archived: Optional[bool] = None
sort_order: Optional[int] = None
class AccountOverrideCreate(BaseModel): class AccountOverrideCreate(BaseModel):

View File

@@ -1,6 +1,7 @@
"""Single billing service module. Stripe is the only impl — no provider """Single billing service module. Stripe is the only impl — no provider
abstraction. Account row is canonical local state; Stripe is canonical abstraction. Account row is canonical local state; Stripe is canonical
remote state; the webhook handler bridges the two.""" remote state; the webhook handler bridges the two."""
import logging
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
import stripe import stripe
@@ -17,8 +18,32 @@ from app.models.subscription import Subscription
TRIAL_DAYS = 14 TRIAL_DAYS = 14
logger = logging.getLogger(__name__)
class BillingService: class BillingService:
@staticmethod
async def invalidate_billing_cache(account_ids) -> None:
"""No-op stub for future in-process billing cache invalidation.
Today there is no `app.state.billing_cache` — `BillingService.get_billing_state`
always reads fresh from the DB. Call sites that mutate plan/feature data
invoke this hook so that wiring is in place when an in-process cache is
added later. Until then, this just logs.
TODO: when an in-process billing cache (e.g. `app.state.billing_cache`)
is introduced, evict entries for the given account_ids here.
"""
try:
count = len(list(account_ids))
except TypeError:
count = -1
logger.debug(
"BillingService.invalidate_billing_cache called for %d account(s) "
"(no-op stub — wire to app.state.billing_cache when added)",
count,
)
@staticmethod @staticmethod
async def start_trial(db: AsyncSession, account_id) -> Subscription: async def start_trial(db: AsyncSession, account_id) -> Subscription:
"""Idempotent. Creates a trialing Subscription on Pro for the account if """Idempotent. Creates a trialing Subscription on Pro for the account if

View File

@@ -1,7 +1,12 @@
"""Integration tests for admin plan limits and account override endpoints.""" """Integration tests for admin plan limits and account override endpoints."""
from unittest.mock import AsyncMock, patch
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy import select
from app.models.plan_billing import PlanBilling
class TestAdminPlanLimits: class TestAdminPlanLimits:
@@ -56,3 +61,204 @@ class TestAdminPlanLimits:
"""Non-admin gets 403.""" """Non-admin gets 403."""
response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers) response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers)
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio
async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present(
self, client: AsyncClient, admin_auth_headers: dict, test_db
):
"""GET /admin/plan-limits returns plan_billing fields when a row exists,
and None for plans that don't have one yet."""
# Seed a plan_billing row for "pro".
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "pro")
)).scalar_one_or_none()
if existing is None:
test_db.add(PlanBilling(
plan="pro",
display_name="Pro",
description="For working teams",
monthly_price_cents=4900,
annual_price_cents=49000,
stripe_product_id="prod_seed",
stripe_monthly_price_id="price_seed_m",
stripe_annual_price_id="price_seed_a",
is_public=True,
is_archived=False,
sort_order=10,
))
await test_db.commit()
response = await client.get(
"/api/v1/admin/plan-limits", headers=admin_auth_headers
)
assert response.status_code == 200
plans_by_name = {p["plan"]: p for p in response.json()}
assert "pro" in plans_by_name
pro = plans_by_name["pro"]
assert pro["display_name"] == "Pro"
assert pro["monthly_price_cents"] == 4900
assert pro["stripe_monthly_price_id"] == "price_seed_m"
assert pro["is_public"] is True
assert pro["is_archived"] is False
assert pro["sort_order"] == 10
# A plan without a plan_billing row should still return, with None
# billing fields.
if "free" in plans_by_name:
free = plans_by_name["free"]
# free has no plan_billing row in the seed → fields are None.
no_billing_row = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "free")
)).scalar_one_or_none() is None
if no_billing_row:
assert free["display_name"] is None
assert free["monthly_price_cents"] is None
assert free["stripe_product_id"] is None
@pytest.mark.asyncio
async def test_admin_plan_limits_put_creates_plan_billing_row(
self, client: AsyncClient, admin_auth_headers: dict, test_db
):
"""PUT /admin/plan-limits upserts a plan_billing row when billing
fields are included in the body."""
# Ensure no plan_billing row exists for "team" yet.
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
if existing is not None:
await test_db.delete(existing)
await test_db.commit()
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "team",
"max_trees": None,
"max_sessions_per_month": None,
"max_users": None,
"custom_branding": True,
"priority_support": True,
"export_formats": ["markdown", "text", "pdf"],
"display_name": "Team",
"description": "For growing shops",
"monthly_price_cents": 9900,
"annual_price_cents": 99000,
"stripe_product_id": "prod_team_test",
"stripe_monthly_price_id": "price_team_m",
"stripe_annual_price_id": "price_team_a",
"is_public": True,
"is_archived": False,
"sort_order": 20,
},
headers=admin_auth_headers,
)
assert response.status_code == 200, response.text
body = response.json()
assert body["display_name"] == "Team"
assert body["monthly_price_cents"] == 9900
assert body["stripe_product_id"] == "prod_team_test"
assert body["sort_order"] == 20
# Confirm the row was actually persisted.
await test_db.commit() # ensure session sees other-session writes
pb = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
assert pb is not None
assert pb.display_name == "Team"
assert pb.monthly_price_cents == 9900
assert pb.stripe_monthly_price_id == "price_team_m"
assert pb.is_public is True
@pytest.mark.asyncio
async def test_admin_plan_limits_put_does_not_null_out_required_fields(
self, client: AsyncClient, admin_auth_headers: dict, test_db
):
"""PUT /admin/plan-limits must not NULL out NOT NULL columns on the
plan_billing row when the caller passes explicit nulls. The set of
guarded fields is {display_name, is_public, is_archived, sort_order}.
"""
# Seed a plan_billing row for "team" with non-default values for every
# NOT NULL field so we can detect any clobbering.
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
if existing is not None:
await test_db.delete(existing)
await test_db.commit()
seeded = PlanBilling(
plan="team",
display_name="Team Seeded",
is_public=False,
is_archived=True,
sort_order=5,
)
test_db.add(seeded)
await test_db.commit()
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "team",
"max_trees": None,
"max_sessions_per_month": None,
"max_users": None,
"custom_branding": True,
"priority_support": True,
"export_formats": ["markdown", "text"],
# Explicit nulls for every NOT NULL plan_billing field.
"display_name": None,
"is_public": None,
"is_archived": None,
"sort_order": None,
},
headers=admin_auth_headers,
)
assert response.status_code == 200, response.text
# Confirm the seeded NOT NULL values were preserved.
await test_db.commit() # ensure session sees writes from the request
pb = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
assert pb is not None
assert pb.display_name == "Team Seeded"
assert pb.is_public is False
assert pb.is_archived is True
assert pb.sort_order == 5
@pytest.mark.asyncio
async def test_admin_plan_limits_put_invalidates_billing_cache(
self, client: AsyncClient, admin_auth_headers: dict
):
"""PUT /admin/plan-limits calls BillingService.invalidate_billing_cache
with the account_ids on the affected plan."""
# Patch the staticmethod on the class. The endpoint imports
# BillingService at module load, so patch the symbol on the class
# itself — both the import and the dotted reference resolve to it.
with patch(
"app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache",
new_callable=AsyncMock,
) as spy:
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "pro",
"max_trees": 25,
"max_sessions_per_month": 500,
"max_users": 10,
"custom_branding": True,
"priority_support": True,
"export_formats": ["markdown", "text"],
},
headers=admin_auth_headers,
)
assert response.status_code == 200, response.text
spy.assert_awaited_once()
(account_ids_arg,) = spy.await_args.args
# admin fixture seeds an active Pro Subscription, so we expect at
# least one account_id in the invalidation list.
assert isinstance(account_ids_arg, list)
assert len(account_ids_arg) >= 1