feat: admin invite codes with plan assignment + user detail page

- Migration 030: add email, assigned_plan, trial_duration_days, email_sent_at
  to invite_codes with CHECK constraints
- Resend email integration (graceful degradation when API key not set)
- Invite codes now support plan assignment (free/pro/team) and trial duration (1-90 days)
- Registration applies invite code plan/trial to new subscription
- Auto-downgrade expired trials on authenticated access
- Enriched GET /admin/users/{id} with account, subscription, sessions, audit logs
- New endpoints: PUT /admin/users/{id}/subscription/plan and extend-trial
- Frontend: enhanced invite codes page with email, plan, trial fields
- Frontend: new user detail page at /admin/users/:userId
- Fixed API path drift: /invite-codes -> /invites
- 11 new backend tests, 416 total passing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Michael Chihlas
2026-02-11 21:42:58 -05:00
parent a466400c5b
commit 50cb0fc7f0
24 changed files with 2522 additions and 1121 deletions

View File

@@ -0,0 +1,61 @@
"""enhance invite codes with plan assignment and email
Revision ID: 030
Revises: 029
Create Date: 2026-02-12
Adds email, assigned_plan, trial_duration_days, and email_sent_at columns
to invite_codes table for plan-aware invite code creation.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '030'
down_revision: Union[str, None] = '029'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('invite_codes', sa.Column('email', sa.String(255), nullable=True))
op.add_column('invite_codes', sa.Column('assigned_plan', sa.String(50), nullable=False, server_default='free'))
op.add_column('invite_codes', sa.Column('trial_duration_days', sa.Integer(), nullable=True))
op.add_column('invite_codes', sa.Column('email_sent_at', sa.DateTime(timezone=True), nullable=True))
op.create_index('ix_invite_codes_email', 'invite_codes', ['email'])
# Plan must be free/pro/team
op.create_check_constraint(
'ck_invite_codes_assigned_plan',
'invite_codes',
"assigned_plan IN ('free', 'pro', 'team')"
)
# Trial duration 1-90 days or NULL
op.create_check_constraint(
'ck_invite_codes_trial_duration',
'invite_codes',
"trial_duration_days IS NULL OR (trial_duration_days >= 1 AND trial_duration_days <= 90)"
)
# Free plan cannot have trial duration
op.create_check_constraint(
'ck_invite_codes_free_no_trial',
'invite_codes',
"assigned_plan != 'free' OR trial_duration_days IS NULL"
)
def downgrade() -> None:
op.drop_constraint('ck_invite_codes_free_no_trial', 'invite_codes', type_='check')
op.drop_constraint('ck_invite_codes_trial_duration', 'invite_codes', type_='check')
op.drop_constraint('ck_invite_codes_assigned_plan', 'invite_codes', type_='check')
op.drop_index('ix_invite_codes_email', table_name='invite_codes')
op.drop_column('invite_codes', 'email_sent_at')
op.drop_column('invite_codes', 'trial_duration_days')
op.drop_column('invite_codes', 'assigned_plan')
op.drop_column('invite_codes', 'email')

View File

@@ -65,14 +65,36 @@ async def get_refresh_token_payload(
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_user)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> User:
"""Ensure user is active (not disabled)."""
"""Ensure user is active (not disabled). Auto-downgrades expired trials."""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account has been deactivated"
)
# Lightweight trial expiry check
if current_user.account_id:
from app.models.subscription import Subscription
from datetime import datetime, timezone
result = await db.execute(
select(Subscription).where(Subscription.account_id == current_user.account_id)
)
subscription = result.scalar_one_or_none()
if (
subscription
and subscription.status == "trialing"
and subscription.current_period_end
and subscription.current_period_end < datetime.now(timezone.utc)
):
subscription.plan = "free"
subscription.status = "active"
subscription.current_period_end = None
subscription.current_period_start = None
await db.commit()
return current_user

View File

@@ -1,15 +1,26 @@
from datetime import datetime, timezone, timedelta
from typing import Annotated, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from sqlalchemy.orm import selectinload
from app.core.database import get_db
from app.core.audit import log_audit
from app.models.user import User
from app.models.account import Account
from app.models.subscription import Subscription
from app.models.session import Session
from app.models.audit_log import AuditLog
from app.models.invite_code import InviteCode
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
from app.schemas.admin import MoveUserAccount
from app.schemas.subscription import SubscriptionPlanUpdate, ExtendTrialRequest
from app.schemas.user_detail import (
UserDetailResponse, AccountSummary, SubscriptionSummary,
SessionSummary, AuditLogSummary, InviteCodeUsedSummary,
)
from app.api.deps import require_admin
router = APIRouter(prefix="/admin", tags=["admin"])
@@ -42,13 +53,13 @@ async def list_users(
return users
@router.get("/users/{user_id}", response_model=UserResponse)
@router.get("/users/{user_id}", response_model=UserDetailResponse)
async def get_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Get user details (super admin only)."""
"""Get enriched user details (super admin only)."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
@@ -58,7 +69,104 @@ async def get_user(
detail="User not found"
)
return user
# Account + subscription
account_summary = None
subscription_summary = None
if user.account_id:
acc_result = await db.execute(select(Account).where(Account.id == user.account_id))
account = acc_result.scalar_one_or_none()
if account:
account_summary = AccountSummary(
id=account.id, name=account.name,
display_code=getattr(account, "display_code", None),
)
sub_result = await db.execute(
select(Subscription).where(Subscription.account_id == user.account_id)
)
subscription = sub_result.scalar_one_or_none()
if subscription:
subscription_summary = SubscriptionSummary(
id=subscription.id, plan=subscription.plan, status=subscription.status,
current_period_start=subscription.current_period_start,
current_period_end=subscription.current_period_end,
)
# Recent sessions (latest 10 + total)
total_sessions_result = await db.execute(
select(func.count()).select_from(Session).where(Session.user_id == user_id)
)
total_sessions = total_sessions_result.scalar() or 0
sessions_result = await db.execute(
select(Session).options(selectinload(Session.tree))
.where(Session.user_id == user_id)
.order_by(Session.started_at.desc())
.limit(10)
)
sessions = sessions_result.scalars().all()
recent_sessions = [
SessionSummary(
id=s.id,
tree_name=s.tree.name if s.tree else None,
started_at=s.started_at,
completed_at=s.completed_at,
outcome=s.outcome,
)
for s in sessions
]
# Recent audit logs (latest 10 + total)
total_audits_result = await db.execute(
select(func.count()).select_from(AuditLog).where(AuditLog.user_id == user_id)
)
total_audit_logs = total_audits_result.scalar() or 0
audits_result = await db.execute(
select(AuditLog).where(AuditLog.user_id == user_id)
.order_by(AuditLog.created_at.desc())
.limit(10)
)
audits = audits_result.scalars().all()
recent_audit_logs = [
AuditLogSummary(
id=a.id, action=a.action, resource_type=a.resource_type,
resource_id=str(a.resource_id) if a.resource_id else None,
created_at=a.created_at, details=a.details,
)
for a in audits
]
# Invite code used
invite_code_used = None
if user.invite_code_id:
ic_result = await db.execute(
select(InviteCode).where(InviteCode.id == user.invite_code_id)
)
ic = ic_result.scalar_one_or_none()
if ic:
creator_email = None
if ic.created_by_id:
creator_result = await db.execute(
select(User.email).where(User.id == ic.created_by_id)
)
creator_email = creator_result.scalar_one_or_none()
invite_code_used = InviteCodeUsedSummary(
code=ic.code, assigned_plan=ic.assigned_plan,
trial_duration_days=ic.trial_duration_days,
created_by_email=creator_email,
)
return UserDetailResponse(
id=user.id, email=user.email, full_name=user.name,
role=user.role, is_active=user.is_active,
is_super_admin=user.is_super_admin,
is_team_admin=getattr(user, "is_team_admin", False),
created_at=user.created_at,
account=account_summary, subscription=subscription_summary,
invite_code_used=invite_code_used,
recent_sessions=recent_sessions, total_sessions=total_sessions,
recent_audit_logs=recent_audit_logs, total_audit_logs=total_audit_logs,
)
@router.put("/users/{user_id}/role", response_model=UserResponse)
@@ -198,3 +306,69 @@ async def move_user_account(
await db.commit()
await db.refresh(user)
return user
async def _get_user_subscription(user_id: UUID, db: AsyncSession) -> tuple[User, Subscription]:
"""Helper to load user and their subscription."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
if not user.account_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User has no account")
sub_result = await db.execute(
select(Subscription).where(Subscription.account_id == user.account_id)
)
subscription = sub_result.scalar_one_or_none()
if not subscription:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Subscription not found")
return user, subscription
@router.put("/users/{user_id}/subscription/plan")
async def update_user_plan(
user_id: UUID,
data: SubscriptionPlanUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Change a user's subscription plan (super admin only)."""
if data.plan not in ("free", "pro", "team"):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
user, subscription = await _get_user_subscription(user_id, db)
old_plan = subscription.plan
subscription.plan = data.plan
await log_audit(db, current_user.id, "subscription.plan_change", "subscription", subscription.id,
{"old_plan": old_plan, "new_plan": data.plan, "user_id": str(user_id)})
await db.commit()
return {"plan": subscription.plan, "status": subscription.status}
@router.put("/users/{user_id}/subscription/extend-trial")
async def extend_user_trial(
user_id: UUID,
data: ExtendTrialRequest,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Extend or start a trial for a user's subscription (super admin only)."""
if data.days < 1 or data.days > 90:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Days must be 1-90")
user, subscription = await _get_user_subscription(user_id, db)
now = datetime.now(timezone.utc)
if subscription.status == "trialing" and subscription.current_period_end:
# Extend existing trial
new_end = subscription.current_period_end + timedelta(days=data.days)
else:
# Start new trial
subscription.status = "trialing"
subscription.current_period_start = now
new_end = now + timedelta(days=data.days)
subscription.current_period_end = new_end
await log_audit(db, current_user.id, "subscription.extend_trial", "subscription", subscription.id,
{"days": data.days, "new_end": new_end.isoformat(), "user_id": str(user_id)})
await db.commit()
return {"plan": subscription.plan, "status": subscription.status,
"current_period_end": subscription.current_period_end}

View File

@@ -1,6 +1,6 @@
import secrets
import string
from datetime import datetime, timezone
from datetime import datetime, timezone, timedelta
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
@@ -92,38 +92,39 @@ async def register(
detail="Account invite code has expired"
)
# Validate platform invite code if required (skip if account invite was provided)
# Validate platform invite code (skip if account invite was provided)
invite_code_record = None
if not account_invite_record and settings.REQUIRE_INVITE_CODE:
if not user_data.invite_code:
if not account_invite_record:
if settings.REQUIRE_INVITE_CODE and not user_data.invite_code:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invite code is required"
)
# Look up invite code (case-insensitive)
result = await db.execute(
select(InviteCode).where(InviteCode.code == user_data.invite_code.upper())
)
invite_code_record = result.scalar_one_or_none()
if not invite_code_record:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid invite code"
if user_data.invite_code:
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE
result = await db.execute(
select(InviteCode).where(InviteCode.code == user_data.invite_code.upper())
)
invite_code_record = result.scalar_one_or_none()
if invite_code_record.is_used:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invite code has already been used"
)
if not invite_code_record:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid invite code"
)
if invite_code_record.is_expired:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invite code has expired"
)
if invite_code_record.is_used:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invite code has already been used"
)
if invite_code_record.is_expired:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invite code has expired"
)
# Check if email already exists
result = await db.execute(select(User).where(User.email == user_data.email))
@@ -175,10 +176,24 @@ async def register(
# Now set account owner and create subscription
new_account.owner_id = new_user.id
# Apply plan/trial from invite code if present
sub_plan = "free"
sub_status = "active"
period_start = None
period_end = None
if invite_code_record and invite_code_record.assigned_plan:
sub_plan = invite_code_record.assigned_plan
if invite_code_record.trial_duration_days:
sub_status = "trialing"
period_start = datetime.now(timezone.utc)
period_end = period_start + timedelta(days=invite_code_record.trial_duration_days)
new_subscription = Subscription(
account_id=new_account.id,
plan="free",
status="active",
plan=sub_plan,
status=sub_status,
current_period_start=period_start,
current_period_end=period_end,
)
db.add(new_subscription)

View File

@@ -5,6 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.core.rate_limit import limiter
from app.core.audit import log_audit
from app.core.email import EmailService
from app.models.user import User
from app.models.invite_code import InviteCode
from app.schemas.invite_code import InviteCodeCreate, InviteCodeResponse, InviteCodeValidation
@@ -23,9 +25,35 @@ async def create_invite_code(
invite_code = InviteCode(
created_by_id=current_user.id,
expires_at=invite_data.expires_at,
note=invite_data.note
note=invite_data.note,
email=invite_data.email,
assigned_plan=invite_data.assigned_plan,
trial_duration_days=invite_data.trial_duration_days,
)
db.add(invite_code)
await db.flush()
# Send invite email if email provided
email_sent = False
if invite_data.email:
email_sent = await EmailService.send_invite_email(
to_email=invite_data.email,
code=invite_code.code,
plan=invite_data.assigned_plan,
trial_days=invite_data.trial_duration_days,
)
if email_sent:
invite_code.email_sent_at = datetime.now(timezone.utc)
await log_audit(
db, current_user.id, "invite.create", "invite_code", invite_code.id,
{
"code": invite_code.code,
"plan": invite_data.assigned_plan,
"email": invite_data.email,
"email_sent": email_sent,
},
)
await db.commit()
await db.refresh(invite_code)

View File

@@ -52,6 +52,15 @@ class Settings(BaseSettings):
# Registration
REQUIRE_INVITE_CODE: bool = True # Set to False to allow open registration
# Email (Resend)
RESEND_API_KEY: Optional[str] = None
FROM_EMAIL: str = "ResolutionFlow <invites@resolutionflow.com>"
@property
def email_enabled(self) -> bool:
"""Check if email sending is configured."""
return self.RESEND_API_KEY is not None
# Stripe
STRIPE_SECRET_KEY: Optional[str] = None
STRIPE_PUBLISHABLE_KEY: Optional[str] = None

105
backend/app/core/email.py Normal file
View File

@@ -0,0 +1,105 @@
import logging
from app.core.config import settings
logger = logging.getLogger(__name__)
class EmailService:
"""Best-effort email delivery via Resend. Never raises on failure."""
@staticmethod
async def send_invite_email(
to_email: str,
code: str,
plan: str,
trial_days: int | None = None,
signup_url: str = "https://resolutionflow.com/register",
) -> bool:
if not settings.email_enabled:
logger.warning("Email not sent — RESEND_API_KEY not configured")
return False
try:
import resend
resend.api_key = settings.RESEND_API_KEY
plan_label = plan.capitalize()
trial_text = f" with a {trial_days}-day free trial" if trial_days else ""
subject = f"You're invited to ResolutionFlow ({plan_label} plan{trial_text})"
html = _render_invite_html(
code=code,
plan_label=plan_label,
trial_days=trial_days,
signup_url=signup_url,
)
resend.Emails.send(
{
"from": settings.FROM_EMAIL,
"to": [to_email],
"subject": subject,
"html": html,
}
)
logger.info("Invite email sent to %s", to_email)
return True
except Exception:
logger.exception("Failed to send invite email to %s", to_email)
return False
def _render_invite_html(
code: str,
plan_label: str,
trial_days: int | None,
signup_url: str,
) -> str:
trial_section = ""
if trial_days:
trial_section = f"""
<tr><td style="padding:0 40px 20px;">
<p style="margin:0;color:#a0a0a0;font-size:14px;">
Your <strong style="color:#fff;">{trial_days}-day free trial</strong> starts when you register.
After your trial ends, your account will revert to the Free plan.
</p>
</td></tr>"""
return f"""<!DOCTYPE html>
<html><head><meta charset="utf-8"><meta name="viewport" content="width=device-width"></head>
<body style="margin:0;padding:0;background:#000;font-family:'Inter',Helvetica,Arial,sans-serif;">
<table width="100%" cellpadding="0" cellspacing="0" style="background:#000;padding:40px 0;">
<tr><td align="center">
<table width="560" cellpadding="0" cellspacing="0" style="background:#111;border:1px solid rgba(255,255,255,0.06);border-radius:16px;">
<tr><td style="padding:40px 40px 24px;text-align:center;">
<h1 style="margin:0;color:#fff;font-size:24px;font-weight:600;">ResolutionFlow</h1>
<p style="margin:8px 0 0;color:#a0a0a0;font-size:14px;">Decision Tree Platform for MSP Professionals</p>
</td></tr>
<tr><td style="padding:0 40px 24px;">
<p style="margin:0;color:#e0e0e0;font-size:16px;line-height:1.6;">
You've been invited to join ResolutionFlow on the <strong style="color:#fff;">{plan_label}</strong> plan.
</p>
</td></tr>
<tr><td style="padding:0 40px 24px;text-align:center;">
<div style="background:#000;border:1px solid rgba(255,255,255,0.1);border-radius:12px;padding:20px;">
<p style="margin:0 0 4px;color:#a0a0a0;font-size:12px;text-transform:uppercase;letter-spacing:1px;">Your Invite Code</p>
<p style="margin:0;color:#fff;font-size:28px;font-weight:700;letter-spacing:4px;">{code}</p>
</div>
</td></tr>
{trial_section}
<tr><td style="padding:0 40px 32px;text-align:center;">
<a href="{signup_url}" style="display:inline-block;background:#fff;color:#000;font-size:16px;font-weight:600;text-decoration:none;padding:14px 40px;border-radius:8px;">
Create Your Account
</a>
</td></tr>
<tr><td style="padding:0 40px 32px;">
<p style="margin:0;color:#666;font-size:12px;text-align:center;">
Enter the code above during registration, or click the button to get started.
</p>
</td></tr>
</table>
</td></tr>
</table>
</body></html>"""

View File

@@ -46,6 +46,13 @@ class InviteCode(Base):
DateTime(timezone=True),
nullable=True
)
email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True)
assigned_plan: Mapped[str] = mapped_column(String(50), nullable=False, server_default="free")
trial_duration_days: Mapped[Optional[int]] = mapped_column(nullable=True)
email_sent_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True),
nullable=True
)
note: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
@@ -84,3 +91,11 @@ class InviteCode(Base):
def is_valid(self) -> bool:
"""Check if the invite code is valid (not used and not expired)."""
return not self.is_used and not self.is_expired
@property
def has_trial(self) -> bool:
return self.trial_duration_days is not None and self.trial_duration_days > 0
@property
def email_sent(self) -> bool:
return self.email_sent_at is not None

