Files
resolutionflow/backend/app/api/endpoints/auth.py
chihlasm ad59446332 feat: user management — admin create, password reset, archive/delete, quick invite
Phase 1: must_change_password enforcement + change password endpoint/page
Phase 2: Admin user creation (M365-style) with temp password
Phase 3: Password reset (self-service forgot + admin-triggered)
Phase 4: User archive (soft delete) + hard delete with precheck
Phase 5: Quick invite from admin Users page

Also fixes:
- Auto-create subscription for accounts missing one
- Hard delete precheck ignores sole-member personal accounts
- Seed script patches tree nodes for validation compliance

Migrations: 031 (must_change_password), 032 (password_reset_tokens), 033 (user soft delete)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 01:42:51 -05:00

546 lines
18 KiB
Python

import secrets
import string
from datetime import datetime, timezone, timedelta
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.config import settings
from app.core.database import get_db
from app.core.rate_limit import limiter
from app.core.security import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
create_password_reset_token,
decode_token,
hash_token,
)
from app.models.user import User
from app.models.invite_code import InviteCode
from app.models.refresh_token import RefreshToken
from app.models.account import Account
from app.models.subscription import Subscription
from app.models.account_invite import AccountInvite
from app.schemas.user import UserCreate, UserResponse, UserLogin
from app.schemas.token import Token
from app.schemas.auth_password import (
ChangePasswordRequest,
ForgotPasswordRequest,
VerifyResetTokenRequest,
VerifyResetTokenResponse,
ResetPasswordRequest,
)
from app.models.password_reset_token import PasswordResetToken
from app.core.email import EmailService
from app.api.deps import get_current_active_user, get_refresh_token_payload
from app.core.audit import log_audit
# Router collecting the authentication endpoints; every route registered on it
# is served under the "/auth" prefix and tagged "authentication".
router = APIRouter(prefix="/auth", tags=["authentication"])
async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
    """Add a hashed record of a refresh token to the session.

    The JWT is decoded and, when it carries a ``jti`` claim, a ``RefreshToken``
    row holding the hash of that ``jti`` is added to ``db``. Nothing is added
    when decoding yields no payload or the payload has no ``jti``. This
    function does not commit; the caller does.

    Args:
        db: Session the record is added to.
        refresh_token_str: Encoded refresh token JWT.
        user_id: Identifier of the user the token belongs to.
    """
    claims = decode_token(refresh_token_str)
    if not claims or not claims.get("jti"):
        return

    jti_hash = hash_token(claims["jti"])
    # The ``exp`` claim is read as a POSIX timestamp and stored as aware UTC.
    expiry = datetime.fromtimestamp(claims["exp"], tz=timezone.utc)
    db.add(
        RefreshToken(
            token_hash=jti_hash,
            user_id=user_id,
            expires_at=expiry,
        )
    )
def _generate_display_code() -> str:
    """Return a random 8-character code of uppercase ASCII letters and digits."""
    alphabet = string.ascii_uppercase + string.digits
    picked = [secrets.choice(alphabet) for _ in range(8)]
    return "".join(picked)
async def _require_valid_account_invite(db: AsyncSession, code: str) -> AccountInvite:
    """Look up an account invite by code and check it can still be redeemed.

    Raises:
        HTTPException: 400 when the code is unknown, already used, or expired.
    """
    result = await db.execute(
        select(AccountInvite).where(AccountInvite.code == code)
    )
    invite = result.scalar_one_or_none()
    if not invite:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid account invite code"
        )
    if invite.is_used:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Account invite code has already been used"
        )
    if invite.is_expired:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Account invite code has expired"
        )
    return invite


async def _require_valid_platform_invite(db: AsyncSession, code: str) -> InviteCode:
    """Look up a platform invite code and check it can still be redeemed.

    The submitted code is upper-cased before it is compared with stored codes.

    Raises:
        HTTPException: 400 when the code is unknown, already used, or expired.
    """
    result = await db.execute(
        select(InviteCode).where(InviteCode.code == code.upper())
    )
    invite = result.scalar_one_or_none()
    if not invite:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid invite code"
        )
    if invite.is_used:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invite code has already been used"
        )
    if invite.is_expired:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invite code has expired"
        )
    return invite


