feat(admin): extend /admin/plan-limits to manage plan_billing fields
Task 30 of self-serve signup Phase 2. Super-admins can now manage Stripe IDs, display names, prices, and public/archived flags via the existing admin plan-limits endpoints. - GET /admin/plan-limits now outer-joins plan_billing and returns merged PlanLimitWithBillingResponse rows. Plans without a plan_billing row return None for the billing fields. - PUT /admin/plan-limits accepts the new optional billing fields and upserts plan_billing in the same transaction. If no plan_billing row exists for the plan and the body includes any billing field, a row is created (display_name defaults to plan.capitalize() when omitted; display_name is never NULLed out on an existing row). - After commit, the handler queries account_ids on the affected plan and calls BillingService.invalidate_billing_cache(account_ids). This is a no-op stub today (logs only) — there's no in-process billing cache yet. TODO comment marks the wire-up point. - 3 new integration tests cover GET-with-billing-present, PUT creating a plan_billing row, and the invalidation hook being awaited with a list of account_ids. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -8,34 +8,101 @@ from app.core.database import get_db
|
|||||||
from app.core.audit import log_audit
|
from app.core.audit import log_audit
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.plan_limits import PlanLimits
|
from app.models.plan_limits import PlanLimits
|
||||||
|
from app.models.plan_billing import PlanBilling
|
||||||
from app.models.account import Account
|
from app.models.account import Account
|
||||||
from app.models.account_limit_override import AccountLimitOverride
|
from app.models.account_limit_override import AccountLimitOverride
|
||||||
|
from app.models.subscription import Subscription
|
||||||
from app.schemas.admin import (
|
from app.schemas.admin import (
|
||||||
PlanLimitResponse, PlanLimitUpdate,
|
PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
|
||||||
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
|
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
|
||||||
)
|
)
|
||||||
from app.api.deps import require_admin
|
from app.api.deps import require_admin
|
||||||
|
from app.services.billing import BillingService
|
||||||
|
|
||||||
router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])
|
router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/plan-limits", response_model=list[PlanLimitResponse])
|
# Fields on PlanLimitUpdate that map to plan_billing (not plan_limits).
|
||||||
|
_PLAN_BILLING_FIELDS = (
|
||||||
|
"display_name",
|
||||||
|
"description",
|
||||||
|
"monthly_price_cents",
|
||||||
|
"annual_price_cents",
|
||||||
|
"stripe_product_id",
|
||||||
|
"stripe_monthly_price_id",
|
||||||
|
"stripe_annual_price_id",
|
||||||
|
"is_public",
|
||||||
|
"is_archived",
|
||||||
|
"sort_order",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model.
|
||||||
|
# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit
|
||||||
|
# null for any of them would otherwise trigger a NOT NULL violation at commit.
|
||||||
|
_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({
|
||||||
|
"display_name",
|
||||||
|
"is_public",
|
||||||
|
"is_archived",
|
||||||
|
"sort_order",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_plan_with_billing(
|
||||||
|
plan: PlanLimits, billing: PlanBilling | None
|
||||||
|
) -> PlanLimitWithBillingResponse:
|
||||||
|
"""Build a merged response. Billing fields are None when no plan_billing row
|
||||||
|
exists for the plan."""
|
||||||
|
payload = {
|
||||||
|
"plan": plan.plan,
|
||||||
|
"max_trees": plan.max_trees,
|
||||||
|
"max_sessions_per_month": plan.max_sessions_per_month,
|
||||||
|
"max_users": plan.max_users,
|
||||||
|
"custom_branding": plan.custom_branding,
|
||||||
|
"priority_support": plan.priority_support,
|
||||||
|
"export_formats": plan.export_formats or [],
|
||||||
|
}
|
||||||
|
if billing is not None:
|
||||||
|
payload.update({
|
||||||
|
"display_name": billing.display_name,
|
||||||
|
"description": billing.description,
|
||||||
|
"monthly_price_cents": billing.monthly_price_cents,
|
||||||
|
"annual_price_cents": billing.annual_price_cents,
|
||||||
|
"stripe_product_id": billing.stripe_product_id,
|
||||||
|
"stripe_monthly_price_id": billing.stripe_monthly_price_id,
|
||||||
|
"stripe_annual_price_id": billing.stripe_annual_price_id,
|
||||||
|
"is_public": billing.is_public,
|
||||||
|
"is_archived": billing.is_archived,
|
||||||
|
"sort_order": billing.sort_order,
|
||||||
|
})
|
||||||
|
return PlanLimitWithBillingResponse(**payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
|
||||||
async def list_plan_limits(
|
async def list_plan_limits(
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(require_admin)],
|
current_user: Annotated[User, Depends(require_admin)],
|
||||||
):
|
):
|
||||||
"""List all plan limit configurations."""
|
"""List all plan limit configurations, merged with plan_billing fields
|
||||||
result = await db.execute(select(PlanLimits))
|
where present. Plans without a plan_billing row return None for the
|
||||||
return result.scalars().all()
|
billing fields."""
|
||||||
|
rows = (await db.execute(
|
||||||
|
select(PlanLimits, PlanBilling)
|
||||||
|
.outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
|
||||||
|
)).all()
|
||||||
|
return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
|
||||||
|
|
||||||
|
|
||||||
@router.put("/plan-limits", response_model=PlanLimitResponse)
|
@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
|
||||||
async def update_plan_limits(
|
async def update_plan_limits(
|
||||||
data: PlanLimitUpdate,
|
data: PlanLimitUpdate,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(require_admin)],
|
current_user: Annotated[User, Depends(require_admin)],
|
||||||
):
|
):
|
||||||
"""Update a plan's limits."""
|
"""Update a plan's limits and (if any plan_billing field is included)
|
||||||
|
upsert the matching plan_billing row in the same transaction. After
|
||||||
|
commit, invalidates the in-process billing cache for accounts on this
|
||||||
|
plan (currently a no-op — see BillingService.invalidate_billing_cache).
|
||||||
|
"""
|
||||||
result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
|
result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
|
||||||
plan = result.scalar_one_or_none()
|
plan = result.scalar_one_or_none()
|
||||||
if not plan:
|
if not plan:
|
||||||
@@ -48,10 +115,50 @@ async def update_plan_limits(
|
|||||||
plan.priority_support = data.priority_support
|
plan.priority_support = data.priority_support
|
||||||
plan.export_formats = data.export_formats
|
plan.export_formats = data.export_formats
|
||||||
|
|
||||||
await log_audit(db, current_user.id, "plan_limits.update", "plan_limits", details={"plan": data.plan})
|
# Did the request include any plan_billing field? (Pydantic gives us
|
||||||
|
# `model_fields_set` to distinguish "user passed null" from "field omitted".)
|
||||||
|
billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)
|
||||||
|
billing: PlanBilling | None = None
|
||||||
|
if billing_fields_set:
|
||||||
|
billing = (await db.execute(
|
||||||
|
select(PlanBilling).where(PlanBilling.plan == data.plan)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
|
||||||
|
if billing is None:
|
||||||
|
# Create. display_name is required on the model — derive from the
|
||||||
|
# plan name when the caller didn't supply one (e.g. "pro" → "Pro").
|
||||||
|
display_name = data.display_name or data.plan.capitalize()
|
||||||
|
billing = PlanBilling(plan=data.plan, display_name=display_name)
|
||||||
|
db.add(billing)
|
||||||
|
|
||||||
|
# Apply only the fields the caller actually included. Allows partial
|
||||||
|
# updates without clobbering existing values.
|
||||||
|
for field in billing_fields_set:
|
||||||
|
value = getattr(data, field)
|
||||||
|
if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
|
||||||
|
# Don't NULL out a NOT NULL column on update.
|
||||||
|
continue
|
||||||
|
setattr(billing, field, value)
|
||||||
|
|
||||||
|
await log_audit(
|
||||||
|
db, current_user.id, "plan_limits.update", "plan_limits",
|
||||||
|
details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
|
||||||
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(plan)
|
await db.refresh(plan)
|
||||||
return plan
|
if billing is not None:
|
||||||
|
await db.refresh(billing)
|
||||||
|
|
||||||
|
# Invalidate any in-process billing cache for accounts on this plan.
|
||||||
|
# TODO: invalidate app.state.billing_cache when added.
|
||||||
|
account_ids = [
|
||||||
|
row[0] for row in (await db.execute(
|
||||||
|
select(Subscription.account_id).where(Subscription.plan == data.plan)
|
||||||
|
)).all()
|
||||||
|
]
|
||||||
|
await BillingService.invalidate_billing_cache(account_ids)
|
||||||
|
|
||||||
|
return _merge_plan_with_billing(plan, billing)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])
|
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])
|
||||||
|
|||||||
@@ -172,6 +172,21 @@ class PlanLimitResponse(BaseModel):
|
|||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class PlanLimitWithBillingResponse(PlanLimitResponse):
|
||||||
|
"""PlanLimits + plan_billing fields merged. Billing fields are None when no
|
||||||
|
plan_billing row exists for the plan yet."""
|
||||||
|
display_name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
monthly_price_cents: Optional[int] = None
|
||||||
|
annual_price_cents: Optional[int] = None
|
||||||
|
stripe_product_id: Optional[str] = None
|
||||||
|
stripe_monthly_price_id: Optional[str] = None
|
||||||
|
stripe_annual_price_id: Optional[str] = None
|
||||||
|
is_public: Optional[bool] = None
|
||||||
|
is_archived: Optional[bool] = None
|
||||||
|
sort_order: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class PlanLimitUpdate(BaseModel):
|
class PlanLimitUpdate(BaseModel):
|
||||||
plan: str
|
plan: str
|
||||||
max_trees: Optional[int] = None
|
max_trees: Optional[int] = None
|
||||||
@@ -180,6 +195,19 @@ class PlanLimitUpdate(BaseModel):
|
|||||||
custom_branding: bool = False
|
custom_branding: bool = False
|
||||||
priority_support: bool = False
|
priority_support: bool = False
|
||||||
export_formats: list = Field(default_factory=lambda: ["markdown", "text"])
|
export_formats: list = Field(default_factory=lambda: ["markdown", "text"])
|
||||||
|
# plan_billing fields — all optional, partial-update semantics. If any are
|
||||||
|
# set in the body, the admin endpoint upserts the plan_billing row in the
|
||||||
|
# same transaction.
|
||||||
|
display_name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
monthly_price_cents: Optional[int] = None
|
||||||
|
annual_price_cents: Optional[int] = None
|
||||||
|
stripe_product_id: Optional[str] = None
|
||||||
|
stripe_monthly_price_id: Optional[str] = None
|
||||||
|
stripe_annual_price_id: Optional[str] = None
|
||||||
|
is_public: Optional[bool] = None
|
||||||
|
is_archived: Optional[bool] = None
|
||||||
|
sort_order: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class AccountOverrideCreate(BaseModel):
|
class AccountOverrideCreate(BaseModel):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Single billing service module. Stripe is the only impl — no provider
|
"""Single billing service module. Stripe is the only impl — no provider
|
||||||
abstraction. Account row is canonical local state; Stripe is canonical
|
abstraction. Account row is canonical local state; Stripe is canonical
|
||||||
remote state; the webhook handler bridges the two."""
|
remote state; the webhook handler bridges the two."""
|
||||||
|
import logging
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
|
|
||||||
import stripe
|
import stripe
|
||||||
@@ -17,8 +18,32 @@ from app.models.subscription import Subscription
|
|||||||
|
|
||||||
TRIAL_DAYS = 14
|
TRIAL_DAYS = 14
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BillingService:
|
class BillingService:
|
||||||
|
@staticmethod
|
||||||
|
async def invalidate_billing_cache(account_ids) -> None:
|
||||||
|
"""No-op stub for future in-process billing cache invalidation.
|
||||||
|
|
||||||
|
Today there is no `app.state.billing_cache` — `BillingService.get_billing_state`
|
||||||
|
always reads fresh from the DB. Call sites that mutate plan/feature data
|
||||||
|
invoke this hook so that wiring is in place when an in-process cache is
|
||||||
|
added later. Until then, this just logs.
|
||||||
|
|
||||||
|
TODO: when an in-process billing cache (e.g. `app.state.billing_cache`)
|
||||||
|
is introduced, evict entries for the given account_ids here.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
count = len(list(account_ids))
|
||||||
|
except TypeError:
|
||||||
|
count = -1
|
||||||
|
logger.debug(
|
||||||
|
"BillingService.invalidate_billing_cache called for %d account(s) "
|
||||||
|
"(no-op stub — wire to app.state.billing_cache when added)",
|
||||||
|
count,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def start_trial(db: AsyncSession, account_id) -> Subscription:
|
async def start_trial(db: AsyncSession, account_id) -> Subscription:
|
||||||
"""Idempotent. Creates a trialing Subscription on Pro for the account if
|
"""Idempotent. Creates a trialing Subscription on Pro for the account if
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
"""Integration tests for admin plan limits and account override endpoints."""
|
"""Integration tests for admin plan limits and account override endpoints."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.models.plan_billing import PlanBilling
|
||||||
|
|
||||||
|
|
||||||
class TestAdminPlanLimits:
|
class TestAdminPlanLimits:
|
||||||
@@ -56,3 +61,204 @@ class TestAdminPlanLimits:
|
|||||||
"""Non-admin gets 403."""
|
"""Non-admin gets 403."""
|
||||||
response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers)
|
response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers)
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present(
|
||||||
|
self, client: AsyncClient, admin_auth_headers: dict, test_db
|
||||||
|
):
|
||||||
|
"""GET /admin/plan-limits returns plan_billing fields when a row exists,
|
||||||
|
and None for plans that don't have one yet."""
|
||||||
|
# Seed a plan_billing row for "pro".
|
||||||
|
existing = (await test_db.execute(
|
||||||
|
select(PlanBilling).where(PlanBilling.plan == "pro")
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if existing is None:
|
||||||
|
test_db.add(PlanBilling(
|
||||||
|
plan="pro",
|
||||||
|
display_name="Pro",
|
||||||
|
description="For working teams",
|
||||||
|
monthly_price_cents=4900,
|
||||||
|
annual_price_cents=49000,
|
||||||
|
stripe_product_id="prod_seed",
|
||||||
|
stripe_monthly_price_id="price_seed_m",
|
||||||
|
stripe_annual_price_id="price_seed_a",
|
||||||
|
is_public=True,
|
||||||
|
is_archived=False,
|
||||||
|
sort_order=10,
|
||||||
|
))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/admin/plan-limits", headers=admin_auth_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
plans_by_name = {p["plan"]: p for p in response.json()}
|
||||||
|
|
||||||
|
assert "pro" in plans_by_name
|
||||||
|
pro = plans_by_name["pro"]
|
||||||
|
assert pro["display_name"] == "Pro"
|
||||||
|
assert pro["monthly_price_cents"] == 4900
|
||||||
|
assert pro["stripe_monthly_price_id"] == "price_seed_m"
|
||||||
|
assert pro["is_public"] is True
|
||||||
|
assert pro["is_archived"] is False
|
||||||
|
assert pro["sort_order"] == 10
|
||||||
|
|
||||||
|
# A plan without a plan_billing row should still return, with None
|
||||||
|
# billing fields.
|
||||||
|
if "free" in plans_by_name:
|
||||||
|
free = plans_by_name["free"]
|
||||||
|
# free has no plan_billing row in the seed → fields are None.
|
||||||
|
no_billing_row = (await test_db.execute(
|
||||||
|
select(PlanBilling).where(PlanBilling.plan == "free")
|
||||||
|
)).scalar_one_or_none() is None
|
||||||
|
if no_billing_row:
|
||||||
|
assert free["display_name"] is None
|
||||||
|
assert free["monthly_price_cents"] is None
|
||||||
|
assert free["stripe_product_id"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_plan_limits_put_creates_plan_billing_row(
|
||||||
|
self, client: AsyncClient, admin_auth_headers: dict, test_db
|
||||||
|
):
|
||||||
|
"""PUT /admin/plan-limits upserts a plan_billing row when billing
|
||||||
|
fields are included in the body."""
|
||||||
|
# Ensure no plan_billing row exists for "team" yet.
|
||||||
|
existing = (await test_db.execute(
|
||||||
|
select(PlanBilling).where(PlanBilling.plan == "team")
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
await test_db.delete(existing)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
response = await client.put(
|
||||||
|
"/api/v1/admin/plan-limits",
|
||||||
|
json={
|
||||||
|
"plan": "team",
|
||||||
|
"max_trees": None,
|
||||||
|
"max_sessions_per_month": None,
|
||||||
|
"max_users": None,
|
||||||
|
"custom_branding": True,
|
||||||
|
"priority_support": True,
|
||||||
|
"export_formats": ["markdown", "text", "pdf"],
|
||||||
|
"display_name": "Team",
|
||||||
|
"description": "For growing shops",
|
||||||
|
"monthly_price_cents": 9900,
|
||||||
|
"annual_price_cents": 99000,
|
||||||
|
"stripe_product_id": "prod_team_test",
|
||||||
|
"stripe_monthly_price_id": "price_team_m",
|
||||||
|
"stripe_annual_price_id": "price_team_a",
|
||||||
|
"is_public": True,
|
||||||
|
"is_archived": False,
|
||||||
|
"sort_order": 20,
|
||||||
|
},
|
||||||
|
headers=admin_auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
body = response.json()
|
||||||
|
assert body["display_name"] == "Team"
|
||||||
|
assert body["monthly_price_cents"] == 9900
|
||||||
|
assert body["stripe_product_id"] == "prod_team_test"
|
||||||
|
assert body["sort_order"] == 20
|
||||||
|
|
||||||
|
# Confirm the row was actually persisted.
|
||||||
|
await test_db.commit() # ensure session sees other-session writes
|
||||||
|
pb = (await test_db.execute(
|
||||||
|
select(PlanBilling).where(PlanBilling.plan == "team")
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
assert pb is not None
|
||||||
|
assert pb.display_name == "Team"
|
||||||
|
assert pb.monthly_price_cents == 9900
|
||||||
|
assert pb.stripe_monthly_price_id == "price_team_m"
|
||||||
|
assert pb.is_public is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_plan_limits_put_does_not_null_out_required_fields(
|
||||||
|
self, client: AsyncClient, admin_auth_headers: dict, test_db
|
||||||
|
):
|
||||||
|
"""PUT /admin/plan-limits must not NULL out NOT NULL columns on the
|
||||||
|
plan_billing row when the caller passes explicit nulls. The set of
|
||||||
|
guarded fields is {display_name, is_public, is_archived, sort_order}.
|
||||||
|
"""
|
||||||
|
# Seed a plan_billing row for "team" with non-default values for every
|
||||||
|
# NOT NULL field so we can detect any clobbering.
|
||||||
|
existing = (await test_db.execute(
|
||||||
|
select(PlanBilling).where(PlanBilling.plan == "team")
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
await test_db.delete(existing)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
seeded = PlanBilling(
|
||||||
|
plan="team",
|
||||||
|
display_name="Team Seeded",
|
||||||
|
is_public=False,
|
||||||
|
is_archived=True,
|
||||||
|
sort_order=5,
|
||||||
|
)
|
||||||
|
test_db.add(seeded)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
response = await client.put(
|
||||||
|
"/api/v1/admin/plan-limits",
|
||||||
|
json={
|
||||||
|
"plan": "team",
|
||||||
|
"max_trees": None,
|
||||||
|
"max_sessions_per_month": None,
|
||||||
|
"max_users": None,
|
||||||
|
"custom_branding": True,
|
||||||
|
"priority_support": True,
|
||||||
|
"export_formats": ["markdown", "text"],
|
||||||
|
# Explicit nulls for every NOT NULL plan_billing field.
|
||||||
|
"display_name": None,
|
||||||
|
"is_public": None,
|
||||||
|
"is_archived": None,
|
||||||
|
"sort_order": None,
|
||||||
|
},
|
||||||
|
headers=admin_auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
|
||||||
|
# Confirm the seeded NOT NULL values were preserved.
|
||||||
|
await test_db.commit() # ensure session sees writes from the request
|
||||||
|
pb = (await test_db.execute(
|
||||||
|
select(PlanBilling).where(PlanBilling.plan == "team")
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
assert pb is not None
|
||||||
|
assert pb.display_name == "Team Seeded"
|
||||||
|
assert pb.is_public is False
|
||||||
|
assert pb.is_archived is True
|
||||||
|
assert pb.sort_order == 5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_plan_limits_put_invalidates_billing_cache(
|
||||||
|
self, client: AsyncClient, admin_auth_headers: dict
|
||||||
|
):
|
||||||
|
"""PUT /admin/plan-limits calls BillingService.invalidate_billing_cache
|
||||||
|
with the account_ids on the affected plan."""
|
||||||
|
# Patch the staticmethod on the class. The endpoint imports
|
||||||
|
# BillingService at module load, so patch the symbol on the class
|
||||||
|
# itself — both the import and the dotted reference resolve to it.
|
||||||
|
with patch(
|
||||||
|
"app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as spy:
|
||||||
|
response = await client.put(
|
||||||
|
"/api/v1/admin/plan-limits",
|
||||||
|
json={
|
||||||
|
"plan": "pro",
|
||||||
|
"max_trees": 25,
|
||||||
|
"max_sessions_per_month": 500,
|
||||||
|
"max_users": 10,
|
||||||
|
"custom_branding": True,
|
||||||
|
"priority_support": True,
|
||||||
|
"export_formats": ["markdown", "text"],
|
||||||
|
},
|
||||||
|
headers=admin_auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
spy.assert_awaited_once()
|
||||||
|
(account_ids_arg,) = spy.await_args.args
|
||||||
|
# admin fixture seeds an active Pro Subscription, so we expect at
|
||||||
|
# least one account_id in the invalidation list.
|
||||||
|
assert isinstance(account_ids_arg, list)
|
||||||
|
assert len(account_ids_arg) >= 1
|
||||||
|
|||||||
Reference in New Issue
Block a user