Files
resolutionflow/backend/app/api/endpoints/auth.py
Michael Chihlas 47ff8ad2b5 feat(l1): enforce seat limits on invite, accept-invite, role-change
For engineer + l1_tech roles, check_seat_available is called at each
mutation point. Returns 402 Payment Required with structured detail
{code: 'seat_limit_exceeded', role, current, limit, upgrade_url} when
seats are full. Grandfathering: existing over-seated accounts keep
existing users; only new mutations are blocked.

Also updates AccountInviteCreate and AccountRoleUpdate schemas to
accept l1_tech as a valid role value.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-28 12:49:59 -04:00

928 lines
34 KiB
Python

import logging
import secrets
import string
from datetime import datetime, timezone, timedelta
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update as sa_update
from app.core.config import settings
from app.core.settings_manager import SettingsManager
from app.core.admin_database import get_admin_db
from app.core.rate_limit import limiter
from app.core.security import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
create_password_reset_token,
create_email_verification_token,
decode_token,
hash_token,
resolve_session_policy,
)
from app.models.user import User
from app.models.invite_code import InviteCode
from app.models.refresh_token import RefreshToken
from app.models.account import Account
from app.models.subscription import Subscription
from app.models.account_invite import AccountInvite
from app.schemas.user import UserCreate, UserResponse, UserLogin, UserUpdate
from app.schemas.token import Token
from app.schemas.auth_password import (
ChangePasswordRequest,
ForgotPasswordRequest,
VerifyResetTokenRequest,
VerifyResetTokenResponse,
ResetPasswordRequest,
)
from app.models.password_reset_token import PasswordResetToken
from app.models.email_verification_token import EmailVerificationToken
from app.core.email import EmailService
from app.api.deps import get_current_active_user, get_refresh_token_payload
from app.core.audit import log_audit
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["authentication"])
async def store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
    """Record a refresh token's hashed JTI in the ``refresh_tokens`` table.

    Public at module level so OAuth callback endpoints (and any other
    token-issuing surface) can register a JTI the same way ``/auth/login``
    does. If the row is never written, the first ``/auth/refresh`` call
    rejects the token as "revoked" because no matching row exists.

    When ``decode_token`` yields nothing, or the payload has no ``jti``,
    nothing is stored and no error is raised. The session is not committed
    here; the caller owns the transaction.
    """
    claims = decode_token(refresh_token_str)
    if not claims:
        return

    jti = claims.get("jti")
    if not jti:
        return

    db.add(
        RefreshToken(
            token_hash=hash_token(jti),
            user_id=user_id,
            expires_at=datetime.fromtimestamp(claims["exp"], tz=timezone.utc),
        )
    )
async def _mint_session_tokens(user: User, db: AsyncSession) -> Token:
    """Issue a brand-new refresh/access token pair for a fresh login.

    The account's current session policy is snapshotted into the refresh JWT
    (auth_time / idle_max / abs_max) and the refresh token's JTI is registered
    via ``store_refresh_token``. Intended for every NEW login (password,
    OAuth, ...); ``/auth/refresh`` must carry existing claims forward instead
    of re-snapshotting them.

    The session is not committed here; the caller commits.
    See docs/plans/2026-05-13-session-expiration-policy.md §4.6.
    """
    account_lookup = await db.execute(
        select(Account).where(Account.id == user.account_id)
    )
    account = account_lookup.scalar_one()

    # Policy is expressed in minutes; the token claims are in seconds.
    idle_minutes, abs_minutes = resolve_session_policy(account)
    idle_window = idle_minutes * 60
    absolute_window = abs_minutes * 60

    issued_at = datetime.now(timezone.utc)
    issued_at_unix = int(issued_at.timestamp())

    refresh_jwt = create_refresh_token(
        user_id=str(user.id),
        auth_time=issued_at_unix,
        idle_max_seconds=idle_window,
        abs_max_seconds=absolute_window,
    )
    access_jwt = create_access_token(data={"sub": str(user.id)})
    await store_refresh_token(db, refresh_jwt, user.id)

    return Token(
        access_token=access_jwt,
        refresh_token=refresh_jwt,
        token_type="bearer",
        must_change_password=user.must_change_password,
        idle_expires_at=issued_at + timedelta(seconds=idle_window),
        absolute_expires_at=datetime.fromtimestamp(
            issued_at_unix + absolute_window, tz=timezone.utc
        ),
    )
