Merge pull request #134 from resolutionflow/feat/tenant-isolation-phase-2

feat: Phase 2 tenant isolation — RLS on 11 session tables
This commit was merged in pull request #134.
This commit is contained in:
chihlasm
2026-04-11 03:02:25 -04:00
committed by GitHub
28 changed files with 652 additions and 13 deletions

View File

@@ -31,6 +31,8 @@ jobs:
SECRET_KEY: ci-test-secret-key-not-for-production
DEBUG: "true"
APP_NAME: ResolutionFlow
TEST_DB_NAME: resolutionflow_test
DB_APP_ROLE_PASSWORD: app_secret_ci
steps:
- uses: actions/checkout@v5
@@ -47,6 +49,9 @@ jobs:
- name: Install dependencies
run: pip install -r backend/requirements.txt -r backend/requirements-dev.txt
- name: Run Alembic migrations
run: cd backend && alembic upgrade head
- name: Check tenant filter enforcement
run: cd backend && python scripts/check_tenant_filters.py
# Warn mode only (exits 0). Switch to --fail after Phase 1 backlog clears.

View File

@@ -29,13 +29,37 @@ from app.models.session_branch import SessionBranch # noqa: F401
from app.models.fork_point import ForkPoint # noqa: F401
from app.models.session_handoff import SessionHandoff # noqa: F401
from app.models.session_resolution_output import SessionResolutionOutput # noqa: F401
from app.core.config import settings
def _alembic_sync_url() -> str:
"""Return a psycopg2-compatible sync URL for Alembic.
Priority order:
1. DATABASE_URL_SYNC — in Railway this is set as a reference variable
(${{pgvector.DATABASE_URL}}) that resolves to the correct postgres
superuser credentials for the current environment (production, PR preview,
etc.). This always works even on fresh databases before any custom roles
have been created, because it uses the postgres superuser.
2. ADMIN_DATABASE_URL (resolutionflow_admin, BYPASSRLS) converted to a sync
driver — fallback for local dev where DATABASE_URL_SYNC may not be set.
"""
if settings.DATABASE_URL_SYNC:
return settings.DATABASE_URL_SYNC
admin_url = settings.ADMIN_DATABASE_URL
if admin_url and "+asyncpg" in admin_url:
return admin_url.replace("postgresql+asyncpg://", "postgresql://")
return settings.DATABASE_URL_SYNC
# this is the Alembic Config object
config = context.config
# Override sqlalchemy.url with the sync version for migrations
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL_SYNC)
config.set_main_option("sqlalchemy.url", _alembic_sync_url())
# Interpret the config file for Python logging.
if config.config_file_name is not None:
@@ -86,7 +110,7 @@ def run_migrations_online() -> None:
from sqlalchemy import create_engine
connectable = create_engine(
settings.DATABASE_URL_SYNC,
_alembic_sync_url(),
poolclass=pool.NullPool,
)

View File

