feat: add tenant_context module — ContextVar, transaction listener, tenant_filter
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
92
backend/app/core/tenant_context.py
Normal file
92
backend/app/core/tenant_context.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# backend/app/core/tenant_context.py
|
||||
"""
|
||||
Per-request tenant context for row-level security.
|
||||
|
||||
Flow:
|
||||
1. require_tenant_context (FastAPI dep) calls set_current_account_id().
|
||||
2. The SQLAlchemy transaction-begin listener fires on every new transaction
|
||||
and calls set_config('app.current_account_id', <id>, true) automatically.
|
||||
3. PostgreSQL RLS policies read current_setting('app.current_account_id', TRUE)
|
||||
to filter rows.
|
||||
|
||||
The ContextVar is asyncio-task-scoped: each concurrent request has its own value.
|
||||
set_config with is_local=true is transaction-scoped: it resets on COMMIT or
|
||||
ROLLBACK, so the listener re-applies it at the start of every transaction.
|
||||
"""
|
||||
import contextvars
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import event, or_, text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
# One slot per async task — each concurrent request gets its own value.
|
||||
_current_account_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
"current_account_id", default=None
|
||||
)
|
||||
|
||||
# Platform account — global content visible to all tenants.
|
||||
PLATFORM_ACCOUNT_ID = UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
def set_current_account_id(account_id: UUID) -> contextvars.Token:
|
||||
"""Set tenant context for the current request coroutine.
|
||||
|
||||
Returns a token so the caller can reset it after the request.
|
||||
"""
|
||||
return _current_account_id.set(str(account_id))
|
||||
|
||||
|
||||
def clear_current_account_id(token: contextvars.Token) -> None:
|
||||
"""Reset the ContextVar to its previous value (call in finally block)."""
|
||||
_current_account_id.reset(token)
|
||||
|
||||
|
||||
def get_current_account_id() -> str | None:
|
||||
"""Return the account_id string for the current request, or None."""
|
||||
return _current_account_id.get()
|
||||
|
||||
|
||||
def register_tenant_listener(engine: AsyncEngine) -> None:
|
||||
"""Register the transaction-begin listener on the given engine.
|
||||
|
||||
Must be called once at application startup, AFTER the engine is created.
|
||||
The listener issues set_config() at the start of every transaction so that
|
||||
the setting is re-applied automatically even when a request commits
|
||||
mid-flight and starts a new transaction.
|
||||
|
||||
Do NOT call this on admin_engine — admin connections must never set tenant
|
||||
context automatically.
|
||||
"""
|
||||
|
||||
@event.listens_for(engine.sync_engine, "begin")
|
||||
def _on_transaction_begin(conn) -> None: # noqa: ANN001
|
||||
account_id = _current_account_id.get()
|
||||
if account_id:
|
||||
# set_config(name, value, is_local=true) ≡ SET LOCAL.
|
||||
# Unlike SET LOCAL, set_config IS parameterisable.
|
||||
conn.execute(
|
||||
text("SELECT set_config('app.current_account_id', :id, true)"),
|
||||
{"id": account_id},
|
||||
)
|
||||
# If no account_id is set, do nothing. The RLS policy falls back to a
|
||||
# null-matching UUID and returns zero rows — fail-closed behaviour.
|
||||
|
||||
|
||||
def tenant_filter(Model, current_user: "User"): # noqa: ANN001
|
||||
"""SQLAlchemy filter clause for tables that contain platform-owned rows.
|
||||
|
||||
Use for: tree_tags, tree_categories, step_categories, step_library,
|
||||
template_trees, platform_steps.
|
||||
|
||||
For tenant-only tables (trees, sessions, psa_connections, etc.) use:
|
||||
Model.account_id == current_user.account_id
|
||||
directly.
|
||||
"""
|
||||
return or_(
|
||||
Model.account_id == current_user.account_id,
|
||||
Model.account_id == PLATFORM_ACCOUNT_ID,
|
||||
)
|
||||
43
backend/tests/test_tenant_context.py
Normal file
43
backend/tests/test_tenant_context.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.core.tenant_context import (
|
||||
set_current_account_id,
|
||||
clear_current_account_id,
|
||||
get_current_account_id,
|
||||
)
|
||||
|
||||
|
||||
def test_contextvar_is_none_by_default():
|
||||
assert get_current_account_id() is None
|
||||
|
||||
|
||||
def test_set_and_clear():
|
||||
account_id = UUID("aaaaaaaa-0000-0000-0000-000000000001")
|
||||
token = set_current_account_id(account_id)
|
||||
assert get_current_account_id() == str(account_id)
|
||||
clear_current_account_id(token)
|
||||
assert get_current_account_id() is None
|
||||
|
||||
|
||||
def test_tasks_are_isolated():
|
||||
"""Each asyncio task has its own ContextVar value."""
|
||||
results = {}
|
||||
|
||||
async def set_in_task(name: str, value: str):
|
||||
token = set_current_account_id(UUID(value))
|
||||
await asyncio.sleep(0)
|
||||
results[name] = get_current_account_id()
|
||||
clear_current_account_id(token)
|
||||
|
||||
async def run():
|
||||
await asyncio.gather(
|
||||
set_in_task("a", "aaaaaaaa-0000-0000-0000-000000000001"),
|
||||
set_in_task("b", "bbbbbbbb-0000-0000-0000-000000000002"),
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001"
|
||||
assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002"
|
||||
Reference in New Issue
Block a user