View File

@@ -1,13 +1,22 @@
from datetime import datetime
from typing import Optional
from typing import Literal, Optional
from uuid import UUID
from pydantic import BaseModel, Field
from pydantic import BaseModel, EmailStr, Field, model_validator
class InviteCodeCreate(BaseModel):
"""Schema for creating a new invite code."""
expires_at: Optional[datetime] = Field(None, description="Optional expiration time")
note: Optional[str] = Field(None, max_length=255, description="Note about who this code is for")
email: Optional[EmailStr] = Field(None, description="Recipient email for invite delivery")
assigned_plan: Literal["free", "pro", "team"] = Field("free", description="Plan to assign on registration")
trial_duration_days: Optional[int] = Field(None, ge=1, le=90, description="Trial duration in days (1-90)")
@model_validator(mode="after")
def free_plan_no_trial(self):
if self.assigned_plan == "free" and self.trial_duration_days is not None:
raise ValueError("Free plan cannot have a trial duration")
return self
class InviteCodeResponse(BaseModel):
@@ -23,6 +32,12 @@ class InviteCodeResponse(BaseModel):
is_used: bool
is_expired: bool
is_valid: bool
email: Optional[str] = None
assigned_plan: str = "free"
trial_duration_days: Optional[int] = None
email_sent_at: Optional[datetime] = None
has_trial: bool = False
email_sent: bool = False
class Config:
from_attributes = True

