fix: prevent race conditions in token operations and auth flows
Backend: - Refresh token rotation: use atomic UPDATE...WHERE revoked_at IS NULL to prevent concurrent refresh requests from both succeeding - Account invite codes: SELECT FOR UPDATE to prevent double-spend - Platform invite codes: SELECT FOR UPDATE to prevent double-spend - Password reset tokens: SELECT FOR UPDATE to prevent double-use - Email verification tokens: SELECT FOR UPDATE to prevent double-use Frontend: - Token refresh subscriber arrays: swap before iterating so a throwing callback doesn't leave the queue in a dirty state Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Annotated
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, update as sa_update
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.settings_manager import SettingsManager
|
from app.core.settings_manager import SettingsManager
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
@@ -78,13 +78,15 @@ async def register(
|
|||||||
After user creation, if no account invite was used, a personal Account
|
After user creation, if no account invite was used, a personal Account
|
||||||
and free Subscription are created automatically.
|
and free Subscription are created automatically.
|
||||||
"""
|
"""
|
||||||
# Check for account invite code FIRST — bypasses platform invite gate
|
# Check for account invite code FIRST — bypasses platform invite gate.
|
||||||
|
# SELECT FOR UPDATE prevents two concurrent registrations from both
|
||||||
|
# reading the same invite as unused and double-spending it.
|
||||||
account_invite_record = None
|
account_invite_record = None
|
||||||
if user_data.account_invite_code:
|
if user_data.account_invite_code:
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(AccountInvite).where(
|
select(AccountInvite)
|
||||||
AccountInvite.code == user_data.account_invite_code
|
.where(AccountInvite.code == user_data.account_invite_code)
|
||||||
)
|
.with_for_update()
|
||||||
)
|
)
|
||||||
account_invite_record = result.scalar_one_or_none()
|
account_invite_record = result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -116,9 +118,12 @@ async def register(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if user_data.invite_code:
|
if user_data.invite_code:
|
||||||
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE
|
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE.
|
||||||
|
# FOR UPDATE prevents double-spend by concurrent registrations.
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(InviteCode).where(InviteCode.code == user_data.invite_code.upper())
|
select(InviteCode)
|
||||||
|
.where(InviteCode.code == user_data.invite_code.upper())
|
||||||
|
.with_for_update()
|
||||||
)
|
)
|
||||||
invite_code_record = result.scalar_one_or_none()
|
invite_code_record = result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -305,24 +310,29 @@ async def refresh_token(
|
|||||||
user_id = payload.get("sub")
|
user_id = payload.get("sub")
|
||||||
jti = payload.get("jti")
|
jti = payload.get("jti")
|
||||||
|
|
||||||
# Validate refresh token hasn't been revoked
|
# Atomically revoke the old refresh token (token rotation).
|
||||||
|
# Using a conditional UPDATE prevents the race where two concurrent
|
||||||
|
# refresh requests both read revoked_at=NULL and both succeed.
|
||||||
if jti:
|
if jti:
|
||||||
token_hash = hash_token(jti)
|
token_hash = hash_token(jti)
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
sa_update(RefreshToken)
|
||||||
|
.where(
|
||||||
|
RefreshToken.token_hash == token_hash,
|
||||||
|
RefreshToken.revoked_at.is_(None),
|
||||||
|
)
|
||||||
|
.values(revoked_at=datetime.now(timezone.utc))
|
||||||
|
.returning(RefreshToken.id, RefreshToken.user_id)
|
||||||
)
|
)
|
||||||
stored_token = result.scalar_one_or_none()
|
revoked_row = result.fetchone()
|
||||||
|
|
||||||
if stored_token and stored_token.is_revoked:
|
if not revoked_row:
|
||||||
|
# Either the token doesn't exist or was already revoked/used
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Refresh token has been revoked"
|
detail="Refresh token has been revoked"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Revoke the old refresh token (token rotation)
|
|
||||||
if stored_token:
|
|
||||||
stored_token.revoked_at = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -552,9 +562,12 @@ async def reset_password(
|
|||||||
detail="Invalid reset token"
|
detail="Invalid reset token"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate token in DB (single-use)
|
# Validate token in DB (single-use).
|
||||||
|
# FOR UPDATE prevents two concurrent reset requests from both succeeding.
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
|
select(PasswordResetToken)
|
||||||
|
.where(PasswordResetToken.token_hash == hash_token(jti))
|
||||||
|
.with_for_update()
|
||||||
)
|
)
|
||||||
token_record = result.scalar_one_or_none()
|
token_record = result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -674,10 +687,11 @@ async def verify_email(
|
|||||||
detail="Invalid verification token"
|
detail="Invalid verification token"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FOR UPDATE prevents two concurrent verification requests from both succeeding.
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(EmailVerificationToken).where(
|
select(EmailVerificationToken)
|
||||||
EmailVerificationToken.token_hash == hash_token(jti)
|
.where(EmailVerificationToken.token_hash == hash_token(jti))
|
||||||
)
|
.with_for_update()
|
||||||
)
|
)
|
||||||
token_record = result.scalar_one_or_none()
|
token_record = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|||||||
@@ -75,15 +75,19 @@ let refreshSubscribers: ((token: string) => void)[] = []
|
|||||||
let refreshFailSubscribers: ((error: unknown) => void)[] = []
|
let refreshFailSubscribers: ((error: unknown) => void)[] = []
|
||||||
|
|
||||||
function onRefreshed(token: string) {
|
function onRefreshed(token: string) {
|
||||||
refreshSubscribers.forEach(cb => cb(token))
|
// Swap arrays before iterating — if a callback throws, the arrays
|
||||||
|
// are already cleared so the next refresh cycle starts clean.
|
||||||
|
const subscribers = refreshSubscribers
|
||||||
refreshSubscribers = []
|
refreshSubscribers = []
|
||||||
refreshFailSubscribers = []
|
refreshFailSubscribers = []
|
||||||
|
subscribers.forEach(cb => cb(token))
|
||||||
}
|
}
|
||||||
|
|
||||||
function onRefreshFailed(error: unknown) {
|
function onRefreshFailed(error: unknown) {
|
||||||
refreshFailSubscribers.forEach(cb => cb(error))
|
const failSubscribers = refreshFailSubscribers
|
||||||
refreshSubscribers = []
|
refreshSubscribers = []
|
||||||
refreshFailSubscribers = []
|
refreshFailSubscribers = []
|
||||||
|
failSubscribers.forEach(cb => cb(error))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Response interceptor - handle token refresh
|
// Response interceptor - handle token refresh
|
||||||
|
|||||||
Reference in New Issue
Block a user