fix: prevent race conditions in token operations and auth flows

Backend:
- Refresh token rotation: use atomic UPDATE...WHERE revoked_at IS NULL
  to prevent concurrent refresh requests from both succeeding
- Account invite codes: SELECT FOR UPDATE to prevent double-spend
- Platform invite codes: SELECT FOR UPDATE to prevent double-spend
- Password reset tokens: SELECT FOR UPDATE to prevent double-use
- Email verification tokens: SELECT FOR UPDATE to prevent double-use

Frontend:
- Token refresh subscriber arrays: swap before iterating so a throwing
  callback doesn't leave the queue in a dirty state

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Michael Chihlas
2026-03-09 17:20:38 -04:00
parent 5095b0d8df
commit c724ad8062
2 changed files with 40 additions and 22 deletions

View File

@@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import select, update as sa_update
from app.core.config import settings
from app.core.settings_manager import SettingsManager
from app.core.database import get_db
@@ -78,13 +78,15 @@ async def register(
After user creation, if no account invite was used, a personal Account
and free Subscription are created automatically.
"""
# Check for account invite code FIRST — bypasses platform invite gate
# Check for account invite code FIRST — bypasses platform invite gate.
# SELECT FOR UPDATE prevents two concurrent registrations from both
# reading the same invite as unused and double-spending it.
account_invite_record = None
if user_data.account_invite_code:
result = await db.execute(
select(AccountInvite).where(
AccountInvite.code == user_data.account_invite_code
)
select(AccountInvite)
.where(AccountInvite.code == user_data.account_invite_code)
.with_for_update()
)
account_invite_record = result.scalar_one_or_none()
@@ -116,9 +118,12 @@ async def register(
)
if user_data.invite_code:
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE.
# FOR UPDATE prevents double-spend by concurrent registrations.
result = await db.execute(
select(InviteCode).where(InviteCode.code == user_data.invite_code.upper())
select(InviteCode)
.where(InviteCode.code == user_data.invite_code.upper())
.with_for_update()
)
invite_code_record = result.scalar_one_or_none()
@@ -305,24 +310,29 @@ async def refresh_token(
user_id = payload.get("sub")
jti = payload.get("jti")
# Validate refresh token hasn't been revoked
# Atomically revoke the old refresh token (token rotation).
# Using a conditional UPDATE prevents the race where two concurrent
# refresh requests both read revoked_at=NULL and both succeed.
if jti:
token_hash = hash_token(jti)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
sa_update(RefreshToken)
.where(
RefreshToken.token_hash == token_hash,
RefreshToken.revoked_at.is_(None),
)
.values(revoked_at=datetime.now(timezone.utc))
.returning(RefreshToken.id, RefreshToken.user_id)
)
stored_token = result.scalar_one_or_none()
revoked_row = result.fetchone()
if stored_token and stored_token.is_revoked:
if not revoked_row:
# Either the token doesn't exist or was already revoked/used
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has been revoked"
)
# Revoke the old refresh token (token rotation)
if stored_token:
stored_token.revoked_at = datetime.now(timezone.utc)
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
@@ -552,9 +562,12 @@ async def reset_password(
detail="Invalid reset token"
)
# Validate token in DB (single-use)
# Validate token in DB (single-use).
# FOR UPDATE prevents two concurrent reset requests from both succeeding.
result = await db.execute(
select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
select(PasswordResetToken)
.where(PasswordResetToken.token_hash == hash_token(jti))
.with_for_update()
)
token_record = result.scalar_one_or_none()
@@ -674,10 +687,11 @@ async def verify_email(
detail="Invalid verification token"
)
# FOR UPDATE prevents two concurrent verification requests from both succeeding.
result = await db.execute(
select(EmailVerificationToken).where(
EmailVerificationToken.token_hash == hash_token(jti)
)
select(EmailVerificationToken)
.where(EmailVerificationToken.token_hash == hash_token(jti))
.with_for_update()
)
token_record = result.scalar_one_or_none()

View File

@@ -75,15 +75,19 @@ let refreshSubscribers: ((token: string) => void)[] = []
let refreshFailSubscribers: ((error: unknown) => void)[] = []
function onRefreshed(token: string) {
refreshSubscribers.forEach(cb => cb(token))
// Swap arrays before iterating — if a callback throws, the arrays
// are already cleared so the next refresh cycle starts clean.
const subscribers = refreshSubscribers
refreshSubscribers = []
refreshFailSubscribers = []
subscribers.forEach(cb => cb(token))
}
function onRefreshFailed(error: unknown) {
refreshFailSubscribers.forEach(cb => cb(error))
const failSubscribers = refreshFailSubscribers
refreshSubscribers = []
refreshFailSubscribers = []
failSubscribers.forEach(cb => cb(error))
}
// Response interceptor - handle token refresh