feat: add require_tenant_context and require_admin_db dependencies

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-04-10 03:50:59 +00:00
parent b0e5f12897
commit df9ecf2d29
2 changed files with 58 additions and 0 deletions

View File

@@ -10,6 +10,8 @@ from app.core.database import get_db
from app.core.security import decode_token
from app.models.user import User
from app.models.plan_limits import PlanLimits
from app.core.tenant_context import set_current_account_id, clear_current_account_id
from app.core.admin_database import get_admin_db # noqa: F401 — re-exported for use in endpoints
# Routes that are allowed even when must_change_password is True
_PASSWORD_CHANGE_ALLOWLIST = {
@@ -190,3 +192,44 @@ async def get_plan_limits_for_user(
"""Get plan limits for the current user's account."""
from app.core.subscriptions import get_user_plan_limits
return await get_user_plan_limits(current_user.account_id, db)
async def require_tenant_context(
current_user: Annotated[User, Depends(get_current_active_user)],
):
"""Set per-request tenant context for RLS.
Raises 403 if the authenticated user has no account_id — never falls back
to PLATFORM_ACCOUNT_ID (that would grant platform-scope access to a
malformed account).
Sets the ContextVar that the SQLAlchemy transaction-begin listener reads to
issue set_config('app.current_account_id', …, true) on every transaction.
Applied to every user-facing router. NOT applied to /admin/* routers or
public endpoints (auth, shared, webhooks).
"""
if current_user.account_id is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account required",
)
token = set_current_account_id(current_user.account_id)
try:
yield
finally:
clear_current_account_id(token)
async def require_admin_db(
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
) -> AsyncSession:
"""Return a BYPASSRLS admin DB session after verifying super_admin role.
Use on /admin/* endpoints that query RLS-protected tables. Replaces
Depends(get_db) on the db parameter of those endpoints.
The current_user dep is still declared separately on the endpoint if
the user object is needed in the handler.
"""
return db

View File

@@ -41,3 +41,18 @@ def test_tasks_are_isolated():
asyncio.run(run())
assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001"
assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002"
@pytest.mark.asyncio
async def test_require_tenant_context_raises_403_when_no_account():
from fastapi import HTTPException
from app.api.deps import require_tenant_context
user = MagicMock()
user.account_id = None
gen = require_tenant_context(current_user=user)
with pytest.raises(HTTPException) as exc_info:
await gen.__anext__()
assert exc_info.value.status_code == 403
assert "account required" in exc_info.value.detail.lower()