async def _resolve_refresh_claims(
    payload: dict, user: User, db: AsyncSession
) -> tuple[int, int, int]:
    """Work out the effective session claims for a refresh request.

    Returns ``(auth_time, idle_max_seconds, abs_max_seconds)``.

    When the payload carries all three of ``auth_time`` / ``idle_max`` /
    ``abs_max`` they are returned untouched. A legacy token missing any of
    them is grandfathered: it is treated as though it had just been minted
    under the account's current policy, which gives it one free rotation
    under the new rules (plan §5.1).
    """
    carried = (
        payload.get("auth_time"),
        payload.get("idle_max"),
        payload.get("abs_max"),
    )
    if all(claim is not None for claim in carried):
        return carried

    # Legacy token: rebuild every claim from the account's current policy.
    account_lookup = await db.execute(
        select(Account).where(Account.id == user.account_id)
    )
    account = account_lookup.scalar_one()
    idle_minutes, abs_minutes = resolve_session_policy(account)
    fresh_auth_time = int(datetime.now(timezone.utc).timestamp())
    return fresh_auth_time, idle_minutes * 60, abs_minutes * 60
async def _mint_with_claims(
    user: User,
    auth_time: int,
    idle_max_seconds: int,
    abs_max_seconds: int,
    db: AsyncSession,
) -> Token:
    """Issue a refresh/access pair that carries the given session claims.

    ``/auth/refresh`` calls this once the grandfather and absolute-cap checks
    have settled the effective claim values. The new refresh token's JTI is
    registered via ``store_refresh_token``; the caller commits.
    """
    minted_at = datetime.now(timezone.utc)

    refresh_jwt = create_refresh_token(
        user_id=str(user.id),
        auth_time=auth_time,
        idle_max_seconds=idle_max_seconds,
        abs_max_seconds=abs_max_seconds,
    )
    access_jwt = create_access_token(data={"sub": str(user.id)})
    await store_refresh_token(db, refresh_jwt, user.id)

    # The idle window restarts at this rotation; the absolute deadline stays
    # anchored to the carried-forward auth_time.
    return Token(
        access_token=access_jwt,
        refresh_token=refresh_jwt,
        token_type="bearer",
        must_change_password=user.must_change_password,
        idle_expires_at=minted_at + timedelta(seconds=idle_max_seconds),
        absolute_expires_at=datetime.fromtimestamp(
            auth_time + abs_max_seconds, tz=timezone.utc
        ),
    )
def _generate_display_code() -> str:
    """Return a random 8-character code drawn from A-Z and 0-9."""
    alphabet = string.ascii_uppercase + string.digits
    picks = [secrets.choice(alphabet) for _ in range(8)]
    return "".join(picks)
async def _reject_if_oauth_only(db: AsyncSession, user) -> None:
    """Refuse password-based flows for accounts that have no password.

    Does nothing when ``user`` is ``None`` or has a ``password_hash``.
    Otherwise raises a 400 whose detail names the OAuth providers linked to
    the user, so the client can send them to the right OAuth flow.
    """
    if user is None:
        return
    if user.password_hash is not None:
        return

    from app.models.oauth_identity import OAuthIdentity

    linked = await db.execute(
        select(OAuthIdentity.provider).where(OAuthIdentity.user_id == user.id)
    )
    providers = list(linked.scalars().all())
    raise HTTPException(
        status_code=400,
        detail={"error": "use_oauth_provider", "providers": providers},
    )
# Account roles that are seat-counted and therefore go through seat enforcement.
_SEAT_COUNTED_ROLES = ("engineer", "l1_tech")


def _registration_error(detail) -> HTTPException:
    """Build the 400 Bad Request used for every invalid-registration case."""
    return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)


