diff --git a/CLAUDE.md b/CLAUDE.md index 4ce6373a..7ed9b66a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -375,6 +375,12 @@ gh run view --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi **106. Guard async "select item → load data → apply state" flows with a ref:** When a component lets the user switch between items (chat sessions, flows, scripts) and loads data asynchronously on each switch, the load for item A can complete *after* the user has already switched to item B — overwriting B's state with A's stale data. Fix pattern: keep a `currentSelectionRef = useRef(initialId)` and update it synchronously whenever the selection changes (in every creation/switch path). After every `await`, bail out if `currentSelectionRef.current !== thisItemId`. See `AssistantChatPage.tsx` `selectChat` for the reference implementation (`currentChatRef`). +**107. Startup routines must use `_admin_session_factory()` after Phase 4 RLS:** Any code that runs at startup (lifespan, `ensure_service_account`, seed scripts) and touches tenant-isolated tables (`users`, etc.) must use `_admin_session_factory()` — not `get_db()`. Phase 4 enabled RLS on `users`; a tenant-scoped session has no `app.current_account_id` set at startup, so all queries return 0 rows or fail. `get_service_account_id` in `deps.py` is safe — it reads from `app.state` cached at startup, never hits the DB per-request. + +**108. Tables with no `account_id` column (never add to RLS migrations):** `script_categories`, `platform_steps`, `template_trees`, `plan_feature_defaults`, `accounts` — global/platform tables documented with "No account_id. No RLS." in their model files. When writing RLS migrations, scan at the class level (check for `account_id: Mapped` within the class block), not the file level — multiple classes in one `.py` file can have different columns (e.g. `ScriptCategory` vs `ScriptTemplate` in `script_template.py`). + +**109. `tree_shares.account_id` must equal `tree.account_id`, not the actor's account:** When creating a `TreeShare`, always use `account_id=tree.account_id` (tree owner's tenant). A super admin in tenant A sharing tenant B's tree must produce a share row in tenant B's RLS context — using `current_user.account_id` instead makes the share invisible to the tree owner after RLS is enforced. + ## RBAC & Permissions - **Role hierarchy:** super_admin > team_admin > engineer > viewer @@ -522,7 +528,7 @@ When a feature, fix, or significant piece of work is finished and merged/committ # GitNexus — Code Intelligence -This project is indexed by GitNexus as **resolutionflow** (14787 symbols, 31366 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. +This project is indexed by GitNexus as **resolutionflow** (16703 symbols, 35922 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. > If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first. diff --git a/backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py b/backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py new file mode 100644 index 00000000..919c6904 --- /dev/null +++ b/backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py @@ -0,0 +1,85 @@ +"""Enable RLS on Phase 4 tables — all remaining tenant-scoped tables. + +All tables in this migration already have account_id NOT NULL (enforced by +earlier migrations). This migration adds ENABLE ROW LEVEL SECURITY, +FORCE ROW LEVEL SECURITY, and the appropriate tenant isolation policy to each. + +Policy variants used: + - Standard: account_id = current_setting(app.current_account_id)::uuid + - Platform: standard OR account_id = PLATFORM_ACCOUNT_ID + (for global content tables readable by all tenants) + +Skipped intentionally: + - accounts — IS the root table; no account_id column + - plan_feature_defaults — platform config; no account_id column + - script_categories — global lookup table; no account_id column + - platform_steps — global content; no account_id column (readable by all) + - template_trees — global content; no account_id column (readable by all) + +Revision ID: b3c7e9f2a1d8 +Revises: 172ad76d7d20 +Create Date: 2026-04-12 +""" + +from typing import Union +from alembic import op + +revision: str = "b3c7e9f2a1d8" +down_revision: Union[str, None] = "172ad76d7d20" +branch_labels = None +depends_on = None + +# Standard policy — tenant sees only own rows. +_STANDARD_TABLES = [ + "users", + "account_invites", + "account_limit_overrides", + "account_feature_overrides", + "subscriptions", + "ai_chat_sessions", + "ai_conversations", + "ai_session_steps", + "ai_session_embeddings", + "ai_suggestions", + "ai_usage", + "assistant_chats", + "attachments", + "copilot_conversations", + "feedback", + "file_uploads", + "fork_points", + "kb_imports", + "notifications", + "notification_configs", + "notification_logs", + "psa_activity_logs", + "psa_member_mappings", + "script_builder_sessions", + "session_ratings", + "tree_embeddings", + "user_folders", + "user_pinned_trees", +] + +_POLICY_EXPR = ( + "account_id = COALESCE(" + "NULLIF(current_setting('app.current_account_id', TRUE), ''), " + "'00000000-0000-0000-0000-000000000000'" + ")::uuid" +) + + +def upgrade() -> None: + for table in _STANDARD_TABLES: + op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON {table} + USING ({_POLICY_EXPR}) + """) + + +def downgrade() -> None: + for table in _STANDARD_TABLES: + op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}") + op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY") diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index bae3f935..79770ed9 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -24,10 +24,14 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") async def get_current_user( - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], token: Annotated[str, Depends(oauth2_scheme)] ) -> User: - """Get current authenticated user from JWT token.""" + """Get current authenticated user from JWT token. + + Must use get_admin_db (BYPASSRLS): this dep runs before require_tenant_context + sets app.current_account_id, so the users table RLS would block the lookup. + """ credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -77,10 +81,14 @@ async def get_refresh_token_payload( async def get_current_active_user( request: Request, current_user: Annotated[User, Depends(get_current_user)], - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], ) -> User: """Ensure user is active (not disabled). Auto-downgrades expired trials. - Enforces must_change_password — blocks all routes except allowlist.""" + Enforces must_change_password — blocks all routes except allowlist. + + Uses get_admin_db: runs before require_tenant_context sets the ContextVar, + so tenant-scoped tables (subscriptions) would return 0 rows via app role. + """ if not current_user.is_active: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/backend/app/api/endpoints/accounts.py b/backend/app/api/endpoints/accounts.py index fd49ec48..66148952 100644 --- a/backend/app/api/endpoints/accounts.py +++ b/backend/app/api/endpoints/accounts.py @@ -9,6 +9,7 @@ from sqlalchemy import select from pydantic import BaseModel from app.core.database import get_db +from app.core.admin_database import get_admin_db from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage from app.core.audit import log_audit from app.models.refresh_token import RefreshToken @@ -148,7 +149,7 @@ async def update_member_role( @router.post("/me/transfer-ownership", response_model=AccountResponse) async def transfer_ownership( data: TransferOwnershipRequest, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_account_owner)] ): """Transfer account ownership to another member (owner only).""" @@ -377,7 +378,7 @@ async def list_invites( @router.post("/me/leave") async def leave_account( - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(get_current_active_user)] ): """Leave the current account (non-owners only). Creates a personal account.""" @@ -423,7 +424,7 @@ class DeleteAccountRequest(BaseModel): @router.delete("/me") async def delete_account( data: DeleteAccountRequest, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_account_owner)] ): """Delete the current account and soft-delete the user (owner only, no other members).""" diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index ed913441..2634a6ef 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, update as sa_update from app.core.config import settings from app.core.settings_manager import SettingsManager -from app.core.database import get_db +from app.core.admin_database import get_admin_db from app.core.rate_limit import limiter from app.core.security import ( verify_password, @@ -67,7 +67,7 @@ def _generate_display_code() -> str: async def register( request: Request, user_data: UserCreate, - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Register a new user. @@ -232,7 +232,7 @@ async def register( async def login( request: Request, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Login and get access token.""" # Find user by email @@ -270,7 +270,7 @@ async def login( async def login_json( request: Request, credentials: UserLogin, - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Login with JSON body (alternative to form data).""" result = await db.execute(select(User).where(User.email == credentials.email)) @@ -304,7 +304,7 @@ async def login_json( async def refresh_token( request: Request, payload: Annotated[dict, Depends(get_refresh_token_payload)], - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Refresh access token using refresh token (rotation: old token is revoked).""" user_id = payload.get("sub") @@ -368,7 +368,7 @@ async def get_me( async def update_me( data: UserUpdate, current_user: Annotated[User, Depends(get_current_active_user)], - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Update current user's profile (name, email).""" update_fields = data.model_fields_set - {"current_password"} @@ -415,7 +415,7 @@ async def update_me( @router.post("/logout") async def logout( payload: Annotated[dict, Depends(get_refresh_token_payload)], - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Logout user by revoking the refresh token.""" jti = payload.get("jti") @@ -438,7 +438,7 @@ async def change_password( request: Request, data: ChangePasswordRequest, current_user: Annotated[User, Depends(get_current_active_user)], - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Change the current user's password.""" if not verify_password(data.current_password, current_user.password_hash): @@ -478,7 +478,7 @@ async def change_password( async def forgot_password( request: Request, data: ForgotPasswordRequest, - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Request a password reset email. Always returns success (anti-enumeration).""" result = await db.execute(select(User).where(User.email == data.email)) @@ -513,7 +513,7 @@ async def forgot_password( @router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse) async def verify_reset_token( data: VerifyResetTokenRequest, - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Verify a password reset token is valid.""" payload = decode_token(data.token) @@ -544,7 +544,7 @@ async def verify_reset_token( async def reset_password( request: Request, data: ResetPasswordRequest, - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Reset password using a valid reset token.""" payload = decode_token(data.token) @@ -611,7 +611,7 @@ async def reset_password( @router.get("/email/verification-status") async def get_verification_status( - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Check if email verification is enabled on the platform.""" enabled = await SettingsManager.get("email_verification_enabled", db, default=True) @@ -623,7 +623,7 @@ async def get_verification_status( async def send_verification_email( request: Request, current_user: Annotated[User, Depends(get_current_active_user)], - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Send an email verification link to the current user.""" verification_enabled = await SettingsManager.get("email_verification_enabled", db, default=True) @@ -662,7 +662,7 @@ async def send_verification_email( @router.post("/email/verify") async def verify_email( data: dict, - db: Annotated[AsyncSession, Depends(get_db)] + db: Annotated[AsyncSession, Depends(get_admin_db)] ): """Verify an email using a token. Public endpoint.""" token = data.get("token") diff --git a/backend/app/api/endpoints/onboarding.py b/backend/app/api/endpoints/onboarding.py index fdb07cd8..534f58a6 100644 --- a/backend/app/api/endpoints/onboarding.py +++ b/backend/app/api/endpoints/onboarding.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_active_user from app.core.database import get_db +from app.core.admin_database import get_admin_db from app.models.assistant_chat import AssistantChat from app.models.psa_connection import PsaConnection from app.models.session import Session @@ -98,7 +99,7 @@ async def get_onboarding_status( @router.post("/onboarding-status/dismiss", response_model=OnboardingStatus) async def dismiss_onboarding( - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> OnboardingStatus: """Dismiss the onboarding checklist for the current user.""" diff --git a/backend/app/api/endpoints/script_builder.py b/backend/app/api/endpoints/script_builder.py index f56c595c..1cb849f8 100644 --- a/backend/app/api/endpoints/script_builder.py +++ b/backend/app/api/endpoints/script_builder.py @@ -85,6 +85,7 @@ async def create_session( session = await script_builder_service.create_session( db=db, user_id=current_user.id, + account_id=current_user.account_id, team_id=current_user.team_id, language=data.language, ) diff --git a/backend/app/core/scheduler.py b/backend/app/core/scheduler.py index 3d5f5ff6..d25157c6 100644 --- a/backend/app/core/scheduler.py +++ b/backend/app/core/scheduler.py @@ -21,7 +21,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None: """Create batch sessions for a scheduled maintenance run.""" # Import all models first to ensure SQLAlchemy mapper relationships resolve import app.models # noqa: F401 - from app.core.database import async_session_maker + from app.core.admin_database import _admin_session_factory as async_session_maker from app.models.maintenance_schedule import MaintenanceSchedule from app.models.session import Session from app.models.target_list import TargetList @@ -118,7 +118,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None: async def _cleanup_expired_ai_conversations() -> None: """Delete expired AI wizard conversations.""" import app.models # noqa: F401 - from app.core.database import async_session_maker + from app.core.admin_database import _admin_session_factory as async_session_maker from app.models.ai_conversation import AIConversation async with async_session_maker() as db: diff --git a/backend/app/core/service_account.py b/backend/app/core/service_account.py index 9d00a1d9..d91c390f 100644 --- a/backend/app/core/service_account.py +++ b/backend/app/core/service_account.py @@ -14,6 +14,8 @@ import logging from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.core.admin_database import _admin_session_factory + logger = logging.getLogger(__name__) SERVICE_ACCOUNT_EMAIL = "noreply@resolutionflow.com" @@ -52,40 +54,45 @@ async def _ensure_system_account(db: AsyncSession) -> uuid.UUID: async def ensure_service_account(db: AsyncSession) -> uuid.UUID: """Ensure the ResolutionFlow service account exists and return its ID. - Idempotent — safe to call on every startup. Creates the account if it - does not exist. The account has no usable password and is_service_account=True - so it can never log in via normal auth flows. + Idempotent — safe to call on every startup. This lookup must bypass RLS + because startup runs before any request-scoped tenant context exists and + the users table is tenant-isolated in Phase 4. The service account is + normally created by Alembic migration 1490781700bc; the runtime create path + remains as a self-healing fallback for environments that predate that seed. """ + _ = db # Retained for call-site compatibility in app lifespan startup. + from app.models.user import User - result = await db.execute( - select(User).where(User.email == SERVICE_ACCOUNT_EMAIL) - ) - user = result.scalar_one_or_none() + async with _admin_session_factory() as admin_db: + result = await admin_db.execute( + select(User).where(User.email == SERVICE_ACCOUNT_EMAIL) + ) + user = result.scalar_one_or_none() - if user is not None: - if not user.is_service_account: - user.is_service_account = True - await db.commit() - return user.id + if user is not None: + if not user.is_service_account: + user.is_service_account = True + await admin_db.commit() + return user.id - account_id = await _ensure_system_account(db) + account_id = await _ensure_system_account(admin_db) - new_user = User( - id=uuid.uuid4(), - email=SERVICE_ACCOUNT_EMAIL, - name=SERVICE_ACCOUNT_NAME, - password_hash="!service-account-no-login", # bcrypt can't produce this prefix - role="engineer", - is_super_admin=False, - is_team_admin=False, - is_active=True, - is_service_account=True, - must_change_password=False, - account_id=account_id, - account_role="engineer", - ) - db.add(new_user) - await db.commit() - logger.info(f"[service_account] Created service account (id={new_user.id})") - return new_user.id + new_user = User( + id=uuid.uuid4(), + email=SERVICE_ACCOUNT_EMAIL, + name=SERVICE_ACCOUNT_NAME, + password_hash="!service-account-no-login", # bcrypt can't produce this prefix + role="engineer", + is_super_admin=False, + is_team_admin=False, + is_active=True, + is_service_account=True, + must_change_password=False, + account_id=account_id, + account_role="engineer", + ) + admin_db.add(new_user) + await admin_db.commit() + logger.info(f"[service_account] Created service account (id={new_user.id})") + return new_user.id diff --git a/backend/app/main.py b/backend/app/main.py index a9ade25e..795c0db9 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -25,7 +25,8 @@ if settings.SENTRY_DSN: ), ) -from app.core.database import init_db, async_session_maker +from app.core.database import init_db +from app.core.admin_database import _admin_session_factory as async_session_maker from app.core.logging_config import setup_logging from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware from app.core.security_headers import SecurityHeadersMiddleware diff --git a/backend/app/services/knowledge_flywheel_scheduler.py b/backend/app/services/knowledge_flywheel_scheduler.py index c1366b7d..65447f4b 100644 --- a/backend/app/services/knowledge_flywheel_scheduler.py +++ b/backend/app/services/knowledge_flywheel_scheduler.py @@ -10,7 +10,7 @@ import logging from sqlalchemy import select -from app.core.database import async_session_maker +from app.core.admin_database import _admin_session_factory as async_session_maker from app.models.ai_session import AISession from app.services.knowledge_flywheel import analyze_session diff --git a/backend/app/services/psa_retry_scheduler.py b/backend/app/services/psa_retry_scheduler.py index f3403eb2..03a32acd 100644 --- a/backend/app/services/psa_retry_scheduler.py +++ b/backend/app/services/psa_retry_scheduler.py @@ -9,7 +9,7 @@ from datetime import datetime, timezone from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.core.database import async_session_maker +from app.core.admin_database import _admin_session_factory as async_session_maker from app.models.psa_post_log import PsaPostLog from app.services.psa_documentation_service import retry_failed_push diff --git a/backend/app/services/retention_cleanup.py b/backend/app/services/retention_cleanup.py index c164240f..c8919451 100644 --- a/backend/app/services/retention_cleanup.py +++ b/backend/app/services/retention_cleanup.py @@ -9,7 +9,7 @@ from datetime import datetime, timezone, timedelta from sqlalchemy import select, delete, func -from app.core.database import async_session_maker +from app.core.admin_database import _admin_session_factory as async_session_maker from app.models.account import Account from app.models.assistant_chat import AssistantChat diff --git a/backend/app/services/script_builder_service.py b/backend/app/services/script_builder_service.py index a1a7647e..991d9a28 100644 --- a/backend/app/services/script_builder_service.py +++ b/backend/app/services/script_builder_service.py @@ -144,6 +144,7 @@ def _extract_script_from_response(content: str, language: str) -> tuple[str | No async def create_session( db: AsyncSession, user_id: UUID, + account_id: UUID, team_id: UUID | None, language: str, initial_prompt: str | None = None, @@ -151,6 +152,7 @@ async def create_session( """Create a new Script Builder session.""" session = ScriptBuilderSession( user_id=user_id, + account_id=account_id, team_id=team_id, language=language, ) diff --git a/backend/scripts/seed_test_users.py b/backend/scripts/seed_test_users.py index f8348d97..8526fea1 100644 --- a/backend/scripts/seed_test_users.py +++ b/backend/scripts/seed_test_users.py @@ -80,7 +80,10 @@ def _display_code() -> str: async def main() -> None: - engine = create_async_engine(settings.DATABASE_URL, echo=False) + # Must use ADMIN_DATABASE_URL (BYPASSRLS) — Phase 4 enabled RLS on users. + # The app-role connection has no tenant context at seed time and would see 0 rows. + admin_url = getattr(settings, "ADMIN_DATABASE_URL", None) or settings.DATABASE_URL + engine = create_async_engine(admin_url, echo=False) password_hash = get_password_hash(SHARED_PASSWORD) now = datetime.now(timezone.utc) team_account_id: uuid.UUID | None = None diff --git a/backend/tests/test_admin_categories_global.py b/backend/tests/test_admin_categories_global.py index 1ae6212a..7265771c 100644 --- a/backend/tests/test_admin_categories_global.py +++ b/backend/tests/test_admin_categories_global.py @@ -29,7 +29,7 @@ class TestAdminGlobalCategories: data = response.json() assert data["name"] == "Test Category" assert data["slug"] == "test-category" - assert data["account_id"] is None + assert data["account_id"] == "00000000-0000-0000-0000-000000000001" # PLATFORM_ACCOUNT_ID @pytest.mark.asyncio async def test_update_global_category( diff --git a/backend/tests/test_permissions_account.py b/backend/tests/test_permissions_account.py index a18067a9..b25e314b 100644 --- a/backend/tests/test_permissions_account.py +++ b/backend/tests/test_permissions_account.py @@ -200,6 +200,7 @@ class TestAccountPermissions: }) outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"} - # Outsider should NOT see the private tree + # Outsider should NOT see the private tree. + # With RLS, the tree is invisible to other tenants — 404 not 403. response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers) - assert response.status_code == 403 + assert response.status_code == 404 diff --git a/backend/tests/test_rls_isolation.py b/backend/tests/test_rls_isolation.py index b083eda8..4c69e3df 100644 --- a/backend/tests/test_rls_isolation.py +++ b/backend/tests/test_rls_isolation.py @@ -24,6 +24,12 @@ from pathlib import Path import asyncpg import pytest +# All tests in this module use module-scoped async fixtures (admin_conn, +# seed_rls_test_data) which run on the module event loop. Without this marker, +# pytest-asyncio 0.23+ defaults tests to function-scoped loops, causing +# "Future attached to a different loop" errors on the asyncpg connections. +pytestmark = pytest.mark.asyncio(loop_scope="module") + _DB_HOST = os.getenv("TEST_DB_HOST", "localhost") _DB_PORT = int(os.getenv("TEST_DB_PORT", "5432")) _DB_NAME = os.getenv("TEST_DB_NAME", "patherly_test") # matches conftest.py @@ -191,7 +197,6 @@ async def conn_no_context(): # trees # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_trees_account_a_cannot_see_account_b_rows(conn_a): rows = await conn_a.fetch( f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}'" @@ -199,7 +204,6 @@ async def test_trees_account_a_cannot_see_account_b_rows(conn_a): assert len(rows) == 0, "Account A should not see Account B trees" -@pytest.mark.asyncio async def test_trees_account_a_can_see_own_rows(conn_a): rows = await conn_a.fetch( f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}'" @@ -207,7 +211,6 @@ async def test_trees_account_a_can_see_own_rows(conn_a): assert len(rows) >= 1, "Account A should see its own trees" -@pytest.mark.asyncio async def test_trees_no_context_sees_no_private_trees(conn_no_context): rows = await conn_no_context.fetch( "SELECT id FROM trees WHERE is_default = FALSE AND is_public = FALSE" @@ -219,7 +222,6 @@ async def test_trees_no_context_sees_no_private_trees(conn_no_context): # tree_tags — platform visibility # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a): rows = await conn_a.fetch( f"SELECT id FROM tree_tags WHERE account_id = '{ACCOUNT_B_ID}'" @@ -227,7 +229,6 @@ async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a): assert len(rows) == 0 -@pytest.mark.asyncio async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b): rows_a = await conn_a.fetch( f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'" @@ -243,7 +244,6 @@ async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b): # tree_categories — platform visibility # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_tree_categories_account_a_cannot_see_account_b(conn_a): rows = await conn_a.fetch( f"SELECT id FROM tree_categories WHERE account_id = '{ACCOUNT_B_ID}'" @@ -255,7 +255,6 @@ async def test_tree_categories_account_a_cannot_see_account_b(conn_a): # step_categories — platform visibility # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_step_categories_account_a_cannot_see_account_b(conn_a): rows = await conn_a.fetch( f"SELECT id FROM step_categories WHERE account_id = '{ACCOUNT_B_ID}'" @@ -267,7 +266,6 @@ async def test_step_categories_account_a_cannot_see_account_b(conn_a): # psa_connections — tenant-only # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_psa_connections_account_a_cannot_see_account_b(conn_a): rows = await conn_a.fetch( f"SELECT id FROM psa_connections WHERE account_id = '{ACCOUNT_B_ID}'" @@ -279,7 +277,6 @@ async def test_psa_connections_account_a_cannot_see_account_b(conn_a): # flow_proposals — tenant-only # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_flow_proposals_account_a_cannot_see_account_b(conn_a): rows = await conn_a.fetch( f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'" @@ -513,7 +510,6 @@ async def session_row_ids(admin_conn): # sessions # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_b']}'" @@ -521,7 +517,6 @@ async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_ assert len(rows) == 0, "Account A should not see Account B sessions" -@pytest.mark.asyncio async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_a']}'" @@ -529,7 +524,6 @@ async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids): assert len(rows) == 1, "Account A should see its own sessions" -@pytest.mark.asyncio async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids): rows = await conn_no_context.fetch( f"SELECT id FROM sessions WHERE id IN " @@ -542,7 +536,6 @@ async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids # ai_sessions # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_b']}'" @@ -550,7 +543,6 @@ async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_id assert len(rows) == 0, "Account A should not see Account B ai_sessions" -@pytest.mark.asyncio async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_a']}'" @@ -562,7 +554,6 @@ async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids): # session_branches # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM session_branches WHERE account_id = '{ACCOUNT_B_ID}'" @@ -574,7 +565,6 @@ async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_r # session_supporting_data # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM session_supporting_data WHERE account_id = '{ACCOUNT_B_ID}'" @@ -586,7 +576,6 @@ async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, se # session_resolution_outputs # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM session_resolution_outputs WHERE account_id = '{ACCOUNT_B_ID}'" @@ -598,7 +587,6 @@ async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, # session_handoffs # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM session_handoffs WHERE account_id = '{ACCOUNT_B_ID}'" @@ -610,7 +598,6 @@ async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_r # script_templates # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM script_templates WHERE account_id = '{ACCOUNT_B_ID}'" @@ -622,7 +609,6 @@ async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_r # script_generations # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_script_generations_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM script_generations WHERE account_id = '{ACCOUNT_B_ID}'" @@ -634,7 +620,6 @@ async def test_script_generations_account_a_cannot_see_account_b(conn_a, session # maintenance_schedules # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM maintenance_schedules WHERE account_id = '{ACCOUNT_B_ID}'" @@ -646,7 +631,6 @@ async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, sess # psa_post_log # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_ids): rows = await conn_a.fetch( f"SELECT id FROM psa_post_log WHERE account_id = '{ACCOUNT_B_ID}'" @@ -658,7 +642,6 @@ async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_i # step_library — visibility-aware policy # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_conn, conn_a): """Private/non-public steps owned by Account B must not be visible to Account A.""" private_step_id = str(uuid.uuid4()) @@ -683,7 +666,6 @@ async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_c ) -@pytest.mark.asyncio async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, conn_a): """Public steps owned by Account B MUST be visible to Account A (cross-tenant visibility).""" public_step_id = str(uuid.uuid4()) @@ -738,7 +720,6 @@ async def _get_tree_b_id(admin_conn) -> str: # step_ratings # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a): """Account A must not see step ratings belonging to Account B.""" user_b_id = await _get_user_b_id(admin_conn) @@ -779,7 +760,6 @@ async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a): # step_usage_log # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a): """Account A must not see step usage logs belonging to Account B.""" user_b_id = await _get_user_b_id(admin_conn) @@ -832,7 +812,6 @@ async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a) # target_lists # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a): """Account A must not see target lists belonging to Account B.""" user_b_id = await _get_user_b_id(admin_conn) @@ -859,7 +838,6 @@ async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a): # session_shares # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a): """Account A must not see session shares belonging to Account B.""" user_b_id = await _get_user_b_id(admin_conn) @@ -903,7 +881,6 @@ async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a) # audit_logs # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a): """Account A must not see audit logs belonging to Account B.""" user_b_id = await _get_user_b_id(admin_conn) @@ -930,7 +907,6 @@ async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a): # tree_shares # --------------------------------------------------------------------------- -@pytest.mark.asyncio async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a): """Account A must not see tree shares belonging to Account B.""" user_b_id = await _get_user_b_id(admin_conn) @@ -954,3 +930,129 @@ async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a): assert len(rows) == 0, "Account A should not see Account B tree_shares" finally: await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'") + + +# =========================================================================== +# Phase 4 RLS isolation tests +# Tables: users, script_builder_sessions, ai_session_steps, notifications +# +# Note: platform_steps and template_trees have no account_id column and no RLS — +# they are globally readable by all authenticated users. +# =========================================================================== + +# --------------------------------------------------------------------------- +# users +# --------------------------------------------------------------------------- + +async def test_users_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see users belonging to Account B.""" + rows = await conn_a.fetch( + f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B users" + + +async def test_users_account_a_can_see_own(admin_conn, conn_a): + """Account A must be able to see its own users.""" + rows = await conn_a.fetch( + f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}'" + ) + assert len(rows) > 0, "Account A should see its own users" + + +# --------------------------------------------------------------------------- +# script_builder_sessions +# --------------------------------------------------------------------------- + +async def test_script_builder_sessions_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see script builder sessions belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + session_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO script_builder_sessions ( + id, user_id, account_id, language, created_at, updated_at + ) VALUES ( + '{session_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + 'powershell', NOW(), NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM script_builder_sessions WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B script_builder_sessions" + finally: + await admin_conn.execute( + f"DELETE FROM script_builder_sessions WHERE id = '{session_id}'" + ) + + +# --------------------------------------------------------------------------- +# ai_session_steps +# --------------------------------------------------------------------------- + +async def test_ai_session_steps_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see ai_session_steps belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + tree_b_id = await _get_tree_b_id(admin_conn) + + # Need an ai_sessions row as FK + ai_session_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO ai_sessions ( + id, user_id, account_id, flow_type, status, confidence_tier, + created_at, updated_at + ) VALUES ( + '{ai_session_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + 'troubleshooting', 'active', 'guided', NOW(), NOW() + ) + """) + + step_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO ai_session_steps ( + id, session_id, account_id, step_type, content, + created_at + ) VALUES ( + '{step_id}', '{ai_session_id}', '{ACCOUNT_B_ID}', + 'question', 'Phase4 RLS test step', NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM ai_session_steps WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B ai_session_steps" + finally: + await admin_conn.execute(f"DELETE FROM ai_session_steps WHERE id = '{step_id}'") + await admin_conn.execute(f"DELETE FROM ai_sessions WHERE id = '{ai_session_id}'") + + +# --------------------------------------------------------------------------- +# notifications +# --------------------------------------------------------------------------- + +async def test_notifications_account_a_cannot_see_account_b(admin_conn, conn_a): + """Account A must not see notifications belonging to Account B.""" + user_b_id = await _get_user_b_id(admin_conn) + + notif_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO notifications ( + id, user_id, account_id, type, title, message, + is_read, created_at + ) VALUES ( + '{notif_id}', '{user_b_id}', '{ACCOUNT_B_ID}', + 'info', 'Phase4 RLS Test', 'RLS isolation test notification', + FALSE, NOW() + ) + """) + try: + rows = await conn_a.fetch( + f"SELECT id FROM notifications WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B notifications" + finally: + await admin_conn.execute(f"DELETE FROM notifications WHERE id = '{notif_id}'") + diff --git a/backend/tests/test_service_account.py b/backend/tests/test_service_account.py new file mode 100644 index 00000000..9609cc60 --- /dev/null +++ b/backend/tests/test_service_account.py @@ -0,0 +1,89 @@ +import pytest +from sqlalchemy import select + +from app.core import service_account as service_account_module +from app.core.service_account import ( + SERVICE_ACCOUNT_EMAIL, + SYSTEM_ACCOUNT_DISPLAY_CODE, + ensure_service_account, +) +from app.models.account import Account +from app.models.user import User + + +class _SessionFactoryOverride: + def __init__(self, session): + self._session = session + + def __call__(self): + return self + + async def __aenter__(self): + return self._session + + async def __aexit__(self, exc_type, exc, tb): + return False + + +@pytest.mark.asyncio +async def test_ensure_service_account_creates_and_reuses_seeded_user(test_db, monkeypatch): + monkeypatch.setattr( + service_account_module, + "_admin_session_factory", + _SessionFactoryOverride(test_db), + ) + + service_account_id = await ensure_service_account(test_db) + + created_user = ( + await test_db.execute(select(User).where(User.id == service_account_id)) + ).scalar_one() + assert created_user.email == SERVICE_ACCOUNT_EMAIL + assert created_user.is_service_account is True + + system_account = ( + await test_db.execute( + select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE) + ) + ).scalar_one() + assert created_user.account_id == system_account.id + + second_id = await ensure_service_account(test_db) + assert second_id == service_account_id + + +@pytest.mark.asyncio +async def test_ensure_service_account_marks_existing_user_as_service_account(test_db, monkeypatch): + monkeypatch.setattr( + service_account_module, + "_admin_session_factory", + _SessionFactoryOverride(test_db), + ) + + system_account = ( + await test_db.execute( + select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE) + ) + ).scalar_one() + + existing_user = User( + email=SERVICE_ACCOUNT_EMAIL, + name="ResolutionFlow", + password_hash="!service-account-no-login", + role="engineer", + is_super_admin=False, + is_team_admin=False, + is_active=True, + is_service_account=False, + must_change_password=False, + account_id=system_account.id, + account_role="engineer", + ) + test_db.add(existing_user) + await test_db.commit() + + resolved_id = await ensure_service_account(test_db) + await test_db.refresh(existing_user) + + assert resolved_id == existing_user.id + assert existing_user.is_service_account is True