feat: expand admin customer account controls
This commit is contained in:
@@ -6,7 +6,7 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import selectinload, aliased
|
||||
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.audit import log_audit
|
||||
@@ -36,6 +36,9 @@ from app.schemas.admin import (
|
||||
AdminAccountMember,
|
||||
AdminAccountListItem,
|
||||
AdminAccountListResponse,
|
||||
AdminAccountOwnerSummary,
|
||||
AdminAccountSubscriptionSummary,
|
||||
AdminAccountUsageSummary,
|
||||
)
|
||||
from app.schemas.subscription import SubscriptionPlanUpdate, ExtendTrialRequest
|
||||
from app.schemas.user_detail import (
|
||||
@@ -43,6 +46,7 @@ from app.schemas.user_detail import (
|
||||
SessionSummary, AuditLogSummary, InviteCodeUsedSummary,
|
||||
)
|
||||
from app.api.deps import require_admin
|
||||
from app.core.subscriptions import get_account_usage
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
@@ -149,22 +153,70 @@ async def list_accounts(
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(12, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="Search by account, display code, or owner"),
|
||||
plan: Optional[str] = Query(None, description="Filter by subscription plan"),
|
||||
status: Optional[str] = Query(None, description="Filter by subscription status"),
|
||||
include_archived: bool = Query(False, description="Include archived users in account member lists"),
|
||||
):
|
||||
"""List accounts with embedded members for the admin panel."""
|
||||
total_result = await db.execute(select(func.count()).select_from(Account))
|
||||
owner_user = aliased(User)
|
||||
|
||||
count_query = (
|
||||
select(func.count(func.distinct(Account.id)))
|
||||
.select_from(Account)
|
||||
.outerjoin(owner_user, Account.owner_id == owner_user.id)
|
||||
.outerjoin(Subscription, Subscription.account_id == Account.id)
|
||||
)
|
||||
accounts_query = (
|
||||
select(
|
||||
Account,
|
||||
owner_user.id.label("owner_user_id"),
|
||||
owner_user.name.label("owner_name"),
|
||||
owner_user.email.label("owner_email"),
|
||||
Subscription.id.label("subscription_id"),
|
||||
Subscription.plan.label("subscription_plan"),
|
||||
Subscription.status.label("subscription_status"),
|
||||
Subscription.billing_interval.label("subscription_billing_interval"),
|
||||
Subscription.current_period_end.label("subscription_current_period_end"),
|
||||
Subscription.cancel_at_period_end.label("subscription_cancel_at_period_end"),
|
||||
)
|
||||
.outerjoin(owner_user, Account.owner_id == owner_user.id)
|
||||
.outerjoin(Subscription, Subscription.account_id == Account.id)
|
||||
)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search.strip()}%"
|
||||
search_filter = or_(
|
||||
Account.name.ilike(search_term),
|
||||
Account.display_code.ilike(search_term),
|
||||
owner_user.name.ilike(search_term),
|
||||
owner_user.email.ilike(search_term),
|
||||
)
|
||||
count_query = count_query.where(search_filter)
|
||||
accounts_query = accounts_query.where(search_filter)
|
||||
if plan:
|
||||
count_query = count_query.where(Subscription.plan == plan)
|
||||
accounts_query = accounts_query.where(Subscription.plan == plan)
|
||||
if status:
|
||||
count_query = count_query.where(Subscription.status == status)
|
||||
accounts_query = accounts_query.where(Subscription.status == status)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
accounts_result = await db.execute(
|
||||
select(Account)
|
||||
accounts_query
|
||||
.order_by(Account.created_at.desc())
|
||||
.offset((page - 1) * size)
|
||||
.limit(size)
|
||||
)
|
||||
accounts = accounts_result.scalars().all()
|
||||
rows = accounts_result.all()
|
||||
accounts = [row.Account for row in rows]
|
||||
|
||||
account_ids = [account.id for account in accounts]
|
||||
members_by_account: dict[UUID, list[AdminAccountMember]] = {account_id: [] for account_id in account_ids}
|
||||
pending_invites_by_account: dict[UUID, int] = {account_id: 0 for account_id in account_ids}
|
||||
usage_by_account: dict[UUID, AdminAccountUsageSummary] = {}
|
||||
|
||||
if account_ids:
|
||||
members_query = select(User).where(User.account_id.in_(account_ids))
|
||||
@@ -189,18 +241,56 @@ async def list_accounts(
|
||||
)
|
||||
)
|
||||
|
||||
pending_invites_result = await db.execute(
|
||||
select(AccountInvite.account_id, func.count(AccountInvite.id))
|
||||
.where(
|
||||
AccountInvite.account_id.in_(account_ids),
|
||||
AccountInvite.used_at.is_(None),
|
||||
)
|
||||
.group_by(AccountInvite.account_id)
|
||||
)
|
||||
pending_invites_by_account.update({row[0]: row[1] for row in pending_invites_result.all()})
|
||||
|
||||
for account_id in account_ids:
|
||||
usage = await get_account_usage(account_id, db)
|
||||
usage_by_account[account_id] = AdminAccountUsageSummary(
|
||||
tree_count=usage.get("tree_count", 0),
|
||||
session_count_this_month=usage.get("session_count_this_month", 0),
|
||||
)
|
||||
|
||||
items = [
|
||||
AdminAccountListItem(
|
||||
id=account.id,
|
||||
name=account.name,
|
||||
display_code=account.display_code,
|
||||
created_at=account.created_at,
|
||||
owner_id=account.owner_id,
|
||||
member_count=len(members_by_account.get(account.id, [])),
|
||||
active_member_count=sum(1 for member in members_by_account.get(account.id, []) if member.is_active),
|
||||
members=members_by_account.get(account.id, []),
|
||||
id=row.Account.id,
|
||||
name=row.Account.name,
|
||||
display_code=row.Account.display_code,
|
||||
created_at=row.Account.created_at,
|
||||
owner_id=row.Account.owner_id,
|
||||
owner=(
|
||||
AdminAccountOwnerSummary(
|
||||
id=row.owner_user_id,
|
||||
name=row.owner_name,
|
||||
email=row.owner_email,
|
||||
) if row.owner_user_id and row.owner_name and row.owner_email else None
|
||||
),
|
||||
subscription=(
|
||||
AdminAccountSubscriptionSummary(
|
||||
id=row.subscription_id,
|
||||
plan=row.subscription_plan,
|
||||
status=row.subscription_status,
|
||||
billing_interval=row.subscription_billing_interval,
|
||||
current_period_end=row.subscription_current_period_end,
|
||||
cancel_at_period_end=row.subscription_cancel_at_period_end or False,
|
||||
) if row.subscription_id and row.subscription_plan and row.subscription_status else None
|
||||
),
|
||||
usage=usage_by_account.get(row.Account.id, AdminAccountUsageSummary()),
|
||||
member_count=len(members_by_account.get(row.Account.id, [])),
|
||||
active_member_count=sum(1 for member in members_by_account.get(row.Account.id, []) if member.is_active),
|
||||
pending_invite_count=pending_invites_by_account.get(row.Account.id, 0),
|
||||
sso_enabled=row.Account.sso_enabled,
|
||||
branding_company_name=row.Account.branding_company_name,
|
||||
members=members_by_account.get(row.Account.id, []),
|
||||
)
|
||||
for account in accounts
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return AdminAccountListResponse(
|
||||
@@ -662,6 +752,28 @@ async def _get_user_subscription(user_id: UUID, db: AsyncSession) -> tuple[User,
|
||||
return user, subscription
|
||||
|
||||
|
||||
async def _get_account_subscription(account_id: UUID, db: AsyncSession) -> tuple[Account, Subscription]:
|
||||
"""Helper to load account and its subscription."""
|
||||
account_result = await db.execute(select(Account).where(Account.id == account_id))
|
||||
account = account_result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
|
||||
|
||||
sub_result = await db.execute(
|
||||
select(Subscription).where(Subscription.account_id == account.id)
|
||||
)
|
||||
subscription = sub_result.scalar_one_or_none()
|
||||
if not subscription:
|
||||
subscription = Subscription(
|
||||
account_id=account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(subscription)
|
||||
await db.flush()
|
||||
return account, subscription
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/subscription/plan")
|
||||
async def update_user_plan(
|
||||
user_id: UUID,
|
||||
@@ -681,6 +793,31 @@ async def update_user_plan(
|
||||
return {"plan": subscription.plan, "status": subscription.status}
|
||||
|
||||
|
||||
@router.put("/accounts/{account_id}/subscription/plan")
|
||||
async def update_account_plan(
|
||||
account_id: UUID,
|
||||
data: SubscriptionPlanUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Change an account subscription plan (super admin only)."""
|
||||
if data.plan not in ("free", "pro", "team"):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
|
||||
account, subscription = await _get_account_subscription(account_id, db)
|
||||
old_plan = subscription.plan
|
||||
subscription.plan = data.plan
|
||||
await log_audit(
|
||||
db,
|
||||
current_user.id,
|
||||
"subscription.plan_change",
|
||||
"subscription",
|
||||
subscription.id,
|
||||
{"old_plan": old_plan, "new_plan": data.plan, "account_id": str(account_id)},
|
||||
)
|
||||
await db.commit()
|
||||
return {"plan": subscription.plan, "status": subscription.status}
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/subscription/extend-trial")
|
||||
async def extend_user_trial(
|
||||
user_id: UUID,
|
||||
@@ -711,6 +848,43 @@ async def extend_user_trial(
|
||||
"current_period_end": subscription.current_period_end}
|
||||
|
||||
|
||||
@router.put("/accounts/{account_id}/subscription/extend-trial")
|
||||
async def extend_account_trial(
|
||||
account_id: UUID,
|
||||
data: ExtendTrialRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Extend or start a trial for an account subscription (super admin only)."""
|
||||
if data.days < 1 or data.days > 90:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Days must be 1-90")
|
||||
account, subscription = await _get_account_subscription(account_id, db)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if subscription.status == "trialing" and subscription.current_period_end:
|
||||
new_end = subscription.current_period_end + timedelta(days=data.days)
|
||||
else:
|
||||
subscription.status = "trialing"
|
||||
subscription.current_period_start = now
|
||||
new_end = now + timedelta(days=data.days)
|
||||
|
||||
subscription.current_period_end = new_end
|
||||
await log_audit(
|
||||
db,
|
||||
current_user.id,
|
||||
"subscription.extend_trial",
|
||||
"subscription",
|
||||
subscription.id,
|
||||
{"days": data.days, "new_end": new_end.isoformat(), "account_id": str(account.id)},
|
||||
)
|
||||
await db.commit()
|
||||
return {
|
||||
"plan": subscription.plan,
|
||||
"status": subscription.status,
|
||||
"current_period_end": subscription.current_period_end,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/password-reset", response_model=AdminPasswordResetResponse)
|
||||
async def admin_reset_password(
|
||||
user_id: UUID,
|
||||
|
||||
@@ -66,14 +66,40 @@ class AdminAccountMember(BaseModel):
|
||||
deleted_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class AdminAccountOwnerSummary(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class AdminAccountSubscriptionSummary(BaseModel):
|
||||
id: UUID
|
||||
plan: str
|
||||
status: str
|
||||
billing_interval: Optional[str] = None
|
||||
current_period_end: Optional[datetime] = None
|
||||
cancel_at_period_end: bool = False
|
||||
|
||||
|
||||
class AdminAccountUsageSummary(BaseModel):
|
||||
tree_count: int = 0
|
||||
session_count_this_month: int = 0
|
||||
|
||||
|
||||
class AdminAccountListItem(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
display_code: str
|
||||
created_at: datetime
|
||||
owner_id: Optional[UUID] = None
|
||||
owner: Optional[AdminAccountOwnerSummary] = None
|
||||
subscription: Optional[AdminAccountSubscriptionSummary] = None
|
||||
usage: AdminAccountUsageSummary = Field(default_factory=AdminAccountUsageSummary)
|
||||
member_count: int = 0
|
||||
active_member_count: int = 0
|
||||
pending_invite_count: int = 0
|
||||
sso_enabled: bool = False
|
||||
branding_company_name: Optional[str] = None
|
||||
members: list[AdminAccountMember] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
||||
@@ -51,6 +51,36 @@ class TestAdminEndpoints:
|
||||
assert payload["total"] >= 1
|
||||
assert len(payload["items"]) >= 1
|
||||
assert "members" in payload["items"][0]
|
||||
assert "subscription" in payload["items"][0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_plan(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test changing an account's subscription plan."""
|
||||
account_id = test_user["user_data"]["account_id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/accounts/{account_id}/subscription/plan",
|
||||
json={"plan": "pro"},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["plan"] == "pro"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extend_account_trial(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test starting or extending an account trial."""
|
||||
account_id = test_user["user_data"]["account_id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/accounts/{account_id}/subscription/extend-trial",
|
||||
json={"days": 14},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "trialing"
|
||||
assert response.json()["current_period_end"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_non_admin(
|
||||
|
||||
Reference in New Issue
Block a user