Files
resolutionflow/backend/app/api/endpoints/admin_plan_limits.py
Michael Chihlas f1be3abcc5
Some checks failed
CI / e2e (push) Has been cancelled
CI / frontend (push) Has been cancelled
CI / backend (push) Has been cancelled
Mirror to GitHub / mirror (push) Has been cancelled
feat: self-serve signup Phase 2 (frontend cutover) (#162)
Co-authored-by: Michael Chihlas <michael@resolutionflow.com>
Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
2026-05-07 18:42:20 +00:00

306 lines
12 KiB
Python

from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.core.audit import log_audit
from app.models.user import User
from app.models.plan_limits import PlanLimits
from app.models.plan_billing import PlanBilling
from app.models.account import Account
from app.models.account_limit_override import AccountLimitOverride
from app.models.subscription import Subscription
from app.schemas.admin import (
PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
)
from app.api.deps import require_admin
from app.services.billing import BillingService
# Every endpoint in this module is mounted under /admin and grouped under the
# "admin-plan-limits" tag.
router = APIRouter(tags=["admin-plan-limits"], prefix="/admin")
# PlanLimitUpdate fields that are persisted on the plan_billing table rather
# than on plan_limits. Grouped by purpose: presentation, pricing, Stripe ids,
# catalog visibility/ordering.
_PLAN_BILLING_FIELDS = (
    "display_name", "description",
    "monthly_price_cents", "annual_price_cents",
    "stripe_product_id", "stripe_monthly_price_id", "stripe_annual_price_id",
    "is_public", "is_archived", "sort_order",
)
# The members of _PLAN_BILLING_FIELDS whose PlanBilling columns are NOT NULL.
# PlanLimitUpdate declares them Optional[...], so an explicit JSON null from a
# caller has to be skipped rather than written, or the commit would raise a
# NOT NULL violation.
_PLAN_BILLING_NOT_NULL_FIELDS = frozenset(
    ("display_name", "is_public", "is_archived", "sort_order")
)
def _merge_plan_with_billing(
    plan: PlanLimits, billing: PlanBilling | None
) -> PlanLimitWithBillingResponse:
    """Combine a plan_limits row and its optional plan_billing row into one
    response object.

    Args:
        plan: The plan_limits row; supplies the plan key and all limit fields.
        billing: The matching plan_billing row, or None. When None, no billing
            fields are passed to the response, so they come back as None.

    Returns:
        A PlanLimitWithBillingResponse built from the merged fields.
    """
    limit_attrs = (
        "plan",
        "max_trees",
        "max_sessions_per_month",
        "max_users",
        "custom_branding",
        "priority_support",
    )
    billing_attrs = (
        "display_name", "description",
        "monthly_price_cents", "annual_price_cents",
        "stripe_product_id", "stripe_monthly_price_id", "stripe_annual_price_id",
        "is_public", "is_archived", "sort_order",
    )

    fields = {attr: getattr(plan, attr) for attr in limit_attrs}
    # A falsy export_formats (e.g. None) is reported as an empty list.
    fields["export_formats"] = plan.export_formats or []

    if billing is not None:
        fields.update((attr, getattr(billing, attr)) for attr in billing_attrs)

    return PlanLimitWithBillingResponse(**fields)
@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
async def list_plan_limits(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """List all plan limit configurations, merged with plan_billing fields.

    Every plan_limits row is returned. Because of the outer join, plans
    without a plan_billing row still appear, and their billing fields are None.

    Results are ordered by plan_billing.sort_order (plans with no billing row
    last), then by plan name, so the listing is stable between requests.

    ``current_user`` is not read in the body. The parameter exists for its
    ``require_admin`` dependency (presumably the admin-only access guard).
    """
    rows = (await db.execute(
        select(PlanLimits, PlanBilling)
        .outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
        # Without an explicit ORDER BY the database may return rows in any
        # order, which makes the admin listing jump around between requests.
        .order_by(PlanBilling.sort_order.nulls_last(), PlanLimits.plan)
    )).all()
    return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
async def update_plan_limits(
    data: PlanLimitUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Update a plan's limits and, if any plan_billing field is included,
    upsert the matching plan_billing row in the same transaction.

    Only the plan_billing fields the caller actually sent are applied (a
    partial update). An explicit null for a NOT NULL billing column is ignored.
    The response merges the plan with its current plan_billing row whenever
    one exists, whether or not this request touched billing fields.

    After commit, invalidates the in-process billing cache for accounts on
    this plan (currently a no-op — see BillingService.invalidate_billing_cache).

    Raises:
        HTTPException: 404 if no plan_limits row exists for ``data.plan``.
    """
    # Load the plan and its optional billing row together. The billing row is
    # needed for the response even when no billing field is being updated.
    # Skipping it would report None for billing fields that actually exist.
    row = (await db.execute(
        select(PlanLimits, PlanBilling)
        .outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
        .where(PlanLimits.plan == data.plan)
    )).one_or_none()
    if row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plan not found")
    plan, billing = row

    plan.max_trees = data.max_trees
    plan.max_sessions_per_month = data.max_sessions_per_month
    plan.max_users = data.max_users
    plan.custom_branding = data.custom_branding
    plan.priority_support = data.priority_support
    plan.export_formats = data.export_formats

    # Did the request include any plan_billing field? (Pydantic's
    # `model_fields_set` distinguishes "user passed null" from "field omitted".)
    billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)
    if billing_fields_set:
        if billing is None:
            # Create. display_name is required on the model — derive it from
            # the plan name when the caller didn't supply one ("pro" -> "Pro").
            display_name = data.display_name or data.plan.capitalize()
            billing = PlanBilling(plan=data.plan, display_name=display_name)
            db.add(billing)
        # Apply only the fields the caller actually included, so a partial
        # update doesn't clobber existing values.
        for field in billing_fields_set:
            value = getattr(data, field)
            if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
                # Don't NULL out a NOT NULL column.
                continue
            setattr(billing, field, value)

    await log_audit(
        db, current_user.id, "plan_limits.update", "plan_limits",
        details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
    )
    await db.commit()
    await db.refresh(plan)
    if billing is not None:
        await db.refresh(billing)

    # Invalidate any in-process billing cache for accounts on this plan.
    # TODO: invalidate app.state.billing_cache when added.
    account_ids = [
        account_id
        for (account_id,) in (await db.execute(
            select(Subscription.account_id).where(Subscription.plan == data.plan)
        )).all()
    ]
    await BillingService.invalidate_billing_cache(account_ids)

    return _merge_plan_with_billing(plan, billing)
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])
async def list_account_overrides(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Return every account limit override, newest (by created_at) first.

    Each override is outer-joined to its Account so the response can include
    the account's name and display code. Both are None when no matching
    Account row is found.
    """
    stmt = (
        select(
            AccountLimitOverride,
            Account.name.label("account_name"),
            Account.display_code.label("account_display_code"),
        )
        .outerjoin(Account, AccountLimitOverride.account_id == Account.id)
        .order_by(AccountLimitOverride.created_at.desc())
    )
    joined_rows = (await db.execute(stmt)).all()

    # Each row is (override entity, account name, account display code).
    return [
        AccountOverrideResponse(
            id=override.id,
            account_id=override.account_id,
            account_name=account_name,
            account_display_code=account_display_code,
            override_max_trees=override.override_max_trees,
            override_max_sessions_per_month=override.override_max_sessions_per_month,
            override_max_users=override.override_max_users,
            note=override.note,
            created_at=override.created_at,
            updated_at=override.updated_at,
        )
        for override, account_name, account_display_code in joined_rows
    ]
@router.post("/account-overrides", response_model=AccountOverrideResponse, status_code=status.HTTP_201_CREATED)
async def create_account_override(
    data: AccountOverrideCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Create a limit override for the account identified by its display code.

    The requesting admin is stored as ``created_by_id``. An
    ``account_override.create`` audit entry is written through the same
    session before the commit.

    Raises:
        HTTPException: 404 if no account has ``data.account_display_code``;
            409 if that account already has an override.
    """
    account_lookup = select(Account).where(Account.display_code == data.account_display_code)
    account = (await db.execute(account_lookup)).scalar_one_or_none()
    if not account:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")

    # Only one override per account is allowed.
    # NOTE(review): this is check-then-insert. Unless account_id has a unique
    # constraint (not visible here), two concurrent requests could both pass.
    duplicate_lookup = select(AccountLimitOverride).where(AccountLimitOverride.account_id == account.id)
    if (await db.execute(duplicate_lookup)).scalar_one_or_none():
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Override already exists for this account")

    new_override = AccountLimitOverride(
        account_id=account.id,
        override_max_trees=data.override_max_trees,
        override_max_sessions_per_month=data.override_max_sessions_per_month,
        override_max_users=data.override_max_users,
        note=data.note,
        created_by_id=current_user.id,
    )
    db.add(new_override)
    await log_audit(
        db, current_user.id, "account_override.create", "account", account.id,
        {"display_code": data.account_display_code},
    )
    await db.commit()
    await db.refresh(new_override)

    return AccountOverrideResponse(
        id=new_override.id,
        account_id=new_override.account_id,
        account_name=account.name,
        account_display_code=account.display_code,
        override_max_trees=new_override.override_max_trees,
        override_max_sessions_per_month=new_override.override_max_sessions_per_month,
        override_max_users=new_override.override_max_users,
        note=new_override.note,
        created_at=new_override.created_at,
        updated_at=new_override.updated_at,
    )
@router.put("/account-overrides/{override_id}", response_model=AccountOverrideResponse)
async def update_account_override(
    override_id: UUID,
    data: AccountOverrideUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Partially update an existing account limit override.

    Each of ``override_max_trees``, ``override_max_sessions_per_month``,
    ``override_max_users`` and ``note`` is copied onto the override only when
    the request value is not None. None means "leave unchanged".

    Raises:
        HTTPException: 404 if no override has ``override_id``.
    """
    override = (await db.execute(
        select(AccountLimitOverride).where(AccountLimitOverride.id == override_id)
    )).scalar_one_or_none()
    if not override:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Override not found")

    # NOTE(review): since None means "unchanged", an omitted field and an
    # explicit null are treated the same. A caller therefore cannot clear an
    # individual override value back to None through this endpoint.
    for field_name in (
        "override_max_trees",
        "override_max_sessions_per_month",
        "override_max_users",
        "note",
    ):
        incoming = getattr(data, field_name)
        if incoming is not None:
            setattr(override, field_name, incoming)

    await log_audit(db, current_user.id, "account_override.update", "account", override.account_id)
    await db.commit()
    await db.refresh(override)

    # Look up the owning account so the response can carry its name and
    # display code. Both are None if the account row is missing.
    account = (await db.execute(
        select(Account).where(Account.id == override.account_id)
    )).scalar_one_or_none()

    return AccountOverrideResponse(
        id=override.id,
        account_id=override.account_id,
        account_name=account.name if account else None,
        account_display_code=account.display_code if account else None,
        override_max_trees=override.override_max_trees,
        override_max_sessions_per_month=override.override_max_sessions_per_month,
        override_max_users=override.override_max_users,
        note=override.note,
        created_at=override.created_at,
        updated_at=override.updated_at,
    )
@router.delete("/account-overrides/{override_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_account_override(
    override_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Delete an account limit override by id; responds 204 with no body.

    An ``account_override.delete`` audit entry is written through the same
    session before the delete. Both are then committed together (presumably
    atomically, assuming log_audit only stages its row on the session).

    Raises:
        HTTPException: 404 if no override has ``override_id``.
    """
    lookup = select(AccountLimitOverride).where(AccountLimitOverride.id == override_id)
    override = (await db.execute(lookup)).scalar_one_or_none()
    if not override:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Override not found")

    # Audit first, while override.account_id is still available on the object.
    await log_audit(db, current_user.id, "account_override.delete", "account", override.account_id)
    await db.delete(override)
    await db.commit()