fix: race condition hardening across auth, counters, and data fetching (#102)

* fix: prevent race conditions in token operations and auth flows

Backend:
- Refresh token rotation: use atomic UPDATE...WHERE revoked_at IS NULL
  to prevent concurrent refresh requests from both succeeding
- Account invite codes: SELECT FOR UPDATE to prevent double-spend
- Platform invite codes: SELECT FOR UPDATE to prevent double-spend
- Password reset tokens: SELECT FOR UPDATE to prevent double-use
- Email verification tokens: SELECT FOR UPDATE to prevent double-use

Frontend:
- Token refresh subscriber arrays: swap before iterating so a throwing
  callback doesn't leave the queue in a dirty state

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: atomic counters, plan limit re-check, and double-submit guard

Backend:
- Tree usage_count: use SQL-level UPDATE (Tree.usage_count + 1) instead
  of Python-level increment to prevent lost updates under concurrency
- Tag usage_count: same SQL-level atomic increment/decrement in both
  create_tree and update_tree (delete_tree already used this pattern)
- Plan tree limit: re-check count after db.flush() to close the TOCTOU
  window where two concurrent creates could both pass the pre-check

Frontend:
- TreeEditorPage: add isSaving early-return guard inside handleSaveDraft
  and handlePublish callbacks so Ctrl+S can't bypass the button disabled
  prop and fire duplicate save requests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: prevent stale API responses from overwriting newer data

- SessionHistoryPage: move loadSessions into effect with cancelled flag
  so rapid filter/tab changes discard outdated responses
- TreeLibraryPage: add request ID ref to loadTrees so stale responses
  from previous filter selections are discarded
- QuickStartPage: add request ID ref to debounced search so out-of-order
  responses don't overwrite newer search results

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* docs: add flexible intake design — deferred variables + prepared sessions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit was merged in pull request #102.
This commit is contained in:
chihlasm
2026-03-10 01:57:22 -04:00
committed by GitHub
parent 5095b0d8df
commit 4727106141
9 changed files with 305 additions and 98 deletions

View File

@@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import select, update as sa_update
from app.core.config import settings
from app.core.settings_manager import SettingsManager
from app.core.database import get_db
@@ -78,13 +78,15 @@ async def register(
After user creation, if no account invite was used, a personal Account
and free Subscription are created automatically.
"""
# Check for account invite code FIRST — bypasses platform invite gate
# Check for account invite code FIRST — bypasses platform invite gate.
# SELECT FOR UPDATE prevents two concurrent registrations from both
# reading the same invite as unused and double-spending it.
account_invite_record = None
if user_data.account_invite_code:
result = await db.execute(
select(AccountInvite).where(
AccountInvite.code == user_data.account_invite_code
)
select(AccountInvite)
.where(AccountInvite.code == user_data.account_invite_code)
.with_for_update()
)
account_invite_record = result.scalar_one_or_none()
@@ -116,9 +118,12 @@ async def register(
)
if user_data.invite_code:
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE
# Look up invite code (case-insensitive) — applies plan/trial regardless of REQUIRE_INVITE_CODE.
# FOR UPDATE prevents double-spend by concurrent registrations.
result = await db.execute(
select(InviteCode).where(InviteCode.code == user_data.invite_code.upper())
select(InviteCode)
.where(InviteCode.code == user_data.invite_code.upper())
.with_for_update()
)
invite_code_record = result.scalar_one_or_none()
@@ -305,24 +310,29 @@ async def refresh_token(
user_id = payload.get("sub")
jti = payload.get("jti")
# Validate refresh token hasn't been revoked
# Atomically revoke the old refresh token (token rotation).
# Using a conditional UPDATE prevents the race where two concurrent
# refresh requests both read revoked_at=NULL and both succeed.
if jti:
token_hash = hash_token(jti)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
sa_update(RefreshToken)
.where(
RefreshToken.token_hash == token_hash,
RefreshToken.revoked_at.is_(None),
)
.values(revoked_at=datetime.now(timezone.utc))
.returning(RefreshToken.id, RefreshToken.user_id)
)
stored_token = result.scalar_one_or_none()
revoked_row = result.fetchone()
if stored_token and stored_token.is_revoked:
if not revoked_row:
# Either the token doesn't exist or was already revoked/used
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has been revoked"
)
# Revoke the old refresh token (token rotation)
if stored_token:
stored_token.revoked_at = datetime.now(timezone.utc)
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
@@ -552,9 +562,12 @@ async def reset_password(
detail="Invalid reset token"
)
# Validate token in DB (single-use)
# Validate token in DB (single-use).
# FOR UPDATE prevents two concurrent reset requests from both succeeding.
result = await db.execute(
select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
select(PasswordResetToken)
.where(PasswordResetToken.token_hash == hash_token(jti))
.with_for_update()
)
token_record = result.scalar_one_or_none()
@@ -674,10 +687,11 @@ async def verify_email(
detail="Invalid verification token"
)
# FOR UPDATE prevents two concurrent verification requests from both succeeding.
result = await db.execute(
select(EmailVerificationToken).where(
EmailVerificationToken.token_hash == hash_token(jti)
)
select(EmailVerificationToken)
.where(EmailVerificationToken.token_hash == hash_token(jti))
.with_for_update()
)
token_record = result.scalar_one_or_none()

View File

@@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, Field as PydanticField
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import select, update as sa_update
from app.core.database import get_db
from app.models.tree import Tree
@@ -189,8 +189,10 @@ async def start_session(
session_variables=session_variables,
)
# Increment tree usage count
tree.usage_count += 1
# Atomically increment tree usage count (SQL-level to avoid lost updates)
await db.execute(
sa_update(Tree).where(Tree.id == tree.id).values(usage_count=Tree.usage_count + 1)
)
db.add(new_session)
await db.commit()

View File

@@ -25,7 +25,7 @@ from app.models.user_pinned_tree import UserPinnedTree
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin, get_service_account_id
from app.core.permissions import can_edit_tree, can_access_tree
from app.core.filters import build_tree_access_filter
from app.core.subscriptions import check_tree_limit
from app.core.subscriptions import check_tree_limit, get_account_subscription, get_plan_limits
from app.core.audit import log_audit
from app.core.config import settings
from app.core.tree_validation import can_publish_tree
@@ -487,6 +487,26 @@ async def create_tree(
db.add(new_tree)
await db.flush() # Get the ID
# Re-check tree limit after flush to close the TOCTOU race window:
# two concurrent creates could both pass the pre-check, but only one
# should succeed when the limit is exactly reached.
if not is_default and current_user.account_id:
post_count = await db.scalar(
select(func.count(Tree.id)).where(
Tree.account_id == current_user.account_id,
Tree.deleted_at.is_(None),
)
)
sub = await get_account_subscription(current_user.account_id, db)
if sub:
limits = await get_plan_limits(sub.plan, db)
if limits and limits.max_trees and (post_count or 0) > limits.max_trees:
await db.rollback()
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Tree limit reached ({limits.max_trees}/{limits.max_trees}). Upgrade your plan to create more trees."
)
# Handle tags
if tree_data.tags:
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
@@ -519,7 +539,6 @@ async def create_tree(
await db.flush()
tags_to_add.append(tag)
tag.usage_count += 1
# Use direct SQL insert for the junction table to avoid lazy load issues
from app.models.tag import tree_tag_assignments
@@ -531,6 +550,10 @@ async def create_tree(
assigned_by=current_user.id
)
)
# Atomically increment (SQL-level to avoid lost updates from concurrent requests)
await db.execute(
update(TreeTag).where(TreeTag.id == tag.id).values(usage_count=TreeTag.usage_count + 1)
)
await db.commit()
@@ -673,9 +696,14 @@ async def update_tree(
if tags_data is not None:
from app.models.tag import tree_tag_assignments
# Decrement usage count for old tags (already eagerly loaded)
for tag in tree.tags:
tag.usage_count = max(0, tag.usage_count - 1)
# Atomically decrement usage count for old tags
old_tag_ids = [tag.id for tag in tree.tags]
if old_tag_ids:
await db.execute(
update(TreeTag)
.where(TreeTag.id.in_(old_tag_ids))
.values(usage_count=func.greatest(TreeTag.usage_count - 1, 0))
)
# Delete existing tag assignments using direct SQL
await db.execute(
@@ -720,7 +748,10 @@ async def update_tree(
)
)
added_tag_ids.add(tag.id)
tag.usage_count += 1
# Atomically increment (SQL-level to avoid lost updates)
await db.execute(
update(TreeTag).where(TreeTag.id == tag.id).values(usage_count=TreeTag.usage_count + 1)
)
await db.commit()