# backend/app/core/tenant_context.py
"""
Per-request tenant context for row-level security.

Flow:
  1. require_tenant_context (FastAPI dependency) calls set_current_account_id().
  2. The SQLAlchemy transaction-begin listener fires on every new transaction
     and calls set_config('app.current_account_id', <id>, true) automatically.
  3. PostgreSQL RLS policies read current_setting('app.current_account_id', TRUE)
     to filter rows.

The ContextVar is scoped to the asyncio task: each task runs in its own copy of
the context, so each concurrent request sees only its own value.
set_config with is_local=true is transaction-scoped: it resets on COMMIT or
ROLLBACK, so the listener re-applies it at the start of every transaction.
Once a pooled connection has set the variable, that reset leaves an empty
string rather than an unset variable, so current_setting(..., TRUE) may return
'' instead of NULL; RLS policies should treat '' like NULL (e.g. via NULLIF).
"""
|
|
import contextvars
|
|
from typing import TYPE_CHECKING
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import event, or_, text
|
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
|
|
if TYPE_CHECKING:
|
|
from app.models.user import User
|
|
|
|
# One slot per async task — each concurrent request gets its own value.
|
|
_current_account_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
|
"current_account_id", default=None
|
|
)
|
|
|
|
# Platform account — global content visible to all tenants.
|
|
PLATFORM_ACCOUNT_ID = UUID("00000000-0000-0000-0000-000000000001")
|
|
|
|
|
|
def set_current_account_id(account_id: UUID) -> contextvars.Token:
    """Bind *account_id* as the tenant for the running request context.

    The id is stored as its string form, which is the value the
    transaction-begin listener passes to ``set_config``.

    Args:
        account_id: Tenant account that subsequent transactions are scoped to.

    Returns:
        A token to hand to ``clear_current_account_id`` once the request ends.
    """
    account_id_text = str(account_id)
    previous_state_token = _current_account_id.set(account_id_text)
    return previous_state_token
|
|
|
|
|
|
def clear_current_account_id(token: contextvars.Token) -> None:
    """Restore the tenant context that was active before *token* was issued.

    Call from a ``finally`` block so the tenant never outlives the request.

    Args:
        token: The token returned by ``set_current_account_id``.
    """
    tenant_var = _current_account_id
    tenant_var.reset(token)
|
|
|
|
|
|
def get_current_account_id() -> str | None:
    """Look up the tenant bound to the running context.

    Returns:
        The account id as a string, or ``None`` when no tenant has been set.
    """
    bound_account_id = _current_account_id.get()
    return bound_account_id
|
|
|
|
|
|
def _on_transaction_begin(conn) -> None:  # noqa: ANN001
    """Apply the current request's tenant id to a newly started transaction.

    Registered as a SQLAlchemy ``begin`` listener by register_tenant_listener.
    Defined at module level so its identity is stable; that lets registration
    detect an existing listener instead of adding a duplicate.
    """
    account_id = _current_account_id.get()
    if not account_id:
        # No tenant bound: leave the setting untouched.
        # NOTE(review): this is fail-closed only if the RLS policy maps a
        # missing/empty current_setting() to "no rows". On a pooled connection
        # that previously had a tenant, the value is '' rather than NULL.
        # Confirm the policy handles both.
        return
    # set_config(name, value, is_local=true) is equivalent to SET LOCAL, but
    # unlike SET LOCAL it accepts a bound parameter, so the id is never
    # interpolated into the SQL text.
    conn.execute(
        text("SELECT set_config('app.current_account_id', :id, true)"),
        {"id": account_id},
    )


def register_tenant_listener(engine: AsyncEngine) -> None:
    """Attach the tenant ``begin`` listener to *engine*'s underlying sync engine.

    Call at application startup, after the engine is created. The listener
    issues set_config() at the start of every transaction, so the setting is
    re-applied automatically even when a request commits mid-flight and a new
    transaction starts.

    Repeated calls for the same engine are harmless: the listener is attached
    only if it is not already registered, so set_config() never runs twice per
    transaction.

    Do NOT call this on admin_engine — admin connections must never set tenant
    context automatically.

    Args:
        engine: The tenant-scoped async engine whose transactions need RLS context.
    """
    sync_engine = engine.sync_engine
    if not event.contains(sync_engine, "begin", _on_transaction_begin):
        event.listen(sync_engine, "begin", _on_transaction_begin)
|
|
|
|
|
|
def tenant_filter(Model, current_user: "User"):  # noqa: ANN001
    """Build a WHERE clause for tables that mix tenant rows with platform rows.

    The clause matches rows owned by the caller's account OR by
    ``PLATFORM_ACCOUNT_ID``.

    Use for: tree_tags, tree_categories, step_categories, step_library,
    template_trees, platform_steps.

    For tenant-only tables (trees, sessions, psa_connections, etc.), filter on
        Model.account_id == current_user.account_id
    directly instead.

    Args:
        Model: Mapped class exposing an ``account_id`` column attribute.
        current_user: The authenticated user; only ``account_id`` is read.

    Returns:
        An ``or_`` of the two ``account_id`` equality comparisons.
    """
    account_column = Model.account_id
    owned_by_tenant = account_column == current_user.account_id
    owned_by_platform = account_column == PLATFORM_ACCOUNT_ID
    return or_(owned_by_tenant, owned_by_platform)
|