fix: race condition hardening across auth, counters, and data fetching #102

Merged
chihlasm merged 4 commits from fix/race-conditions-critical into main 2026-03-10 05:57:22 +00:00
2 changed files with 40 additions and 22 deletions
Showing only changes of commit c724ad8062 - Show all commits

View File

@@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import select, update as sa_update
from app.core.config import settings
from app.core.settings_manager import SettingsManager
from app.core.database import get_db
@@ -78,13 +78,15 @@ async def register(
After user creation, if no account invite was used, a personal Account
and free Subscription are created automatically.
"""
# Check for account invite code FIRST — bypasses platform invite gate
# Check for account invite code FIRST — bypasses platform invite gate.
# SELECT FOR UPDATE prevents two concurrent registrations from both
# reading the same invite as unused and double-spending it.
account_invite_record = None
if user_data.account_invite_code:
result = await db.execute(
select(AccountInvite).where(
AccountInvite.code == user_data.account_invite_code
)
select(AccountInvite)
.where(AccountInvite.code == user_data.account_invite_code)
.with_for_update()
)
account_invite_record = result.scalar_one_or_none()
@@ -116,9 +118,12 @@ async def register(
)
if user_data.invite_code:
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE.
# FOR UPDATE prevents double-spend by concurrent registrations.
result = await db.execute(
select(InviteCode).where(InviteCode.code == user_data.invite_code.upper())
select(InviteCode)
.where(InviteCode.code == user_data.invite_code.upper())
.with_for_update()
)
invite_code_record = result.scalar_one_or_none()
@@ -305,24 +310,29 @@ async def refresh_token(
user_id = payload.get("sub")
jti = payload.get("jti")
# Validate refresh token hasn't been revoked
# Atomically revoke the old refresh token (token rotation).
# Using a conditional UPDATE prevents the race where two concurrent
# refresh requests both read revoked_at=NULL and both succeed.
if jti:
token_hash = hash_token(jti)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
sa_update(RefreshToken)
.where(
RefreshToken.token_hash == token_hash,
RefreshToken.revoked_at.is_(None),
)
.values(revoked_at=datetime.now(timezone.utc))
.returning(RefreshToken.id, RefreshToken.user_id)
)
stored_token = result.scalar_one_or_none()
revoked_row = result.fetchone()
if stored_token and stored_token.is_revoked:
if not revoked_row:
# Either the token doesn't exist or was already revoked/used
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has been revoked"
)
# Revoke the old refresh token (token rotation)
if stored_token:
stored_token.revoked_at = datetime.now(timezone.utc)
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
@@ -552,9 +562,12 @@ async def reset_password(
detail="Invalid reset token"
)
# Validate token in DB (single-use)
# Validate token in DB (single-use).
# FOR UPDATE prevents two concurrent reset requests from both succeeding.
result = await db.execute(
select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
select(PasswordResetToken)
.where(PasswordResetToken.token_hash == hash_token(jti))
.with_for_update()
)
token_record = result.scalar_one_or_none()
@@ -674,10 +687,11 @@ async def verify_email(
detail="Invalid verification token"
)
# FOR UPDATE prevents two concurrent verification requests from both succeeding.
result = await db.execute(
select(EmailVerificationToken).where(
EmailVerificationToken.token_hash == hash_token(jti)
)
select(EmailVerificationToken)
.where(EmailVerificationToken.token_hash == hash_token(jti))
.with_for_update()
)
token_record = result.scalar_one_or_none()

View File

@@ -75,15 +75,19 @@ let refreshSubscribers: ((token: string) => void)[] = []
let refreshFailSubscribers: ((error: unknown) => void)[] = []
function onRefreshed(token: string) {
refreshSubscribers.forEach(cb => cb(token))
// Swap arrays before iterating — if a callback throws, the arrays
// are already cleared so the next refresh cycle starts clean.
const subscribers = refreshSubscribers
refreshSubscribers = []
refreshFailSubscribers = []
subscribers.forEach(cb => cb(token))
}
function onRefreshFailed(error: unknown) {
refreshFailSubscribers.forEach(cb => cb(error))
const failSubscribers = refreshFailSubscribers
refreshSubscribers = []
refreshFailSubscribers = []
failSubscribers.forEach(cb => cb(error))
}
// Response interceptor - handle token refresh