From df9ecf2d294554ca97c4879076481dc70e5d28f4 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Fri, 10 Apr 2026 03:50:59 +0000 Subject: [PATCH] feat: add require_tenant_context and require_admin_db dependencies Co-Authored-By: Claude Sonnet 4.6 --- backend/app/api/deps.py | 43 ++++++++++++++++++++++++++++ backend/tests/test_tenant_context.py | 15 ++++++++++ 2 files changed, 58 insertions(+) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 28536d68..bae3f935 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -10,6 +10,8 @@ from app.core.database import get_db from app.core.security import decode_token from app.models.user import User from app.models.plan_limits import PlanLimits +from app.core.tenant_context import set_current_account_id, clear_current_account_id +from app.core.admin_database import get_admin_db # noqa: F401 — re-exported for use in endpoints # Routes that are allowed even when must_change_password is True _PASSWORD_CHANGE_ALLOWLIST = { @@ -190,3 +192,44 @@ async def get_plan_limits_for_user( """Get plan limits for the current user's account.""" from app.core.subscriptions import get_user_plan_limits return await get_user_plan_limits(current_user.account_id, db) + + +async def require_tenant_context( + current_user: Annotated[User, Depends(get_current_active_user)], +): + """Set per-request tenant context for RLS. + + Raises 403 if the authenticated user has no account_id — never falls back + to PLATFORM_ACCOUNT_ID (that would grant platform-scope access to a + malformed account). + + Sets the ContextVar that the SQLAlchemy transaction-begin listener reads to + issue set_config('app.current_account_id', …, true) on every transaction. + + Applied to every user-facing router. NOT applied to /admin/* routers or + public endpoints (auth, shared, webhooks). + """ + if current_user.account_id is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User account required", + ) + token = set_current_account_id(current_user.account_id) + try: + yield + finally: + clear_current_account_id(token) + + +async def require_admin_db( + db: Annotated[AsyncSession, Depends(get_admin_db)], + current_user: Annotated[User, Depends(require_admin)], +) -> AsyncSession: + """Return a BYPASSRLS admin DB session after verifying super_admin role. + + Use on /admin/* endpoints that query RLS-protected tables. Replaces + Depends(get_db) on the db parameter of those endpoints. + The current_user dep is still declared separately on the endpoint if + the user object is needed in the handler. + """ + return db diff --git a/backend/tests/test_tenant_context.py b/backend/tests/test_tenant_context.py index f3a2e89b..e4ad183e 100644 --- a/backend/tests/test_tenant_context.py +++ b/backend/tests/test_tenant_context.py @@ -41,3 +41,18 @@ def test_tasks_are_isolated(): asyncio.run(run()) assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001" assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002" + + +@pytest.mark.asyncio +async def test_require_tenant_context_raises_403_when_no_account(): + from fastapi import HTTPException + from app.api.deps import require_tenant_context + + user = MagicMock() + user.account_id = None + + gen = require_tenant_context(current_user=user) + with pytest.raises(HTTPException) as exc_info: + await gen.__anext__() + assert exc_info.value.status_code == 403 + assert "account required" in exc_info.value.detail.lower()