Co-authored-by: Michael Chihlas <michael@resolutionflow.com> Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
306 lines
12 KiB
Python
306 lines
12 KiB
Python
from typing import Annotated
|
|
from uuid import UUID
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
|
|
from app.core.database import get_db
|
|
from app.core.audit import log_audit
|
|
from app.models.user import User
|
|
from app.models.plan_limits import PlanLimits
|
|
from app.models.plan_billing import PlanBilling
|
|
from app.models.account import Account
|
|
from app.models.account_limit_override import AccountLimitOverride
|
|
from app.models.subscription import Subscription
|
|
from app.schemas.admin import (
|
|
PlanLimitResponse, PlanLimitUpdate, PlanLimitWithBillingResponse,
|
|
AccountOverrideCreate, AccountOverrideUpdate, AccountOverrideResponse,
|
|
)
|
|
from app.api.deps import require_admin
|
|
from app.services.billing import BillingService
|
|
|
|
# Every route declared in this module is mounted under the "/admin" prefix.
router = APIRouter(tags=["admin-plan-limits"], prefix="/admin")
|
|
|
|
|
|
# Names of PlanLimitUpdate fields that are persisted on the plan_billing
# table rather than on plan_limits. update_plan_limits intersects this with
# the request's explicitly-set fields to decide whether to upsert billing.
_PLAN_BILLING_FIELDS: tuple[str, ...] = (
    # Presentation
    "display_name",
    "description",
    # Prices
    "monthly_price_cents",
    "annual_price_cents",
    # Stripe identifiers
    "stripe_product_id",
    "stripe_monthly_price_id",
    "stripe_annual_price_id",
    # Visibility and ordering flags
    "is_public",
    "is_archived",
    "sort_order",
)

# The members of _PLAN_BILLING_FIELDS whose PlanBilling columns are NOT NULL.
# PlanLimitUpdate declares them Optional, so an explicit null from the caller
# must be skipped; writing it would fail with a NOT NULL violation at commit.
_PLAN_BILLING_NOT_NULL_FIELDS: frozenset[str] = frozenset(
    ("display_name", "is_public", "is_archived", "sort_order")
)
|
|
|
|
|
|
def _merge_plan_with_billing(
    plan: PlanLimits, billing: PlanBilling | None
) -> PlanLimitWithBillingResponse:
    """Combine a plan_limits row with its optional plan_billing row.

    Args:
        plan: The plan_limits row to serialise.
        billing: The plan_billing row for the same plan, or None when the
            plan has no billing row.

    Returns:
        A PlanLimitWithBillingResponse. When ``billing`` is None the billing
        keys are not passed at all, so they take the response schema's
        defaults (None per this module's docstrings — confirm in the schema).
    """
    fields = dict(
        plan=plan.plan,
        max_trees=plan.max_trees,
        max_sessions_per_month=plan.max_sessions_per_month,
        max_users=plan.max_users,
        custom_branding=plan.custom_branding,
        priority_support=plan.priority_support,
        # A NULL/empty export_formats value is reported as an empty list.
        export_formats=plan.export_formats or [],
    )

    if billing is not None:
        fields.update(
            display_name=billing.display_name,
            description=billing.description,
            monthly_price_cents=billing.monthly_price_cents,
            annual_price_cents=billing.annual_price_cents,
            stripe_product_id=billing.stripe_product_id,
            stripe_monthly_price_id=billing.stripe_monthly_price_id,
            stripe_annual_price_id=billing.stripe_annual_price_id,
            is_public=billing.is_public,
            is_archived=billing.is_archived,
            sort_order=billing.sort_order,
        )

    return PlanLimitWithBillingResponse(**fields)
|
|
|
|
|
|
@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
async def list_plan_limits(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """List all plan limit configurations, merged with plan_billing fields
    where present.

    Plans without a plan_billing row (outer join) return None for the
    billing fields. Rows are ordered by plan_billing.sort_order, with plans
    that have no billing row last, then by plan name. This keeps the
    listing stable between calls.

    Access is guarded by the ``require_admin`` dependency.
    """
    stmt = (
        select(PlanLimits, PlanBilling)
        .outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
        # Without an explicit ORDER BY the database may return rows in any
        # order; sort by the admin-controlled sort_order, then by plan name
        # as a deterministic tie-breaker.
        .order_by(PlanBilling.sort_order.nulls_last(), PlanLimits.plan)
    )
    rows = (await db.execute(stmt)).all()
    return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]
