Merge PR #136: feat: tenant isolation Phase 4 — RLS on all remaining tables
This commit was merged in pull request #136.
This commit is contained in:
@@ -375,6 +375,12 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
|
|||||||
|
|
||||||
**106. Guard async "select item → load data → apply state" flows with a ref:** When a component lets the user switch between items (chat sessions, flows, scripts) and loads data asynchronously on each switch, the load for item A can complete *after* the user has already switched to item B — overwriting B's state with A's stale data. Fix pattern: keep a `currentSelectionRef = useRef(initialId)` and update it synchronously whenever the selection changes (in every creation/switch path). After every `await`, bail out if `currentSelectionRef.current !== thisItemId`. See `AssistantChatPage.tsx` `selectChat` for the reference implementation (`currentChatRef`).
|
**106. Guard async "select item → load data → apply state" flows with a ref:** When a component lets the user switch between items (chat sessions, flows, scripts) and loads data asynchronously on each switch, the load for item A can complete *after* the user has already switched to item B — overwriting B's state with A's stale data. Fix pattern: keep a `currentSelectionRef = useRef(initialId)` and update it synchronously whenever the selection changes (in every creation/switch path). After every `await`, bail out if `currentSelectionRef.current !== thisItemId`. See `AssistantChatPage.tsx` `selectChat` for the reference implementation (`currentChatRef`).
|
||||||
|
|
||||||
|
**107. Startup routines must use `_admin_session_factory()` after Phase 4 RLS:** Any code that runs at startup (lifespan, `ensure_service_account`, seed scripts) and touches tenant-isolated tables (`users`, etc.) must use `_admin_session_factory()` — not `get_db()`. Phase 4 enabled RLS on `users`; a tenant-scoped session has no `app.current_account_id` set at startup, so all queries return 0 rows or fail. `get_service_account_id` in `deps.py` is safe — it reads from `app.state` cached at startup, never hits the DB per-request.
|
||||||
|
|
||||||
|
**108. Tables with no `account_id` column (never add to RLS migrations):** `script_categories`, `platform_steps`, `template_trees`, `plan_feature_defaults`, `accounts` — global/platform tables documented with "No account_id. No RLS." in their model files. When writing RLS migrations, scan at the class level (check for `account_id: Mapped` within the class block), not the file level — multiple classes in one `.py` file can have different columns (e.g. `ScriptCategory` vs `ScriptTemplate` in `script_template.py`).
|
||||||
|
|
||||||
|
**109. `tree_shares.account_id` must equal `tree.account_id`, not the actor's account:** When creating a `TreeShare`, always use `account_id=tree.account_id` (tree owner's tenant). A super admin in tenant A sharing tenant B's tree must produce a share row in tenant B's RLS context — using `current_user.account_id` instead makes the share invisible to the tree owner after RLS is enforced.
|
||||||
|
|
||||||
## RBAC & Permissions
|
## RBAC & Permissions
|
||||||
|
|
||||||
- **Role hierarchy:** super_admin > team_admin > engineer > viewer
|
- **Role hierarchy:** super_admin > team_admin > engineer > viewer
|
||||||
@@ -522,7 +528,7 @@ When a feature, fix, or significant piece of work is finished and merged/committ
|
|||||||
<!-- gitnexus:start -->
|
<!-- gitnexus:start -->
|
||||||
# GitNexus — Code Intelligence
|
# GitNexus — Code Intelligence
|
||||||
|
|
||||||
This project is indexed by GitNexus as **resolutionflow** (14787 symbols, 31366 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
|
This project is indexed by GitNexus as **resolutionflow** (16703 symbols, 35922 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
|
||||||
|
|
||||||
> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first.
|
> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first.
|
||||||
|
|
||||||
|
|||||||
85
backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py
Normal file
85
backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Enable RLS on Phase 4 tables — all remaining tenant-scoped tables.
|
||||||
|
|
||||||
|
All tables in this migration already have account_id NOT NULL (enforced by
|
||||||
|
earlier migrations). This migration adds ENABLE ROW LEVEL SECURITY,
|
||||||
|
FORCE ROW LEVEL SECURITY, and the appropriate tenant isolation policy to each.
|
||||||
|
|
||||||
|
Policy variants used:
|
||||||
|
- Standard: account_id = current_setting(app.current_account_id)::uuid
|
||||||
|
- Platform: standard OR account_id = PLATFORM_ACCOUNT_ID
|
||||||
|
(for global content tables readable by all tenants)
|
||||||
|
|
||||||
|
Skipped intentionally:
|
||||||
|
- accounts — IS the root table; no account_id column
|
||||||
|
- plan_feature_defaults — platform config; no account_id column
|
||||||
|
- script_categories — global lookup table; no account_id column
|
||||||
|
- platform_steps — global content; no account_id column (readable by all)
|
||||||
|
- template_trees — global content; no account_id column (readable by all)
|
||||||
|
|
||||||
|
Revision ID: b3c7e9f2a1d8
|
||||||
|
Revises: 172ad76d7d20
|
||||||
|
Create Date: 2026-04-12
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "b3c7e9f2a1d8"
|
||||||
|
down_revision: Union[str, None] = "172ad76d7d20"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
# Standard policy — tenant sees only own rows.
|
||||||
|
_STANDARD_TABLES = [
|
||||||
|
"users",
|
||||||
|
"account_invites",
|
||||||
|
"account_limit_overrides",
|
||||||
|
"account_feature_overrides",
|
||||||
|
"subscriptions",
|
||||||
|
"ai_chat_sessions",
|
||||||
|
"ai_conversations",
|
||||||
|
"ai_session_steps",
|
||||||
|
"ai_session_embeddings",
|
||||||
|
"ai_suggestions",
|
||||||
|
"ai_usage",
|
||||||
|
"assistant_chats",
|
||||||
|
"attachments",
|
||||||
|
"copilot_conversations",
|
||||||
|
"feedback",
|
||||||
|
"file_uploads",
|
||||||
|
"fork_points",
|
||||||
|
"kb_imports",
|
||||||
|
"notifications",
|
||||||
|
"notification_configs",
|
||||||
|
"notification_logs",
|
||||||
|
"psa_activity_logs",
|
||||||
|
"psa_member_mappings",
|
||||||
|
"script_builder_sessions",
|
||||||
|
"session_ratings",
|
||||||
|
"tree_embeddings",
|
||||||
|
"user_folders",
|
||||||
|
"user_pinned_trees",
|
||||||
|
]
|
||||||
|
|
||||||
|
_POLICY_EXPR = (
|
||||||
|
"account_id = COALESCE("
|
||||||
|
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||||
|
"'00000000-0000-0000-0000-000000000000'"
|
||||||
|
")::uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
for table in _STANDARD_TABLES:
|
||||||
|
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||||
|
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||||
|
op.execute(f"""
|
||||||
|
CREATE POLICY tenant_isolation ON {table}
|
||||||
|
USING ({_POLICY_EXPR})
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
for table in _STANDARD_TABLES:
|
||||||
|
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||||
|
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||||
@@ -24,10 +24,14 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
|||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
token: Annotated[str, Depends(oauth2_scheme)]
|
token: Annotated[str, Depends(oauth2_scheme)]
|
||||||
) -> User:
|
) -> User:
|
||||||
"""Get current authenticated user from JWT token."""
|
"""Get current authenticated user from JWT token.
|
||||||
|
|
||||||
|
Must use get_admin_db (BYPASSRLS): this dep runs before require_tenant_context
|
||||||
|
sets app.current_account_id, so the users table RLS would block the lookup.
|
||||||
|
"""
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Could not validate credentials",
|
detail="Could not validate credentials",
|
||||||
@@ -77,10 +81,14 @@ async def get_refresh_token_payload(
|
|||||||
async def get_current_active_user(
|
async def get_current_active_user(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
) -> User:
|
) -> User:
|
||||||
"""Ensure user is active (not disabled). Auto-downgrades expired trials.
|
"""Ensure user is active (not disabled). Auto-downgrades expired trials.
|
||||||
Enforces must_change_password — blocks all routes except allowlist."""
|
Enforces must_change_password — blocks all routes except allowlist.
|
||||||
|
|
||||||
|
Uses get_admin_db: runs before require_tenant_context sets the ContextVar,
|
||||||
|
so tenant-scoped tables (subscriptions) would return 0 rows via app role.
|
||||||
|
"""
|
||||||
if not current_user.is_active:
|
if not current_user.is_active:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from sqlalchemy import select
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
|
from app.core.admin_database import get_admin_db
|
||||||
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
|
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
|
||||||
from app.core.audit import log_audit
|
from app.core.audit import log_audit
|
||||||
from app.models.refresh_token import RefreshToken
|
from app.models.refresh_token import RefreshToken
|
||||||
@@ -148,7 +149,7 @@ async def update_member_role(
|
|||||||
@router.post("/me/transfer-ownership", response_model=AccountResponse)
|
@router.post("/me/transfer-ownership", response_model=AccountResponse)
|
||||||
async def transfer_ownership(
|
async def transfer_ownership(
|
||||||
data: TransferOwnershipRequest,
|
data: TransferOwnershipRequest,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
current_user: Annotated[User, Depends(require_account_owner)]
|
current_user: Annotated[User, Depends(require_account_owner)]
|
||||||
):
|
):
|
||||||
"""Transfer account ownership to another member (owner only)."""
|
"""Transfer account ownership to another member (owner only)."""
|
||||||
@@ -377,7 +378,7 @@ async def list_invites(
|
|||||||
|
|
||||||
@router.post("/me/leave")
|
@router.post("/me/leave")
|
||||||
async def leave_account(
|
async def leave_account(
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||||
):
|
):
|
||||||
"""Leave the current account (non-owners only). Creates a personal account."""
|
"""Leave the current account (non-owners only). Creates a personal account."""
|
||||||
@@ -423,7 +424,7 @@ class DeleteAccountRequest(BaseModel):
|
|||||||
@router.delete("/me")
|
@router.delete("/me")
|
||||||
async def delete_account(
|
async def delete_account(
|
||||||
data: DeleteAccountRequest,
|
data: DeleteAccountRequest,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
current_user: Annotated[User, Depends(require_account_owner)]
|
current_user: Annotated[User, Depends(require_account_owner)]
|
||||||
):
|
):
|
||||||
"""Delete the current account and soft-delete the user (owner only, no other members)."""
|
"""Delete the current account and soft-delete the user (owner only, no other members)."""
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy import select, update as sa_update
|
from sqlalchemy import select, update as sa_update
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.settings_manager import SettingsManager
|
from app.core.settings_manager import SettingsManager
|
||||||
from app.core.database import get_db
|
from app.core.admin_database import get_admin_db
|
||||||
from app.core.rate_limit import limiter
|
from app.core.rate_limit import limiter
|
||||||
from app.core.security import (
|
from app.core.security import (
|
||||||
verify_password,
|
verify_password,
|
||||||
@@ -67,7 +67,7 @@ def _generate_display_code() -> str:
|
|||||||
async def register(
|
async def register(
|
||||||
request: Request,
|
request: Request,
|
||||||
user_data: UserCreate,
|
user_data: UserCreate,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Register a new user.
|
"""Register a new user.
|
||||||
|
|
||||||
@@ -232,7 +232,7 @@ async def register(
|
|||||||
async def login(
|
async def login(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Login and get access token."""
|
"""Login and get access token."""
|
||||||
# Find user by email
|
# Find user by email
|
||||||
@@ -270,7 +270,7 @@ async def login(
|
|||||||
async def login_json(
|
async def login_json(
|
||||||
request: Request,
|
request: Request,
|
||||||
credentials: UserLogin,
|
credentials: UserLogin,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Login with JSON body (alternative to form data)."""
|
"""Login with JSON body (alternative to form data)."""
|
||||||
result = await db.execute(select(User).where(User.email == credentials.email))
|
result = await db.execute(select(User).where(User.email == credentials.email))
|
||||||
@@ -304,7 +304,7 @@ async def login_json(
|
|||||||
async def refresh_token(
|
async def refresh_token(
|
||||||
request: Request,
|
request: Request,
|
||||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Refresh access token using refresh token (rotation: old token is revoked)."""
|
"""Refresh access token using refresh token (rotation: old token is revoked)."""
|
||||||
user_id = payload.get("sub")
|
user_id = payload.get("sub")
|
||||||
@@ -368,7 +368,7 @@ async def get_me(
|
|||||||
async def update_me(
|
async def update_me(
|
||||||
data: UserUpdate,
|
data: UserUpdate,
|
||||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Update current user's profile (name, email)."""
|
"""Update current user's profile (name, email)."""
|
||||||
update_fields = data.model_fields_set - {"current_password"}
|
update_fields = data.model_fields_set - {"current_password"}
|
||||||
@@ -415,7 +415,7 @@ async def update_me(
|
|||||||
@router.post("/logout")
|
@router.post("/logout")
|
||||||
async def logout(
|
async def logout(
|
||||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Logout user by revoking the refresh token."""
|
"""Logout user by revoking the refresh token."""
|
||||||
jti = payload.get("jti")
|
jti = payload.get("jti")
|
||||||
@@ -438,7 +438,7 @@ async def change_password(
|
|||||||
request: Request,
|
request: Request,
|
||||||
data: ChangePasswordRequest,
|
data: ChangePasswordRequest,
|
||||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Change the current user's password."""
|
"""Change the current user's password."""
|
||||||
if not verify_password(data.current_password, current_user.password_hash):
|
if not verify_password(data.current_password, current_user.password_hash):
|
||||||
@@ -478,7 +478,7 @@ async def change_password(
|
|||||||
async def forgot_password(
|
async def forgot_password(
|
||||||
request: Request,
|
request: Request,
|
||||||
data: ForgotPasswordRequest,
|
data: ForgotPasswordRequest,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Request a password reset email. Always returns success (anti-enumeration)."""
|
"""Request a password reset email. Always returns success (anti-enumeration)."""
|
||||||
result = await db.execute(select(User).where(User.email == data.email))
|
result = await db.execute(select(User).where(User.email == data.email))
|
||||||
@@ -513,7 +513,7 @@ async def forgot_password(
|
|||||||
@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
|
@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
|
||||||
async def verify_reset_token(
|
async def verify_reset_token(
|
||||||
data: VerifyResetTokenRequest,
|
data: VerifyResetTokenRequest,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Verify a password reset token is valid."""
|
"""Verify a password reset token is valid."""
|
||||||
payload = decode_token(data.token)
|
payload = decode_token(data.token)
|
||||||
@@ -544,7 +544,7 @@ async def verify_reset_token(
|
|||||||
async def reset_password(
|
async def reset_password(
|
||||||
request: Request,
|
request: Request,
|
||||||
data: ResetPasswordRequest,
|
data: ResetPasswordRequest,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Reset password using a valid reset token."""
|
"""Reset password using a valid reset token."""
|
||||||
payload = decode_token(data.token)
|
payload = decode_token(data.token)
|
||||||
@@ -611,7 +611,7 @@ async def reset_password(
|
|||||||
|
|
||||||
@router.get("/email/verification-status")
|
@router.get("/email/verification-status")
|
||||||
async def get_verification_status(
|
async def get_verification_status(
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Check if email verification is enabled on the platform."""
|
"""Check if email verification is enabled on the platform."""
|
||||||
enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||||
@@ -623,7 +623,7 @@ async def get_verification_status(
|
|||||||
async def send_verification_email(
|
async def send_verification_email(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Send an email verification link to the current user."""
|
"""Send an email verification link to the current user."""
|
||||||
verification_enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
verification_enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||||
@@ -662,7 +662,7 @@ async def send_verification_email(
|
|||||||
@router.post("/email/verify")
|
@router.post("/email/verify")
|
||||||
async def verify_email(
|
async def verify_email(
|
||||||
data: dict,
|
data: dict,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Verify an email using a token. Public endpoint."""
|
"""Verify an email using a token. Public endpoint."""
|
||||||
token = data.get("token")
|
token = data.get("token")
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.api.deps import get_current_active_user
|
from app.api.deps import get_current_active_user
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
|
from app.core.admin_database import get_admin_db
|
||||||
from app.models.assistant_chat import AssistantChat
|
from app.models.assistant_chat import AssistantChat
|
||||||
from app.models.psa_connection import PsaConnection
|
from app.models.psa_connection import PsaConnection
|
||||||
from app.models.session import Session
|
from app.models.session import Session
|
||||||
@@ -98,7 +99,7 @@ async def get_onboarding_status(
|
|||||||
|
|
||||||
@router.post("/onboarding-status/dismiss", response_model=OnboardingStatus)
|
@router.post("/onboarding-status/dismiss", response_model=OnboardingStatus)
|
||||||
async def dismiss_onboarding(
|
async def dismiss_onboarding(
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
) -> OnboardingStatus:
|
) -> OnboardingStatus:
|
||||||
"""Dismiss the onboarding checklist for the current user."""
|
"""Dismiss the onboarding checklist for the current user."""
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ async def create_session(
|
|||||||
session = await script_builder_service.create_session(
|
session = await script_builder_service.create_session(
|
||||||
db=db,
|
db=db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
account_id=current_user.account_id,
|
||||||
team_id=current_user.team_id,
|
team_id=current_user.team_id,
|
||||||
language=data.language,
|
language=data.language,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
|
|||||||
"""Create batch sessions for a scheduled maintenance run."""
|
"""Create batch sessions for a scheduled maintenance run."""
|
||||||
# Import all models first to ensure SQLAlchemy mapper relationships resolve
|
# Import all models first to ensure SQLAlchemy mapper relationships resolve
|
||||||
import app.models # noqa: F401
|
import app.models # noqa: F401
|
||||||
from app.core.database import async_session_maker
|
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||||
from app.models.maintenance_schedule import MaintenanceSchedule
|
from app.models.maintenance_schedule import MaintenanceSchedule
|
||||||
from app.models.session import Session
|
from app.models.session import Session
|
||||||
from app.models.target_list import TargetList
|
from app.models.target_list import TargetList
|
||||||
@@ -118,7 +118,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
|
|||||||
async def _cleanup_expired_ai_conversations() -> None:
|
async def _cleanup_expired_ai_conversations() -> None:
|
||||||
"""Delete expired AI wizard conversations."""
|
"""Delete expired AI wizard conversations."""
|
||||||
import app.models # noqa: F401
|
import app.models # noqa: F401
|
||||||
from app.core.database import async_session_maker
|
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||||
from app.models.ai_conversation import AIConversation
|
from app.models.ai_conversation import AIConversation
|
||||||
|
|
||||||
async with async_session_maker() as db:
|
async with async_session_maker() as db:
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import logging
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.admin_database import _admin_session_factory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SERVICE_ACCOUNT_EMAIL = "noreply@resolutionflow.com"
|
SERVICE_ACCOUNT_EMAIL = "noreply@resolutionflow.com"
|
||||||
@@ -52,40 +54,45 @@ async def _ensure_system_account(db: AsyncSession) -> uuid.UUID:
|
|||||||
async def ensure_service_account(db: AsyncSession) -> uuid.UUID:
|
async def ensure_service_account(db: AsyncSession) -> uuid.UUID:
|
||||||
"""Ensure the ResolutionFlow service account exists and return its ID.
|
"""Ensure the ResolutionFlow service account exists and return its ID.
|
||||||
|
|
||||||
Idempotent — safe to call on every startup. Creates the account if it
|
Idempotent — safe to call on every startup. This lookup must bypass RLS
|
||||||
does not exist. The account has no usable password and is_service_account=True
|
because startup runs before any request-scoped tenant context exists and
|
||||||
so it can never log in via normal auth flows.
|
the users table is tenant-isolated in Phase 4. The service account is
|
||||||
|
normally created by Alembic migration 1490781700bc; the runtime create path
|
||||||
|
remains as a self-healing fallback for environments that predate that seed.
|
||||||
"""
|
"""
|
||||||
|
_ = db # Retained for call-site compatibility in app lifespan startup.
|
||||||
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
result = await db.execute(
|
async with _admin_session_factory() as admin_db:
|
||||||
select(User).where(User.email == SERVICE_ACCOUNT_EMAIL)
|
result = await admin_db.execute(
|
||||||
)
|
select(User).where(User.email == SERVICE_ACCOUNT_EMAIL)
|
||||||
user = result.scalar_one_or_none()
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
if user is not None:
|
if user is not None:
|
||||||
if not user.is_service_account:
|
if not user.is_service_account:
|
||||||
user.is_service_account = True
|
user.is_service_account = True
|
||||||
await db.commit()
|
await admin_db.commit()
|
||||||
return user.id
|
return user.id
|
||||||
|
|
||||||
account_id = await _ensure_system_account(db)
|
account_id = await _ensure_system_account(admin_db)
|
||||||
|
|
||||||
new_user = User(
|
new_user = User(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
email=SERVICE_ACCOUNT_EMAIL,
|
email=SERVICE_ACCOUNT_EMAIL,
|
||||||
name=SERVICE_ACCOUNT_NAME,
|
name=SERVICE_ACCOUNT_NAME,
|
||||||
password_hash="!service-account-no-login", # bcrypt can't produce this prefix
|
password_hash="!service-account-no-login", # bcrypt can't produce this prefix
|
||||||
role="engineer",
|
role="engineer",
|
||||||
is_super_admin=False,
|
is_super_admin=False,
|
||||||
is_team_admin=False,
|
is_team_admin=False,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
is_service_account=True,
|
is_service_account=True,
|
||||||
must_change_password=False,
|
must_change_password=False,
|
||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
account_role="engineer",
|
account_role="engineer",
|
||||||
)
|
)
|
||||||
db.add(new_user)
|
admin_db.add(new_user)
|
||||||
await db.commit()
|
await admin_db.commit()
|
||||||
logger.info(f"[service_account] Created service account (id={new_user.id})")
|
logger.info(f"[service_account] Created service account (id={new_user.id})")
|
||||||
return new_user.id
|
return new_user.id
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ if settings.SENTRY_DSN:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.core.database import init_db, async_session_maker
|
from app.core.database import init_db
|
||||||
|
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||||
from app.core.logging_config import setup_logging
|
from app.core.logging_config import setup_logging
|
||||||
from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware
|
from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware
|
||||||
from app.core.security_headers import SecurityHeadersMiddleware
|
from app.core.security_headers import SecurityHeadersMiddleware
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import logging
|
|||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.core.database import async_session_maker
|
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||||
from app.models.ai_session import AISession
|
from app.models.ai_session import AISession
|
||||||
from app.services.knowledge_flywheel import analyze_session
|
from app.services.knowledge_flywheel import analyze_session
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.database import async_session_maker
|
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||||
from app.models.psa_post_log import PsaPostLog
|
from app.models.psa_post_log import PsaPostLog
|
||||||
from app.services.psa_documentation_service import retry_failed_push
|
from app.services.psa_documentation_service import retry_failed_push
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone, timedelta
|
|||||||
|
|
||||||
from sqlalchemy import select, delete, func
|
from sqlalchemy import select, delete, func
|
||||||
|
|
||||||
from app.core.database import async_session_maker
|
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||||
from app.models.account import Account
|
from app.models.account import Account
|
||||||
from app.models.assistant_chat import AssistantChat
|
from app.models.assistant_chat import AssistantChat
|
||||||
|
|
||||||
|
|||||||
@@ -144,6 +144,7 @@ def _extract_script_from_response(content: str, language: str) -> tuple[str | No
|
|||||||
async def create_session(
|
async def create_session(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
|
account_id: UUID,
|
||||||
team_id: UUID | None,
|
team_id: UUID | None,
|
||||||
language: str,
|
language: str,
|
||||||
initial_prompt: str | None = None,
|
initial_prompt: str | None = None,
|
||||||
@@ -151,6 +152,7 @@ async def create_session(
|
|||||||
"""Create a new Script Builder session."""
|
"""Create a new Script Builder session."""
|
||||||
session = ScriptBuilderSession(
|
session = ScriptBuilderSession(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
account_id=account_id,
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
language=language,
|
language=language,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -80,7 +80,10 @@ def _display_code() -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
engine = create_async_engine(settings.DATABASE_URL, echo=False)
|
# Must use ADMIN_DATABASE_URL (BYPASSRLS) — Phase 4 enabled RLS on users.
|
||||||
|
# The app-role connection has no tenant context at seed time and would see 0 rows.
|
||||||
|
admin_url = getattr(settings, "ADMIN_DATABASE_URL", None) or settings.DATABASE_URL
|
||||||
|
engine = create_async_engine(admin_url, echo=False)
|
||||||
password_hash = get_password_hash(SHARED_PASSWORD)
|
password_hash = get_password_hash(SHARED_PASSWORD)
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
team_account_id: uuid.UUID | None = None
|
team_account_id: uuid.UUID | None = None
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class TestAdminGlobalCategories:
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["name"] == "Test Category"
|
assert data["name"] == "Test Category"
|
||||||
assert data["slug"] == "test-category"
|
assert data["slug"] == "test-category"
|
||||||
assert data["account_id"] is None
|
assert data["account_id"] == "00000000-0000-0000-0000-000000000001" # PLATFORM_ACCOUNT_ID
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_global_category(
|
async def test_update_global_category(
|
||||||
|
|||||||
@@ -200,6 +200,7 @@ class TestAccountPermissions:
|
|||||||
})
|
})
|
||||||
outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"}
|
outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"}
|
||||||
|
|
||||||
# Outsider should NOT see the private tree
|
# Outsider should NOT see the private tree.
|
||||||
|
# With RLS, the tree is invisible to other tenants — 404 not 403.
|
||||||
response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers)
|
response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers)
|
||||||
assert response.status_code == 403
|
assert response.status_code == 404
|
||||||
|
|||||||
@@ -24,6 +24,12 @@ from pathlib import Path
|
|||||||
import asyncpg
|
import asyncpg
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
# All tests in this module use module-scoped async fixtures (admin_conn,
|
||||||
|
# seed_rls_test_data) which run on the module event loop. Without this marker,
|
||||||
|
# pytest-asyncio 0.23+ defaults tests to function-scoped loops, causing
|
||||||
|
# "Future attached to a different loop" errors on the asyncpg connections.
|
||||||
|
pytestmark = pytest.mark.asyncio(loop_scope="module")
|
||||||
|
|
||||||
_DB_HOST = os.getenv("TEST_DB_HOST", "localhost")
|
_DB_HOST = os.getenv("TEST_DB_HOST", "localhost")
|
||||||
_DB_PORT = int(os.getenv("TEST_DB_PORT", "5432"))
|
_DB_PORT = int(os.getenv("TEST_DB_PORT", "5432"))
|
||||||
_DB_NAME = os.getenv("TEST_DB_NAME", "patherly_test") # matches conftest.py
|
_DB_NAME = os.getenv("TEST_DB_NAME", "patherly_test") # matches conftest.py
|
||||||
@@ -191,7 +197,6 @@ async def conn_no_context():
|
|||||||
# trees
|
# trees
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -199,7 +204,6 @@ async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
|||||||
assert len(rows) == 0, "Account A should not see Account B trees"
|
assert len(rows) == 0, "Account A should not see Account B trees"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trees_account_a_can_see_own_rows(conn_a):
|
async def test_trees_account_a_can_see_own_rows(conn_a):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}'"
|
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}'"
|
||||||
@@ -207,7 +211,6 @@ async def test_trees_account_a_can_see_own_rows(conn_a):
|
|||||||
assert len(rows) >= 1, "Account A should see its own trees"
|
assert len(rows) >= 1, "Account A should see its own trees"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
||||||
rows = await conn_no_context.fetch(
|
rows = await conn_no_context.fetch(
|
||||||
"SELECT id FROM trees WHERE is_default = FALSE AND is_public = FALSE"
|
"SELECT id FROM trees WHERE is_default = FALSE AND is_public = FALSE"
|
||||||
@@ -219,7 +222,6 @@ async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
|||||||
# tree_tags — platform visibility
|
# tree_tags — platform visibility
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM tree_tags WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM tree_tags WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -227,7 +229,6 @@ async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
|||||||
assert len(rows) == 0
|
assert len(rows) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
||||||
rows_a = await conn_a.fetch(
|
rows_a = await conn_a.fetch(
|
||||||
f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'"
|
f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'"
|
||||||
@@ -243,7 +244,6 @@ async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
|||||||
# tree_categories — platform visibility
|
# tree_categories — platform visibility
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM tree_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM tree_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -255,7 +255,6 @@ async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
|||||||
# step_categories — platform visibility
|
# step_categories — platform visibility
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM step_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM step_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -267,7 +266,6 @@ async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
|||||||
# psa_connections — tenant-only
|
# psa_connections — tenant-only
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM psa_connections WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM psa_connections WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -279,7 +277,6 @@ async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
|||||||
# flow_proposals — tenant-only
|
# flow_proposals — tenant-only
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_flow_proposals_account_a_cannot_see_account_b(conn_a):
|
async def test_flow_proposals_account_a_cannot_see_account_b(conn_a):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -513,7 +510,6 @@ async def session_row_ids(admin_conn):
|
|||||||
# sessions
|
# sessions
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_row_ids):
|
async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_b']}'"
|
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_b']}'"
|
||||||
@@ -521,7 +517,6 @@ async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_
|
|||||||
assert len(rows) == 0, "Account A should not see Account B sessions"
|
assert len(rows) == 0, "Account A should not see Account B sessions"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids):
|
async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_a']}'"
|
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_a']}'"
|
||||||
@@ -529,7 +524,6 @@ async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids):
|
|||||||
assert len(rows) == 1, "Account A should see its own sessions"
|
assert len(rows) == 1, "Account A should see its own sessions"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids):
|
async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids):
|
||||||
rows = await conn_no_context.fetch(
|
rows = await conn_no_context.fetch(
|
||||||
f"SELECT id FROM sessions WHERE id IN "
|
f"SELECT id FROM sessions WHERE id IN "
|
||||||
@@ -542,7 +536,6 @@ async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids
|
|||||||
# ai_sessions
|
# ai_sessions
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_b']}'"
|
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_b']}'"
|
||||||
@@ -550,7 +543,6 @@ async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_id
|
|||||||
assert len(rows) == 0, "Account A should not see Account B ai_sessions"
|
assert len(rows) == 0, "Account A should not see Account B ai_sessions"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids):
|
async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_a']}'"
|
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_a']}'"
|
||||||
@@ -562,7 +554,6 @@ async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids):
|
|||||||
# session_branches
|
# session_branches
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM session_branches WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM session_branches WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -574,7 +565,6 @@ async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_r
|
|||||||
# session_supporting_data
|
# session_supporting_data
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM session_supporting_data WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM session_supporting_data WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -586,7 +576,6 @@ async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, se
|
|||||||
# session_resolution_outputs
|
# session_resolution_outputs
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM session_resolution_outputs WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM session_resolution_outputs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -598,7 +587,6 @@ async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a,
|
|||||||
# session_handoffs
|
# session_handoffs
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM session_handoffs WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM session_handoffs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -610,7 +598,6 @@ async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_r
|
|||||||
# script_templates
|
# script_templates
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM script_templates WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM script_templates WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -622,7 +609,6 @@ async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_r
|
|||||||
# script_generations
|
# script_generations
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_script_generations_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
async def test_script_generations_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM script_generations WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM script_generations WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -634,7 +620,6 @@ async def test_script_generations_account_a_cannot_see_account_b(conn_a, session
|
|||||||
# maintenance_schedules
|
# maintenance_schedules
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM maintenance_schedules WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM maintenance_schedules WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -646,7 +631,6 @@ async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, sess
|
|||||||
# psa_post_log
|
# psa_post_log
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||||
rows = await conn_a.fetch(
|
rows = await conn_a.fetch(
|
||||||
f"SELECT id FROM psa_post_log WHERE account_id = '{ACCOUNT_B_ID}'"
|
f"SELECT id FROM psa_post_log WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
@@ -658,7 +642,6 @@ async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_i
|
|||||||
# step_library — visibility-aware policy
|
# step_library — visibility-aware policy
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_conn, conn_a):
|
async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_conn, conn_a):
|
||||||
"""Private/non-public steps owned by Account B must not be visible to Account A."""
|
"""Private/non-public steps owned by Account B must not be visible to Account A."""
|
||||||
private_step_id = str(uuid.uuid4())
|
private_step_id = str(uuid.uuid4())
|
||||||
@@ -683,7 +666,6 @@ async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_c
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, conn_a):
|
async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, conn_a):
|
||||||
"""Public steps owned by Account B MUST be visible to Account A (cross-tenant visibility)."""
|
"""Public steps owned by Account B MUST be visible to Account A (cross-tenant visibility)."""
|
||||||
public_step_id = str(uuid.uuid4())
|
public_step_id = str(uuid.uuid4())
|
||||||
@@ -738,7 +720,6 @@ async def _get_tree_b_id(admin_conn) -> str:
|
|||||||
# step_ratings
|
# step_ratings
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
|
async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
"""Account A must not see step ratings belonging to Account B."""
|
"""Account A must not see step ratings belonging to Account B."""
|
||||||
user_b_id = await _get_user_b_id(admin_conn)
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
@@ -779,7 +760,6 @@ async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
|
|||||||
# step_usage_log
|
# step_usage_log
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a):
|
async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
"""Account A must not see step usage logs belonging to Account B."""
|
"""Account A must not see step usage logs belonging to Account B."""
|
||||||
user_b_id = await _get_user_b_id(admin_conn)
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
@@ -832,7 +812,6 @@ async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a)
|
|||||||
# target_lists
|
# target_lists
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
|
async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
"""Account A must not see target lists belonging to Account B."""
|
"""Account A must not see target lists belonging to Account B."""
|
||||||
user_b_id = await _get_user_b_id(admin_conn)
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
@@ -859,7 +838,6 @@ async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
|
|||||||
# session_shares
|
# session_shares
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
"""Account A must not see session shares belonging to Account B."""
|
"""Account A must not see session shares belonging to Account B."""
|
||||||
user_b_id = await _get_user_b_id(admin_conn)
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
@@ -903,7 +881,6 @@ async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a)
|
|||||||
# audit_logs
|
# audit_logs
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
|
async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
"""Account A must not see audit logs belonging to Account B."""
|
"""Account A must not see audit logs belonging to Account B."""
|
||||||
user_b_id = await _get_user_b_id(admin_conn)
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
@@ -930,7 +907,6 @@ async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
|
|||||||
# tree_shares
|
# tree_shares
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
"""Account A must not see tree shares belonging to Account B."""
|
"""Account A must not see tree shares belonging to Account B."""
|
||||||
user_b_id = await _get_user_b_id(admin_conn)
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
@@ -954,3 +930,129 @@ async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
|||||||
assert len(rows) == 0, "Account A should not see Account B tree_shares"
|
assert len(rows) == 0, "Account A should not see Account B tree_shares"
|
||||||
finally:
|
finally:
|
||||||
await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'")
|
await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'")
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# Phase 4 RLS isolation tests
|
||||||
|
# Tables: users, script_builder_sessions, ai_session_steps, notifications
|
||||||
|
#
|
||||||
|
# Note: platform_steps and template_trees have no account_id column and no RLS —
|
||||||
|
# they are globally readable by all authenticated users.
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# users
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_users_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see users belonging to Account B."""
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B users"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_users_account_a_can_see_own(admin_conn, conn_a):
|
||||||
|
"""Account A must be able to see its own users."""
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) > 0, "Account A should see its own users"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# script_builder_sessions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_script_builder_sessions_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see script builder sessions belonging to Account B."""
|
||||||
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
|
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO script_builder_sessions (
|
||||||
|
id, user_id, account_id, language, created_at, updated_at
|
||||||
|
) VALUES (
|
||||||
|
'{session_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'powershell', NOW(), NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try:
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM script_builder_sessions WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B script_builder_sessions"
|
||||||
|
finally:
|
||||||
|
await admin_conn.execute(
|
||||||
|
f"DELETE FROM script_builder_sessions WHERE id = '{session_id}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ai_session_steps
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_ai_session_steps_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see ai_session_steps belonging to Account B."""
|
||||||
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
|
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||||
|
|
||||||
|
# Need an ai_sessions row as FK
|
||||||
|
ai_session_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO ai_sessions (
|
||||||
|
id, user_id, account_id, flow_type, status, confidence_tier,
|
||||||
|
created_at, updated_at
|
||||||
|
) VALUES (
|
||||||
|
'{ai_session_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'troubleshooting', 'active', 'guided', NOW(), NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
step_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO ai_session_steps (
|
||||||
|
id, session_id, account_id, step_type, content,
|
||||||
|
created_at
|
||||||
|
) VALUES (
|
||||||
|
'{step_id}', '{ai_session_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'question', 'Phase4 RLS test step', NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try:
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM ai_session_steps WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B ai_session_steps"
|
||||||
|
finally:
|
||||||
|
await admin_conn.execute(f"DELETE FROM ai_session_steps WHERE id = '{step_id}'")
|
||||||
|
await admin_conn.execute(f"DELETE FROM ai_sessions WHERE id = '{ai_session_id}'")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# notifications
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_notifications_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see notifications belonging to Account B."""
|
||||||
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
|
|
||||||
|
notif_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO notifications (
|
||||||
|
id, user_id, account_id, type, title, message,
|
||||||
|
is_read, created_at
|
||||||
|
) VALUES (
|
||||||
|
'{notif_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'info', 'Phase4 RLS Test', 'RLS isolation test notification',
|
||||||
|
FALSE, NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try:
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM notifications WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B notifications"
|
||||||
|
finally:
|
||||||
|
await admin_conn.execute(f"DELETE FROM notifications WHERE id = '{notif_id}'")
|
||||||
|
|
||||||
|
|||||||
89
backend/tests/test_service_account.py
Normal file
89
backend/tests/test_service_account.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core import service_account as service_account_module
|
||||||
|
from app.core.service_account import (
|
||||||
|
SERVICE_ACCOUNT_EMAIL,
|
||||||
|
SYSTEM_ACCOUNT_DISPLAY_CODE,
|
||||||
|
ensure_service_account,
|
||||||
|
)
|
||||||
|
from app.models.account import Account
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
class _SessionFactoryOverride:
|
||||||
|
def __init__(self, session):
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_service_account_creates_and_reuses_seeded_user(test_db, monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
service_account_module,
|
||||||
|
"_admin_session_factory",
|
||||||
|
_SessionFactoryOverride(test_db),
|
||||||
|
)
|
||||||
|
|
||||||
|
service_account_id = await ensure_service_account(test_db)
|
||||||
|
|
||||||
|
created_user = (
|
||||||
|
await test_db.execute(select(User).where(User.id == service_account_id))
|
||||||
|
).scalar_one()
|
||||||
|
assert created_user.email == SERVICE_ACCOUNT_EMAIL
|
||||||
|
assert created_user.is_service_account is True
|
||||||
|
|
||||||
|
system_account = (
|
||||||
|
await test_db.execute(
|
||||||
|
select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
assert created_user.account_id == system_account.id
|
||||||
|
|
||||||
|
second_id = await ensure_service_account(test_db)
|
||||||
|
assert second_id == service_account_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_service_account_marks_existing_user_as_service_account(test_db, monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
service_account_module,
|
||||||
|
"_admin_session_factory",
|
||||||
|
_SessionFactoryOverride(test_db),
|
||||||
|
)
|
||||||
|
|
||||||
|
system_account = (
|
||||||
|
await test_db.execute(
|
||||||
|
select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
|
||||||
|
existing_user = User(
|
||||||
|
email=SERVICE_ACCOUNT_EMAIL,
|
||||||
|
name="ResolutionFlow",
|
||||||
|
password_hash="!service-account-no-login",
|
||||||
|
role="engineer",
|
||||||
|
is_super_admin=False,
|
||||||
|
is_team_admin=False,
|
||||||
|
is_active=True,
|
||||||
|
is_service_account=False,
|
||||||
|
must_change_password=False,
|
||||||
|
account_id=system_account.id,
|
||||||
|
account_role="engineer",
|
||||||
|
)
|
||||||
|
test_db.add(existing_user)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
resolved_id = await ensure_service_account(test_db)
|
||||||
|
await test_db.refresh(existing_user)
|
||||||
|
|
||||||
|
assert resolved_id == existing_user.id
|
||||||
|
assert existing_user.is_service_account is True
|
||||||
Reference in New Issue
Block a user