Merge PR #136: feat: tenant isolation Phase 4 — RLS on all remaining tables
This commit was merged in pull request #136.
This commit is contained in:
@@ -375,6 +375,12 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
|
||||
|
||||
**106. Guard async "select item → load data → apply state" flows with a ref:** When a component lets the user switch between items (chat sessions, flows, scripts) and loads data asynchronously on each switch, the load for item A can complete *after* the user has already switched to item B — overwriting B's state with A's stale data. Fix pattern: keep a `currentSelectionRef = useRef(initialId)` and update it synchronously whenever the selection changes (in every creation/switch path). After every `await`, bail out if `currentSelectionRef.current !== thisItemId`. See `AssistantChatPage.tsx` `selectChat` for the reference implementation (`currentChatRef`).
|
||||
|
||||
**107. Startup routines must use `_admin_session_factory()` after Phase 4 RLS:** Any code that runs at startup (lifespan, `ensure_service_account`, seed scripts) and touches tenant-isolated tables (`users`, etc.) must use `_admin_session_factory()` — not `get_db()`. Phase 4 enabled RLS on `users`; a tenant-scoped session has no `app.current_account_id` set at startup, so all queries return 0 rows or fail. `get_service_account_id` in `deps.py` is safe — it reads from `app.state` cached at startup, never hits the DB per-request.
|
||||
|
||||
**108. Tables with no `account_id` column (never add to RLS migrations):** `script_categories`, `platform_steps`, `template_trees`, `plan_feature_defaults`, `accounts` — global/platform tables documented with "No account_id. No RLS." in their model files. When writing RLS migrations, scan at the class level (check for `account_id: Mapped` within the class block), not the file level — multiple classes in one `.py` file can have different columns (e.g. `ScriptCategory` vs `ScriptTemplate` in `script_template.py`).
|
||||
|
||||
**109. `tree_shares.account_id` must equal `tree.account_id`, not the actor's account:** When creating a `TreeShare`, always use `account_id=tree.account_id` (tree owner's tenant). A super admin in tenant A sharing tenant B's tree must produce a share row in tenant B's RLS context — using `current_user.account_id` instead makes the share invisible to the tree owner after RLS is enforced.
|
||||
|
||||
## RBAC & Permissions
|
||||
|
||||
- **Role hierarchy:** super_admin > team_admin > engineer > viewer
|
||||
@@ -522,7 +528,7 @@ When a feature, fix, or significant piece of work is finished and merged/committ
|
||||
<!-- gitnexus:start -->
|
||||
# GitNexus — Code Intelligence
|
||||
|
||||
This project is indexed by GitNexus as **resolutionflow** (14787 symbols, 31366 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
|
||||
This project is indexed by GitNexus as **resolutionflow** (16703 symbols, 35922 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
|
||||
|
||||
> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first.
|
||||
|
||||
|
||||
85
backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py
Normal file
85
backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Enable RLS on Phase 4 tables — all remaining tenant-scoped tables.
|
||||
|
||||
All tables in this migration already have account_id NOT NULL (enforced by
|
||||
earlier migrations). This migration adds ENABLE ROW LEVEL SECURITY,
|
||||
FORCE ROW LEVEL SECURITY, and the appropriate tenant isolation policy to each.
|
||||
|
||||
Policy variants used:
|
||||
- Standard: account_id = current_setting(app.current_account_id)::uuid
|
||||
- Platform: standard OR account_id = PLATFORM_ACCOUNT_ID
|
||||
(for global content tables readable by all tenants)
|
||||
|
||||
Skipped intentionally:
|
||||
- accounts — IS the root table; no account_id column
|
||||
- plan_feature_defaults — platform config; no account_id column
|
||||
- script_categories — global lookup table; no account_id column
|
||||
- platform_steps — global content; no account_id column (readable by all)
|
||||
- template_trees — global content; no account_id column (readable by all)
|
||||
|
||||
Revision ID: b3c7e9f2a1d8
|
||||
Revises: 172ad76d7d20
|
||||
Create Date: 2026-04-12
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b3c7e9f2a1d8"
|
||||
down_revision: Union[str, None] = "172ad76d7d20"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Standard policy — tenant sees only own rows.
|
||||
_STANDARD_TABLES = [
|
||||
"users",
|
||||
"account_invites",
|
||||
"account_limit_overrides",
|
||||
"account_feature_overrides",
|
||||
"subscriptions",
|
||||
"ai_chat_sessions",
|
||||
"ai_conversations",
|
||||
"ai_session_steps",
|
||||
"ai_session_embeddings",
|
||||
"ai_suggestions",
|
||||
"ai_usage",
|
||||
"assistant_chats",
|
||||
"attachments",
|
||||
"copilot_conversations",
|
||||
"feedback",
|
||||
"file_uploads",
|
||||
"fork_points",
|
||||
"kb_imports",
|
||||
"notifications",
|
||||
"notification_configs",
|
||||
"notification_logs",
|
||||
"psa_activity_logs",
|
||||
"psa_member_mappings",
|
||||
"script_builder_sessions",
|
||||
"session_ratings",
|
||||
"tree_embeddings",
|
||||
"user_folders",
|
||||
"user_pinned_trees",
|
||||
]
|
||||
|
||||
_POLICY_EXPR = (
|
||||
"account_id = COALESCE("
|
||||
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000'"
|
||||
")::uuid"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table in _STANDARD_TABLES:
|
||||
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON {table}
|
||||
USING ({_POLICY_EXPR})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in _STANDARD_TABLES:
|
||||
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||
@@ -24,10 +24,14 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
token: Annotated[str, Depends(oauth2_scheme)]
|
||||
) -> User:
|
||||
"""Get current authenticated user from JWT token."""
|
||||
"""Get current authenticated user from JWT token.
|
||||
|
||||
Must use get_admin_db (BYPASSRLS): this dep runs before require_tenant_context
|
||||
sets app.current_account_id, so the users table RLS would block the lookup.
|
||||
"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
@@ -77,10 +81,14 @@ async def get_refresh_token_payload(
|
||||
async def get_current_active_user(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> User:
|
||||
"""Ensure user is active (not disabled). Auto-downgrades expired trials.
|
||||
Enforces must_change_password — blocks all routes except allowlist."""
|
||||
Enforces must_change_password — blocks all routes except allowlist.
|
||||
|
||||
Uses get_admin_db: runs before require_tenant_context sets the ContextVar,
|
||||
so tenant-scoped tables (subscriptions) would return 0 rows via app role.
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import select
|
||||
|
||||
from pydantic import BaseModel
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
|
||||
from app.core.audit import log_audit
|
||||
from app.models.refresh_token import RefreshToken
|
||||
@@ -148,7 +149,7 @@ async def update_member_role(
|
||||
@router.post("/me/transfer-ownership", response_model=AccountResponse)
|
||||
async def transfer_ownership(
|
||||
data: TransferOwnershipRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Transfer account ownership to another member (owner only)."""
|
||||
@@ -377,7 +378,7 @@ async def list_invites(
|
||||
|
||||
@router.post("/me/leave")
|
||||
async def leave_account(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Leave the current account (non-owners only). Creates a personal account."""
|
||||
@@ -423,7 +424,7 @@ class DeleteAccountRequest(BaseModel):
|
||||
@router.delete("/me")
|
||||
async def delete_account(
|
||||
data: DeleteAccountRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Delete the current account and soft-delete the user (owner only, no other members)."""
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update as sa_update
|
||||
from app.core.config import settings
|
||||
from app.core.settings_manager import SettingsManager
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.rate_limit import limiter
|
||||
from app.core.security import (
|
||||
verify_password,
|
||||
@@ -67,7 +67,7 @@ def _generate_display_code() -> str:
|
||||
async def register(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Register a new user.
|
||||
|
||||
@@ -232,7 +232,7 @@ async def register(
|
||||
async def login(
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Login and get access token."""
|
||||
# Find user by email
|
||||
@@ -270,7 +270,7 @@ async def login(
|
||||
async def login_json(
|
||||
request: Request,
|
||||
credentials: UserLogin,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Login with JSON body (alternative to form data)."""
|
||||
result = await db.execute(select(User).where(User.email == credentials.email))
|
||||
@@ -304,7 +304,7 @@ async def login_json(
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Refresh access token using refresh token (rotation: old token is revoked)."""
|
||||
user_id = payload.get("sub")
|
||||
@@ -368,7 +368,7 @@ async def get_me(
|
||||
async def update_me(
|
||||
data: UserUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Update current user's profile (name, email)."""
|
||||
update_fields = data.model_fields_set - {"current_password"}
|
||||
@@ -415,7 +415,7 @@ async def update_me(
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Logout user by revoking the refresh token."""
|
||||
jti = payload.get("jti")
|
||||
@@ -438,7 +438,7 @@ async def change_password(
|
||||
request: Request,
|
||||
data: ChangePasswordRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Change the current user's password."""
|
||||
if not verify_password(data.current_password, current_user.password_hash):
|
||||
@@ -478,7 +478,7 @@ async def change_password(
|
||||
async def forgot_password(
|
||||
request: Request,
|
||||
data: ForgotPasswordRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Request a password reset email. Always returns success (anti-enumeration)."""
|
||||
result = await db.execute(select(User).where(User.email == data.email))
|
||||
@@ -513,7 +513,7 @@ async def forgot_password(
|
||||
@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
|
||||
async def verify_reset_token(
|
||||
data: VerifyResetTokenRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Verify a password reset token is valid."""
|
||||
payload = decode_token(data.token)
|
||||
@@ -544,7 +544,7 @@ async def verify_reset_token(
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
data: ResetPasswordRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Reset password using a valid reset token."""
|
||||
payload = decode_token(data.token)
|
||||
@@ -611,7 +611,7 @@ async def reset_password(
|
||||
|
||||
@router.get("/email/verification-status")
|
||||
async def get_verification_status(
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Check if email verification is enabled on the platform."""
|
||||
enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||
@@ -623,7 +623,7 @@ async def get_verification_status(
|
||||
async def send_verification_email(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Send an email verification link to the current user."""
|
||||
verification_enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||
@@ -662,7 +662,7 @@ async def send_verification_email(
|
||||
@router.post("/email/verify")
|
||||
async def verify_email(
|
||||
data: dict,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Verify an email using a token. Public endpoint."""
|
||||
token = data.get("token")
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
from app.models.psa_connection import PsaConnection
|
||||
from app.models.session import Session
|
||||
@@ -98,7 +99,7 @@ async def get_onboarding_status(
|
||||
|
||||
@router.post("/onboarding-status/dismiss", response_model=OnboardingStatus)
|
||||
async def dismiss_onboarding(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> OnboardingStatus:
|
||||
"""Dismiss the onboarding checklist for the current user."""
|
||||
|
||||
@@ -85,6 +85,7 @@ async def create_session(
|
||||
session = await script_builder_service.create_session(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
team_id=current_user.team_id,
|
||||
language=data.language,
|
||||
)
|
||||
|
||||
@@ -21,7 +21,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
|
||||
"""Create batch sessions for a scheduled maintenance run."""
|
||||
# Import all models first to ensure SQLAlchemy mapper relationships resolve
|
||||
import app.models # noqa: F401
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.maintenance_schedule import MaintenanceSchedule
|
||||
from app.models.session import Session
|
||||
from app.models.target_list import TargetList
|
||||
@@ -118,7 +118,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
|
||||
async def _cleanup_expired_ai_conversations() -> None:
|
||||
"""Delete expired AI wizard conversations."""
|
||||
import app.models # noqa: F401
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.ai_conversation import AIConversation
|
||||
|
||||
async with async_session_maker() as db:
|
||||
|
||||
@@ -14,6 +14,8 @@ import logging
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_database import _admin_session_factory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVICE_ACCOUNT_EMAIL = "noreply@resolutionflow.com"
|
||||
@@ -52,40 +54,45 @@ async def _ensure_system_account(db: AsyncSession) -> uuid.UUID:
|
||||
async def ensure_service_account(db: AsyncSession) -> uuid.UUID:
|
||||
"""Ensure the ResolutionFlow service account exists and return its ID.
|
||||
|
||||
Idempotent — safe to call on every startup. Creates the account if it
|
||||
does not exist. The account has no usable password and is_service_account=True
|
||||
so it can never log in via normal auth flows.
|
||||
Idempotent — safe to call on every startup. This lookup must bypass RLS
|
||||
because startup runs before any request-scoped tenant context exists and
|
||||
the users table is tenant-isolated in Phase 4. The service account is
|
||||
normally created by Alembic migration 1490781700bc; the runtime create path
|
||||
remains as a self-healing fallback for environments that predate that seed.
|
||||
"""
|
||||
_ = db # Retained for call-site compatibility in app lifespan startup.
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == SERVICE_ACCOUNT_EMAIL)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
async with _admin_session_factory() as admin_db:
|
||||
result = await admin_db.execute(
|
||||
select(User).where(User.email == SERVICE_ACCOUNT_EMAIL)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is not None:
|
||||
if not user.is_service_account:
|
||||
user.is_service_account = True
|
||||
await db.commit()
|
||||
return user.id
|
||||
if user is not None:
|
||||
if not user.is_service_account:
|
||||
user.is_service_account = True
|
||||
await admin_db.commit()
|
||||
return user.id
|
||||
|
||||
account_id = await _ensure_system_account(db)
|
||||
account_id = await _ensure_system_account(admin_db)
|
||||
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=SERVICE_ACCOUNT_EMAIL,
|
||||
name=SERVICE_ACCOUNT_NAME,
|
||||
password_hash="!service-account-no-login", # bcrypt can't produce this prefix
|
||||
role="engineer",
|
||||
is_super_admin=False,
|
||||
is_team_admin=False,
|
||||
is_active=True,
|
||||
is_service_account=True,
|
||||
must_change_password=False,
|
||||
account_id=account_id,
|
||||
account_role="engineer",
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
logger.info(f"[service_account] Created service account (id={new_user.id})")
|
||||
return new_user.id
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=SERVICE_ACCOUNT_EMAIL,
|
||||
name=SERVICE_ACCOUNT_NAME,
|
||||
password_hash="!service-account-no-login", # bcrypt can't produce this prefix
|
||||
role="engineer",
|
||||
is_super_admin=False,
|
||||
is_team_admin=False,
|
||||
is_active=True,
|
||||
is_service_account=True,
|
||||
must_change_password=False,
|
||||
account_id=account_id,
|
||||
account_role="engineer",
|
||||
)
|
||||
admin_db.add(new_user)
|
||||
await admin_db.commit()
|
||||
logger.info(f"[service_account] Created service account (id={new_user.id})")
|
||||
return new_user.id
|
||||
|
||||
@@ -25,7 +25,8 @@ if settings.SENTRY_DSN:
|
||||
),
|
||||
)
|
||||
|
||||
from app.core.database import init_db, async_session_maker
|
||||
from app.core.database import init_db
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.core.logging_config import setup_logging
|
||||
from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware
|
||||
from app.core.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
@@ -10,7 +10,7 @@ import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.ai_session import AISession
|
||||
from app.services.knowledge_flywheel import analyze_session
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.psa_post_log import PsaPostLog
|
||||
from app.services.psa_documentation_service import retry_failed_push
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone, timedelta
|
||||
|
||||
from sqlalchemy import select, delete, func
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.account import Account
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
|
||||
|
||||
@@ -144,6 +144,7 @@ def _extract_script_from_response(content: str, language: str) -> tuple[str | No
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
team_id: UUID | None,
|
||||
language: str,
|
||||
initial_prompt: str | None = None,
|
||||
@@ -151,6 +152,7 @@ async def create_session(
|
||||
"""Create a new Script Builder session."""
|
||||
session = ScriptBuilderSession(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
team_id=team_id,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@@ -80,7 +80,10 @@ def _display_code() -> str:
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
engine = create_async_engine(settings.DATABASE_URL, echo=False)
|
||||
# Must use ADMIN_DATABASE_URL (BYPASSRLS) — Phase 4 enabled RLS on users.
|
||||
# The app-role connection has no tenant context at seed time and would see 0 rows.
|
||||
admin_url = getattr(settings, "ADMIN_DATABASE_URL", None) or settings.DATABASE_URL
|
||||
engine = create_async_engine(admin_url, echo=False)
|
||||
password_hash = get_password_hash(SHARED_PASSWORD)
|
||||
now = datetime.now(timezone.utc)
|
||||
team_account_id: uuid.UUID | None = None
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestAdminGlobalCategories:
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Category"
|
||||
assert data["slug"] == "test-category"
|
||||
assert data["account_id"] is None
|
||||
assert data["account_id"] == "00000000-0000-0000-0000-000000000001" # PLATFORM_ACCOUNT_ID
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_global_category(
|
||||
|
||||
@@ -200,6 +200,7 @@ class TestAccountPermissions:
|
||||
})
|
||||
outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"}
|
||||
|
||||
# Outsider should NOT see the private tree
|
||||
# Outsider should NOT see the private tree.
|
||||
# With RLS, the tree is invisible to other tenants — 404 not 403.
|
||||
response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers)
|
||||
assert response.status_code == 403
|
||||
assert response.status_code == 404
|
||||
|
||||
@@ -24,6 +24,12 @@ from pathlib import Path
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
# All tests in this module use module-scoped async fixtures (admin_conn,
|
||||
# seed_rls_test_data) which run on the module event loop. Without this marker,
|
||||
# pytest-asyncio 0.23+ defaults tests to function-scoped loops, causing
|
||||
# "Future attached to a different loop" errors on the asyncpg connections.
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="module")
|
||||
|
||||
_DB_HOST = os.getenv("TEST_DB_HOST", "localhost")
|
||||
_DB_PORT = int(os.getenv("TEST_DB_PORT", "5432"))
|
||||
_DB_NAME = os.getenv("TEST_DB_NAME", "patherly_test") # matches conftest.py
|
||||
@@ -191,7 +197,6 @@ async def conn_no_context():
|
||||
# trees
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -199,7 +204,6 @@ async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
||||
assert len(rows) == 0, "Account A should not see Account B trees"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_account_a_can_see_own_rows(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}'"
|
||||
@@ -207,7 +211,6 @@ async def test_trees_account_a_can_see_own_rows(conn_a):
|
||||
assert len(rows) >= 1, "Account A should see its own trees"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
||||
rows = await conn_no_context.fetch(
|
||||
"SELECT id FROM trees WHERE is_default = FALSE AND is_public = FALSE"
|
||||
@@ -219,7 +222,6 @@ async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
||||
# tree_tags — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_tags WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -227,7 +229,6 @@ async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
||||
rows_a = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'"
|
||||
@@ -243,7 +244,6 @@ async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
||||
# tree_categories — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -255,7 +255,6 @@ async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
||||
# step_categories — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -267,7 +266,6 @@ async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
||||
# psa_connections — tenant-only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM psa_connections WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -279,7 +277,6 @@ async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
||||
# flow_proposals — tenant-only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_proposals_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -513,7 +510,6 @@ async def session_row_ids(admin_conn):
|
||||
# sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_b']}'"
|
||||
@@ -521,7 +517,6 @@ async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_
|
||||
assert len(rows) == 0, "Account A should not see Account B sessions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_a']}'"
|
||||
@@ -529,7 +524,6 @@ async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids):
|
||||
assert len(rows) == 1, "Account A should see its own sessions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids):
|
||||
rows = await conn_no_context.fetch(
|
||||
f"SELECT id FROM sessions WHERE id IN "
|
||||
@@ -542,7 +536,6 @@ async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids
|
||||
# ai_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_b']}'"
|
||||
@@ -550,7 +543,6 @@ async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_id
|
||||
assert len(rows) == 0, "Account A should not see Account B ai_sessions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_a']}'"
|
||||
@@ -562,7 +554,6 @@ async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids):
|
||||
# session_branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_branches WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -574,7 +565,6 @@ async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_r
|
||||
# session_supporting_data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_supporting_data WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -586,7 +576,6 @@ async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, se
|
||||
# session_resolution_outputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_resolution_outputs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -598,7 +587,6 @@ async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a,
|
||||
# session_handoffs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_handoffs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -610,7 +598,6 @@ async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_r
|
||||
# script_templates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM script_templates WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -622,7 +609,6 @@ async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_r
|
||||
# script_generations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_script_generations_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM script_generations WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -634,7 +620,6 @@ async def test_script_generations_account_a_cannot_see_account_b(conn_a, session
|
||||
# maintenance_schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM maintenance_schedules WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -646,7 +631,6 @@ async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, sess
|
||||
# psa_post_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM psa_post_log WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -658,7 +642,6 @@ async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_i
|
||||
# step_library — visibility-aware policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_conn, conn_a):
|
||||
"""Private/non-public steps owned by Account B must not be visible to Account A."""
|
||||
private_step_id = str(uuid.uuid4())
|
||||
@@ -683,7 +666,6 @@ async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_c
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, conn_a):
|
||||
"""Public steps owned by Account B MUST be visible to Account A (cross-tenant visibility)."""
|
||||
public_step_id = str(uuid.uuid4())
|
||||
@@ -738,7 +720,6 @@ async def _get_tree_b_id(admin_conn) -> str:
|
||||
# step_ratings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see step ratings belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
@@ -779,7 +760,6 @@ async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
# step_usage_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see step usage logs belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
@@ -832,7 +812,6 @@ async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a)
|
||||
# target_lists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see target lists belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
@@ -859,7 +838,6 @@ async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
# session_shares
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see session shares belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
@@ -903,7 +881,6 @@ async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a)
|
||||
# audit_logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see audit logs belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
@@ -930,7 +907,6 @@ async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
# tree_shares
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see tree shares belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
@@ -954,3 +930,129 @@ async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
assert len(rows) == 0, "Account A should not see Account B tree_shares"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Phase 4 RLS isolation tests
|
||||
# Tables: users, script_builder_sessions, ai_session_steps, notifications
|
||||
#
|
||||
# Note: platform_steps and template_trees have no account_id column and no RLS —
|
||||
# they are globally readable by all authenticated users.
|
||||
# ===========================================================================
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# users
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_users_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see users belonging to Account B."""
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B users"
|
||||
|
||||
|
||||
async def test_users_account_a_can_see_own(admin_conn, conn_a):
|
||||
"""Account A must be able to see its own users."""
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}'"
|
||||
)
|
||||
assert len(rows) > 0, "Account A should see its own users"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# script_builder_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_script_builder_sessions_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see script builder sessions belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO script_builder_sessions (
|
||||
id, user_id, account_id, language, created_at, updated_at
|
||||
) VALUES (
|
||||
'{session_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'powershell', NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM script_builder_sessions WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B script_builder_sessions"
|
||||
finally:
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM script_builder_sessions WHERE id = '{session_id}'"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ai_session_steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_ai_session_steps_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see ai_session_steps belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
# Need an ai_sessions row as FK
|
||||
ai_session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO ai_sessions (
|
||||
id, user_id, account_id, flow_type, status, confidence_tier,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
'{ai_session_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'troubleshooting', 'active', 'guided', NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO ai_session_steps (
|
||||
id, session_id, account_id, step_type, content,
|
||||
created_at
|
||||
) VALUES (
|
||||
'{step_id}', '{ai_session_id}', '{ACCOUNT_B_ID}',
|
||||
'question', 'Phase4 RLS test step', NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM ai_session_steps WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B ai_session_steps"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM ai_session_steps WHERE id = '{step_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM ai_sessions WHERE id = '{ai_session_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# notifications
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_notifications_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see notifications belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
notif_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO notifications (
|
||||
id, user_id, account_id, type, title, message,
|
||||
is_read, created_at
|
||||
) VALUES (
|
||||
'{notif_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'info', 'Phase4 RLS Test', 'RLS isolation test notification',
|
||||
FALSE, NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM notifications WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B notifications"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM notifications WHERE id = '{notif_id}'")
|
||||
|
||||
|
||||
89
backend/tests/test_service_account.py
Normal file
89
backend/tests/test_service_account.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core import service_account as service_account_module
|
||||
from app.core.service_account import (
|
||||
SERVICE_ACCOUNT_EMAIL,
|
||||
SYSTEM_ACCOUNT_DISPLAY_CODE,
|
||||
ensure_service_account,
|
||||
)
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class _SessionFactoryOverride:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __call__(self):
|
||||
return self
|
||||
|
||||
async def __aenter__(self):
|
||||
return self._session
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_service_account_creates_and_reuses_seeded_user(test_db, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_account_module,
|
||||
"_admin_session_factory",
|
||||
_SessionFactoryOverride(test_db),
|
||||
)
|
||||
|
||||
service_account_id = await ensure_service_account(test_db)
|
||||
|
||||
created_user = (
|
||||
await test_db.execute(select(User).where(User.id == service_account_id))
|
||||
).scalar_one()
|
||||
assert created_user.email == SERVICE_ACCOUNT_EMAIL
|
||||
assert created_user.is_service_account is True
|
||||
|
||||
system_account = (
|
||||
await test_db.execute(
|
||||
select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE)
|
||||
)
|
||||
).scalar_one()
|
||||
assert created_user.account_id == system_account.id
|
||||
|
||||
second_id = await ensure_service_account(test_db)
|
||||
assert second_id == service_account_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_service_account_marks_existing_user_as_service_account(test_db, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_account_module,
|
||||
"_admin_session_factory",
|
||||
_SessionFactoryOverride(test_db),
|
||||
)
|
||||
|
||||
system_account = (
|
||||
await test_db.execute(
|
||||
select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE)
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
existing_user = User(
|
||||
email=SERVICE_ACCOUNT_EMAIL,
|
||||
name="ResolutionFlow",
|
||||
password_hash="!service-account-no-login",
|
||||
role="engineer",
|
||||
is_super_admin=False,
|
||||
is_team_admin=False,
|
||||
is_active=True,
|
||||
is_service_account=False,
|
||||
must_change_password=False,
|
||||
account_id=system_account.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
test_db.add(existing_user)
|
||||
await test_db.commit()
|
||||
|
||||
resolved_id = await ensure_service_account(test_db)
|
||||
await test_db.refresh(existing_user)
|
||||
|
||||
assert resolved_id == existing_user.id
|
||||
assert existing_user.is_service_account is True
|
||||
Reference in New Issue
Block a user