feat: add tenant_context module — ContextVar, transaction listener, tenant_filter

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-04-10 03:48:34 +00:00
parent 6f1becf21f
commit b4f8694f6b
2 changed files with 135 additions and 0 deletions

View File

@@ -0,0 +1,92 @@
# backend/app/core/tenant_context.py
"""
Per-request tenant context for row-level security.
Flow:
1. require_tenant_context (FastAPI dep) calls set_current_account_id().
2. The SQLAlchemy transaction-begin listener fires on every new transaction
and calls set_config('app.current_account_id', <id>, true) automatically.
3. PostgreSQL RLS policies read current_setting('app.current_account_id', TRUE)
to filter rows.
The ContextVar is asyncio-task-scoped: each concurrent request has its own value.
set_config with is_local=true is transaction-scoped: it resets on COMMIT or
ROLLBACK, so the listener re-applies it at the start of every transaction.
"""
import contextvars
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import event, or_, text
from sqlalchemy.ext.asyncio import AsyncEngine
if TYPE_CHECKING:
from app.models.user import User
# One slot per async task — each concurrent request gets its own value.
_current_account_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"current_account_id", default=None
)
# Platform account — global content visible to all tenants.
PLATFORM_ACCOUNT_ID = UUID("00000000-0000-0000-0000-000000000001")
def set_current_account_id(account_id: UUID) -> contextvars.Token:
"""Set tenant context for the current request coroutine.
Returns a token so the caller can reset it after the request.
"""
return _current_account_id.set(str(account_id))
def clear_current_account_id(token: contextvars.Token) -> None:
"""Reset the ContextVar to its previous value (call in finally block)."""
_current_account_id.reset(token)
def get_current_account_id() -> str | None:
"""Return the account_id string for the current request, or None."""
return _current_account_id.get()
def register_tenant_listener(engine: AsyncEngine) -> None:
"""Register the transaction-begin listener on the given engine.
Must be called once at application startup, AFTER the engine is created.
The listener issues set_config() at the start of every transaction so that
the setting is re-applied automatically even when a request commits
mid-flight and starts a new transaction.
Do NOT call this on admin_engine — admin connections must never set tenant
context automatically.
"""
@event.listens_for(engine.sync_engine, "begin")
def _on_transaction_begin(conn) -> None: # noqa: ANN001
account_id = _current_account_id.get()
if account_id:
# set_config(name, value, is_local=true) ≡ SET LOCAL.
# Unlike SET LOCAL, set_config IS parameterisable.
conn.execute(
text("SELECT set_config('app.current_account_id', :id, true)"),
{"id": account_id},
)
# If no account_id is set, do nothing. The RLS policy falls back to a
# null-matching UUID and returns zero rows — fail-closed behaviour.
def tenant_filter(Model, current_user: "User"): # noqa: ANN001
"""SQLAlchemy filter clause for tables that contain platform-owned rows.
Use for: tree_tags, tree_categories, step_categories, step_library,
template_trees, platform_steps.
For tenant-only tables (trees, sessions, psa_connections, etc.) use:
Model.account_id == current_user.account_id
directly.
"""
return or_(
Model.account_id == current_user.account_id,
Model.account_id == PLATFORM_ACCOUNT_ID,
)

View File

@@ -0,0 +1,43 @@
import asyncio
from uuid import UUID
import pytest
from unittest.mock import MagicMock
from app.core.tenant_context import (
set_current_account_id,
clear_current_account_id,
get_current_account_id,
)
def test_contextvar_is_none_by_default():
assert get_current_account_id() is None
def test_set_and_clear():
account_id = UUID("aaaaaaaa-0000-0000-0000-000000000001")
token = set_current_account_id(account_id)
assert get_current_account_id() == str(account_id)
clear_current_account_id(token)
assert get_current_account_id() is None
def test_tasks_are_isolated():
"""Each asyncio task has its own ContextVar value."""
results = {}
async def set_in_task(name: str, value: str):
token = set_current_account_id(UUID(value))
await asyncio.sleep(0)
results[name] = get_current_account_id()
clear_current_account_id(token)
async def run():
await asyncio.gather(
set_in_task("a", "aaaaaaaa-0000-0000-0000-000000000001"),
set_in_task("b", "bbbbbbbb-0000-0000-0000-000000000002"),
)
asyncio.run(run())
assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001"
assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002"