@@ -0,0 +1,90 @@
"""Enable RLS on Phase 2 session and supporting tables.
10 tables use a standard tenant-only policy.
step_library uses a visibility-aware policy — public steps visible to all tenants.
NOTE: session_messages does not exist in this codebase (removed from plan).
script_generations is the correct table name (not script_template_generations).
sessions and ai_sessions are two separate tables, both in scope.
Prerequisites:
- Phase 1 migration must have run (resolutionflow_app role exists, Phase 1 tables have RLS)
- NOT NULL write-path bugs fixed (P2-A commits b641ac6)
- shares.py cross-tenant session fix deployed (P2-B commit ac2b193)
Revision ID: 70a5dd746e83
Revises: c5f48b9890f9
Create Date: 2026-04-10 06:54:49.431817
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '70a5dd746e83'
down_revision: Union[str, None] = 'c5f48b9890f9'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_NULL_UUID = "00000000-0000-0000-0000-000000000000"
_CURRENT_ACCOUNT = (
f"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
f"'{_NULL_UUID}')::uuid"
)
# Standard tenant-only policy — account_id must match the current tenant.
# When no tenant context is set, COALESCE returns the nil UUID so zero rows
# are visible (fail-closed).
_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}"
# Visibility-aware policy for step_library — public steps (visibility='public')
# must be visible to ALL tenants regardless of account_id. This covers the
# visibility='public' arm of build_step_visibility_filter() in app/core/filters.py.
# The created_by arm (private steps visible to their author) is covered
# transitively: private steps share account_id with their creator, so the
# account_id match handles it. This relies on account_id NOT NULL on step_library.
_STEP_LIBRARY_USING = f"account_id = {_CURRENT_ACCOUNT} OR visibility = 'public'"
# Standard tables: strict tenant isolation, no cross-tenant visibility.
_STANDARD_TABLES = [
"sessions",
"ai_sessions",
"session_branches",
"session_supporting_data",
"session_resolution_outputs",
"session_handoffs",
"script_templates",
"script_generations",
"maintenance_schedules",
"psa_post_log",
]
def upgrade() -> None:
# ── Standard tenant-isolation tables ────────────────────────────────────
for table in _STANDARD_TABLES:
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON {table}
USING ({_STANDARD_USING})
""")
# ── step_library ────────────────────────────────────────────────────────
# Public steps (visibility='public') must be readable by all tenants so
# the Solutions Library browsing experience works without tenant context.
# Private/team steps remain tenant-scoped.
op.execute("ALTER TABLE step_library ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE step_library FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON step_library
USING ({_STEP_LIBRARY_USING})
""")
def downgrade() -> None:
for table in _STANDARD_TABLES + ["step_library"]:
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")

View File

@@ -43,6 +43,7 @@ async def create_suggestion(
suggestion = AISuggestion(
tree_id=data.tree_id,
user_id=current_user.id,
account_id=current_user.account_id,
session_id=data.session_id,
action_type=data.action_type,
target_node_id=data.target_node_id,

View File

@@ -69,6 +69,7 @@ async def create_schedule(
schedule = MaintenanceSchedule(
tree_id=data.tree_id,
account_id=current_user.account_id,
created_by=current_user.id,
cron_expression=data.cron_expression,
timezone=data.timezone,

View File

@@ -91,6 +91,7 @@ async def submit_step_feedback(
new_rating = StepRating(
step_id=step_id,
user_id=current_user.id,
account_id=current_user.account_id,
session_id=session_uuid,
was_helpful=data.was_helpful,
# rating is nullable now — thumbs-only mode

View File

@@ -196,6 +196,7 @@ async def start_session(
new_session = Session(
tree_id=tree.id,
user_id=current_user.id,
account_id=current_user.account_id,
tree_snapshot=tree_snapshot,
path_taken=[],
decisions=[],
@@ -693,6 +694,7 @@ async def prepare_session(
new_session = Session(
tree_id=tree.id,
user_id=data.assigned_to_id or current_user.id,
account_id=current_user.account_id,
tree_snapshot=tree_snapshot,
path_taken=[],
decisions=[],
@@ -770,6 +772,7 @@ async def batch_launch_sessions(
session = Session(
tree_id=tree.id,
user_id=current_user.id,
account_id=current_user.account_id,
tree_snapshot=tree_snapshot,
path_taken=[],
decisions=[],
@@ -1102,6 +1105,7 @@ async def psa_post_to_ticket(
# Log to audit trail
log_entry = PsaPostLog(
session_id=session.id,
account_id=session.account_id,
psa_connection_id=psa_connection.id if psa_connection else None,
ticket_id=session.psa_ticket_id,
note_type=data.note_type,

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import joinedload
from sqlalchemy.exc import IntegrityError
from app.core.database import get_db
from app.core.admin_database import get_admin_db
from app.models.session import Session
from app.models.session_share import SessionShare, SessionShareView
from app.models.user import User
@@ -210,7 +211,7 @@ async def _get_optional_user(request: Request, db: AsyncSession) -> Optional[Use
async def access_share(
share_token: str,
request: Request,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
):
"""Access a shared session via share token.

View File

@@ -460,6 +460,7 @@ async def rate_step(
rating = StepRating(
step_id=step_id,
user_id=current_user.id,
account_id=current_user.account_id,
rating=rating_data.rating,
was_helpful=rating_data.was_helpful,
review_text=rating_data.review_text,

View File

@@ -103,6 +103,7 @@ async def create_supporting_data(
item = SessionSupportingData(
session_id=session_id,
account_id=session.account_id,
label=data.label,
data_type=data.data_type,
content=data.content,

View File

@@ -2,8 +2,10 @@
"""
Admin database engine — connects as resolutionflow_admin (BYPASSRLS).
Use ONLY for /admin/* endpoints and internal tooling.
Never use this engine from user-facing endpoints.
Use ONLY where explicit application-level access control makes database-layer
tenant filtering unnecessary: /admin/* endpoints, internal tooling, and public
endpoints that enforce their own authorization before returning data (e.g.
share access via opaque token + visibility check).
"""
from collections.abc import AsyncGenerator
@@ -25,7 +27,7 @@ _admin_session_factory = async_sessionmaker(
async def get_admin_db() -> AsyncGenerator[AsyncSession, None]:
"""Yield an admin DB session (BYPASSRLS). Use only on /admin/* endpoints."""
"""Yield an admin DB session (BYPASSRLS). See module docstring for approved use cases."""
async with _admin_session_factory() as session:
try:
yield session

View File

@@ -34,6 +34,7 @@ class BranchManager:
root = SessionBranch(
id=uuid.uuid4(),
session_id=session_id,
account_id=session.account_id,
parent_branch_id=None,
branch_order=1,
label="Root",
@@ -68,9 +69,17 @@ class BranchManager:
"status": "untried",
})
# Load session to get account_id for FK constraints
session_result = await self.db.execute(
select(AISession).where(AISession.id == session_id)
)
session = session_result.scalar_one_or_none()
account_id = session.account_id if session else None
fork_point = ForkPoint(
id=uuid.uuid4(),
session_id=session_id,
account_id=account_id,
parent_branch_id=parent_branch_id,
trigger_step_id=trigger_step_id,
fork_reason=fork_reason,
@@ -90,6 +99,7 @@ class BranchManager:
branch = SessionBranch(
id=branch_ids[i],
session_id=session_id,
account_id=account_id,
parent_branch_id=parent_branch_id,
fork_point_step_id=trigger_step_id,
branch_order=i + 1,

View File

@@ -56,6 +56,7 @@ class HandoffManager:
handoff = SessionHandoff(
session_id=session_id,
account_id=session.account_id,
handed_off_by=user_id,
intent=intent,
source_branch_id=session.active_branch_id,

View File

@@ -371,6 +371,7 @@ async def push_documentation(
# Log success
log_entry = PsaPostLog(
id=uuid.uuid4(),
account_id=session.account_id,
ai_session_id=session.id,
psa_connection_id=session.psa_connection_id,
ticket_id=session.psa_ticket_id,
@@ -394,6 +395,7 @@ async def push_documentation(
# Log failure with retry scheduling
log_entry = PsaPostLog(
id=uuid.uuid4(),
account_id=session.account_id,
ai_session_id=session.id,
psa_connection_id=session.psa_connection_id,
ticket_id=session.psa_ticket_id,

View File

@@ -45,6 +45,7 @@ class ResolutionOutputGenerator:
output = SessionResolutionOutput(
session_id=session_id,
account_id=session.account_id,
output_type=output_type,
generated_content=content,
status="draft",

View File

@@ -75,6 +75,19 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
"""))
# Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by
# global categories, gallery items, and other platform-owned content.
await conn.execute(sa.text("""
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
VALUES (
'00000000-0000-0000-0000-000000000001',
'ResolutionFlow System',
'RF-SYS-1',
NOW(), NOW()
)
ON CONFLICT (id) DO NOTHING
"""))
# Create async session maker
async_session_maker = async_sessionmaker(
engine,

View File

@@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.models.tree import Tree
from app.models.script_template import ScriptTemplate, ScriptCategory
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
# ---------------------------------------------------------------------------
# Helpers
@@ -22,6 +23,7 @@ async def _create_tree(db: AsyncSession, admin_user_id: str) -> Tree:
name="Gallery Test Flow",
tree_type="troubleshooting",
visibility="public",
account_id=_PLATFORM_ACCOUNT_ID,
is_gallery_featured=False,
gallery_sort_order=0,
tree_structure={
@@ -53,6 +55,7 @@ async def _create_script(db: AsyncSession, admin_user_id: str) -> ScriptTemplate
script = ScriptTemplate(
id=uuid.uuid4(),
category_id=category.id,
account_id=_PLATFORM_ACCOUNT_ID,
name="Gallery Test Script",
slug=f"gallery-test-script-{uuid.uuid4().hex[:6]}",
script_body="Write-Host 'Test'",

View File

@@ -594,6 +594,7 @@ class TestPsaMetrics:
post_log = PsaPostLog(
id=uuid.uuid4(),
ai_session_id=push_session_id,
account_id=account_id,
ticket_id="TICKET-123",
note_type="internal",
content_posted="Session summary",

View File

@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.security import get_password_hash
from app.models.account import Account
from app.models.team import Team
from app.models.user import User
@@ -23,6 +24,8 @@ async def _create_team_with_admin(
team_name: str = "Branding Test Team",
) -> tuple[dict, str, Team]:
"""Create a team + team admin user. Returns (auth_headers, team_id_str, team)."""
account = Account(name=team_name, display_code=uuid.uuid4().hex[:8].upper())
test_db.add(account)
team = Team(name=team_name)
test_db.add(team)
await test_db.flush()
@@ -36,6 +39,8 @@ async def _create_team_with_admin(
team_id=team.id,
is_team_admin=True,
role="engineer",
account_id=account.id,
account_role="engineer",
)
test_db.add(user)
await test_db.commit()
@@ -58,6 +63,15 @@ async def _create_team_member(
is_team_admin: bool = False,
) -> dict:
"""Create a regular team member. Returns auth_headers."""
# Look up the account associated with this team via an existing member
from sqlalchemy import select as _select
from app.models.user import User as _User
result = await test_db.execute(
_select(_User).where(_User.team_id == team.id).limit(1)
)
team_member = result.scalar_one_or_none()
member_account_id = team_member.account_id if team_member else None
email = f"member_{uuid.uuid4().hex[:8]}@test.com"
user = User(
email=email,
@@ -67,6 +81,8 @@ async def _create_team_member(
team_id=team.id,
is_team_admin=is_team_admin,
role="engineer",
account_id=member_account_id,
account_role="engineer",
)
test_db.add(user)
await test_db.commit()

View File

@@ -334,12 +334,13 @@ class TestDraftTreesAPI:
"""Test that migration defaults existing trees to published status."""
# Create a tree without specifying status (relies on DB default)
from uuid import UUID, uuid4
_platform_id = UUID("00000000-0000-0000-0000-000000000001")
tree = Tree(
name="Legacy Tree",
description="Created before status field",
tree_structure={"id": "root", "type": "solution", "title": "Fix"},
author_id=None,
account_id=None
account_id=_platform_id,
)
test_db.add(tree)
await test_db.commit()

View File

@@ -127,10 +127,12 @@ async def test_cannot_schedule_other_teams_tree(client: AsyncClient, auth_header
test_db.add(other_team)
await test_db.flush()
from uuid import UUID as _UUID
other_tree = Tree(
name="Other Team Tree",
tree_type="maintenance",
team_id=other_team.id,
account_id=_UUID("00000000-0000-0000-0000-000000000001"),
tree_structure={
"steps": [
{"id": "s1", "type": "procedure_step", "title": "Step",

View File

@@ -11,6 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.models.script_template import ScriptCategory, ScriptTemplate
from app.models.tree import Tree
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
# ---------------------------------------------------------------------------
# Helpers
@@ -41,6 +43,7 @@ async def _create_featured_tree(db: AsyncSession, name: str = "Featured Flow", f
description="A featured flow for the gallery",
tree_type="troubleshooting",
tree_structure=_make_tree_structure(4),
account_id=_PLATFORM_ACCOUNT_ID,
is_gallery_featured=featured,
is_active=True,
usage_count=42,
@@ -74,6 +77,7 @@ async def _create_featured_script(
) -> ScriptTemplate:
script = ScriptTemplate(
category_id=category.id,
account_id=_PLATFORM_ACCOUNT_ID,
name=name,
slug=name.lower().replace(" ", "-"),
description="A gallery-featured script",
@@ -312,7 +316,7 @@ class TestCategoriesEndpoint:
from app.models.category import TreeCategory
# Create a category and a featured tree in that category
cat = TreeCategory(name="Networking", slug="networking", is_active=True)
cat = TreeCategory(name="Networking", slug="networking", is_active=True, account_id=_PLATFORM_ACCOUNT_ID)
test_db.add(cat)
await test_db.commit()
await test_db.refresh(cat)
@@ -321,6 +325,7 @@ class TestCategoriesEndpoint:
name="Router Diagnostics",
tree_type="troubleshooting",
tree_structure=_make_tree_structure(2),
account_id=_PLATFORM_ACCOUNT_ID,
is_gallery_featured=True,
is_active=True,
usage_count=5,

View File

@@ -62,6 +62,7 @@ async def test_edit_output(client: AsyncClient, test_user, auth_headers, test_db
output = SessionResolutionOutput(
session_id=session.id,
account_id=session.account_id,
output_type="psa_ticket_notes",
generated_content="Original notes",
status="draft",

View File

@@ -16,7 +16,10 @@ Run with:
The test DB is patherly_test (matches conftest.py default).
"""
import os
import subprocess
import sys
import uuid
from pathlib import Path
import asyncpg
import pytest
@@ -37,7 +40,25 @@ ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
async def admin_conn():
def _ensure_rls_schema():
"""Re-apply Alembic migrations before the module runs.
Function-scoped test_db fixtures in other modules drop and recreate the
public schema using Base.metadata.create_all, which does not enable RLS
or create DB roles. This fixture re-runs 'alembic upgrade head' so that
the full migration-managed schema (including RLS policies) is in place.
"""
backend_dir = Path(__file__).parent.parent
subprocess.run(
[sys.executable, "-m", "alembic", "upgrade", "head"],
cwd=backend_dir,
check=True,
capture_output=True,
)
@pytest.fixture(scope="module")
async def admin_conn(_ensure_rls_schema):
"""Superuser asyncpg connection for fixture setup and teardown."""
conn = await asyncpg.connect(_ADMIN_DSN)
yield conn
@@ -264,3 +285,426 @@ async def test_flow_proposals_account_a_cannot_see_account_b(conn_a):
f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0
# ---------------------------------------------------------------------------
# Phase 2 fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
async def session_row_ids(admin_conn):
"""
Insert one `sessions` row and one `ai_sessions` row for each of
ACCOUNT_A and ACCOUNT_B using the superuser connection (BYPASSRLS).
Returns a dict with the inserted IDs for use in tests.
Cleans up on exit.
"""
# Resolve a valid tree_id and user_id for each account
tree_a = await admin_conn.fetchrow(
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}' LIMIT 1"
)
tree_b = await admin_conn.fetchrow(
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
)
user_a = await admin_conn.fetchrow(
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}' LIMIT 1"
)
user_b = await admin_conn.fetchrow(
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
)
assert tree_a is not None, f"No tree found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first"
assert tree_b is not None, f"No tree found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first"
assert user_a is not None, f"No user found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first"
assert user_b is not None, f"No user found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first"
tree_a_id = str(tree_a["id"])
tree_b_id = str(tree_b["id"])
user_a_id = str(user_a["id"])
user_b_id = str(user_b["id"])
session_a_id = str(uuid.uuid4())
session_b_id = str(uuid.uuid4())
ai_session_a_id = str(uuid.uuid4())
ai_session_b_id = str(uuid.uuid4())
# Insert sessions rows (sessions uses started_at not created_at)
await admin_conn.execute(f"""
INSERT INTO sessions (
id, tree_id, user_id, account_id, tree_snapshot,
path_taken, decisions, custom_steps, started_at
) VALUES
('{session_a_id}', '{tree_a_id}', '{user_a_id}', '{ACCOUNT_A_ID}',
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()),
('{session_b_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW())
""")
# Insert ai_sessions rows
# confidence_tier valid values: 'guided' | 'exploring' | 'discovery'
await admin_conn.execute(f"""
INSERT INTO ai_sessions (
id, user_id, account_id, session_type, intake_type,
intake_content, status, confidence_tier, confidence_score,
created_at, updated_at
) VALUES
('{ai_session_a_id}', '{user_a_id}', '{ACCOUNT_A_ID}',
'guided', 'free_text', '{{}}'::jsonb, 'active', 'guided', 0.0,
NOW(), NOW()),
('{ai_session_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'guided', 'free_text', '{{}}'::jsonb, 'active', 'guided', 0.0,
NOW(), NOW())
""")
# -------------------------------------------------------------------------
# Seed Account B rows for every "cannot-see" table that would otherwise be
# empty. Without these, isolation tests pass vacuously even when RLS is off.
# -------------------------------------------------------------------------
# session_branches (FK: ai_sessions.id)
branch_b_row = await admin_conn.fetchrow("""
INSERT INTO session_branches (
id, session_id, account_id, branch_order, label, status,
conversation_messages, created_at, updated_at
) VALUES (
gen_random_uuid(), $1::uuid, $2::uuid, 1, 'test-branch', 'active',
'[]'::jsonb, NOW(), NOW()
) RETURNING id
""", ai_session_b_id, ACCOUNT_B_ID)
branch_b_id = str(branch_b_row["id"])
# session_supporting_data (FK: sessions.id)
supporting_data_b_row = await admin_conn.fetchrow("""
INSERT INTO session_supporting_data (
id, session_id, account_id, label, data_type, content,
sort_order, created_at, updated_at
) VALUES (
gen_random_uuid(), $1::uuid, $2::uuid, 'test-data', 'text_snippet',
'test content', 0, NOW(), NOW()
) RETURNING id
""", session_b_id, ACCOUNT_B_ID)
supporting_data_b_id = str(supporting_data_b_row["id"])
# session_resolution_outputs (FK: ai_sessions.id)
resolution_output_b_row = await admin_conn.fetchrow("""
INSERT INTO session_resolution_outputs (
id, session_id, account_id, output_type, generated_content,
status, generated_by_model, created_at, updated_at
) VALUES (
gen_random_uuid(), $1::uuid, $2::uuid, 'psa_ticket_notes',
'test content', 'draft', 'test-model', NOW(), NOW()
) RETURNING id
""", ai_session_b_id, ACCOUNT_B_ID)
resolution_output_b_id = str(resolution_output_b_row["id"])
# session_handoffs (FK: ai_sessions.id, users.id)
handoff_b_row = await admin_conn.fetchrow("""
INSERT INTO session_handoffs (
id, session_id, account_id, handed_off_by, intent, snapshot,
priority, psa_note_pushed, notification_sent, created_at
) VALUES (
gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, 'park',
'{}'::jsonb, 'normal', false, false, NOW()
) RETURNING id
""", ai_session_b_id, ACCOUNT_B_ID, user_b_id)
handoff_b_id = str(handoff_b_row["id"])
# maintenance_schedules (FK: trees.id)
maintenance_b_row = await admin_conn.fetchrow("""
INSERT INTO maintenance_schedules (
id, tree_id, account_id, cron_expression, timezone,
created_at, updated_at
) VALUES (
gen_random_uuid(), $1::uuid, $2::uuid, '0 9 * * 1', 'UTC',
NOW(), NOW()
) RETURNING id
""", tree_b_id, ACCOUNT_B_ID)
maintenance_b_id = str(maintenance_b_row["id"])
# psa_post_log (FK: ai_sessions.id, users.id)
psa_log_b_row = await admin_conn.fetchrow("""
INSERT INTO psa_post_log (
id, ai_session_id, account_id, ticket_id, note_type,
content_posted, status, posted_by, posted_at
) VALUES (
gen_random_uuid(), $1::uuid, $2::uuid, 'TEST-0001', 'internal',
'test note', 'success', $3::uuid, NOW()
) RETURNING id
""", ai_session_b_id, ACCOUNT_B_ID, user_b_id)
psa_log_b_id = str(psa_log_b_row["id"])
# script_templates requires a script_categories row — insert a temporary one
script_category_b_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO script_categories (id, name, slug, sort_order, is_active, created_at, updated_at)
VALUES ('{script_category_b_id}', 'RLS Test Category', 'rls-test-category-{script_category_b_id[:8]}',
0, true, NOW(), NOW())
""")
script_template_b_row = await admin_conn.fetchrow(f"""
INSERT INTO script_templates (
id, category_id, account_id, name, slug, script_body,
complexity, is_active, created_at, updated_at
) VALUES (
gen_random_uuid(), '{script_category_b_id}'::uuid, $1::uuid,
'RLS Test Template', 'rls-test-template-b-' || gen_random_uuid()::text,
'Write-Host "test"', 'beginner', true, NOW(), NOW()
) RETURNING id
""", ACCOUNT_B_ID)
script_template_b_id = str(script_template_b_row["id"])
# script_generations (FK: script_templates.id, users.id)
script_gen_b_row = await admin_conn.fetchrow("""
INSERT INTO script_generations (
id, template_id, user_id, account_id, parameters_used,
generated_script, created_at
) VALUES (
gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, '{}'::jsonb,
'test script', NOW()
) RETURNING id
""", script_template_b_id, user_b_id, ACCOUNT_B_ID)
script_gen_b_id = str(script_gen_b_row["id"])
try:
yield {
"session_a": session_a_id,
"session_b": session_b_id,
"ai_session_a": ai_session_a_id,
"ai_session_b": ai_session_b_id,
}
finally:
# Cleanup in reverse FK order (children before parents)
await admin_conn.execute(
f"DELETE FROM script_generations WHERE id = '{script_gen_b_id}'"
)
await admin_conn.execute(
f"DELETE FROM session_branches WHERE id = '{branch_b_id}'"
)
await admin_conn.execute(
f"DELETE FROM session_supporting_data WHERE id = '{supporting_data_b_id}'"
)
await admin_conn.execute(
f"DELETE FROM session_resolution_outputs WHERE id = '{resolution_output_b_id}'"
)
await admin_conn.execute(
f"DELETE FROM session_handoffs WHERE id = '{handoff_b_id}'"
)
await admin_conn.execute(
f"DELETE FROM maintenance_schedules WHERE id = '{maintenance_b_id}'"
)
await admin_conn.execute(
f"DELETE FROM psa_post_log WHERE id = '{psa_log_b_id}'"
)
await admin_conn.execute(
f"DELETE FROM script_templates WHERE id = '{script_template_b_id}'"
)
await admin_conn.execute(
f"DELETE FROM script_categories WHERE id = '{script_category_b_id}'"
)
await admin_conn.execute(
f"DELETE FROM sessions WHERE id IN ('{session_a_id}', '{session_b_id}')"
)
await admin_conn.execute(
f"DELETE FROM ai_sessions WHERE id IN ('{ai_session_a_id}', '{ai_session_b_id}')"
)
# ---------------------------------------------------------------------------
# sessions
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_b']}'"
)
assert len(rows) == 0, "Account A should not see Account B sessions"
@pytest.mark.asyncio
async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_a']}'"
)
assert len(rows) == 1, "Account A should see its own sessions"
@pytest.mark.asyncio
async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids):
rows = await conn_no_context.fetch(
f"SELECT id FROM sessions WHERE id IN "
f"('{session_row_ids['session_a']}', '{session_row_ids['session_b']}')"
)
assert len(rows) == 0, "No-context connection should see no sessions"
# ---------------------------------------------------------------------------
# ai_sessions
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_b']}'"
)
assert len(rows) == 0, "Account A should not see Account B ai_sessions"
@pytest.mark.asyncio
async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_a']}'"
)
assert len(rows) == 1, "Account A should see its own ai_sessions"
# ---------------------------------------------------------------------------
# session_branches
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM session_branches WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B session_branches"
# ---------------------------------------------------------------------------
# session_supporting_data
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM session_supporting_data WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B session_supporting_data"
# ---------------------------------------------------------------------------
# session_resolution_outputs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM session_resolution_outputs WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B session_resolution_outputs"
# ---------------------------------------------------------------------------
# session_handoffs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM session_handoffs WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B session_handoffs"
# ---------------------------------------------------------------------------
# script_templates
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM script_templates WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B script_templates"
# ---------------------------------------------------------------------------
# script_generations
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_script_generations_account_a_cannot_see_account_b(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM script_generations WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B script_generations"
# ---------------------------------------------------------------------------
# maintenance_schedules
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM maintenance_schedules WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B maintenance_schedules"
# ---------------------------------------------------------------------------
# psa_post_log
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_ids):
rows = await conn_a.fetch(
f"SELECT id FROM psa_post_log WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B psa_post_log"
# ---------------------------------------------------------------------------
# step_library — visibility-aware policy
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_conn, conn_a):
"""Private/non-public steps owned by Account B must not be visible to Account A."""
private_step_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_library (
id, account_id, title, step_type, content,
visibility, is_active, created_at, updated_at
) VALUES (
'{private_step_id}', '{ACCOUNT_B_ID}', 'RLS Private Step', 'action',
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM step_library "
f"WHERE id = '{private_step_id}' AND visibility != 'public'"
)
assert len(rows) == 0, "Account A should not see Account B's private step_library rows"
finally:
await admin_conn.execute(
f"DELETE FROM step_library WHERE id = '{private_step_id}'"
)
@pytest.mark.asyncio
async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, conn_a):
"""Public steps owned by Account B MUST be visible to Account A (cross-tenant visibility)."""
public_step_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_library (
id, account_id, title, step_type, content,
visibility, is_active, created_at, updated_at
) VALUES (
'{public_step_id}', '{ACCOUNT_B_ID}', 'RLS Public Step', 'action',
'{{}}'::jsonb, 'public', TRUE, NOW(), NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM step_library WHERE id = '{public_step_id}'"
)
assert len(rows) == 1, (
"Account A should see public steps owned by Account B "
"(cross-tenant public visibility policy)"
)
finally:
await admin_conn.execute(
f"DELETE FROM step_library WHERE id = '{public_step_id}'"
)

View File

@@ -155,6 +155,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=UUID(test_user["user_data"]["id"]),
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[{"node_id": "root", "timestamp": datetime.now(timezone.utc).isoformat()}],
@@ -199,6 +200,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=UUID(test_user["user_data"]["id"]),
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[],
@@ -239,6 +241,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=UUID(test_user["user_data"]["id"]),
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[],
@@ -279,6 +282,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=UUID(test_user["user_data"]["id"]),
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[],
@@ -352,6 +356,7 @@ class TestSaveSessionAsTreeAPI:
session = Session(
tree_id=tree.id,
user_id=other_user.id,
account_id=UUID(test_user["user_data"]["account_id"]),
tree_snapshot=tree.tree_structure,
path_taken=["root"],
decisions=[],

View File

@@ -57,6 +57,7 @@ function loadTaskState(sessionId: string): TaskResponse[] | null {
} catch { return null }
}
// eslint-disable-next-line react-refresh/only-export-components
export function clearTaskState(sessionId: string) {
try { sessionStorage.removeItem(`${TASK_LANE_STORAGE_KEY}:${sessionId}`) } catch { /* ignore */ }
}

View File

@@ -9,10 +9,10 @@ export function TeamSummary() {
const { isAccountOwner } = usePermissions()
const navigate = useNavigate()
const [escalationCount, setEscalationCount] = useState(0)
const [loading, setLoading] = useState(true)
const [loading, setLoading] = useState(!!isAccountOwner)
useEffect(() => {
if (!isAccountOwner) { setLoading(false); return }
if (!isAccountOwner) return
aiSessionsApi.getEscalationQueue()
.then((esc) => setEscalationCount(esc.length))
.catch(() => {})

View File

@@ -1,4 +1,4 @@
import { useCallback, useRef } from 'react'
import { useCallback, useEffect, useRef } from 'react'
import Editor, { type BeforeMount } from '@monaco-editor/react'
import { resolutionFlowTheme, THEME_ID } from '@/components/tree-editor/code-mode/resolutionFlowTheme'
import { Spinner } from '@/components/common/Spinner'
@@ -11,7 +11,9 @@ interface Props {
export function ScriptBodyEditor({ value, onChange, disabled }: Props) {
const lastValueRef = useRef(value)
lastValueRef.current = value
useEffect(() => {
lastValueRef.current = value
}, [value])
const handleBeforeMount: BeforeMount = useCallback((monaco) => {
// Register our dark theme if not already defined