async def _create_personal_account_user(
    db: AsyncSession,
    user_data: UserCreate,
    invite_code_record=None,
) -> User:
    """Create a personal Account, its owning User and a Subscription.

    ``invite_code_record`` is the validated platform ``InviteCode`` or ``None``.
    When it carries an assigned plan, that plan replaces "free"; when it also
    carries a trial duration, the subscription starts in "trialing" status with
    a period running that many days from now. Nothing is committed here.
    """
    # The account is flushed first so its id is available for the new user.
    new_account = Account(
        name=f"{user_data.name}'s Account",
        display_code=_generate_display_code(),
    )
    db.add(new_account)
    await db.flush()

    new_user = User(
        email=user_data.email,
        password_hash=get_password_hash(user_data.password),
        name=user_data.name,
        role="engineer",
        invite_code_id=invite_code_record.id if invite_code_record else None,
        account_id=new_account.id,
        account_role="owner",
    )
    db.add(new_user)
    await db.flush()  # Assigns the user id needed for the owner link below.
    new_account.owner_id = new_user.id

    sub_plan = "free"
    sub_status = "active"
    period_start = None
    period_end = None
    if invite_code_record and invite_code_record.assigned_plan:
        sub_plan = invite_code_record.assigned_plan
        if invite_code_record.trial_duration_days:
            sub_status = "trialing"
            period_start = datetime.now(timezone.utc)
            period_end = period_start + timedelta(
                days=invite_code_record.trial_duration_days
            )

    db.add(
        Subscription(
            account_id=new_account.id,
            plan=sub_plan,
            status=sub_status,
            current_period_start=period_start,
            current_period_end=period_end,
        )
    )
    return new_user


@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
    request: Request,
    user_data: UserCreate,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Register a new user.

    Supports two flows:
    - account_invite_code: Join an existing account (bypasses platform invite gate)
    - invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled)

    After user creation, if no account invite was used, a personal Account
    and a Subscription are created automatically. The subscription is "free"
    unless the platform invite code assigns a plan or trial.

    Raises:
        HTTPException: 400 for a missing, invalid, used or expired invite code,
            or when the email is already registered.
    """
    # The account invite is checked FIRST because it bypasses the platform gate.
    account_invite_record = None
    if user_data.account_invite_code:
        account_invite_record = await _require_valid_account_invite(
            db, user_data.account_invite_code
        )

    # Platform invite validation is skipped when an account invite was given.
    invite_code_record = None
    if not account_invite_record:
        if settings.REQUIRE_INVITE_CODE and not user_data.invite_code:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invite code is required"
            )
        if user_data.invite_code:
            # A supplied code is validated and applied even when not required.
            invite_code_record = await _require_valid_platform_invite(
                db, user_data.invite_code
            )

    result = await db.execute(select(User).where(User.email == user_data.email))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    if account_invite_record:
        # Join the inviting account with the role recorded on the invite.
        new_user = User(
            email=user_data.email,
            password_hash=get_password_hash(user_data.password),
            name=user_data.name,
            role="engineer",
            invite_code_id=invite_code_record.id if invite_code_record else None,
            account_id=account_invite_record.account_id,
            account_role=account_invite_record.role,
        )
        db.add(new_user)
        await db.flush()
        account_invite_record.accepted_by_id = new_user.id
        account_invite_record.used_at = datetime.now(timezone.utc)
    else:
        new_user = await _create_personal_account_user(
            db, user_data, invite_code_record
        )

    # Mark the platform invite code as used.
    if invite_code_record:
        invite_code_record.used_by_id = new_user.id
        invite_code_record.used_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(new_user)
    return new_user
@router.post("/login", response_model=Token)
@limiter.limit("5/minute")
async def login(
    request: Request,
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Authenticate with OAuth2 form credentials and issue a token pair.

    The form's ``username`` field is matched against the user's email. On
    success the user's ``last_login`` is updated, a record of the refresh token
    is stored, and the session is committed.

    Raises:
        HTTPException: 401 when no user matches or the password is wrong.
    """
    lookup = await db.execute(select(User).where(User.email == form_data.username))
    user = lookup.scalar_one_or_none()

    # One error covers both "unknown email" and "wrong password".
    if not (user and verify_password(form_data.password, user.password_hash)):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user.last_login = datetime.now(timezone.utc)

    subject = str(user.id)
    access_token = create_access_token(data={"sub": subject})
    refresh_token_str = create_refresh_token(data={"sub": subject})

    # Persist the refresh token record together with the last_login update.
    await _store_refresh_token(db, refresh_token_str, user.id)
    await db.commit()

    return Token(
        access_token=access_token,
        refresh_token=refresh_token_str,
        token_type="bearer",
        must_change_password=user.must_change_password,
    )
