feat: add require_tenant_context and require_admin_db dependencies
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,8 @@ from app.core.database import get_db
|
||||
from app.core.security import decode_token
|
||||
from app.models.user import User
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.core.tenant_context import set_current_account_id, clear_current_account_id
|
||||
from app.core.admin_database import get_admin_db # noqa: F401 — re-exported for use in endpoints
|
||||
|
||||
# Routes that are allowed even when must_change_password is True
|
||||
_PASSWORD_CHANGE_ALLOWLIST = {
|
||||
@@ -190,3 +192,44 @@ async def get_plan_limits_for_user(
|
||||
"""Get plan limits for the current user's account."""
|
||||
from app.core.subscriptions import get_user_plan_limits
|
||||
return await get_user_plan_limits(current_user.account_id, db)
|
||||
|
||||
|
||||
async def require_tenant_context(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
"""Set per-request tenant context for RLS.
|
||||
|
||||
Raises 403 if the authenticated user has no account_id — never falls back
|
||||
to PLATFORM_ACCOUNT_ID (that would grant platform-scope access to a
|
||||
malformed account).
|
||||
|
||||
Sets the ContextVar that the SQLAlchemy transaction-begin listener reads to
|
||||
issue set_config('app.current_account_id', …, true) on every transaction.
|
||||
|
||||
Applied to every user-facing router. NOT applied to /admin/* routers or
|
||||
public endpoints (auth, shared, webhooks).
|
||||
"""
|
||||
if current_user.account_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User account required",
|
||||
)
|
||||
token = set_current_account_id(current_user.account_id)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
clear_current_account_id(token)
|
||||
|
||||
|
||||
async def require_admin_db(
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
) -> AsyncSession:
|
||||
"""Return a BYPASSRLS admin DB session after verifying super_admin role.
|
||||
|
||||
Use on /admin/* endpoints that query RLS-protected tables. Replaces
|
||||
Depends(get_db) on the db parameter of those endpoints.
|
||||
The current_user dep is still declared separately on the endpoint if
|
||||
the user object is needed in the handler.
|
||||
"""
|
||||
return db
|
||||
|
||||
@@ -41,3 +41,18 @@ def test_tasks_are_isolated():
|
||||
asyncio.run(run())
|
||||
assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001"
|
||||
assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_tenant_context_raises_403_when_no_account():
|
||||
from fastapi import HTTPException
|
||||
from app.api.deps import require_tenant_context
|
||||
|
||||
user = MagicMock()
|
||||
user.account_id = None
|
||||
|
||||
gen = require_tenant_context(current_user=user)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gen.__anext__()
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "account required" in exc_info.value.detail.lower()
|
||||
|
||||
Reference in New Issue
Block a user