diff --git a/backend/app/core/tenant_context.py b/backend/app/core/tenant_context.py new file mode 100644 index 00000000..9cdb80c2 --- /dev/null +++ b/backend/app/core/tenant_context.py @@ -0,0 +1,92 @@ +# backend/app/core/tenant_context.py +""" +Per-request tenant context for row-level security. + +Flow: + 1. require_tenant_context (FastAPI dep) calls set_current_account_id(). + 2. The SQLAlchemy transaction-begin listener fires on every new transaction + and calls set_config('app.current_account_id', , true) automatically. + 3. PostgreSQL RLS policies read current_setting('app.current_account_id', TRUE) + to filter rows. + +The ContextVar is asyncio-task-scoped: each concurrent request has its own value. +set_config with is_local=true is transaction-scoped: it resets on COMMIT or +ROLLBACK, so the listener re-applies it at the start of every transaction. +""" +import contextvars +from typing import TYPE_CHECKING +from uuid import UUID + +from sqlalchemy import event, or_, text +from sqlalchemy.ext.asyncio import AsyncEngine + +if TYPE_CHECKING: + from app.models.user import User + +# One slot per async task — each concurrent request gets its own value. +_current_account_id: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "current_account_id", default=None +) + +# Platform account — global content visible to all tenants. +PLATFORM_ACCOUNT_ID = UUID("00000000-0000-0000-0000-000000000001") + + +def set_current_account_id(account_id: UUID) -> contextvars.Token: + """Set tenant context for the current request coroutine. + + Returns a token so the caller can reset it after the request. + """ + return _current_account_id.set(str(account_id)) + + +def clear_current_account_id(token: contextvars.Token) -> None: + """Reset the ContextVar to its previous value (call in finally block).""" + _current_account_id.reset(token) + + +def get_current_account_id() -> str | None: + """Return the account_id string for the current request, or None.""" + return _current_account_id.get() + + +def register_tenant_listener(engine: AsyncEngine) -> None: + """Register the transaction-begin listener on the given engine. + + Must be called once at application startup, AFTER the engine is created. + The listener issues set_config() at the start of every transaction so that + the setting is re-applied automatically even when a request commits + mid-flight and starts a new transaction. + + Do NOT call this on admin_engine — admin connections must never set tenant + context automatically. + """ + + @event.listens_for(engine.sync_engine, "begin") + def _on_transaction_begin(conn) -> None: # noqa: ANN001 + account_id = _current_account_id.get() + if account_id: + # set_config(name, value, is_local=true) ≡ SET LOCAL. + # Unlike SET LOCAL, set_config IS parameterisable. + conn.execute( + text("SELECT set_config('app.current_account_id', :id, true)"), + {"id": account_id}, + ) + # If no account_id is set, do nothing. The RLS policy falls back to a + # null-matching UUID and returns zero rows — fail-closed behaviour. + + +def tenant_filter(Model, current_user: "User"): # noqa: ANN001 + """SQLAlchemy filter clause for tables that contain platform-owned rows. + + Use for: tree_tags, tree_categories, step_categories, step_library, + template_trees, platform_steps. + + For tenant-only tables (trees, sessions, psa_connections, etc.) use: + Model.account_id == current_user.account_id + directly. + """ + return or_( + Model.account_id == current_user.account_id, + Model.account_id == PLATFORM_ACCOUNT_ID, + ) diff --git a/backend/tests/test_tenant_context.py b/backend/tests/test_tenant_context.py new file mode 100644 index 00000000..f3a2e89b --- /dev/null +++ b/backend/tests/test_tenant_context.py @@ -0,0 +1,43 @@ +import asyncio +from uuid import UUID +import pytest +from unittest.mock import MagicMock + +from app.core.tenant_context import ( + set_current_account_id, + clear_current_account_id, + get_current_account_id, +) + + +def test_contextvar_is_none_by_default(): + assert get_current_account_id() is None + + +def test_set_and_clear(): + account_id = UUID("aaaaaaaa-0000-0000-0000-000000000001") + token = set_current_account_id(account_id) + assert get_current_account_id() == str(account_id) + clear_current_account_id(token) + assert get_current_account_id() is None + + +def test_tasks_are_isolated(): + """Each asyncio task has its own ContextVar value.""" + results = {} + + async def set_in_task(name: str, value: str): + token = set_current_account_id(UUID(value)) + await asyncio.sleep(0) + results[name] = get_current_account_id() + clear_current_account_id(token) + + async def run(): + await asyncio.gather( + set_in_task("a", "aaaaaaaa-0000-0000-0000-000000000001"), + set_in_task("b", "bbbbbbbb-0000-0000-0000-000000000002"), + ) + + asyncio.run(run()) + assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001" + assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002"