feat: expand admin customer account controls

This commit is contained in:
chihlasm
2026-04-02 04:17:29 +00:00
parent 70242ad037
commit 7cbc9fe224
6 changed files with 735 additions and 147 deletions

View File

@@ -6,7 +6,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, or_
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import selectinload, aliased
from app.core.admin_database import get_admin_db
from app.core.audit import log_audit
@@ -36,6 +36,9 @@ from app.schemas.admin import (
AdminAccountMember,
AdminAccountListItem,
AdminAccountListResponse,
AdminAccountOwnerSummary,
AdminAccountSubscriptionSummary,
AdminAccountUsageSummary,
)
from app.schemas.subscription import SubscriptionPlanUpdate, ExtendTrialRequest
from app.schemas.user_detail import (
@@ -43,6 +46,7 @@ from app.schemas.user_detail import (
SessionSummary, AuditLogSummary, InviteCodeUsedSummary,
)
from app.api.deps import require_admin
from app.core.subscriptions import get_account_usage
router = APIRouter(prefix="/admin", tags=["admin"])
@@ -149,22 +153,70 @@ async def list_accounts(
current_user: Annotated[User, Depends(require_admin)],
page: int = Query(1, ge=1),
size: int = Query(12, ge=1, le=100),
search: Optional[str] = Query(None, description="Search by account, display code, or owner"),
plan: Optional[str] = Query(None, description="Filter by subscription plan"),
status: Optional[str] = Query(None, description="Filter by subscription status"),
include_archived: bool = Query(False, description="Include archived users in account member lists"),
):
"""List accounts with embedded members for the admin panel."""
total_result = await db.execute(select(func.count()).select_from(Account))
owner_user = aliased(User)
count_query = (
select(func.count(func.distinct(Account.id)))
.select_from(Account)
.outerjoin(owner_user, Account.owner_id == owner_user.id)
.outerjoin(Subscription, Subscription.account_id == Account.id)
)
accounts_query = (
select(
Account,
owner_user.id.label("owner_user_id"),
owner_user.name.label("owner_name"),
owner_user.email.label("owner_email"),
Subscription.id.label("subscription_id"),
Subscription.plan.label("subscription_plan"),
Subscription.status.label("subscription_status"),
Subscription.billing_interval.label("subscription_billing_interval"),
Subscription.current_period_end.label("subscription_current_period_end"),
Subscription.cancel_at_period_end.label("subscription_cancel_at_period_end"),
)
.outerjoin(owner_user, Account.owner_id == owner_user.id)
.outerjoin(Subscription, Subscription.account_id == Account.id)
)
if search:
search_term = f"%{search.strip()}%"
search_filter = or_(
Account.name.ilike(search_term),
Account.display_code.ilike(search_term),
owner_user.name.ilike(search_term),
owner_user.email.ilike(search_term),
)
count_query = count_query.where(search_filter)
accounts_query = accounts_query.where(search_filter)
if plan:
count_query = count_query.where(Subscription.plan == plan)
accounts_query = accounts_query.where(Subscription.plan == plan)
if status:
count_query = count_query.where(Subscription.status == status)
accounts_query = accounts_query.where(Subscription.status == status)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
accounts_result = await db.execute(
select(Account)
accounts_query
.order_by(Account.created_at.desc())
.offset((page - 1) * size)
.limit(size)
)
accounts = accounts_result.scalars().all()
rows = accounts_result.all()
accounts = [row.Account for row in rows]
account_ids = [account.id for account in accounts]
members_by_account: dict[UUID, list[AdminAccountMember]] = {account_id: [] for account_id in account_ids}
pending_invites_by_account: dict[UUID, int] = {account_id: 0 for account_id in account_ids}
usage_by_account: dict[UUID, AdminAccountUsageSummary] = {}
if account_ids:
members_query = select(User).where(User.account_id.in_(account_ids))
@@ -189,18 +241,56 @@ async def list_accounts(
)
)
pending_invites_result = await db.execute(
select(AccountInvite.account_id, func.count(AccountInvite.id))
.where(
AccountInvite.account_id.in_(account_ids),
AccountInvite.used_at.is_(None),
)
.group_by(AccountInvite.account_id)
)
pending_invites_by_account.update({row[0]: row[1] for row in pending_invites_result.all()})
for account_id in account_ids:
usage = await get_account_usage(account_id, db)
usage_by_account[account_id] = AdminAccountUsageSummary(
tree_count=usage.get("tree_count", 0),
session_count_this_month=usage.get("session_count_this_month", 0),
)
items = [
AdminAccountListItem(
id=account.id,
name=account.name,
display_code=account.display_code,
created_at=account.created_at,
owner_id=account.owner_id,
member_count=len(members_by_account.get(account.id, [])),
active_member_count=sum(1 for member in members_by_account.get(account.id, []) if member.is_active),
members=members_by_account.get(account.id, []),
id=row.Account.id,
name=row.Account.name,
display_code=row.Account.display_code,
created_at=row.Account.created_at,
owner_id=row.Account.owner_id,
owner=(
AdminAccountOwnerSummary(
id=row.owner_user_id,
name=row.owner_name,
email=row.owner_email,
) if row.owner_user_id and row.owner_name and row.owner_email else None
),
subscription=(
AdminAccountSubscriptionSummary(
id=row.subscription_id,
plan=row.subscription_plan,
status=row.subscription_status,
billing_interval=row.subscription_billing_interval,
current_period_end=row.subscription_current_period_end,
cancel_at_period_end=row.subscription_cancel_at_period_end or False,
) if row.subscription_id and row.subscription_plan and row.subscription_status else None
),
usage=usage_by_account.get(row.Account.id, AdminAccountUsageSummary()),
member_count=len(members_by_account.get(row.Account.id, [])),
active_member_count=sum(1 for member in members_by_account.get(row.Account.id, []) if member.is_active),
pending_invite_count=pending_invites_by_account.get(row.Account.id, 0),
sso_enabled=row.Account.sso_enabled,
branding_company_name=row.Account.branding_company_name,
members=members_by_account.get(row.Account.id, []),
)
for account in accounts
for row in rows
]
return AdminAccountListResponse(
@@ -662,6 +752,28 @@ async def _get_user_subscription(user_id: UUID, db: AsyncSession) -> tuple[User,
return user, subscription
async def _get_account_subscription(account_id: UUID, db: AsyncSession) -> tuple[Account, Subscription]:
"""Helper to load account and its subscription."""
account_result = await db.execute(select(Account).where(Account.id == account_id))
account = account_result.scalar_one_or_none()
if not account:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
sub_result = await db.execute(
select(Subscription).where(Subscription.account_id == account.id)
)
subscription = sub_result.scalar_one_or_none()
if not subscription:
subscription = Subscription(
account_id=account.id,
plan="free",
status="active",
)
db.add(subscription)
await db.flush()
return account, subscription
@router.put("/users/{user_id}/subscription/plan")
async def update_user_plan(
user_id: UUID,
@@ -681,6 +793,31 @@ async def update_user_plan(
return {"plan": subscription.plan, "status": subscription.status}
@router.put("/accounts/{account_id}/subscription/plan")
async def update_account_plan(
account_id: UUID,
data: SubscriptionPlanUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Change an account subscription plan (super admin only)."""
if data.plan not in ("free", "pro", "team"):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
account, subscription = await _get_account_subscription(account_id, db)
old_plan = subscription.plan
subscription.plan = data.plan
await log_audit(
db,
current_user.id,
"subscription.plan_change",
"subscription",
subscription.id,
{"old_plan": old_plan, "new_plan": data.plan, "account_id": str(account_id)},
)
await db.commit()
return {"plan": subscription.plan, "status": subscription.status}
@router.put("/users/{user_id}/subscription/extend-trial")
async def extend_user_trial(
user_id: UUID,
@@ -711,6 +848,43 @@ async def extend_user_trial(
"current_period_end": subscription.current_period_end}
@router.put("/accounts/{account_id}/subscription/extend-trial")
async def extend_account_trial(
account_id: UUID,
data: ExtendTrialRequest,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Extend or start a trial for an account subscription (super admin only)."""
if data.days < 1 or data.days > 90:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Days must be 1-90")
account, subscription = await _get_account_subscription(account_id, db)
now = datetime.now(timezone.utc)
if subscription.status == "trialing" and subscription.current_period_end:
new_end = subscription.current_period_end + timedelta(days=data.days)
else:
subscription.status = "trialing"
subscription.current_period_start = now
new_end = now + timedelta(days=data.days)
subscription.current_period_end = new_end
await log_audit(
db,
current_user.id,
"subscription.extend_trial",
"subscription",
subscription.id,
{"days": data.days, "new_end": new_end.isoformat(), "account_id": str(account.id)},
)
await db.commit()
return {
"plan": subscription.plan,
"status": subscription.status,
"current_period_end": subscription.current_period_end,
}
@router.post("/users/{user_id}/password-reset", response_model=AdminPasswordResetResponse)
async def admin_reset_password(
user_id: UUID,

View File

@@ -66,14 +66,40 @@ class AdminAccountMember(BaseModel):
deleted_at: Optional[datetime] = None
class AdminAccountOwnerSummary(BaseModel):
id: UUID
name: str
email: EmailStr
class AdminAccountSubscriptionSummary(BaseModel):
id: UUID
plan: str
status: str
billing_interval: Optional[str] = None
current_period_end: Optional[datetime] = None
cancel_at_period_end: bool = False
class AdminAccountUsageSummary(BaseModel):
tree_count: int = 0
session_count_this_month: int = 0
class AdminAccountListItem(BaseModel):
id: UUID
name: str
display_code: str
created_at: datetime
owner_id: Optional[UUID] = None
owner: Optional[AdminAccountOwnerSummary] = None
subscription: Optional[AdminAccountSubscriptionSummary] = None
usage: AdminAccountUsageSummary = Field(default_factory=AdminAccountUsageSummary)
member_count: int = 0
active_member_count: int = 0
pending_invite_count: int = 0
sso_enabled: bool = False
branding_company_name: Optional[str] = None
members: list[AdminAccountMember] = Field(default_factory=list)

View File

@@ -51,6 +51,36 @@ class TestAdminEndpoints:
assert payload["total"] >= 1
assert len(payload["items"]) >= 1
assert "members" in payload["items"][0]
assert "subscription" in payload["items"][0]
@pytest.mark.asyncio
async def test_update_account_plan(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test changing an account's subscription plan."""
account_id = test_user["user_data"]["account_id"]
response = await client.put(
f"/api/v1/admin/accounts/{account_id}/subscription/plan",
json={"plan": "pro"},
headers=admin_auth_headers,
)
assert response.status_code == 200
assert response.json()["plan"] == "pro"
@pytest.mark.asyncio
async def test_extend_account_trial(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test starting or extending an account trial."""
account_id = test_user["user_data"]["account_id"]
response = await client.put(
f"/api/v1/admin/accounts/{account_id}/subscription/extend-trial",
json={"days": 14},
headers=admin_auth_headers,
)
assert response.status_code == 200
assert response.json()["status"] == "trialing"
assert response.json()["current_period_end"] is not None
@pytest.mark.asyncio
async def test_list_users_as_non_admin(