feat(admin): extend /admin/plan-limits to manage plan_billing fields
Task 30 of self-serve signup Phase 2. Super-admins can now manage Stripe IDs, display names, prices, and public/archived flags via the existing admin plan-limits endpoints. - GET /admin/plan-limits now outer-joins plan_billing and returns merged PlanLimitWithBillingResponse rows. Plans without a plan_billing row return None for the billing fields. - PUT /admin/plan-limits accepts the new optional billing fields and upserts plan_billing in the same transaction. If no plan_billing row exists for the plan and the body includes any billing field, a row is created (display_name defaults to plan.capitalize() when omitted; display_name is never NULLed out on an existing row). - After commit, the handler queries account_ids on the affected plan and calls BillingService.invalidate_billing_cache(account_ids). This is a no-op stub today (logs only) — there's no in-process billing cache yet. TODO comment marks the wire-up point. - 3 new integration tests cover GET-with-billing-present, PUT creating a plan_billing row, and the invalidation hook being awaited with a list of account_ids. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -8,34 +8,101 @@ from app.core.database import get_db
|
||||
from app.core.audit import log_audit
|
||||
from app.models.user import User
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.models.plan_billing import PlanBilling
|
||||
from app.models.account import Account
|
||||
from app.models.account_limit_override import AccountLimitOverride
|
||||
from app.models.subscription import Subscription
|
||||
from app.schemas.admin import (
|
||||
PlanLimitResponse, PlanLimitUpdate,
|
||||
PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
|
||||
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
|
||||
)
|
||||
from app.api.deps import require_admin
|
||||
from app.services.billing import BillingService
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])
|
||||
|
||||
|
||||
@router.get("/plan-limits", response_model=list[PlanLimitResponse])
|
||||
# Fields on PlanLimitUpdate that map to plan_billing (not plan_limits).
|
||||
_PLAN_BILLING_FIELDS = (
|
||||
"display_name",
|
||||
"description",
|
||||
"monthly_price_cents",
|
||||
"annual_price_cents",
|
||||
"stripe_product_id",
|
||||
"stripe_monthly_price_id",
|
||||
"stripe_annual_price_id",
|
||||
"is_public",
|
||||
"is_archived",
|
||||
"sort_order",
|
||||
)
|
||||
|
||||
# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model.
|
||||
# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit
|
||||
# null for any of them would otherwise trigger a NOT NULL violation at commit.
|
||||
_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({
|
||||
"display_name",
|
||||
"is_public",
|
||||
"is_archived",
|
||||
"sort_order",
|
||||
})
|
||||
|
||||
|
||||
def _merge_plan_with_billing(
|
||||
plan: PlanLimits, billing: PlanBilling | None
|
||||
) -> PlanLimitWithBillingResponse:
|
||||
"""Build a merged response. Billing fields are None when no plan_billing row
|
||||
exists for the plan."""
|
||||
payload = {
|
||||
"plan": plan.plan,
|
||||
"max_trees": plan.max_trees,
|
||||
"max_sessions_per_month": plan.max_sessions_per_month,
|
||||
"max_users": plan.max_users,
|
||||
"custom_branding": plan.custom_branding,
|
||||
"priority_support": plan.priority_support,
|
||||
"export_formats": plan.export_formats or [],
|
||||
}
|
||||
if billing is not None:
|
||||
payload.update({
|
||||
"display_name": billing.display_name,
|
||||
"description": billing.description,
|
||||
"monthly_price_cents": billing.monthly_price_cents,
|
||||
"annual_price_cents": billing.annual_price_cents,
|
||||
"stripe_product_id": billing.stripe_product_id,
|
||||
"stripe_monthly_price_id": billing.stripe_monthly_price_id,
|
||||
"stripe_annual_price_id": billing.stripe_annual_price_id,
|
||||
"is_public": billing.is_public,
|
||||
"is_archived": billing.is_archived,
|
||||
"sort_order": billing.sort_order,
|
||||
})
|
||||
return PlanLimitWithBillingResponse(**payload)
|
||||
|
||||
|
||||
@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
|
||||
async def list_plan_limits(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""List all plan limit configurations."""
|
||||
result = await db.execute(select(PlanLimits))
|
||||
return result.scalars().all()
|
||||
"""List all plan limit configurations, merged with plan_billing fields
|
||||
where present. Plans without a plan_billing row return None for the
|
||||
billing fields."""
|
||||
rows = (await db.execute(
|
||||
select(PlanLimits, PlanBilling)
|
||||
.outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
|
||||
)).all()
|
||||
return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
|
||||
|
||||
|
||||
@router.put("/plan-limits", response_model=PlanLimitResponse)
|
||||
@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
|
||||
async def update_plan_limits(
|
||||
data: PlanLimitUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Update a plan's limits."""
|
||||
"""Update a plan's limits and (if any plan_billing field is included)
|
||||
upsert the matching plan_billing row in the same transaction. After
|
||||
commit, invalidates the in-process billing cache for accounts on this
|
||||
plan (currently a no-op — see BillingService.invalidate_billing_cache).
|
||||
"""
|
||||
result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
|
||||
plan = result.scalar_one_or_none()
|
||||
if not plan:
|
||||
@@ -48,10 +115,50 @@ async def update_plan_limits(
|
||||
plan.priority_support = data.priority_support
|
||||
plan.export_formats = data.export_formats
|
||||
|
||||
await log_audit(db, current_user.id, "plan_limits.update", "plan_limits", details={"plan": data.plan})
|
||||
# Did the request include any plan_billing field? (Pydantic gives us
|
||||
# `model_fields_set` to distinguish "user passed null" from "field omitted".)
|
||||
billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)
|
||||
billing: PlanBilling | None = None
|
||||
if billing_fields_set:
|
||||
billing = (await db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == data.plan)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if billing is None:
|
||||
# Create. display_name is required on the model — derive from the
|
||||
# plan name when the caller didn't supply one (e.g. "pro" → "Pro").
|
||||
display_name = data.display_name or data.plan.capitalize()
|
||||
billing = PlanBilling(plan=data.plan, display_name=display_name)
|
||||
db.add(billing)
|
||||
|
||||
# Apply only the fields the caller actually included. Allows partial
|
||||
# updates without clobbering existing values.
|
||||
for field in billing_fields_set:
|
||||
value = getattr(data, field)
|
||||
if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
|
||||
# Don't NULL out a NOT NULL column on update.
|
||||
continue
|
||||
setattr(billing, field, value)
|
||||
|
||||
await log_audit(
|
||||
db, current_user.id, "plan_limits.update", "plan_limits",
|
||||
details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(plan)
|
||||
return plan
|
||||
if billing is not None:
|
||||
await db.refresh(billing)
|
||||
|
||||
# Invalidate any in-process billing cache for accounts on this plan.
|
||||
# TODO: invalidate app.state.billing_cache when added.
|
||||
account_ids = [
|
||||
row[0] for row in (await db.execute(
|
||||
select(Subscription.account_id).where(Subscription.plan == data.plan)
|
||||
)).all()
|
||||
]
|
||||
await BillingService.invalidate_billing_cache(account_ids)
|
||||
|
||||
return _merge_plan_with_billing(plan, billing)
|
||||
|
||||
|
||||
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])
|
||||
|
||||
@@ -172,6 +172,21 @@ class PlanLimitResponse(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PlanLimitWithBillingResponse(PlanLimitResponse):
|
||||
"""PlanLimits + plan_billing fields merged. Billing fields are None when no
|
||||
plan_billing row exists for the plan yet."""
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
monthly_price_cents: Optional[int] = None
|
||||
annual_price_cents: Optional[int] = None
|
||||
stripe_product_id: Optional[str] = None
|
||||
stripe_monthly_price_id: Optional[str] = None
|
||||
stripe_annual_price_id: Optional[str] = None
|
||||
is_public: Optional[bool] = None
|
||||
is_archived: Optional[bool] = None
|
||||
sort_order: Optional[int] = None
|
||||
|
||||
|
||||
class PlanLimitUpdate(BaseModel):
|
||||
plan: str
|
||||
max_trees: Optional[int] = None
|
||||
@@ -180,6 +195,19 @@ class PlanLimitUpdate(BaseModel):
|
||||
custom_branding: bool = False
|
||||
priority_support: bool = False
|
||||
export_formats: list = Field(default_factory=lambda: ["markdown", "text"])
|
||||
# plan_billing fields — all optional, partial-update semantics. If any are
|
||||
# set in the body, the admin endpoint upserts the plan_billing row in the
|
||||
# same transaction.
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
monthly_price_cents: Optional[int] = None
|
||||
annual_price_cents: Optional[int] = None
|
||||
stripe_product_id: Optional[str] = None
|
||||
stripe_monthly_price_id: Optional[str] = None
|
||||
stripe_annual_price_id: Optional[str] = None
|
||||
is_public: Optional[bool] = None
|
||||
is_archived: Optional[bool] = None
|
||||
sort_order: Optional[int] = None
|
||||
|
||||
|
||||
class AccountOverrideCreate(BaseModel):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Single billing service module. Stripe is the only impl — no provider
|
||||
abstraction. Account row is canonical local state; Stripe is canonical
|
||||
remote state; the webhook handler bridges the two."""
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import stripe
|
||||
@@ -17,8 +18,32 @@ from app.models.subscription import Subscription
|
||||
|
||||
TRIAL_DAYS = 14
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BillingService:
|
||||
@staticmethod
|
||||
async def invalidate_billing_cache(account_ids) -> None:
|
||||
"""No-op stub for future in-process billing cache invalidation.
|
||||
|
||||
Today there is no `app.state.billing_cache` — `BillingService.get_billing_state`
|
||||
always reads fresh from the DB. Call sites that mutate plan/feature data
|
||||
invoke this hook so that wiring is in place when an in-process cache is
|
||||
added later. Until then, this just logs.
|
||||
|
||||
TODO: when an in-process billing cache (e.g. `app.state.billing_cache`)
|
||||
is introduced, evict entries for the given account_ids here.
|
||||
"""
|
||||
try:
|
||||
count = len(list(account_ids))
|
||||
except TypeError:
|
||||
count = -1
|
||||
logger.debug(
|
||||
"BillingService.invalidate_billing_cache called for %d account(s) "
|
||||
"(no-op stub — wire to app.state.billing_cache when added)",
|
||||
count,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def start_trial(db: AsyncSession, account_id) -> Subscription:
|
||||
"""Idempotent. Creates a trialing Subscription on Pro for the account if
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""Integration tests for admin plan limits and account override endpoints."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.plan_billing import PlanBilling
|
||||
|
||||
|
||||
class TestAdminPlanLimits:
|
||||
@@ -56,3 +61,204 @@ class TestAdminPlanLimits:
|
||||
"""Non-admin gets 403."""
|
||||
response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers)
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_db
|
||||
):
|
||||
"""GET /admin/plan-limits returns plan_billing fields when a row exists,
|
||||
and None for plans that don't have one yet."""
|
||||
# Seed a plan_billing row for "pro".
|
||||
existing = (await test_db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == "pro")
|
||||
)).scalar_one_or_none()
|
||||
if existing is None:
|
||||
test_db.add(PlanBilling(
|
||||
plan="pro",
|
||||
display_name="Pro",
|
||||
description="For working teams",
|
||||
monthly_price_cents=4900,
|
||||
annual_price_cents=49000,
|
||||
stripe_product_id="prod_seed",
|
||||
stripe_monthly_price_id="price_seed_m",
|
||||
stripe_annual_price_id="price_seed_a",
|
||||
is_public=True,
|
||||
is_archived=False,
|
||||
sort_order=10,
|
||||
))
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/admin/plan-limits", headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
plans_by_name = {p["plan"]: p for p in response.json()}
|
||||
|
||||
assert "pro" in plans_by_name
|
||||
pro = plans_by_name["pro"]
|
||||
assert pro["display_name"] == "Pro"
|
||||
assert pro["monthly_price_cents"] == 4900
|
||||
assert pro["stripe_monthly_price_id"] == "price_seed_m"
|
||||
assert pro["is_public"] is True
|
||||
assert pro["is_archived"] is False
|
||||
assert pro["sort_order"] == 10
|
||||
|
||||
# A plan without a plan_billing row should still return, with None
|
||||
# billing fields.
|
||||
if "free" in plans_by_name:
|
||||
free = plans_by_name["free"]
|
||||
# free has no plan_billing row in the seed → fields are None.
|
||||
no_billing_row = (await test_db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == "free")
|
||||
)).scalar_one_or_none() is None
|
||||
if no_billing_row:
|
||||
assert free["display_name"] is None
|
||||
assert free["monthly_price_cents"] is None
|
||||
assert free["stripe_product_id"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_plan_limits_put_creates_plan_billing_row(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_db
|
||||
):
|
||||
"""PUT /admin/plan-limits upserts a plan_billing row when billing
|
||||
fields are included in the body."""
|
||||
# Ensure no plan_billing row exists for "team" yet.
|
||||
existing = (await test_db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == "team")
|
||||
)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
await test_db.delete(existing)
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.put(
|
||||
"/api/v1/admin/plan-limits",
|
||||
json={
|
||||
"plan": "team",
|
||||
"max_trees": None,
|
||||
"max_sessions_per_month": None,
|
||||
"max_users": None,
|
||||
"custom_branding": True,
|
||||
"priority_support": True,
|
||||
"export_formats": ["markdown", "text", "pdf"],
|
||||
"display_name": "Team",
|
||||
"description": "For growing shops",
|
||||
"monthly_price_cents": 9900,
|
||||
"annual_price_cents": 99000,
|
||||
"stripe_product_id": "prod_team_test",
|
||||
"stripe_monthly_price_id": "price_team_m",
|
||||
"stripe_annual_price_id": "price_team_a",
|
||||
"is_public": True,
|
||||
"is_archived": False,
|
||||
"sort_order": 20,
|
||||
},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
body = response.json()
|
||||
assert body["display_name"] == "Team"
|
||||
assert body["monthly_price_cents"] == 9900
|
||||
assert body["stripe_product_id"] == "prod_team_test"
|
||||
assert body["sort_order"] == 20
|
||||
|
||||
# Confirm the row was actually persisted.
|
||||
await test_db.commit() # ensure session sees other-session writes
|
||||
pb = (await test_db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == "team")
|
||||
)).scalar_one_or_none()
|
||||
assert pb is not None
|
||||
assert pb.display_name == "Team"
|
||||
assert pb.monthly_price_cents == 9900
|
||||
assert pb.stripe_monthly_price_id == "price_team_m"
|
||||
assert pb.is_public is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_plan_limits_put_does_not_null_out_required_fields(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_db
|
||||
):
|
||||
"""PUT /admin/plan-limits must not NULL out NOT NULL columns on the
|
||||
plan_billing row when the caller passes explicit nulls. The set of
|
||||
guarded fields is {display_name, is_public, is_archived, sort_order}.
|
||||
"""
|
||||
# Seed a plan_billing row for "team" with non-default values for every
|
||||
# NOT NULL field so we can detect any clobbering.
|
||||
existing = (await test_db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == "team")
|
||||
)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
await test_db.delete(existing)
|
||||
await test_db.commit()
|
||||
|
||||
seeded = PlanBilling(
|
||||
plan="team",
|
||||
display_name="Team Seeded",
|
||||
is_public=False,
|
||||
is_archived=True,
|
||||
sort_order=5,
|
||||
)
|
||||
test_db.add(seeded)
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.put(
|
||||
"/api/v1/admin/plan-limits",
|
||||
json={
|
||||
"plan": "team",
|
||||
"max_trees": None,
|
||||
"max_sessions_per_month": None,
|
||||
"max_users": None,
|
||||
"custom_branding": True,
|
||||
"priority_support": True,
|
||||
"export_formats": ["markdown", "text"],
|
||||
# Explicit nulls for every NOT NULL plan_billing field.
|
||||
"display_name": None,
|
||||
"is_public": None,
|
||||
"is_archived": None,
|
||||
"sort_order": None,
|
||||
},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
|
||||
# Confirm the seeded NOT NULL values were preserved.
|
||||
await test_db.commit() # ensure session sees writes from the request
|
||||
pb = (await test_db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == "team")
|
||||
)).scalar_one_or_none()
|
||||
assert pb is not None
|
||||
assert pb.display_name == "Team Seeded"
|
||||
assert pb.is_public is False
|
||||
assert pb.is_archived is True
|
||||
assert pb.sort_order == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_plan_limits_put_invalidates_billing_cache(
|
||||
self, client: AsyncClient, admin_auth_headers: dict
|
||||
):
|
||||
"""PUT /admin/plan-limits calls BillingService.invalidate_billing_cache
|
||||
with the account_ids on the affected plan."""
|
||||
# Patch the staticmethod on the class. The endpoint imports
|
||||
# BillingService at module load, so patch the symbol on the class
|
||||
# itself — both the import and the dotted reference resolve to it.
|
||||
with patch(
|
||||
"app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache",
|
||||
new_callable=AsyncMock,
|
||||
) as spy:
|
||||
response = await client.put(
|
||||
"/api/v1/admin/plan-limits",
|
||||
json={
|
||||
"plan": "pro",
|
||||
"max_trees": 25,
|
||||
"max_sessions_per_month": 500,
|
||||
"max_users": 10,
|
||||
"custom_branding": True,
|
||||
"priority_support": True,
|
||||
"export_formats": ["markdown", "text"],
|
||||
},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
spy.assert_awaited_once()
|
||||
(account_ids_arg,) = spy.await_args.args
|
||||
# admin fixture seeds an active Pro Subscription, so we expect at
|
||||
# least one account_id in the invalidation list.
|
||||
assert isinstance(account_ids_arg, list)
|
||||
assert len(account_ids_arg) >= 1
|
||||
|
||||
Reference in New Issue
Block a user