async def _claim_account_invite(db: AsyncSession, code: str, email: str) -> AccountInvite:
    """Lock and validate the account invite identified by ``code``.

    The row is selected FOR UPDATE so two concurrent registrations cannot
    both read the same invite as unused and double-spend it.

    Raises:
        HTTPException: 400 when the invite is unknown, already used, expired,
            or addressed to a different email (compared case-insensitively).
    """
    result = await db.execute(
        select(AccountInvite)
        .where(AccountInvite.code == code)
        .with_for_update()
    )
    invite = result.scalar_one_or_none()
    if not invite:
        raise _registration_error("Invalid account invite code")
    if invite.is_used:
        raise _registration_error("Account invite code has already been used")
    if invite.is_expired:
        raise _registration_error("Account invite code has expired")
    if invite.email.lower() != email.lower():
        raise _registration_error({"error": "invite_email_mismatch"})
    return invite


async def _claim_platform_invite(db: AsyncSession, user_data: UserCreate):
    """Apply the platform invite gate and lock the supplied invite code, if any.

    Returns the locked ``InviteCode`` row, or ``None`` when no code was
    supplied and none is required.

    Raises:
        HTTPException: 400 when a code is required but missing, or when the
            supplied code is unknown, already used, or expired.
    """
    # When self-serve is active for this email the platform invite gate is
    # bypassed entirely. The invite_code field stays in the schema so that
    # paid/trial-bearing codes still apply when supplied.
    if (
        settings.REQUIRE_INVITE_CODE
        and not settings.is_self_serve_active_for(user_data.email)
        and not user_data.invite_code
    ):
        raise _registration_error("Invite code is required")

    if not user_data.invite_code:
        return None

    # Codes are looked up upper-cased, so the match is case-insensitive for
    # the caller. This lookup runs regardless of REQUIRE_INVITE_CODE so a
    # supplied code still applies its plan/trial. FOR UPDATE prevents
    # double-spend by concurrent registrations.
    result = await db.execute(
        select(InviteCode)
        .where(InviteCode.code == user_data.invite_code.upper())
        .with_for_update()
    )
    invite_code = result.scalar_one_or_none()
    if not invite_code:
        raise _registration_error("Invalid invite code")
    if invite_code.is_used:
        raise _registration_error("Invite code has already been used")
    if invite_code.is_expired:
        raise _registration_error("Invite code has expired")
    return invite_code


async def _enforce_invite_seat_limit(db: AsyncSession, invite: AccountInvite) -> None:
    """Re-check seat availability at accept time for seat-counted roles.

    Accounts without a subscription (free / pre-billing) are not blocked.

    Raises:
        HTTPException: 402 with a structured ``seat_limit_exceeded`` detail
            when the target account has no seat left for the invite's role.
    """
    if invite.role not in _SEAT_COUNTED_ROLES:
        return

    # Imported lazily, only when a seat-counted invite is actually accepted.
    from app.core.subscriptions import get_account_subscription
    from app.services.seat_enforcement import check_seat_available

    sub = await get_account_subscription(invite.account_id, db)
    if sub is None:
        return

    acct_result = await db.execute(
        select(Account).where(Account.id == invite.account_id)
    )
    acct = acct_result.scalar_one()
    seat_result = await check_seat_available(acct, sub, invite.role, db)
    if not seat_result.available:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail={
                "code": "seat_limit_exceeded",
                "role": seat_result.role,
                "current": seat_result.current,
                "limit": seat_result.limit,
                "upgrade_url": "/account/billing",
            },
        )


async def _join_invited_account(
    db: AsyncSession,
    user_data: UserCreate,
    account_invite: AccountInvite,
    platform_invite,
) -> User:
    """Create the user inside the inviting account and mark the invite used."""
    new_user = User(
        email=user_data.email,
        password_hash=get_password_hash(user_data.password),
        name=user_data.name,
        role="engineer",
        invite_code_id=platform_invite.id if platform_invite else None,
        account_id=account_invite.account_id,
        account_role=account_invite.role,
    )
    db.add(new_user)
    await db.flush()  # Populate new_user.id for the invite bookkeeping below.

    account_invite.accepted_by_id = new_user.id
    account_invite.used_at = datetime.now(timezone.utc)
    return new_user