|
|
|
|
|
|
@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
async def update_plan_limits(
    data: PlanLimitUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Update a plan's limits and (if any plan_billing field is included)
    upsert the matching plan_billing row in the same transaction.

    - The six plan_limits columns are always overwritten from ``data``.
    - The plan_billing row is written only when the request body explicitly
      sets at least one billing field (``model_fields_set``).
    - A missing plan_billing row is created, with ``display_name`` defaulting
      to the capitalised plan name.
    - Explicit nulls for NOT NULL billing columns are ignored.

    After commit, invalidates the in-process billing cache for accounts on
    this plan (currently a no-op — see
    BillingService.invalidate_billing_cache).

    The response always includes the plan's existing plan_billing data when
    such a row exists. This holds even if this request only changed limits,
    so the response agrees with ``GET /plan-limits``.

    Raises:
        HTTPException: 404 if no plan_limits row exists for ``data.plan``.
    """
    plan = (await db.execute(
        select(PlanLimits).where(PlanLimits.plan == data.plan)
    )).scalar_one_or_none()
    if not plan:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plan not found")

    plan.max_trees = data.max_trees
    plan.max_sessions_per_month = data.max_sessions_per_month
    plan.max_users = data.max_users
    plan.custom_branding = data.custom_branding
    plan.priority_support = data.priority_support
    plan.export_formats = data.export_formats

    # Load the billing row unconditionally: the merged response needs it even
    # when this request only changes limits, otherwise a limits-only update
    # would report None for every billing field.
    billing: PlanBilling | None = (await db.execute(
        select(PlanBilling).where(PlanBilling.plan == data.plan)
    )).scalar_one_or_none()

    # Did the request include any plan_billing field? (Pydantic gives us
    # `model_fields_set` to distinguish "user passed null" from "field omitted".)
    billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)

    if billing_fields_set:
        if billing is None:
            # Create. display_name is required on the model — derive from the
            # plan name when the caller didn't supply one (e.g. "pro" → "Pro").
            display_name = data.display_name or data.plan.capitalize()
            billing = PlanBilling(plan=data.plan, display_name=display_name)
            db.add(billing)

        # Apply only the fields the caller actually included. Allows partial
        # updates without clobbering existing values.
        for field in billing_fields_set:
            value = getattr(data, field)
            if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
                # Don't NULL out a NOT NULL column on update.
                continue
            setattr(billing, field, value)

    await log_audit(
        db, current_user.id, "plan_limits.update", "plan_limits",
        details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
    )
    await db.commit()
    await db.refresh(plan)
    if billing is not None:
        await db.refresh(billing)

    # Invalidate any in-process billing cache for accounts on this plan.
    # DISTINCT so an account with several subscription rows on this plan is
    # only passed once.
    # TODO: invalidate app.state.billing_cache when added.
    account_ids = list((await db.execute(
        select(Subscription.account_id)
        .where(Subscription.plan == data.plan)
        .distinct()
    )).scalars().all())
    await BillingService.invalidate_billing_cache(account_ids)

    return _merge_plan_with_billing(plan, billing)
|
|
|
|
|
|
@router.get("/account-overrides", response_model=list[AccountOverrideResponse])
async def list_account_overrides(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """List all account limit overrides, newest (by created_at) first.

    Each override is outer-joined to its Account so the response can carry
    the account's name and display code. Both are None when no matching
    account row is found.
    """
    stmt = (
        select(
            AccountLimitOverride,
            Account.name.label("account_name"),
            Account.display_code.label("account_display_code"),
        )
        .outerjoin(Account, AccountLimitOverride.account_id == Account.id)
        .order_by(AccountLimitOverride.created_at.desc())
    )
    fetched = (await db.execute(stmt)).all()

    responses = []
    for ovr, acct_name, acct_code in fetched:
        responses.append(
            AccountOverrideResponse(
                id=ovr.id,
                account_id=ovr.account_id,
                account_name=acct_name,
                account_display_code=acct_code,
                override_max_trees=ovr.override_max_trees,
                override_max_sessions_per_month=ovr.override_max_sessions_per_month,
                override_max_users=ovr.override_max_users,
                note=ovr.note,
                created_at=ovr.created_at,
                updated_at=ovr.updated_at,
            )
        )
    return responses
|
|
|
|
|
|
@router.post("/account-overrides", response_model=AccountOverrideResponse, status_code=status.HTTP_201_CREATED)
async def create_account_override(
    data: AccountOverrideCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Create a limit override for the account with ``data.account_display_code``.

    Raises:
        HTTPException: 404 if no account has that display code; 409 if the
            account already has an override.
    """
    # Look up account by display_code
    account = (await db.execute(
        select(Account).where(Account.display_code == data.account_display_code)
    )).scalar_one_or_none()
    if not account:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")

    # Check for existing override
    existing = (await db.execute(
        select(AccountLimitOverride).where(AccountLimitOverride.account_id == account.id)
    )).scalar_one_or_none()
    if existing:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Override already exists for this account")

    # Snapshot the account fields needed for the response before committing.
    # If the session expires instances on commit (expire_on_commit=True),
    # touching account.* afterwards would trigger an implicit lazy refresh,
    # which AsyncSession cannot perform outside an awaited call.
    account_name = account.name
    account_display_code = account.display_code

    override = AccountLimitOverride(
        account_id=account.id,
        override_max_trees=data.override_max_trees,
        override_max_sessions_per_month=data.override_max_sessions_per_month,
        override_max_users=data.override_max_users,
        note=data.note,
        created_by_id=current_user.id,
    )
    db.add(override)
    await log_audit(db, current_user.id, "account_override.create", "account", account.id,
                    {"display_code": data.account_display_code})
    await db.commit()
    # Reload so DB-populated columns (id, created_at, updated_at) are present.
    await db.refresh(override)

    return AccountOverrideResponse(
        id=override.id,
        account_id=override.account_id,
        account_name=account_name,
        account_display_code=account_display_code,
        override_max_trees=override.override_max_trees,
        override_max_sessions_per_month=override.override_max_sessions_per_month,
        override_max_users=override.override_max_users,
        note=override.note,
        created_at=override.created_at,
        updated_at=override.updated_at,
    )
|
|
|
|
|
|
@router.put("/account-overrides/{override_id}", response_model=AccountOverrideResponse)
async def update_account_override(
    override_id: UUID,
    data: AccountOverrideUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Partially update an account limit override.

    Only fields the caller explicitly includes in the request body are
    written, detected via Pydantic's ``model_fields_set``; the same approach
    ``update_plan_limits`` uses for billing fields. Omitted fields keep
    their current value. A field sent as explicit null is set to NULL, so a
    single limit or the note can be cleared without deleting the whole
    override.

    The names of the written fields are recorded in the audit entry.

    Raises:
        HTTPException: 404 if no override has ``override_id``.
    """
    result = await db.execute(select(AccountLimitOverride).where(AccountLimitOverride.id == override_id))
    override = result.scalar_one_or_none()
    if not override:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Override not found")

    updatable = (
        "override_max_trees",
        "override_max_sessions_per_month",
        "override_max_users",
        "note",
    )
    # NOTE(review): explicit nulls are now persisted; this assumes these four
    # columns are nullable on AccountLimitOverride — confirm against the model.
    changed = sorted(data.model_fields_set & set(updatable))
    for field in changed:
        setattr(override, field, getattr(data, field))

    await log_audit(
        db, current_user.id, "account_override.update", "account", override.account_id,
        details={"fields": changed},
    )
    await db.commit()
    await db.refresh(override)

    # Fetch account info for the response; the account may be missing.
    acct = await db.execute(select(Account).where(Account.id == override.account_id))
    account = acct.scalar_one_or_none()

    return AccountOverrideResponse(
        id=override.id,
        account_id=override.account_id,
        account_name=account.name if account else None,
        account_display_code=account.display_code if account else None,
        override_max_trees=override.override_max_trees,
        override_max_sessions_per_month=override.override_max_sessions_per_month,
        override_max_users=override.override_max_users,
        note=override.note,
        created_at=override.created_at,
        updated_at=override.updated_at,
    )
|
|
|
|
|
|
@router.delete("/account-overrides/{override_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_account_override(
    override_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Delete the account limit override identified by ``override_id``.

    The audit call receives the same session and runs before the delete.
    Both happen ahead of the single commit.

    Raises:
        HTTPException: 404 if no override has ``override_id``.
    """
    target = (await db.execute(
        select(AccountLimitOverride).where(AccountLimitOverride.id == override_id)
    )).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Override not found")

    await log_audit(db, current_user.id, "account_override.delete", "account", target.account_id)
    await db.delete(target)
    await db.commit()
|