@router.post("/login/json", response_model=Token)
@limiter.limit("5/minute")
async def login_json(
    request: Request,
    credentials: UserLogin,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Login with a JSON body (alternative to the form-data ``/login``).

    On success the user's ``last_login`` is updated, a record of the refresh
    token is stored, and the session is committed.

    Raises:
        HTTPException: 401 when no user matches the email or the password is
            wrong.
    """
    result = await db.execute(select(User).where(User.email == credentials.email))
    user = result.scalar_one_or_none()
    if not user or not verify_password(credentials.password, user.password_hash):
        # A 401 response carries a WWW-Authenticate challenge, matching the
        # form-based /login endpoint.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user.last_login = datetime.now(timezone.utc)
    access_token = create_access_token(data={"sub": str(user.id)})
    refresh_token_str = create_refresh_token(data={"sub": str(user.id)})

    # Store the refresh token hash in the DB.
    await _store_refresh_token(db, refresh_token_str, user.id)
    await db.commit()

    return Token(
        access_token=access_token,
        refresh_token=refresh_token_str,
        token_type="bearer",
        must_change_password=user.must_change_password,
    )
@router.post("/refresh", response_model=Token)
@limiter.limit("10/minute")
async def refresh_token(
    request: Request,
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Exchange a refresh token for a new token pair (rotation).

    The presented token's stored record is marked revoked and a new refresh
    token is issued and stored in its place.

    Args:
        request: Incoming request (used by the rate limiter decorator).
        payload: Refresh token claims provided by ``get_refresh_token_payload``.
        db: Database session.

    Raises:
        HTTPException: 401 when the stored token is already revoked or the
            user named in ``sub`` no longer exists.
    """
    user_id = payload.get("sub")
    jti = payload.get("jti")

    if jti:
        result = await db.execute(
            select(RefreshToken).where(RefreshToken.token_hash == hash_token(jti))
        )
        stored_token = result.scalar_one_or_none()
        if stored_token and stored_token.is_revoked:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Refresh token has been revoked"
            )
        # NOTE(review): a token with no jti or no stored record is accepted
        # here - confirm this is intended rather than rejecting it.
        if stored_token:
            # Revoke the old refresh token (token rotation).
            stored_token.revoked_at = datetime.now(timezone.utc)

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found"
        )

    access_token = create_access_token(data={"sub": str(user.id)})
    new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})

    # Store the replacement token and commit it together with the revocation.
    await _store_refresh_token(db, new_refresh_token_str, user.id)
    await db.commit()

    return Token(
        access_token=access_token,
        refresh_token=new_refresh_token_str,
        token_type="bearer",
        # Report the flag as the login endpoints do, so a client refreshing its
        # session still sees that a password change is pending.
        must_change_password=user.must_change_password,
    )