View File

@@ -38,3 +38,15 @@ class SubscriptionDetails(BaseModel):
subscription: SubscriptionResponse
limits: PlanLimitsResponse
usage: UsageResponse
class SubscriptionPlanUpdate(BaseModel):
plan: str # free, pro, team
model_config = {"json_schema_extra": {"examples": [{"plan": "pro"}]}}
class ExtendTrialRequest(BaseModel):
days: int # 1-90
model_config = {"json_schema_extra": {"examples": [{"days": 14}]}}

View File

@@ -0,0 +1,72 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
class AccountSummary(BaseModel):
id: UUID
name: str
display_code: Optional[str] = None
model_config = {"from_attributes": True}
class SubscriptionSummary(BaseModel):
id: UUID
plan: str
status: str
current_period_start: Optional[datetime] = None
current_period_end: Optional[datetime] = None
model_config = {"from_attributes": True}
class SessionSummary(BaseModel):
id: UUID
tree_name: Optional[str] = None
started_at: datetime
completed_at: Optional[datetime] = None
outcome: Optional[str] = None
model_config = {"from_attributes": True}
class AuditLogSummary(BaseModel):
id: UUID
action: str
resource_type: Optional[str] = None
resource_id: Optional[str] = None
created_at: datetime
details: Optional[dict] = None
model_config = {"from_attributes": True}
class InviteCodeUsedSummary(BaseModel):
code: str
assigned_plan: str
trial_duration_days: Optional[int] = None
created_by_email: Optional[str] = None
model_config = {"from_attributes": True}
class UserDetailResponse(BaseModel):
id: UUID
email: str
full_name: Optional[str] = None
role: str
is_active: bool
is_super_admin: bool
is_team_admin: bool
created_at: datetime
account: Optional[AccountSummary] = None
subscription: Optional[SubscriptionSummary] = None
invite_code_used: Optional[InviteCodeUsedSummary] = None
recent_sessions: list[SessionSummary] = []
total_sessions: int = 0
recent_audit_logs: list[AuditLogSummary] = []
total_audit_logs: int = 0

