# backend/app/core/tenant_context.py
"""
Per-request tenant context for row-level security.

Flow:
    1. ``require_tenant_context`` (FastAPI dependency) calls
       ``set_current_account_id()``.
    2. The SQLAlchemy transaction-begin listener fires on every new
       transaction and automatically runs
       ``set_config('app.current_account_id', <account_id>, true)``.
    3. PostgreSQL RLS policies read
       ``current_setting('app.current_account_id', TRUE)`` to filter rows.

The ContextVar is asyncio-task-scoped: each concurrent request has its own
value. set_config with is_local=true is transaction-scoped: it resets on
COMMIT or ROLLBACK, so the listener re-applies it at the start of every
transaction.
"""

import contextvars
from typing import TYPE_CHECKING
from uuid import UUID

from sqlalchemy import event, or_, text
from sqlalchemy.ext.asyncio import AsyncEngine

if TYPE_CHECKING:
    from app.models.user import User

# One slot per async task — each concurrent request gets its own value.
# Stored as a string because it is passed straight to set_config().
_current_account_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_account_id", default=None
)

# Platform account — global content visible to all tenants.
PLATFORM_ACCOUNT_ID = UUID("00000000-0000-0000-0000-000000000001")

# set_config(name, value, is_local=true) ≡ SET LOCAL. Unlike SET LOCAL,
# set_config accepts a bound parameter, so the id is never interpolated into
# the SQL text. Built once and reused for every transaction.
_SET_TENANT_SQL = text("SELECT set_config('app.current_account_id', :id, true)")


def set_current_account_id(account_id: UUID) -> contextvars.Token[str | None]:
    """Set tenant context for the current request coroutine.

    Args:
        account_id: The tenant's account id; stored as ``str(account_id)``.

    Returns:
        A token so the caller can reset the context after the request
        (see ``clear_current_account_id``).
    """
    return _current_account_id.set(str(account_id))


def clear_current_account_id(token: contextvars.Token[str | None]) -> None:
    """Reset the ContextVar to its previous value (call in a finally block).

    Args:
        token: The token returned by ``set_current_account_id``.
    """
    _current_account_id.reset(token)


def get_current_account_id() -> str | None:
    """Return the account_id string for the current request, or None."""
    return _current_account_id.get()


def _on_transaction_begin(conn) -> None:  # noqa: ANN001
    """Apply the current tenant id to a transaction that has just begun.

    Registered as a ``begin`` listener on the engine's sync facade by
    ``register_tenant_listener``; ``conn`` is the SQLAlchemy Connection whose
    transaction is starting.
    """
    account_id = _current_account_id.get()
    if not account_id:
        # No tenant in context: issue nothing. With the setting absent,
        # current_setting(..., TRUE) yields NULL, and the RLS policies are
        # expected to match no rows — fail-closed behaviour.
        return
    conn.execute(_SET_TENANT_SQL, {"id": account_id})


def register_tenant_listener(engine: AsyncEngine) -> None:
    """Register the transaction-begin listener on the given engine.

    Call at application startup, AFTER the engine is created. Calling it
    again for the same engine is a no-op, so a duplicate listener (and a
    duplicate set_config per transaction) is never attached.

    The listener issues set_config() at the start of every transaction so
    that the setting is re-applied automatically even when a request commits
    mid-flight and starts a new transaction.

    Do NOT call this on admin_engine — admin connections must never set
    tenant context automatically.

    Args:
        engine: The tenant-scoped async engine; the listener is attached to
            its ``sync_engine``.
    """
    target = engine.sync_engine
    if not event.contains(target, "begin", _on_transaction_begin):
        event.listen(target, "begin", _on_transaction_begin)


def tenant_filter(Model, current_user: "User"):  # noqa: ANN001
    """SQLAlchemy filter clause for tables that contain platform-owned rows.

    Use for: tree_tags, tree_categories, step_categories, step_library,
    template_trees, platform_steps.

    For tenant-only tables (trees, sessions, psa_connections, etc.) use:
        Model.account_id == current_user.account_id
    directly.

    Args:
        Model: A mapped class exposing an ``account_id`` column.
        current_user: The authenticated user whose ``account_id`` scopes
            the query.

    Returns:
        An ``or_`` clause matching rows owned by the user's account OR by
        ``PLATFORM_ACCOUNT_ID``.
    """
    return or_(
        Model.account_id == current_user.account_id,
        Model.account_id == PLATFORM_ACCOUNT_ID,
    )