From c724ad80621cc11e9e1484c9c359f07ed7cbf9fb Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Mon, 9 Mar 2026 17:20:38 -0400 Subject: [PATCH] fix: prevent race conditions in token operations and auth flows Backend: - Refresh token rotation: use atomic UPDATE...WHERE revoked_at IS NULL to prevent concurrent refresh requests from both succeeding - Account invite codes: SELECT FOR UPDATE to prevent double-spend - Platform invite codes: SELECT FOR UPDATE to prevent double-spend - Password reset tokens: SELECT FOR UPDATE to prevent double-use - Email verification tokens: SELECT FOR UPDATE to prevent double-use Frontend: - Token refresh subscriber arrays: swap before iterating so a throwing callback doesn't leave the queue in a dirty state Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/auth.py | 54 +++++++++++++++++++------------ frontend/src/api/client.ts | 8 +++-- 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index 83de7f89..ed913441 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -5,7 +5,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select +from sqlalchemy import select, update as sa_update from app.core.config import settings from app.core.settings_manager import SettingsManager from app.core.database import get_db @@ -78,13 +78,15 @@ async def register( After user creation, if no account invite was used, a personal Account and free Subscription are created automatically. """ - # Check for account invite code FIRST — bypasses platform invite gate + # Check for account invite code FIRST — bypasses platform invite gate. + # SELECT FOR UPDATE prevents two concurrent registrations from both + # reading the same invite as unused and double-spending it. account_invite_record = None if user_data.account_invite_code: result = await db.execute( - select(AccountInvite).where( - AccountInvite.code == user_data.account_invite_code - ) + select(AccountInvite) + .where(AccountInvite.code == user_data.account_invite_code) + .with_for_update() ) account_invite_record = result.scalar_one_or_none() @@ -116,9 +118,12 @@ async def register( ) if user_data.invite_code: - # Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE + # Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE. + # FOR UPDATE prevents double-spend by concurrent registrations. result = await db.execute( - select(InviteCode).where(InviteCode.code == user_data.invite_code.upper()) + select(InviteCode) + .where(InviteCode.code == user_data.invite_code.upper()) + .with_for_update() ) invite_code_record = result.scalar_one_or_none() @@ -305,24 +310,29 @@ async def refresh_token( user_id = payload.get("sub") jti = payload.get("jti") - # Validate refresh token hasn't been revoked + # Atomically revoke the old refresh token (token rotation). + # Using a conditional UPDATE prevents the race where two concurrent + # refresh requests both read revoked_at=NULL and both succeed. if jti: token_hash = hash_token(jti) result = await db.execute( - select(RefreshToken).where(RefreshToken.token_hash == token_hash) + sa_update(RefreshToken) + .where( + RefreshToken.token_hash == token_hash, + RefreshToken.revoked_at.is_(None), + ) + .values(revoked_at=datetime.now(timezone.utc)) + .returning(RefreshToken.id, RefreshToken.user_id) ) - stored_token = result.scalar_one_or_none() + revoked_row = result.fetchone() - if stored_token and stored_token.is_revoked: + if not revoked_row: + # Either the token doesn't exist or was already revoked/used raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token has been revoked" ) - # Revoke the old refresh token (token rotation) - if stored_token: - stored_token.revoked_at = datetime.now(timezone.utc) - result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() @@ -552,9 +562,12 @@ async def reset_password( detail="Invalid reset token" ) - # Validate token in DB (single-use) + # Validate token in DB (single-use). + # FOR UPDATE prevents two concurrent reset requests from both succeeding. result = await db.execute( - select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti)) + select(PasswordResetToken) + .where(PasswordResetToken.token_hash == hash_token(jti)) + .with_for_update() ) token_record = result.scalar_one_or_none() @@ -674,10 +687,11 @@ async def verify_email( detail="Invalid verification token" ) + # FOR UPDATE prevents two concurrent verification requests from both succeeding. result = await db.execute( - select(EmailVerificationToken).where( - EmailVerificationToken.token_hash == hash_token(jti) - ) + select(EmailVerificationToken) + .where(EmailVerificationToken.token_hash == hash_token(jti)) + .with_for_update() ) token_record = result.scalar_one_or_none() diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index fbc5b251..71e80ba5 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -75,15 +75,19 @@ let refreshSubscribers: ((token: string) => void)[] = [] let refreshFailSubscribers: ((error: unknown) => void)[] = [] function onRefreshed(token: string) { - refreshSubscribers.forEach(cb => cb(token)) + // Swap arrays before iterating — if a callback throws, the arrays + // are already cleared so the next refresh cycle starts clean. + const subscribers = refreshSubscribers refreshSubscribers = [] refreshFailSubscribers = [] + subscribers.forEach(cb => cb(token)) } function onRefreshFailed(error: unknown) { - refreshFailSubscribers.forEach(cb => cb(error)) + const failSubscribers = refreshFailSubscribers refreshSubscribers = [] refreshFailSubscribers = [] + failSubscribers.forEach(cb => cb(error)) } // Response interceptor - handle token refresh