View File

@@ -25,5 +25,8 @@ slowapi==0.1.9
# Payments
stripe==14.3.0
# Email
resend==2.21.0
# Utilities
python-dotenv==1.0.1

View File

@@ -0,0 +1,227 @@
"""Tests for enhanced invite codes with plan assignment and trial durations."""
import pytest
from datetime import datetime, timezone, timedelta
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.invite_code import InviteCode
from app.models.subscription import Subscription
from app.models.user import User
class TestInviteCodeCreation:
"""Test invite code creation with plan/trial fields."""
@pytest.mark.asyncio
async def test_create_invite_with_plan(
self, client: AsyncClient, admin_auth_headers: dict
):
response = await client.post(
"/api/v1/invites",
json={"assigned_plan": "pro", "note": "Beta tester"},
headers=admin_auth_headers,
)
assert response.status_code == 201
data = response.json()
assert data["assigned_plan"] == "pro"
assert data["has_trial"] is False
assert data["trial_duration_days"] is None
@pytest.mark.asyncio
async def test_create_invite_with_trial(
self, client: AsyncClient, admin_auth_headers: dict
):
response = await client.post(
"/api/v1/invites",
json={"assigned_plan": "pro", "trial_duration_days": 14},
headers=admin_auth_headers,
)
assert response.status_code == 201
data = response.json()
assert data["assigned_plan"] == "pro"
assert data["trial_duration_days"] == 14
assert data["has_trial"] is True
@pytest.mark.asyncio
async def test_create_invite_with_email(
self, client: AsyncClient, admin_auth_headers: dict
):
response = await client.post(
"/api/v1/invites",
json={"assigned_plan": "team", "email": "beta@example.com"},
headers=admin_auth_headers,
)
assert response.status_code == 201
data = response.json()
assert data["email"] == "beta@example.com"
# Email not sent because RESEND_API_KEY not configured
assert data["email_sent"] is False
@pytest.mark.asyncio
async def test_free_plan_rejects_trial(
self, client: AsyncClient, admin_auth_headers: dict
):
response = await client.post(
"/api/v1/invites",
json={"assigned_plan": "free", "trial_duration_days": 14},
headers=admin_auth_headers,
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_trial_duration_bounds(
self, client: AsyncClient, admin_auth_headers: dict
):
# Too low
response = await client.post(
"/api/v1/invites",
json={"assigned_plan": "pro", "trial_duration_days": 0},
headers=admin_auth_headers,
)
assert response.status_code == 422
# Too high
response = await client.post(
"/api/v1/invites",
json={"assigned_plan": "pro", "trial_duration_days": 91},
headers=admin_auth_headers,
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_default_plan_is_free(
self, client: AsyncClient, admin_auth_headers: dict
):
response = await client.post(
"/api/v1/invites",
json={},
headers=admin_auth_headers,
)
assert response.status_code == 201
assert response.json()["assigned_plan"] == "free"
class TestRegistrationWithInvitePlan:
"""Test that registration applies invite code plan/trial to subscription."""
@pytest.mark.asyncio
async def test_register_with_pro_trial_invite(
self, client: AsyncClient, admin_auth_headers: dict, test_db: AsyncSession
):
# Create a pro trial invite
resp = await client.post(
"/api/v1/invites",
json={"assigned_plan": "pro", "trial_duration_days": 14},
headers=admin_auth_headers,
)
code = resp.json()["code"]
# Register with the invite code
reg_resp = await client.post(
"/api/v1/auth/register",
json={
"email": "trial_user@example.com",
"password": "SecurePass1",
"name": "Trial User",
"invite_code": code,
},
)
assert reg_resp.status_code == 201
user_id = reg_resp.json()["id"]
# Check subscription
user = (await test_db.execute(
select(User).where(User.id == user_id)
)).scalar_one()
sub = (await test_db.execute(
select(Subscription).where(Subscription.account_id == user.account_id)
)).scalar_one()
assert sub.plan == "pro"
assert sub.status == "trialing"
assert sub.current_period_end is not None
assert sub.current_period_end > datetime.now(timezone.utc)
@pytest.mark.asyncio
async def test_register_with_team_no_trial(
self, client: AsyncClient, admin_auth_headers: dict, test_db: AsyncSession
):
# Create team invite without trial
resp = await client.post(
"/api/v1/invites",
json={"assigned_plan": "team"},
headers=admin_auth_headers,
)
code = resp.json()["code"]
reg_resp = await client.post(
"/api/v1/auth/register",
json={
"email": "team_user@example.com",
"password": "SecurePass1",
"name": "Team User",
"invite_code": code,
},
)
assert reg_resp.status_code == 201
user_id = reg_resp.json()["id"]
user = (await test_db.execute(
select(User).where(User.id == user_id)
)).scalar_one()
sub = (await test_db.execute(
select(Subscription).where(Subscription.account_id == user.account_id)
)).scalar_one()
assert sub.plan == "team"
assert sub.status == "active"
class TestAdminSubscriptionManagement:
"""Test admin subscription plan change and trial extension endpoints."""
@pytest.mark.asyncio
async def test_change_user_plan(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
user_id = test_user["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{user_id}/subscription/plan",
json={"plan": "pro"},
headers=admin_auth_headers,
)
assert response.status_code == 200
assert response.json()["plan"] == "pro"
@pytest.mark.asyncio
async def test_extend_trial(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
user_id = test_user["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{user_id}/subscription/extend-trial",
json={"days": 14},
headers=admin_auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "trialing"
assert data["current_period_end"] is not None
@pytest.mark.asyncio
async def test_enriched_user_detail(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
user_id = test_user["user_data"]["id"]
response = await client.get(
f"/api/v1/admin/users/{user_id}",
headers=admin_auth_headers,
)
assert response.status_code == 200
data = response.json()
# Should have enriched fields
assert "subscription" in data
assert "account" in data
assert "recent_sessions" in data
assert "total_sessions" in data
assert "recent_audit_logs" in data
assert "total_audit_logs" in data