fix: atomic counters, plan limit re-check, and double-submit guard
Backend: - Tree usage_count: use SQL-level UPDATE (Tree.usage_count + 1) instead of Python-level increment to prevent lost updates under concurrency - Tag usage_count: same SQL-level atomic increment/decrement in both create_tree and update_tree (delete_tree already used this pattern) - Plan tree limit: re-check count after db.flush() to close the TOCTOU window where two concurrent creates could both pass the pre-check Frontend: - TreeEditorPage: add isSaving early-return guard inside handleSaveDraft and handlePublish callbacks so Ctrl+S can't bypass the button disabled prop and fire duplicate save requests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
from pydantic import BaseModel, Field as PydanticField
|
from pydantic import BaseModel, Field as PydanticField
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, update as sa_update
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.models.tree import Tree
|
from app.models.tree import Tree
|
||||||
@@ -189,8 +189,10 @@ async def start_session(
|
|||||||
session_variables=session_variables,
|
session_variables=session_variables,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Increment tree usage count
|
# Atomically increment tree usage count (SQL-level to avoid lost updates)
|
||||||
tree.usage_count += 1
|
await db.execute(
|
||||||
|
sa_update(Tree).where(Tree.id == tree.id).values(usage_count=Tree.usage_count + 1)
|
||||||
|
)
|
||||||
|
|
||||||
db.add(new_session)
|
db.add(new_session)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from app.models.user_pinned_tree import UserPinnedTree
|
|||||||
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin, get_service_account_id
|
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin, get_service_account_id
|
||||||
from app.core.permissions import can_edit_tree, can_access_tree
|
from app.core.permissions import can_edit_tree, can_access_tree
|
||||||
from app.core.filters import build_tree_access_filter
|
from app.core.filters import build_tree_access_filter
|
||||||
from app.core.subscriptions import check_tree_limit
|
from app.core.subscriptions import check_tree_limit, get_account_subscription, get_plan_limits
|
||||||
from app.core.audit import log_audit
|
from app.core.audit import log_audit
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.tree_validation import can_publish_tree
|
from app.core.tree_validation import can_publish_tree
|
||||||
@@ -487,6 +487,26 @@ async def create_tree(
|
|||||||
db.add(new_tree)
|
db.add(new_tree)
|
||||||
await db.flush() # Get the ID
|
await db.flush() # Get the ID
|
||||||
|
|
||||||
|
# Re-check tree limit after flush to close the TOCTOU race window:
|
||||||
|
# two concurrent creates could both pass the pre-check, but only one
|
||||||
|
# should succeed when the limit is exactly reached.
|
||||||
|
if not is_default and current_user.account_id:
|
||||||
|
post_count = await db.scalar(
|
||||||
|
select(func.count(Tree.id)).where(
|
||||||
|
Tree.account_id == current_user.account_id,
|
||||||
|
Tree.deleted_at.is_(None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub = await get_account_subscription(current_user.account_id, db)
|
||||||
|
if sub:
|
||||||
|
limits = await get_plan_limits(sub.plan, db)
|
||||||
|
if limits and limits.max_trees and (post_count or 0) > limits.max_trees:
|
||||||
|
await db.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Tree limit reached ({limits.max_trees}/{limits.max_trees}). Upgrade your plan to create more trees."
|
||||||
|
)
|
||||||
|
|
||||||
# Handle tags
|
# Handle tags
|
||||||
if tree_data.tags:
|
if tree_data.tags:
|
||||||
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||||
@@ -519,7 +539,6 @@ async def create_tree(
|
|||||||
await db.flush()
|
await db.flush()
|
||||||
|
|
||||||
tags_to_add.append(tag)
|
tags_to_add.append(tag)
|
||||||
tag.usage_count += 1
|
|
||||||
|
|
||||||
# Use direct SQL insert for the junction table to avoid lazy load issues
|
# Use direct SQL insert for the junction table to avoid lazy load issues
|
||||||
from app.models.tag import tree_tag_assignments
|
from app.models.tag import tree_tag_assignments
|
||||||
@@ -531,6 +550,10 @@ async def create_tree(
|
|||||||
assigned_by=current_user.id
|
assigned_by=current_user.id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# Atomically increment (SQL-level to avoid lost updates from concurrent requests)
|
||||||
|
await db.execute(
|
||||||
|
update(TreeTag).where(TreeTag.id == tag.id).values(usage_count=TreeTag.usage_count + 1)
|
||||||
|
)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
@@ -673,9 +696,14 @@ async def update_tree(
|
|||||||
if tags_data is not None:
|
if tags_data is not None:
|
||||||
from app.models.tag import tree_tag_assignments
|
from app.models.tag import tree_tag_assignments
|
||||||
|
|
||||||
# Decrement usage count for old tags (already eagerly loaded)
|
# Atomically decrement usage count for old tags
|
||||||
for tag in tree.tags:
|
old_tag_ids = [tag.id for tag in tree.tags]
|
||||||
tag.usage_count = max(0, tag.usage_count - 1)
|
if old_tag_ids:
|
||||||
|
await db.execute(
|
||||||
|
update(TreeTag)
|
||||||
|
.where(TreeTag.id.in_(old_tag_ids))
|
||||||
|
.values(usage_count=func.greatest(TreeTag.usage_count - 1, 0))
|
||||||
|
)
|
||||||
|
|
||||||
# Delete existing tag assignments using direct SQL
|
# Delete existing tag assignments using direct SQL
|
||||||
await db.execute(
|
await db.execute(
|
||||||
@@ -720,7 +748,10 @@ async def update_tree(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
added_tag_ids.add(tag.id)
|
added_tag_ids.add(tag.id)
|
||||||
tag.usage_count += 1
|
# Atomically increment (SQL-level to avoid lost updates)
|
||||||
|
await db.execute(
|
||||||
|
update(TreeTag).where(TreeTag.id == tag.id).values(usage_count=TreeTag.usage_count + 1)
|
||||||
|
)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -330,6 +330,7 @@ export function TreeEditorPage() {
|
|||||||
}, [updateNode, selectNode])
|
}, [updateNode, selectNode])
|
||||||
|
|
||||||
const handleSaveDraft = useCallback(async () => {
|
const handleSaveDraft = useCallback(async () => {
|
||||||
|
if (isSaving) return
|
||||||
setSaving(true)
|
setSaving(true)
|
||||||
try {
|
try {
|
||||||
// In Code Mode, run fresh validation on current markdown before saving
|
// In Code Mode, run fresh validation on current markdown before saving
|
||||||
@@ -388,9 +389,10 @@ export function TreeEditorPage() {
|
|||||||
} finally {
|
} finally {
|
||||||
setSaving(false)
|
setSaving(false)
|
||||||
}
|
}
|
||||||
}, [isEditMode, id, editorMode, getTreeForSave, markSaved, navigate])
|
}, [isSaving, isEditMode, id, editorMode, getTreeForSave, markSaved, navigate])
|
||||||
|
|
||||||
const handlePublish = useCallback(async () => {
|
const handlePublish = useCallback(async () => {
|
||||||
|
if (isSaving) return
|
||||||
setSaving(true)
|
setSaving(true)
|
||||||
try {
|
try {
|
||||||
// In Code Mode, run fresh validation on current markdown before publishing
|
// In Code Mode, run fresh validation on current markdown before publishing
|
||||||
@@ -467,7 +469,7 @@ export function TreeEditorPage() {
|
|||||||
} finally {
|
} finally {
|
||||||
setSaving(false)
|
setSaving(false)
|
||||||
}
|
}
|
||||||
}, [isEditMode, id, editorMode, validate, getTreeForSave, markSaved, navigate])
|
}, [isSaving, isEditMode, id, editorMode, validate, getTreeForSave, markSaved, navigate])
|
||||||
|
|
||||||
// Keep handleSave for backward compatibility (Ctrl+S shortcut)
|
// Keep handleSave for backward compatibility (Ctrl+S shortcut)
|
||||||
const handleSave = useCallback(async () => {
|
const handleSave = useCallback(async () => {
|
||||||
|
|||||||
Reference in New Issue
Block a user