async def _create_personal_account(
    db: AsyncSession, user_data: UserCreate, platform_invite
) -> User:
    """Create a personal Account, its owner user, and the initial subscription."""
    # The account is created first: the user needs account_id (NOT NULL).
    new_account = Account(
        name=f"{user_data.name}'s Account",
        display_code=_generate_display_code(),
    )
    db.add(new_account)
    await db.flush()  # Get account ID

    new_user = User(
        email=user_data.email,
        password_hash=get_password_hash(user_data.password),
        name=user_data.name,
        role="engineer",
        invite_code_id=platform_invite.id if platform_invite else None,
        account_id=new_account.id,
        account_role="owner",
    )
    db.add(new_user)
    await db.flush()  # Get user ID

    new_account.owner_id = new_user.id

    if platform_invite and platform_invite.assigned_plan:
        # Plan/trial driven by the platform invite code (existing pilot flow).
        sub_status = "active"
        period_start = None
        period_end = None
        if platform_invite.trial_duration_days:
            sub_status = "trialing"
            period_start = datetime.now(timezone.utc)
            period_end = period_start + timedelta(days=platform_invite.trial_duration_days)
        db.add(Subscription(
            account_id=new_account.id,
            plan=platform_invite.assigned_plan,
            status=sub_status,
            current_period_start=period_start,
            current_period_end=period_end,
        ))
    else:
        # New self-serve shop — start the standard Pro trial. start_trial
        # commits internally, so flush the pending User/Account changes
        # first to satisfy the FK.
        await db.flush()
        from app.services.billing import BillingService
        await BillingService.start_trial(db, new_account.id)

    return new_user


async def _send_signup_verification_email(db: AsyncSession, new_user: User) -> None:
    """Best-effort: issue a verification token and email it to a new user.

    Skipped when the user is already verified or the
    ``email_verification_enabled`` setting is off. Failures while issuing or
    sending are logged and swallowed so registration itself still succeeds.
    """
    if new_user.email_verified_at is not None:
        return

    verification_enabled = await SettingsManager.get(
        "email_verification_enabled", db, default=True
    )
    if not verification_enabled:
        return

    try:
        raw_token = create_email_verification_token(str(new_user.id))
        payload = decode_token(raw_token)
        if payload and payload.get("jti"):
            token_record = EmailVerificationToken(
                token_hash=hash_token(payload["jti"]),
                user_id=new_user.id,
                expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
            )
            db.add(token_record)
            await db.commit()

        verification_url = f"{settings.FRONTEND_URL}/verify-email?token={raw_token}"
        await EmailService.send_email_verification_email(
            to_email=new_user.email,
            verification_url=verification_url,
        )
    except Exception as e:
        # Deliberately broad: a mail or token failure must not fail signup.
        logger.warning("verification email send failed for %s: %s", new_user.email, e)


@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
    request: Request,
    user_data: UserCreate,
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Register a new user.

    Supports two flows:
    - account_invite_code: join an existing account (bypasses the platform
      invite gate).
    - invite_code: platform invite code (required when REQUIRE_INVITE_CODE is
      enabled and self-serve is not active for the email).

    When no account invite is used, a personal Account is created with the
    new user as owner, plus a subscription: either the plan/trial assigned by
    the platform invite code, or the trial started by
    ``BillingService.start_trial``.

    Raises:
        HTTPException: 400 for an invalid/used/expired/mismatched invite, a
            missing required invite code, or an already-registered email;
            402 when the inviting account has no seat left for the role.
    """
    # Account invite is checked FIRST — it bypasses the platform invite gate.
    account_invite_record = None
    if user_data.account_invite_code:
        account_invite_record = await _claim_account_invite(
            db, user_data.account_invite_code, user_data.email
        )

    # Platform invite validation is skipped when an account invite was provided.
    invite_code_record = None
    if not account_invite_record:
        invite_code_record = await _claim_platform_invite(db, user_data)

    # Seat enforcement: re-check at accept time (race-condition guard).
    if account_invite_record:
        await _enforce_invite_seat_limit(db, account_invite_record)

    # Check if email already exists
    result = await db.execute(select(User).where(User.email == user_data.email))
    if result.scalar_one_or_none():
        raise _registration_error("Email already registered")

    if account_invite_record:
        new_user = await _join_invited_account(
            db, user_data, account_invite_record, invite_code_record
        )
    else:
        new_user = await _create_personal_account(db, user_data, invite_code_record)

    # Mark platform invite code as used
    if invite_code_record:
        invite_code_record.used_by_id = new_user.id
        invite_code_record.used_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(new_user)

    await _send_signup_verification_email(db, new_user)
    return new_user
@router.post("/login", response_model=Token)
@limiter.limit("5/minute")
async def login(
    request: Request,
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Authenticate with OAuth2 form credentials and start a new session.

    The form's ``username`` field carries the email address. Accounts with no
    password are rejected by ``_reject_if_oauth_only`` before the password is
    checked. On success ``last_login`` is stamped and a new token pair is
    minted and committed.

    Raises:
        HTTPException: 401 when the email is unknown or the password is wrong.
    """
    lookup = await db.execute(select(User).where(User.email == form_data.username))
    user = lookup.scalar_one_or_none()

    await _reject_if_oauth_only(db, user)

    authenticated = bool(user) and verify_password(form_data.password, user.password_hash)
    if not authenticated:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user.last_login = datetime.now(timezone.utc)
    session_tokens = await _mint_session_tokens(user, db)
    await db.commit()
    return session_tokens
