from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.core.database import get_db
from app.core.audit import log_audit
from app.models.user import User
from app.models.plan_limits import PlanLimits
from app.models.plan_billing import PlanBilling
from app.models.account import Account
from app.models.account_limit_override import AccountLimitOverride
from app.models.subscription import Subscription
from app.schemas.admin import (
    PlanLimitResponse,
    PlanLimitUpdate,
    PlanLimitWithBillingResponse,
    AccountOverrideCreate,
    AccountOverrideUpdate,
    AccountOverrideResponse,
)
from app.api.deps import require_admin
from app.services.billing import BillingService

router = APIRouter(prefix="/admin", tags=["admin-plan-limits"])

# Fields on PlanLimitUpdate that map to plan_billing (not plan_limits).
_PLAN_BILLING_FIELDS = (
    "display_name",
    "description",
    "monthly_price_cents",
    "annual_price_cents",
    "stripe_product_id",
    "stripe_monthly_price_id",
    "stripe_annual_price_id",
    "is_public",
    "is_archived",
    "sort_order",
)

# Subset of _PLAN_BILLING_FIELDS that are NOT NULL on the PlanBilling model.
# These are Optional[...] on PlanLimitUpdate, so a caller sending an explicit
# null for any of them would otherwise trigger a NOT NULL violation at commit.
_PLAN_BILLING_NOT_NULL_FIELDS = frozenset({
    "display_name",
    "is_public",
    "is_archived",
    "sort_order",
})


def _merge_plan_with_billing(
    plan: PlanLimits, billing: PlanBilling | None
) -> PlanLimitWithBillingResponse:
    """Build a merged response.

    Billing fields are None when no plan_billing row exists for the plan."""
    payload = {
        "plan": plan.plan,
        "max_trees": plan.max_trees,
        "max_sessions_per_month": plan.max_sessions_per_month,
        "max_users": plan.max_users,
        "custom_branding": plan.custom_branding,
        "priority_support": plan.priority_support,
        "export_formats": plan.export_formats or [],
    }
    if billing is not None:
        payload.update({
            "display_name": billing.display_name,
            "description": billing.description,
            "monthly_price_cents": billing.monthly_price_cents,
            "annual_price_cents": billing.annual_price_cents,
            "stripe_product_id": billing.stripe_product_id,
            "stripe_monthly_price_id": billing.stripe_monthly_price_id,
            "stripe_annual_price_id": billing.stripe_annual_price_id,
            "is_public": billing.is_public,
            "is_archived": billing.is_archived,
            "sort_order": billing.sort_order,
        })
    return PlanLimitWithBillingResponse(**payload)


def _override_response(
    override: AccountLimitOverride,
    account_name: str | None,
    account_display_code: str | None,
) -> AccountOverrideResponse:
    """Build an AccountOverrideResponse from an override row.

    The owning account's name and display code are passed separately because
    callers obtain them differently (joined columns vs. a loaded Account);
    either may be None when the account row is not found.
    """
    return AccountOverrideResponse(
        id=override.id,
        account_id=override.account_id,
        account_name=account_name,
        account_display_code=account_display_code,
        override_max_trees=override.override_max_trees,
        override_max_sessions_per_month=override.override_max_sessions_per_month,
        override_max_users=override.override_max_users,
        note=override.note,
        created_at=override.created_at,
        updated_at=override.updated_at,
    )


