fix: race condition hardening across auth, counters, and data fetching (#102)
* fix: prevent race conditions in token operations and auth flows Backend: - Refresh token rotation: use atomic UPDATE...WHERE revoked_at IS NULL to prevent concurrent refresh requests from both succeeding - Account invite codes: SELECT FOR UPDATE to prevent double-spend - Platform invite codes: SELECT FOR UPDATE to prevent double-spend - Password reset tokens: SELECT FOR UPDATE to prevent double-use - Email verification tokens: SELECT FOR UPDATE to prevent double-use Frontend: - Token refresh subscriber arrays: swap before iterating so a throwing callback doesn't leave the queue in a dirty state Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: atomic counters, plan limit re-check, and double-submit guard Backend: - Tree usage_count: use SQL-level UPDATE (Tree.usage_count + 1) instead of Python-level increment to prevent lost updates under concurrency - Tag usage_count: same SQL-level atomic increment/decrement in both create_tree and update_tree (delete_tree already used this pattern) - Plan tree limit: re-check count after db.flush() to close the TOCTOU window where two concurrent creates could both pass the pre-check Frontend: - TreeEditorPage: add isSaving early-return guard inside handleSaveDraft and handlePublish callbacks so Ctrl+S can't bypass the button disabled prop and fire duplicate save requests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: prevent stale API responses from overwriting newer data - SessionHistoryPage: move loadSessions into effect with cancelled flag so rapid filter/tab changes discard outdated responses - TreeLibraryPage: add request ID ref to loadTrees so stale responses from previous filter selections are discarded - QuickStartPage: add request ID ref to debounced search so out-of-order responses don't overwrite newer search results Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * docs: add flexible intake design — deferred variables + prepared sessions Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit was merged in pull request #102.
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update as sa_update
|
||||
from app.core.config import settings
|
||||
from app.core.settings_manager import SettingsManager
|
||||
from app.core.database import get_db
|
||||
@@ -78,13 +78,15 @@ async def register(
|
||||
After user creation, if no account invite was used, a personal Account
|
||||
and free Subscription are created automatically.
|
||||
"""
|
||||
# Check for account invite code FIRST — bypasses platform invite gate
|
||||
# Check for account invite code FIRST — bypasses platform invite gate.
|
||||
# SELECT FOR UPDATE prevents two concurrent registrations from both
|
||||
# reading the same invite as unused and double-spending it.
|
||||
account_invite_record = None
|
||||
if user_data.account_invite_code:
|
||||
result = await db.execute(
|
||||
select(AccountInvite).where(
|
||||
AccountInvite.code == user_data.account_invite_code
|
||||
)
|
||||
select(AccountInvite)
|
||||
.where(AccountInvite.code == user_data.account_invite_code)
|
||||
.with_for_update()
|
||||
)
|
||||
account_invite_record = result.scalar_one_or_none()
|
||||
|
||||
@@ -116,9 +118,12 @@ async def register(
|
||||
)
|
||||
|
||||
if user_data.invite_code:
|
||||
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE
|
||||
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE.
|
||||
# FOR UPDATE prevents double-spend by concurrent registrations.
|
||||
result = await db.execute(
|
||||
select(InviteCode).where(InviteCode.code == user_data.invite_code.upper())
|
||||
select(InviteCode)
|
||||
.where(InviteCode.code == user_data.invite_code.upper())
|
||||
.with_for_update()
|
||||
)
|
||||
invite_code_record = result.scalar_one_or_none()
|
||||
|
||||
@@ -305,24 +310,29 @@ async def refresh_token(
|
||||
user_id = payload.get("sub")
|
||||
jti = payload.get("jti")
|
||||
|
||||
# Validate refresh token hasn't been revoked
|
||||
# Atomically revoke the old refresh token (token rotation).
|
||||
# Using a conditional UPDATE prevents the race where two concurrent
|
||||
# refresh requests both read revoked_at=NULL and both succeed.
|
||||
if jti:
|
||||
token_hash = hash_token(jti)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
sa_update(RefreshToken)
|
||||
.where(
|
||||
RefreshToken.token_hash == token_hash,
|
||||
RefreshToken.revoked_at.is_(None),
|
||||
)
|
||||
.values(revoked_at=datetime.now(timezone.utc))
|
||||
.returning(RefreshToken.id, RefreshToken.user_id)
|
||||
)
|
||||
stored_token = result.scalar_one_or_none()
|
||||
revoked_row = result.fetchone()
|
||||
|
||||
if stored_token and stored_token.is_revoked:
|
||||
if not revoked_row:
|
||||
# Either the token doesn't exist or was already revoked/used
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token has been revoked"
|
||||
)
|
||||
|
||||
# Revoke the old refresh token (token rotation)
|
||||
if stored_token:
|
||||
stored_token.revoked_at = datetime.now(timezone.utc)
|
||||
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
@@ -552,9 +562,12 @@ async def reset_password(
|
||||
detail="Invalid reset token"
|
||||
)
|
||||
|
||||
# Validate token in DB (single-use)
|
||||
# Validate token in DB (single-use).
|
||||
# FOR UPDATE prevents two concurrent reset requests from both succeeding.
|
||||
result = await db.execute(
|
||||
select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
|
||||
select(PasswordResetToken)
|
||||
.where(PasswordResetToken.token_hash == hash_token(jti))
|
||||
.with_for_update()
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
|
||||
@@ -674,10 +687,11 @@ async def verify_email(
|
||||
detail="Invalid verification token"
|
||||
)
|
||||
|
||||
# FOR UPDATE prevents two concurrent verification requests from both succeeding.
|
||||
result = await db.execute(
|
||||
select(EmailVerificationToken).where(
|
||||
EmailVerificationToken.token_hash == hash_token(jti)
|
||||
)
|
||||
select(EmailVerificationToken)
|
||||
.where(EmailVerificationToken.token_hash == hash_token(jti))
|
||||
.with_for_update()
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update as sa_update
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.tree import Tree
|
||||
@@ -189,8 +189,10 @@ async def start_session(
|
||||
session_variables=session_variables,
|
||||
)
|
||||
|
||||
# Increment tree usage count
|
||||
tree.usage_count += 1
|
||||
# Atomically increment tree usage count (SQL-level to avoid lost updates)
|
||||
await db.execute(
|
||||
sa_update(Tree).where(Tree.id == tree.id).values(usage_count=Tree.usage_count + 1)
|
||||
)
|
||||
|
||||
db.add(new_session)
|
||||
await db.commit()
|
||||
|
||||
@@ -25,7 +25,7 @@ from app.models.user_pinned_tree import UserPinnedTree
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin, get_service_account_id
|
||||
from app.core.permissions import can_edit_tree, can_access_tree
|
||||
from app.core.filters import build_tree_access_filter
|
||||
from app.core.subscriptions import check_tree_limit
|
||||
from app.core.subscriptions import check_tree_limit, get_account_subscription, get_plan_limits
|
||||
from app.core.audit import log_audit
|
||||
from app.core.config import settings
|
||||
from app.core.tree_validation import can_publish_tree
|
||||
@@ -487,6 +487,26 @@ async def create_tree(
|
||||
db.add(new_tree)
|
||||
await db.flush() # Get the ID
|
||||
|
||||
# Re-check tree limit after flush to close the TOCTOU race window:
|
||||
# two concurrent creates could both pass the pre-check, but only one
|
||||
# should succeed when the limit is exactly reached.
|
||||
if not is_default and current_user.account_id:
|
||||
post_count = await db.scalar(
|
||||
select(func.count(Tree.id)).where(
|
||||
Tree.account_id == current_user.account_id,
|
||||
Tree.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
sub = await get_account_subscription(current_user.account_id, db)
|
||||
if sub:
|
||||
limits = await get_plan_limits(sub.plan, db)
|
||||
if limits and limits.max_trees and (post_count or 0) > limits.max_trees:
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Tree limit reached ({limits.max_trees}/{limits.max_trees}). Upgrade your plan to create more trees."
|
||||
)
|
||||
|
||||
# Handle tags
|
||||
if tree_data.tags:
|
||||
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
@@ -519,7 +539,6 @@ async def create_tree(
|
||||
await db.flush()
|
||||
|
||||
tags_to_add.append(tag)
|
||||
tag.usage_count += 1
|
||||
|
||||
# Use direct SQL insert for the junction table to avoid lazy load issues
|
||||
from app.models.tag import tree_tag_assignments
|
||||
@@ -531,6 +550,10 @@ async def create_tree(
|
||||
assigned_by=current_user.id
|
||||
)
|
||||
)
|
||||
# Atomically increment (SQL-level to avoid lost updates from concurrent requests)
|
||||
await db.execute(
|
||||
update(TreeTag).where(TreeTag.id == tag.id).values(usage_count=TreeTag.usage_count + 1)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
@@ -673,9 +696,14 @@ async def update_tree(
|
||||
if tags_data is not None:
|
||||
from app.models.tag import tree_tag_assignments
|
||||
|
||||
# Decrement usage count for old tags (already eagerly loaded)
|
||||
for tag in tree.tags:
|
||||
tag.usage_count = max(0, tag.usage_count - 1)
|
||||
# Atomically decrement usage count for old tags
|
||||
old_tag_ids = [tag.id for tag in tree.tags]
|
||||
if old_tag_ids:
|
||||
await db.execute(
|
||||
update(TreeTag)
|
||||
.where(TreeTag.id.in_(old_tag_ids))
|
||||
.values(usage_count=func.greatest(TreeTag.usage_count - 1, 0))
|
||||
)
|
||||
|
||||
# Delete existing tag assignments using direct SQL
|
||||
await db.execute(
|
||||
@@ -720,7 +748,10 @@ async def update_tree(
|
||||
)
|
||||
)
|
||||
added_tag_ids.add(tag.id)
|
||||
tag.usage_count += 1
|
||||
# Atomically increment (SQL-level to avoid lost updates)
|
||||
await db.execute(
|
||||
update(TreeTag).where(TreeTag.id == tag.id).values(usage_count=TreeTag.usage_count + 1)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user