feat: add require_tenant_context and require_admin_db dependencies
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,8 @@ from app.core.database import get_db
|
|||||||
from app.core.security import decode_token
|
from app.core.security import decode_token
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.plan_limits import PlanLimits
|
from app.models.plan_limits import PlanLimits
|
||||||
|
from app.core.tenant_context import set_current_account_id, clear_current_account_id
|
||||||
|
from app.core.admin_database import get_admin_db # noqa: F401 — re-exported for use in endpoints
|
||||||
|
|
||||||
# Routes that are allowed even when must_change_password is True
|
# Routes that are allowed even when must_change_password is True
|
||||||
_PASSWORD_CHANGE_ALLOWLIST = {
|
_PASSWORD_CHANGE_ALLOWLIST = {
|
||||||
@@ -190,3 +192,44 @@ async def get_plan_limits_for_user(
|
|||||||
"""Get plan limits for the current user's account."""
|
"""Get plan limits for the current user's account."""
|
||||||
from app.core.subscriptions import get_user_plan_limits
|
from app.core.subscriptions import get_user_plan_limits
|
||||||
return await get_user_plan_limits(current_user.account_id, db)
|
return await get_user_plan_limits(current_user.account_id, db)
|
||||||
|
|
||||||
|
|
||||||
|
async def require_tenant_context(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
):
|
||||||
|
"""Set per-request tenant context for RLS.
|
||||||
|
|
||||||
|
Raises 403 if the authenticated user has no account_id — never falls back
|
||||||
|
to PLATFORM_ACCOUNT_ID (that would grant platform-scope access to a
|
||||||
|
malformed account).
|
||||||
|
|
||||||
|
Sets the ContextVar that the SQLAlchemy transaction-begin listener reads to
|
||||||
|
issue set_config('app.current_account_id', …, true) on every transaction.
|
||||||
|
|
||||||
|
Applied to every user-facing router. NOT applied to /admin/* routers or
|
||||||
|
public endpoints (auth, shared, webhooks).
|
||||||
|
"""
|
||||||
|
if current_user.account_id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="User account required",
|
||||||
|
)
|
||||||
|
token = set_current_account_id(current_user.account_id)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
clear_current_account_id(token)
|
||||||
|
|
||||||
|
|
||||||
|
async def require_admin_db(
|
||||||
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
|
current_user: Annotated[User, Depends(require_admin)],
|
||||||
|
) -> AsyncSession:
|
||||||
|
"""Return a BYPASSRLS admin DB session after verifying super_admin role.
|
||||||
|
|
||||||
|
Use on /admin/* endpoints that query RLS-protected tables. Replaces
|
||||||
|
Depends(get_db) on the db parameter of those endpoints.
|
||||||
|
The current_user dep is still declared separately on the endpoint if
|
||||||
|
the user object is needed in the handler.
|
||||||
|
"""
|
||||||
|
return db
|
||||||
|
|||||||
@@ -41,3 +41,18 @@ def test_tasks_are_isolated():
|
|||||||
asyncio.run(run())
|
asyncio.run(run())
|
||||||
assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001"
|
assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001"
|
||||||
assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002"
|
assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_tenant_context_raises_403_when_no_account():
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from app.api.deps import require_tenant_context
|
||||||
|
|
||||||
|
user = MagicMock()
|
||||||
|
user.account_id = None
|
||||||
|
|
||||||
|
gen = require_tenant_context(current_user=user)
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await gen.__anext__()
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
assert "account required" in exc_info.value.detail.lower()
|
||||||
|
|||||||
Reference in New Issue
Block a user