@router.post("/login/json", response_model=Token)
@limiter.limit("5/minute")
async def login_json(
    request: Request,
    credentials: UserLogin,
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Login with a JSON body (alternative to the form-based ``/login``).

    Accounts with no password are rejected by ``_reject_if_oauth_only``
    before the password is checked. On success ``last_login`` is stamped and
    a new token pair is minted and committed.

    Raises:
        HTTPException: 401 when the email is unknown or the password is wrong.
    """
    result = await db.execute(select(User).where(User.email == credentials.email))
    user = result.scalar_one_or_none()

    await _reject_if_oauth_only(db, user)

    if not user or not verify_password(credentials.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            # Same challenge header as /login: a 401 is expected to carry a
            # WWW-Authenticate challenge, and both login routes should agree.
            headers={"WWW-Authenticate": "Bearer"},
        )

    user.last_login = datetime.now(timezone.utc)
    token = await _mint_session_tokens(user, db)
    await db.commit()
    return token
@router.post("/refresh", response_model=Token)
@limiter.limit("10/minute")
async def refresh_token(
    request: Request,
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Refresh access token, enforcing both idle and absolute session windows.

    Algorithm (see plan §4.5):
    1. Decode refresh JWT (the dep already rejects idle-expired tokens with
       session_expired_idle).
    2. Reject a token that carries no JTI: 401 invalid_refresh_token.
    3. Load the user. If missing or inactive, 401 invalid_refresh_token.
    4. Resolve effective auth_time/idle_max/abs_max (grandfather legacy
       tokens that pre-date the session-policy change).
    5. Atomically revoke the JTI regardless of outcome — so an absolute-
       expired token cannot be replayed; the second attempt finds it
       already revoked and gets invalid_refresh_token instead.
    6. If the atomic UPDATE matched zero rows, 401 invalid_refresh_token.
    7. If now >= auth_time + abs_max, 401 session_expired_absolute.
    8. Otherwise mint new tokens carrying the claims forward.

    Raises:
        HTTPException: 401 with detail ``invalid_refresh_token`` or
            ``session_expired_absolute`` as described above.
    """
    user_id = payload.get("sub")
    jti = payload.get("jti")

    # Without a JTI there is no refresh_tokens row to consume, so the token
    # could be rotated any number of times. Single-use enforcement depends on
    # the revoke below, so such a token is never accepted.
    if not jti:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="invalid_refresh_token",
        )

    user = (await db.execute(select(User).where(User.id == user_id))).scalar_one_or_none()
    if not user or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="invalid_refresh_token",
        )

    auth_time, idle_max_seconds, abs_max_seconds = await _resolve_refresh_claims(
        payload, user, db
    )

    # Atomically revoke the old refresh token first — this consumes the
    # token regardless of whether the absolute check passes, so an absolute-
    # expired token cannot be replayed.
    result = await db.execute(
        sa_update(RefreshToken)
        .where(
            RefreshToken.token_hash == hash_token(jti),
            RefreshToken.revoked_at.is_(None),
        )
        .values(revoked_at=datetime.now(timezone.utc))
        .returning(RefreshToken.id, RefreshToken.user_id)
    )
    # No row back means the token was unknown or already revoked.
    if not result.fetchone():
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="invalid_refresh_token",
        )

    # Absolute-window check. Boundary is `>=`, not `>` — a deadline equal to
    # now is expired. The token row has already been revoked above, so the
    # client cannot retry this token even though we raise after the consume.
    now_unix = int(datetime.now(timezone.utc).timestamp())
    if now_unix >= auth_time + abs_max_seconds:
        # Commit the revoke so the consumed-on-failure invariant survives
        # any subsequent rollback in the request lifecycle.
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="session_expired_absolute",
        )

    token = await _mint_with_claims(
        user, auth_time, idle_max_seconds, abs_max_seconds, db
    )
    await db.commit()
    return token