@router.get("/plan-limits", response_model=list[PlanLimitWithBillingResponse])
async def list_plan_limits(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """List all plan limit configurations, merged with plan_billing fields
    where present. Plans without a plan_billing row return None for the
    billing fields."""
    rows = (await db.execute(
        select(PlanLimits, PlanBilling)
        .outerjoin(PlanBilling, PlanLimits.plan == PlanBilling.plan)
    )).all()
    return [_merge_plan_with_billing(pl, pb) for pl, pb in rows]


@router.put("/plan-limits", response_model=PlanLimitWithBillingResponse)
async def update_plan_limits(
    data: PlanLimitUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Update a plan's limits and (if any plan_billing field is included)
    upsert the matching plan_billing row in the same transaction.

    The response always includes the plan's stored billing fields when a
    plan_billing row exists, even if this request changed only limits.

    After commit, invalidates the in-process billing cache for accounts on
    this plan (currently a no-op — see BillingService.invalidate_billing_cache).

    Raises:
        HTTPException: 404 when no plan_limits row exists for ``data.plan``.
    """
    result = await db.execute(select(PlanLimits).where(PlanLimits.plan == data.plan))
    plan = result.scalar_one_or_none()
    if not plan:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plan not found")

    plan.max_trees = data.max_trees
    plan.max_sessions_per_month = data.max_sessions_per_month
    plan.max_users = data.max_users
    plan.custom_branding = data.custom_branding
    plan.priority_support = data.priority_support
    plan.export_formats = data.export_formats

    # Always load the existing billing row so the response reflects stored
    # billing data even when this request only touches plan_limits fields.
    billing: PlanBilling | None = (await db.execute(
        select(PlanBilling).where(PlanBilling.plan == data.plan)
    )).scalar_one_or_none()

    # Did the request include any plan_billing field? (Pydantic gives us
    # `model_fields_set` to distinguish "user passed null" from "field omitted".)
    billing_fields_set = data.model_fields_set & set(_PLAN_BILLING_FIELDS)

    if billing_fields_set:
        if billing is None:
            # Create. display_name is required on the model — derive from the
            # plan name when the caller didn't supply one (e.g. "pro" → "Pro").
            display_name = data.display_name or data.plan.capitalize()
            billing = PlanBilling(plan=data.plan, display_name=display_name)
            db.add(billing)

        # Apply only the fields the caller actually included. Allows partial
        # updates without clobbering existing values.
        for field in billing_fields_set:
            value = getattr(data, field)
            if value is None and field in _PLAN_BILLING_NOT_NULL_FIELDS:
                # Don't NULL out a NOT NULL column on update.
                continue
            setattr(billing, field, value)

    await log_audit(
        db, current_user.id, "plan_limits.update", "plan_limits",
        details={"plan": data.plan, "updated_billing": bool(billing_fields_set)},
    )
    await db.commit()
    await db.refresh(plan)
    if billing is not None:
        await db.refresh(billing)

    # Invalidate any in-process billing cache for accounts on this plan.
    # TODO: invalidate app.state.billing_cache when added.
    account_ids = [
        row[0]
        for row in (await db.execute(
            select(Subscription.account_id).where(Subscription.plan == data.plan)
        )).all()
    ]
    await BillingService.invalidate_billing_cache(account_ids)

    return _merge_plan_with_billing(plan, billing)


@router.get("/account-overrides", response_model=list[AccountOverrideResponse])
async def list_account_overrides(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """List all account limit overrides, newest first.

    Uses an outer join, so an override whose account row is missing is still
    listed with None for the account name and display code.
    """
    query = (
        select(
            AccountLimitOverride,
            Account.name.label("account_name"),
            Account.display_code.label("account_display_code"),
        )
        .outerjoin(Account, AccountLimitOverride.account_id == Account.id)
        .order_by(AccountLimitOverride.created_at.desc())
    )
    result = await db.execute(query)
    return [
        _override_response(row.AccountLimitOverride, row.account_name, row.account_display_code)
        for row in result.all()
    ]


@router.post("/account-overrides", response_model=AccountOverrideResponse, status_code=status.HTTP_201_CREATED)
async def create_account_override(
    data: AccountOverrideCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Create an account limit override.

    Raises:
        HTTPException: 404 when no account has ``data.account_display_code``;
            409 when the account already has an override.
    """
    # Look up account by display_code
    result = await db.execute(select(Account).where(Account.display_code == data.account_display_code))
    account = result.scalar_one_or_none()
    if not account:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")

    # Check for existing override. NOTE(review): check-then-insert is not
    # atomic; concurrent creates rely on a DB uniqueness constraint, if any.
    existing = await db.execute(
        select(AccountLimitOverride).where(AccountLimitOverride.account_id == account.id)
    )
    if existing.scalar_one_or_none():
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Override already exists for this account")

    # Capture account fields before commit: once the session commits, loaded
    # instances may be expired, and reading them would trigger implicit IO.
    account_name = account.name
    account_display_code = account.display_code

    override = AccountLimitOverride(
        account_id=account.id,
        override_max_trees=data.override_max_trees,
        override_max_sessions_per_month=data.override_max_sessions_per_month,
        override_max_users=data.override_max_users,
        note=data.note,
        created_by_id=current_user.id,
    )
    db.add(override)
    await log_audit(db, current_user.id, "account_override.create", "account", account.id, {"display_code": data.account_display_code})
    await db.commit()
    await db.refresh(override)

    return _override_response(override, account_name, account_display_code)


@router.put("/account-overrides/{override_id}", response_model=AccountOverrideResponse)
async def update_account_override(
    override_id: UUID,
    data: AccountOverrideUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Update an account limit override.

    Only non-None fields in ``data`` are applied; a field sent as null (or
    omitted) leaves the stored value unchanged, so this endpoint cannot clear
    an individual field.

    Raises:
        HTTPException: 404 when no override has ``override_id``.
    """
    result = await db.execute(select(AccountLimitOverride).where(AccountLimitOverride.id == override_id))
    override = result.scalar_one_or_none()
    if not override:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Override not found")

    if data.override_max_trees is not None:
        override.override_max_trees = data.override_max_trees
    if data.override_max_sessions_per_month is not None:
        override.override_max_sessions_per_month = data.override_max_sessions_per_month
    if data.override_max_users is not None:
        override.override_max_users = data.override_max_users
    if data.note is not None:
        override.note = data.note

    await log_audit(db, current_user.id, "account_override.update", "account", override.account_id)
    await db.commit()
    await db.refresh(override)

    # Fetch account info (may be missing; the response then carries None).
    acct = await db.execute(select(Account).where(Account.id == override.account_id))
    account = acct.scalar_one_or_none()

    return _override_response(
        override,
        account.name if account else None,
        account.display_code if account else None,
    )


@router.delete("/account-overrides/{override_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_account_override(
    override_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(require_admin)],
):
    """Delete an account limit override.

    The audit entry is written in the same transaction as the delete.

    Raises:
        HTTPException: 404 when no override has ``override_id``.
    """
    result = await db.execute(select(AccountLimitOverride).where(AccountLimitOverride.id == override_id))
    override = result.scalar_one_or_none()
    if not override:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Override not found")

    await log_audit(db, current_user.id, "account_override.delete", "account", override.account_id)
    await db.delete(override)
    await db.commit()