feat: self-serve signup Phase 2 (frontend cutover) (#162)
Co-authored-by: Michael Chihlas <michael@resolutionflow.com> Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
This commit was merged in pull request #162.
This commit is contained in:
@@ -8,34 +8,101 @@ from app.core.database import get_db
|
||||
from app.core.audit import log_audit
|
||||
from app.models.user import User
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.models.plan_billing import PlanBilling
|
||||
from app.models.account import Account
|
||||
from app.models.account_limit_override import AccountLimitOverride
|
||||
from app.models.subscription import Subscription
|
||||
from app.schemas.admin import (
|
||||
PlanLimitResponse, PlanLimitUpdate,
|
||||
PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
|
||||
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
|
||||
)
|
||||
from app.api.deps import require_admin
|
||||
from app.services.billing import BillingService
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])
|
||||
|
||||
|
||||
@router.get("/plan-limits", response_model=list[PlanLimitResponse])
|
||||
# Fields on PlanLimitUpdate that map to plan_billing (not plan_limits).
|
||||
_PLAN_BILLING_FIELDS = (
|
||||
"display_name",
|
||||
"description",
|
||||
"monthly_price_cents",
|
||||
"annual_price_cents",
|
||||
"stripe_product_id",
|
||||
"stripe_monthly_price_id",
|
||||
"stripe_annual_price_id",
|
||||
"is_public",
|
||||
"is_archived",
|
||||
"sort_order",
|
||||
)
|
||||
|
||||
# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model.
|
||||
# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit
|
||||
# null for any of them would otherwise trigger a NOT NULL violation at commit.
|
||||
_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({
|
||||
"display_name",
|
||||
"is_public",
|
||||
"is_archived",
|
||||
"sort_order",
|
||||
})
|
||||
|
||||
|
||||
def _merge_plan_with_billing(
|
||||
plan: PlanLimits, billing: PlanBilling | None
|
||||
) -> PlanLimitWithBillingResponse:
|
||||
"""Build a merged response. Billing fields are None when no plan_billing row
|
||||
exists for the plan."""
|
||||
payload = {
|
||||
"plan": plan.plan,
|
||||
"max_trees": plan.max_trees,
|
||||
"max_sessions_per_month": plan.max_sessions_per_month,
|
||||
"max_users": plan.max_users,
|
||||
"custom_branding": plan.custom_branding,
|
||||
"priority_support": plan.priority_support,
|
||||
"export_formats": plan.export_formats or [],
|
||||
}
|
||||
if billing is not None:
|
||||
payload.update({
|
||||
"display_name": billing.display_name,
|
||||
"description": billing.description,
|
||||
"monthly_price_cents": billing.monthly_price_cents,
|
||||
"annual_price_cents": billing.annual_price_cents,
|
||||
"stripe_product_id": billing.stripe_product_id,
|
||||
"stripe_monthly_price_id": billing.stripe_monthly_price_id,
|
||||
"stripe_annual_price_id": billing.stripe_annual_price_id,
|
||||
"is_public": billing.is_public,
|
||||
"is_archived": billing.is_archived,
|
||||
"sort_order": billing.sort_order,
|
||||
})
|
||||
return PlanLimitWithBillingResponse(**payload)
|
||||
|
||||
|
||||
@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
|
||||
async def list_plan_limits(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""List all plan limit configurations."""
|
||||
result = await db.execute(select(PlanLimits))
|
||||
return result.scalars().all()
|
||||
"""List all plan limit configurations, merged with plan_billing fields
|
||||
where present. Plans without a plan_billing row return None for the
|
||||
billing fields."""
|
||||
rows = (await db.execute(
|
||||
select(PlanLimits, PlanBilling)
|
||||
.outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
|
||||
)).all()
|
||||
return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
|
||||
|
||||
|
||||
@router.put("/plan-limits", response_model=PlanLimitResponse)
|
||||
@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
|
||||
async def update_plan_limits(
|
||||
data: PlanLimitUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Update a plan's limits."""
|
||||
"""Update a plan's limits and (if any plan_billing field is included)
|
||||
upsert the matching plan_billing row in the same transaction. After
|
||||
commit, invalidates the in-process billing cache for accounts on this
|
||||
plan (currently a no-op — see BillingService.invalidate_billing_cache).
|
||||
"""
|
||||
result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
|
||||
plan = result.scalar_one_or_none()
|
||||
if not plan:
|
||||
@@ -48,10 +115,50 @@ async def update_plan_limits(
|
||||
plan.priority_support = data.priority_support
|
||||
plan.export_formats = data.export_formats
|
||||
|
||||
await log_audit(db, current_user.id, "plan_limits.update", "plan_limits", details={"plan": data.plan})
|
||||
# Did the request include any plan_billing field? (Pydantic gives us
|
||||
# `model_fields_set` to distinguish "user passed null" from "field omitted".)
|
||||
billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)
|
||||
billing: PlanBilling | None = None
|
||||
if billing_fields_set:
|
||||
billing = (await db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == data.plan)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if billing is None:
|
||||
# Create. display_name is required on the model — derive from the
|
||||
# plan name when the caller didn't supply one (e.g. "pro" → "Pro").
|
||||
display_name = data.display_name or data.plan.capitalize()
|
||||
billing = PlanBilling(plan=data.plan, display_name=display_name)
|
||||
db.add(billing)
|
||||
|
||||
# Apply only the fields the caller actually included. Allows partial
|
||||
# updates without clobbering existing values.
|
||||
for field in billing_fields_set:
|
||||
value = getattr(data, field)
|
||||
if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
|
||||
# Don't NULL out a NOT NULL column on update.
|
||||
continue
|
||||
setattr(billing, field, value)
|
||||
|
||||
await log_audit(
|
||||
db, current_user.id, "plan_limits.update", "plan_limits",
|
||||
details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(plan)
|
||||
return plan
|
||||
if billing is not None:
|
||||
await db.refresh(billing)
|
||||
|
||||
# Invalidate any in-process billing cache for accounts on this plan.
|
||||
# TODO: invalidate app.state.billing_cache when added.
|
||||
account_ids = [
|
||||
row[0] for row in (await db.execute(
|
||||
select(Subscription.account_id).where(Subscription.plan == data.plan)
|
||||
)).all()
|
||||
]
|
||||
await BillingService.invalidate_billing_cache(account_ids)
|
||||
|
||||
return _merge_plan_with_billing(plan, billing)
|
||||
|
||||
|
||||
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])
|
||||
|
||||
Reference in New Issue
Block a user