@router.get("/me", response_model=UserResponse)
async def get_me(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Return the user resolved by the ``get_current_active_user`` dependency."""
    return current_user
@router.patch("/me", response_model=UserResponse)
async def update_me(
    data: UserUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Update the current user's profile (name, email, phone, job_title, timezone).

    Only fields explicitly present in the request are considered. Changing
    the email requires ``current_password`` and a unique address. An explicit
    null for ``email`` or ``name`` is treated as "no change". The change is
    audit-logged and committed, and the refreshed user is returned.

    Raises:
        HTTPException: 400 when ``current_password`` is missing for an email
            change or the new email is already registered; 401 when
            ``current_password`` is wrong.
    """
    update_fields = data.model_fields_set - {"current_password"}
    if not update_fields:
        return current_user

    # Email change requires current_password. A null email is skipped, the
    # same way a null name is below, so it is never written to the user.
    if "email" in data.model_fields_set and data.email is not None:
        if not data.current_password:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is required to change email"
            )
        if not verify_password(data.current_password, current_user.password_hash):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Current password is incorrect"
            )
        # Check uniqueness against every other user.
        result = await db.execute(
            select(User).where(User.email == data.email, User.id != current_user.id)
        )
        if result.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already registered"
            )
        # NOTE(review): email_verified_at is left untouched here — confirm
        # whether a changed address should have to be re-verified.
        current_user.email = data.email

    if "name" in data.model_fields_set and data.name is not None:
        current_user.name = data.name

    # Simple string profile fields are copied as sent, including null.
    for field in ("phone", "job_title", "timezone"):
        if field in data.model_fields_set:
            setattr(current_user, field, getattr(data, field))

    await log_audit(db, current_user.id, "auth.profile_update", "user", current_user.id)
    await db.commit()
    await db.refresh(current_user)
    return current_user
@router.post("/logout")
async def logout(
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Log out by revoking the presented refresh token.

    The token's JTI is hashed and looked up in ``refresh_tokens``; a stored,
    not-yet-revoked row is stamped with ``revoked_at`` and committed. The
    same success message is returned whether or not a row was revoked.
    """
    response = {"message": "Successfully logged out"}

    jti = payload.get("jti")
    if not jti:
        return response

    lookup = await db.execute(
        select(RefreshToken).where(RefreshToken.token_hash == hash_token(jti))
    )
    stored = lookup.scalar_one_or_none()
    if not stored or stored.is_revoked:
        return response

    stored.revoked_at = datetime.now(timezone.utc)
    await db.commit()
    return response
@router.post("/password/change")
@limiter.limit("5/minute")
async def change_password(
    request: Request,
    data: ChangePasswordRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Change the authenticated user's password.

    Accounts with no password are rejected by ``_reject_if_oauth_only``. On
    success the new hash is stored, ``must_change_password`` is cleared,
    every non-revoked refresh token of the user is revoked, and the change is
    audit-logged before the commit.

    Raises:
        HTTPException: 401 when ``current_password`` is wrong; 400 when the
            new password equals the current one.
    """
    await _reject_if_oauth_only(db, current_user)

    password_matches = verify_password(data.current_password, current_user.password_hash)
    if not password_matches:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Current password is incorrect"
        )

    if data.new_password == data.current_password:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="New password must be different from current password"
        )

    current_user.password_hash = get_password_hash(data.new_password)
    current_user.must_change_password = False

    # Revoke every refresh token that is still live for this user.
    live_tokens = await db.execute(
        select(RefreshToken).where(
            RefreshToken.user_id == current_user.id,
            RefreshToken.revoked_at.is_(None)
        )
    )
    for live_token in live_tokens.scalars().all():
        live_token.revoked_at = datetime.now(timezone.utc)

    await log_audit(db, current_user.id, "auth.password_change", "user", current_user.id)
    await db.commit()
    return {"message": "Password changed successfully"}