@router.get("/me", response_model=UserResponse)
async def get_me(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Return the user resolved by the ``get_current_active_user`` dependency."""
    return current_user
@router.post("/logout")
async def logout(
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Revoke the presented refresh token and confirm the logout.

    The same success message is returned whether or not a matching,
    still-active token record was found.
    """
    jti = payload.get("jti")
    if jti:
        lookup = await db.execute(
            select(RefreshToken).where(RefreshToken.token_hash == hash_token(jti))
        )
        record = lookup.scalar_one_or_none()
        # Only a record that exists and is not yet revoked needs updating.
        if record and not record.is_revoked:
            record.revoked_at = datetime.now(timezone.utc)
            await db.commit()

    return {"message": "Successfully logged out"}
@router.post("/password/change")
@limiter.limit("5/minute")
async def change_password(
    request: Request,
    data: ChangePasswordRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Replace the authenticated user's password.

    Clears ``must_change_password``, revokes every refresh token of the user
    that is not yet revoked, writes an audit entry and commits.

    Raises:
        HTTPException: 401 when ``current_password`` does not verify; 400 when
            the new password equals the current one.
    """
    if not verify_password(data.current_password, current_user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Current password is incorrect"
        )

    if data.new_password == data.current_password:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="New password must be different from current password"
        )

    current_user.password_hash = get_password_hash(data.new_password)
    current_user.must_change_password = False

    # Revoke the user's outstanding refresh tokens.
    unrevoked_query = select(RefreshToken).where(
        RefreshToken.user_id == current_user.id,
        RefreshToken.revoked_at.is_(None)
    )
    unrevoked = (await db.execute(unrevoked_query)).scalars().all()
    for live_token in unrevoked:
        live_token.revoked_at = datetime.now(timezone.utc)

    await log_audit(db, current_user.id, "auth.password_change", "user", current_user.id)
    await db.commit()

    return {"message": "Password changed successfully"}
@router.post("/password/forgot")
@limiter.limit("3/minute")
async def forgot_password(
    request: Request,
    data: ForgotPasswordRequest,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Request a password reset email. Always returns success (anti-enumeration).

    When a user with the given email exists, a reset token is created, the
    hash of its ``jti`` is stored, and a reset link is emailed. Sending is
    best-effort: a delivery failure is logged and does not change the response,
    so the outcome does not reveal whether the address is registered.

    Args:
        request: Incoming request (used by the rate limiter decorator).
        data: Body carrying the email address.
        db: Database session.

    Returns:
        The same generic confirmation message in every case.
    """
    # Local import keeps this block self-contained; used only on send failure.
    import logging

    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()

    if user:
        # Create the reset token JWT and decode it to read its jti and expiry.
        raw_token = create_password_reset_token(str(user.id))
        payload = decode_token(raw_token)
        if payload and payload.get("jti"):
            token_record = PasswordResetToken(
                token_hash=hash_token(payload["jti"]),
                user_id=user.id,
                expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
            )
            db.add(token_record)
            # Commit before mailing so the emailed link has a stored token.
            await db.commit()

            reset_url = f"{settings.FRONTEND_URL}/reset-password?token={raw_token}"
            try:
                await EmailService.send_password_reset_email(
                    to_email=user.email,
                    reset_url=reset_url,
                )
            except Exception:
                # Deliberately broad: an error here would otherwise return a
                # 500 only for existing accounts. The reset URL is not logged.
                logging.getLogger(__name__).exception(
                    "Failed to send password reset email for user %s", user.id
                )

            await log_audit(db, user.id, "auth.password_reset.request", "user", user.id)
            await db.commit()

    return {"message": "If an account with that email exists, a reset link has been sent."}
@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
async def verify_reset_token(
    data: VerifyResetTokenRequest,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Report whether a password reset token can still be used.

    Returns ``valid=False`` when the token does not decode, is not of type
    "password_reset", has no ``jti``, has no stored record, or its record is
    no longer valid. Otherwise returns ``valid=True`` with the owning user's
    email for display.
    """
    claims = decode_token(data.token)
    if not claims or claims.get("type") != "password_reset" or not claims.get("jti"):
        return VerifyResetTokenResponse(valid=False)

    lookup = await db.execute(
        select(PasswordResetToken).where(
            PasswordResetToken.token_hash == hash_token(claims["jti"])
        )
    )
    stored = lookup.scalar_one_or_none()
    if not (stored and stored.is_valid):
        return VerifyResetTokenResponse(valid=False)

    # Only the email column is fetched; it is shown back to the user.
    email_lookup = await db.execute(select(User.email).where(User.id == stored.user_id))
    return VerifyResetTokenResponse(valid=True, email=email_lookup.scalar_one_or_none())
@router.post("/password/reset")
@limiter.limit("5/minute")
async def reset_password(
    request: Request,
    data: ResetPasswordRequest,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Set a new password for the user named by a valid reset token.

    Marks the token as used, clears ``must_change_password``, revokes every
    refresh token of the user that is not yet revoked, writes an audit entry
    and commits.

    Raises:
        HTTPException: 400 when the token does not decode, is of the wrong
            type, lacks ``jti`` or ``sub``, has no valid stored record, or
            names a user that does not exist.
    """

    def reject(detail: str) -> HTTPException:
        """Build the 400 error used for every rejection in this endpoint."""
        return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)

    claims = decode_token(data.token)
    if not claims or claims.get("type") != "password_reset":
        raise reject("Invalid or expired reset token")

    jti = claims.get("jti")
    user_id = claims.get("sub")
    if not (jti and user_id):
        raise reject("Invalid reset token")

    # The stored record enforces single use.
    record_lookup = await db.execute(
        select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
    )
    token_record = record_lookup.scalar_one_or_none()
    if not (token_record and token_record.is_valid):
        raise reject("Reset token has already been used or has expired")

    user_lookup = await db.execute(select(User).where(User.id == user_id))
    user = user_lookup.scalar_one_or_none()
    if not user:
        raise reject("Invalid reset token")

    user.password_hash = get_password_hash(data.new_password)
    user.must_change_password = False
    token_record.used_at = datetime.now(timezone.utc)

    # Revoke the user's outstanding refresh tokens.
    unrevoked_query = select(RefreshToken).where(
        RefreshToken.user_id == user.id,
        RefreshToken.revoked_at.is_(None)
    )
    unrevoked = (await db.execute(unrevoked_query)).scalars().all()
    for live_token in unrevoked:
        live_token.revoked_at = datetime.now(timezone.utc)

    await log_audit(db, user.id, "auth.password_reset.complete", "user", user.id)
    await db.commit()

    return {"message": "Password has been reset successfully"}