diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 28536d68..4bd3fd3c 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -190,3 +190,20 @@ async def get_plan_limits_for_user( """Get plan limits for the current user's account.""" from app.core.subscriptions import get_user_plan_limits return await get_user_plan_limits(current_user.account_id, db) + + +async def get_tenant_context( + current_user: Annotated[User, Depends(get_current_active_user)], +) -> UUID: + """Return the current user's account_id. + + Use this dependency instead of reading current_user.account_id directly. + Raises 403 if the user has no account association (should not happen in + normal flows — users are always associated with an account on registration). + """ + if current_user.account_id is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User not associated with any account", + ) + return current_user.account_id diff --git a/backend/app/core/filters.py b/backend/app/core/filters.py index 005e269d..f6e629e3 100644 --- a/backend/app/core/filters.py +++ b/backend/app/core/filters.py @@ -1,10 +1,12 @@ """ Centralized query filters for ResolutionFlow. -Provides reusable SQLAlchemy filter builders for tree access control -and step visibility, used across multiple endpoint modules. +Provides reusable SQLAlchemy filter builders for tree access control, +step visibility, and the canonical tenant_filter used by all queries +on tenant-scoped tables. """ from __future__ import annotations +import uuid from typing import TYPE_CHECKING from sqlalchemy import or_, and_, true as sa_true @@ -13,6 +15,18 @@ if TYPE_CHECKING: from app.models.user import User +def tenant_filter(model, account_id: uuid.UUID): + """Primary app-layer tenant filter. + + MUST be used in every SELECT/UPDATE/DELETE on tenant tables. + RLS (Phase 2) is the safety net — this is the primary enforcement. + + Usage: + stmt = select(Tree).where(tenant_filter(Tree, current_user.account_id), ...) + """ + return model.account_id == account_id + + def build_tree_access_filter(current_user: User): """Build the access filter for trees based on user permissions. @@ -36,10 +50,11 @@ def build_tree_access_filter(current_user: User): Tree.author_id == current_user.id, ] if current_user.account_id: + # Team-visible trees: use tenant_filter as the account match conditions.append( and_( Tree.visibility == 'team', - Tree.account_id == current_user.account_id + tenant_filter(Tree, current_user.account_id), ) ) return or_(*conditions) @@ -58,11 +73,14 @@ def build_step_visibility_filter(current_user: User): if current_user.account_id: return or_( StepLibrary.visibility == 'public', - and_(StepLibrary.visibility == 'team', StepLibrary.account_id == current_user.account_id), - StepLibrary.created_by == current_user.id # Own private steps + and_( + StepLibrary.visibility == 'team', + tenant_filter(StepLibrary, current_user.account_id), + ), + StepLibrary.created_by == current_user.id, ) else: return or_( StepLibrary.visibility == 'public', - StepLibrary.created_by == current_user.id + StepLibrary.created_by == current_user.id, )