@router.post("/password/forgot")
@limiter.limit("3/minute")
async def forgot_password(
    request: Request,
    data: ForgotPasswordRequest,
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Request a password reset email. Always returns success (anti-enumeration).

    A reset token is issued only for an existing user that has a password.
    Its hashed JTI is stored, the reset link is emailed, and the request is
    audit-logged. A failure while sending the email is logged and swallowed
    so the response is the same whether or not the account exists.
    """
    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()

    if user and user.password_hash is not None:
        # Create reset token JWT and persist its hashed JTI.
        raw_token = create_password_reset_token(str(user.id))
        payload = decode_token(raw_token)
        if payload and payload.get("jti"):
            token_record = PasswordResetToken(
                token_hash=hash_token(payload["jti"]),
                user_id=user.id,
                expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
            )
            db.add(token_record)
            await db.commit()

        # Send email (best-effort). A mailer failure must not surface as an
        # error: only existing accounts reach this point, so an error here
        # would reveal that the address is registered.
        reset_url = f"{settings.FRONTEND_URL}/reset-password?token={raw_token}"
        try:
            await EmailService.send_password_reset_email(
                to_email=user.email,
                reset_url=reset_url,
            )
        except Exception as e:
            logger.warning("password reset email send failed for %s: %s", user.email, e)

        await log_audit(db, user.id, "auth.password_reset.request", "user", user.id)
        await db.commit()

    return {"message": "If an account with that email exists, a reset link has been sent."}
@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
async def verify_reset_token(
    data: VerifyResetTokenRequest,
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Report whether a password reset token can still be used.

    Responds ``valid=False`` when the token does not decode, is not of type
    ``password_reset``, has no JTI, or has no valid stored record. Otherwise
    responds ``valid=True`` together with the owning user's email.
    """
    payload = decode_token(data.token)
    is_reset_token = bool(payload) and payload.get("type") == "password_reset"
    jti = payload.get("jti") if is_reset_token else None
    if not jti:
        return VerifyResetTokenResponse(valid=False)

    lookup = await db.execute(
        select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
    )
    stored = lookup.scalar_one_or_none()
    if not (stored and stored.is_valid):
        return VerifyResetTokenResponse(valid=False)

    # The email is returned so the client can display it.
    email_lookup = await db.execute(select(User.email).where(User.id == stored.user_id))
    return VerifyResetTokenResponse(valid=True, email=email_lookup.scalar_one_or_none())
@router.post("/password/reset")
@limiter.limit("5/minute")
async def reset_password(
    request: Request,
    data: ResetPasswordRequest,
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Set a new password using a single-use reset token.

    On success the password hash is replaced, ``must_change_password`` is
    cleared, the reset token is marked used, every non-revoked refresh token
    of the user is revoked, and the reset is audit-logged before the commit.

    Raises:
        HTTPException: 400 when the token does not decode, has the wrong
            type, lacks ``jti``/``sub``, has no valid stored record, or
            names a user that does not exist.
    """
    payload = decode_token(data.token)
    if not (payload and payload.get("type") == "password_reset"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired reset token"
        )

    jti, user_id = payload.get("jti"), payload.get("sub")
    if not (jti and user_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid reset token"
        )

    # The stored record is locked FOR UPDATE so two concurrent requests
    # cannot both redeem the same single-use token.
    stored_lookup = await db.execute(
        select(PasswordResetToken)
        .where(PasswordResetToken.token_hash == hash_token(jti))
        .with_for_update()
    )
    stored = stored_lookup.scalar_one_or_none()
    if not (stored and stored.is_valid):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Reset token has already been used or has expired"
        )

    user_lookup = await db.execute(select(User).where(User.id == user_id))
    user = user_lookup.scalar_one_or_none()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid reset token"
        )

    user.password_hash = get_password_hash(data.new_password)
    user.must_change_password = False
    stored.used_at = datetime.now(timezone.utc)

    # Revoke every refresh token that is still live for this user.
    live_tokens = await db.execute(
        select(RefreshToken).where(
            RefreshToken.user_id == user.id,
            RefreshToken.revoked_at.is_(None)
        )
    )
    for live_token in live_tokens.scalars().all():
        live_token.revoked_at = datetime.now(timezone.utc)

    await log_audit(db, user.id, "auth.password_reset.complete", "user", user.id)
    await db.commit()
    return {"message": "Password has been reset successfully"}
@router.get("/email/verification-status")
async def get_verification_status(
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Report the ``email_verification_enabled`` platform setting (default True)."""
    return {
        "enabled": await SettingsManager.get(
            "email_verification_enabled", db, default=True
        )
    }
@router.post("/email/send-verification")
@limiter.limit("3/minute")
async def send_verification_email(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Email a fresh verification link to the authenticated user.

    A new verification token is minted, its hashed JTI is stored and
    committed, and the link is sent to the user's current email.

    Raises:
        HTTPException: 403 when email verification is disabled on the
            platform; 400 when the user's email is already verified.
    """
    enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
    if not enabled:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Email verification is currently disabled"
        )

    already_verified = current_user.email_verified_at is not None
    if already_verified:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email is already verified"
        )

    raw_token = create_email_verification_token(str(current_user.id))
    claims = decode_token(raw_token)
    if claims and claims.get("jti"):
        db.add(
            EmailVerificationToken(
                token_hash=hash_token(claims["jti"]),
                user_id=current_user.id,
                expires_at=datetime.fromtimestamp(claims["exp"], tz=timezone.utc),
            )
        )
        await db.commit()

    await EmailService.send_email_verification_email(
        to_email=current_user.email,
        verification_url=f"{settings.FRONTEND_URL}/verify-email?token={raw_token}",
    )
    return {"message": "Verification email sent"}
@router.post("/email/verify")
async def verify_email(
    data: dict,
    db: Annotated[AsyncSession, Depends(get_admin_db)]
):
    """Verify an email using a token. Public endpoint.

    The body is a plain JSON object with a ``token`` entry. On success the
    stored token is marked used, the user's ``email_verified_at`` is stamped,
    and the verification is audit-logged before the commit.

    Raises:
        HTTPException: 400 when the token is missing, is not a string, does
            not decode, has the wrong type, lacks ``jti``/``sub``, has no
            valid stored record, or names a user that does not exist.
    """
    token = data.get("token")
    if not token:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Token is required"
        )
    # The body is an untyped dict, so "token" may be any JSON value. Only a
    # string is ever handed to decode_token.
    if not isinstance(token, str):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired verification token"
        )

    payload = decode_token(token)
    if not payload or payload.get("type") != "email_verification":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired verification token"
        )

    jti = payload.get("jti")
    user_id = payload.get("sub")
    if not jti or not user_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid verification token"
        )

    # FOR UPDATE prevents two concurrent verification requests from both succeeding.
    result = await db.execute(
        select(EmailVerificationToken)
        .where(EmailVerificationToken.token_hash == hash_token(jti))
        .with_for_update()
    )
    token_record = result.scalar_one_or_none()
    if not token_record or not token_record.is_valid:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Verification token has already been used or has expired"
        )

    # Mark token as used. Nothing is committed until the end, so the
    # user-not-found error below returns before this change is committed.
    token_record.used_at = datetime.now(timezone.utc)

    # Mark user email as verified
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid verification token"
        )
    user.email_verified_at = datetime.now(timezone.utc)

    await log_audit(db, user.id, "auth.email_verified", "user", user.id)
    await db.commit()
    return {"message": "Email verified successfully"}