fix: atomic counters, plan limit re-check, and double-submit guard

Backend:
- Tree usage_count: use SQL-level UPDATE (Tree.usage_count + 1) instead
  of Python-level increment to prevent lost updates under concurrency
- Tag usage_count: same SQL-level atomic increment/decrement in both
  create_tree and update_tree (delete_tree already used this pattern)
- Plan tree limit: re-check count after db.flush() to close the TOCTOU
  window where two concurrent creates could both pass the pre-check

Frontend:
- TreeEditorPage: add isSaving early-return guard inside handleSaveDraft
  and handlePublish callbacks so Ctrl+S can't bypass the button disabled
  prop and fire duplicate save requests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Michael Chihlas
2026-03-09 17:27:09 -04:00
parent c724ad8062
commit 7e3b383a65
3 changed files with 46 additions and 11 deletions

View File

@@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, Field as PydanticField
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import select, update as sa_update
from app.core.database import get_db
from app.models.tree import Tree
@@ -189,8 +189,10 @@ async def start_session(
session_variables=session_variables,
)
# Increment tree usage count
tree.usage_count += 1
# Atomically increment tree usage count (SQL-level to avoid lost updates)
await db.execute(
sa_update(Tree).where(Tree.id == tree.id).values(usage_count=Tree.usage_count + 1)
)
db.add(new_session)
await db.commit()

View File

@@ -25,7 +25,7 @@ from app.models.user_pinned_tree import UserPinnedTree
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin, get_service_account_id
from app.core.permissions import can_edit_tree, can_access_tree
from app.core.filters import build_tree_access_filter
from app.core.subscriptions import check_tree_limit
from app.core.subscriptions import check_tree_limit, get_account_subscription, get_plan_limits
from app.core.audit import log_audit
from app.core.config import settings
from app.core.tree_validation import can_publish_tree
@@ -487,6 +487,26 @@ async def create_tree(
db.add(new_tree)
await db.flush() # Get the ID
# Re-check tree limit after flush to close the TOCTOU race window:
# two concurrent creates could both pass the pre-check, but only one
# should succeed when the limit is exactly reached.
if not is_default and current_user.account_id:
post_count = await db.scalar(
select(func.count(Tree.id)).where(
Tree.account_id == current_user.account_id,
Tree.deleted_at.is_(None),
)
)
sub = await get_account_subscription(current_user.account_id, db)
if sub:
limits = await get_plan_limits(sub.plan, db)
if limits and limits.max_trees and (post_count or 0) > limits.max_trees:
await db.rollback()
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Tree limit reached ({limits.max_trees}/{limits.max_trees}). Upgrade your plan to create more trees."
)
# Handle tags
if tree_data.tags:
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
@@ -519,7 +539,6 @@ async def create_tree(
await db.flush()
tags_to_add.append(tag)
tag.usage_count += 1
# Use direct SQL insert for the junction table to avoid lazy load issues
from app.models.tag import tree_tag_assignments
@@ -531,6 +550,10 @@ async def create_tree(
assigned_by=current_user.id
)
)
# Atomically increment (SQL-level to avoid lost updates from concurrent requests)
await db.execute(
update(TreeTag).where(TreeTag.id == tag.id).values(usage_count=TreeTag.usage_count + 1)
)
await db.commit()
@@ -673,9 +696,14 @@ async def update_tree(
if tags_data is not None:
from app.models.tag import tree_tag_assignments
# Decrement usage count for old tags (already eagerly loaded)
for tag in tree.tags:
tag.usage_count = max(0, tag.usage_count - 1)
# Atomically decrement usage count for old tags
old_tag_ids = [tag.id for tag in tree.tags]
if old_tag_ids:
await db.execute(
update(TreeTag)
.where(TreeTag.id.in_(old_tag_ids))
.values(usage_count=func.greatest(TreeTag.usage_count - 1, 0))
)
# Delete existing tag assignments using direct SQL
await db.execute(
@@ -720,7 +748,10 @@ async def update_tree(
)
)
added_tag_ids.add(tag.id)
tag.usage_count += 1
# Atomically increment (SQL-level to avoid lost updates)
await db.execute(
update(TreeTag).where(TreeTag.id == tag.id).values(usage_count=TreeTag.usage_count + 1)
)
await db.commit()

View File

@@ -330,6 +330,7 @@ export function TreeEditorPage() {
}, [updateNode, selectNode])
const handleSaveDraft = useCallback(async () => {
if (isSaving) return
setSaving(true)
try {
// In Code Mode, run fresh validation on current markdown before saving
@@ -388,9 +389,10 @@ export function TreeEditorPage() {
} finally {
setSaving(false)
}
}, [isEditMode, id, editorMode, getTreeForSave, markSaved, navigate])
}, [isSaving, isEditMode, id, editorMode, getTreeForSave, markSaved, navigate])
const handlePublish = useCallback(async () => {
if (isSaving) return
setSaving(true)
try {
// In Code Mode, run fresh validation on current markdown before publishing
@@ -467,7 +469,7 @@ export function TreeEditorPage() {
} finally {
setSaving(false)
}
}, [isEditMode, id, editorMode, validate, getTreeForSave, markSaved, navigate])
}, [isSaving, isEditMode, id, editorMode, validate, getTreeForSave, markSaved, navigate])
// Keep handleSave for backward compatibility (Ctrl+S shortcut)
const handleSave = useCallback(async () => {