Merge pull request #135 from resolutionflow/feat/tenant-isolation-phase-3

feat: tenant isolation Phase 3 — audit_logs, tree_shares, remaining RLS
This commit was merged in pull request #135.
This commit is contained in:
chihlasm
2026-04-11 04:28:47 -04:00
committed by GitHub
15 changed files with 526 additions and 56 deletions

View File

@@ -0,0 +1,59 @@
"""Enable RLS on Phase 3 tables.
Tables covered:
- step_ratings (account_id NOT NULL since migration 7167e9374b0c)
- step_usage_log (account_id NOT NULL since migration 7167e9374b0c)
- target_lists (account_id NOT NULL since migration 2c6aabd89bc6)
- session_shares (account_id NOT NULL since session_share model)
- audit_logs (account_id NOT NULL since migration 2a9056eddd90)
- tree_shares (account_id NOT NULL since migration a05e1a1bea7c)
All use a standard intra-tenant isolation policy.
Token-based access to session_shares and tree_shares goes through
endpoints that use get_admin_db (BYPASSRLS), so a strict tenant
policy here is correct.
Revision ID: 04f013768235
Revises: a05e1a1bea7c
Create Date: 2026-04-11 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
revision: str = '04f013768235'
down_revision: Union[str, None] = 'a05e1a1bea7c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_CURRENT_ACCOUNT = (
"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
"'00000000-0000-0000-0000-000000000000')::uuid"
)
_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}"
_PHASE3_TABLES = [
"step_ratings",
"step_usage_log",
"target_lists",
"session_shares",
"audit_logs",
"tree_shares",
]
def upgrade() -> None:
for table in _PHASE3_TABLES:
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON {table}
USING ({_STANDARD_USING})
""")
def downgrade() -> None:
for table in _PHASE3_TABLES:
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")

View File

@@ -0,0 +1,32 @@
"""Drop team_id from target_lists.
account_id (NOT NULL) is now the tenant isolation key; team_id is redundant.
All reads/writes use account_id via RLS + application filter.
Revision ID: 172ad76d7d20
Revises: 04f013768235
Create Date: 2026-04-11 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '172ad76d7d20'
down_revision: Union[str, None] = '04f013768235'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.drop_index('ix_target_lists_team_id', table_name='target_lists', if_exists=True)
op.drop_constraint('target_lists_team_id_fkey', 'target_lists', type_='foreignkey')
op.drop_column('target_lists', 'team_id')
def downgrade() -> None:
op.add_column('target_lists', sa.Column('team_id', sa.UUID(), nullable=True))
op.create_foreign_key(
'target_lists_team_id_fkey', 'target_lists', 'teams',
['team_id'], ['id'], ondelete='CASCADE',
)
op.create_index('ix_target_lists_team_id', 'target_lists', ['team_id'])

View File

@@ -0,0 +1,51 @@
"""Add account_id to audit_logs and backfill via user_id.
Revision ID: 2a9056eddd90
Revises: 70a5dd746e83
Create Date: 2026-04-11 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '2a9056eddd90'
down_revision: Union[str, None] = '70a5dd746e83'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('audit_logs', sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
'fk_audit_logs_account_id', 'audit_logs', 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# Backfill: derive from the acting user's account
op.execute("""
UPDATE audit_logs al
SET account_id = u.account_id
FROM users u
WHERE al.user_id = u.id
AND u.account_id IS NOT NULL
AND al.account_id IS NULL
""")
result = op.get_bind().execute(
sa.text("SELECT COUNT(*) FROM audit_logs WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} audit_logs rows have NULL account_id after backfill. "
"All audit log entries must have an associated user with an account."
)
op.alter_column('audit_logs', 'account_id', nullable=False)
op.create_index('ix_audit_logs_account_id', 'audit_logs', ['account_id'])
def downgrade() -> None:
op.drop_index('ix_audit_logs_account_id', table_name='audit_logs')
op.drop_constraint('fk_audit_logs_account_id', 'audit_logs', type_='foreignkey')
op.drop_column('audit_logs', 'account_id')

View File

@@ -0,0 +1,57 @@
"""Add account_id to tree_shares and backfill via tree owner's account.
The share belongs to the tree's tenant, not the actor who created it.
A super admin in account A can share a tree owned by account B; that share
must land in account B so account B's RLS filter sees it.
Revision ID: a05e1a1bea7c
Revises: 2a9056eddd90
Create Date: 2026-04-11 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = 'a05e1a1bea7c'
down_revision: Union[str, None] = '2a9056eddd90'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('tree_shares', sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
'fk_tree_shares_account_id', 'tree_shares', 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# Backfill: derive from the tree's account, not the creator's account.
# A share lives in the same tenant as its tree so that the tree owner's
# RLS context covers their own shares regardless of who created them.
op.execute("""
UPDATE tree_shares ts
SET account_id = t.account_id
FROM trees t
WHERE ts.tree_id = t.id
AND t.account_id IS NOT NULL
AND ts.account_id IS NULL
""")
result = op.get_bind().execute(
sa.text("SELECT COUNT(*) FROM tree_shares WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} tree_shares rows have NULL account_id after backfill. "
"All share entries must have a creating user with an account."
)
op.alter_column('tree_shares', 'account_id', nullable=False)
op.create_index('ix_tree_shares_account_id', 'tree_shares', ['account_id'])
def downgrade() -> None:
op.drop_index('ix_tree_shares_account_id', table_name='tree_shares')
op.drop_constraint('fk_tree_shares_account_id', 'tree_shares', type_='foreignkey')
op.drop_column('tree_shares', 'account_id')

View File

@@ -18,12 +18,10 @@ async def list_target_lists(
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""List all target lists for the current user's team."""
if not current_user.team_id:
return []
"""List all target lists for the current user's account."""
result = await db.execute(
select(TargetList)
.where(TargetList.team_id == current_user.team_id)
.where(TargetList.account_id == current_user.account_id)
.order_by(TargetList.name)
)
return result.scalars().all()
@@ -36,11 +34,9 @@ async def create_target_list(
db: Annotated[AsyncSession, Depends(get_db)],
_: None = Depends(require_engineer_or_admin),
):
"""Create a new target list for the current team."""
if not current_user.team_id:
raise HTTPException(status_code=400, detail="User must belong to a team")
"""Create a new target list for the current account."""
target_list = TargetList(
team_id=current_user.team_id,
account_id=current_user.account_id,
created_by=current_user.id,
name=data.name,
description=data.description,
@@ -61,7 +57,7 @@ async def get_target_list(
result = await db.execute(
select(TargetList).where(
TargetList.id == list_id,
TargetList.team_id == current_user.team_id,
TargetList.account_id == current_user.account_id,
)
)
target_list = result.scalar_one_or_none()
@@ -81,7 +77,7 @@ async def update_target_list(
result = await db.execute(
select(TargetList).where(
TargetList.id == list_id,
TargetList.team_id == current_user.team_id,
TargetList.account_id == current_user.account_id,
)
)
target_list = result.scalar_one_or_none()
@@ -91,7 +87,7 @@ async def update_target_list(
if "name" in update_fields and data.name is not None:
target_list.name = data.name
if "description" in update_fields:
target_list.description = data.description # allow setting to None
target_list.description = data.description
if "targets" in update_fields and data.targets is not None:
target_list.targets = [t.model_dump() for t in data.targets]
await db.commit()
@@ -109,7 +105,7 @@ async def delete_target_list(
result = await db.execute(
select(TargetList).where(
TargetList.id == list_id,
TargetList.team_id == current_user.team_id,
TargetList.account_id == current_user.account_id,
)
)
target_list = result.scalar_one_or_none()

View File

@@ -1048,6 +1048,7 @@ async def create_tree_share(
# Create share
tree_share = TreeShare(
tree_id=tree.id,
account_id=tree.account_id, # share belongs to the tree's tenant, not the actor
share_token=share_token,
created_by=current_user.id,
allow_forking=share_data.allow_forking,

View File

@@ -12,10 +12,19 @@ async def log_audit(
resource_type: str,
resource_id: Optional[UUID] = None,
details: Optional[dict] = None,
account_id: Optional[UUID] = None,
) -> None:
"""Record an audit log entry. Does not commit — piggybacks on the caller's commit."""
if account_id is None:
# Derive from the acting user's account as a fallback (one extra query).
from sqlalchemy import select
from app.models.user import User
result = await db.execute(select(User.account_id).where(User.id == user_id))
account_id = result.scalar_one()
entry = AuditLog(
user_id=user_id,
account_id=account_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,

View File

@@ -21,6 +21,12 @@ class AuditLog(Base):
nullable=False,
index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True
)
action: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
resource_id: Mapped[Optional[uuid.UUID]] = mapped_column(

View File

@@ -8,7 +8,6 @@ from app.core.database import Base
if TYPE_CHECKING:
from app.models.user import User
from app.models.team import Team
from app.models.account import Account
@@ -18,10 +17,6 @@ class TargetList(Base):
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
team_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"),
nullable=False, index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),

View File

@@ -25,6 +25,12 @@ class TreeShare(Base):
nullable=False,
index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True
)
share_token: Mapped[str] = mapped_column(
String(64),
unique=True,

View File

@@ -23,7 +23,7 @@ class TargetListUpdate(BaseModel):
class TargetListResponse(BaseModel):
id: UUID
team_id: UUID
account_id: UUID
created_by: Optional[UUID]
name: str
description: Optional[str]

View File

@@ -464,7 +464,6 @@ async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
await test_db.flush()
target_list = TargetList(
team_id=team.id,
account_id=account.id,
created_by=user.id,
name="Server Targets",

View File

@@ -708,3 +708,249 @@ async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn,
await admin_conn.execute(
f"DELETE FROM step_library WHERE id = '{public_step_id}'"
)
# ===========================================================================
# Phase 3 RLS isolation tests
# Tables: step_ratings, step_usage_log, target_lists,
# session_shares, audit_logs, tree_shares
# ===========================================================================
# ---------------------------------------------------------------------------
# Helpers shared by Phase 3 fixtures
# ---------------------------------------------------------------------------
async def _get_user_b_id(admin_conn) -> str:
row = await admin_conn.fetchrow(
"SELECT id FROM users WHERE email = 'rls-user-b@example.com'"
)
return str(row["id"])
async def _get_tree_b_id(admin_conn) -> str:
row = await admin_conn.fetchrow(
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
)
return str(row["id"])
# ---------------------------------------------------------------------------
# step_ratings
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see step ratings belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
# Need a step_library row as FK target
step_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_library (
id, account_id, title, step_type, content,
visibility, is_active, created_at, updated_at
) VALUES (
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 RLS Step', 'action',
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
)
""")
rating_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_ratings (
id, step_id, user_id, account_id, is_verified_use, is_visible,
created_at, updated_at
) VALUES (
'{rating_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
FALSE, TRUE, NOW(), NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM step_ratings WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B step_ratings"
finally:
await admin_conn.execute(f"DELETE FROM step_ratings WHERE id = '{rating_id}'")
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
# ---------------------------------------------------------------------------
# step_usage_log
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see step usage logs belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
tree_b_id = await _get_tree_b_id(admin_conn)
step_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_library (
id, account_id, title, step_type, content,
visibility, is_active, created_at, updated_at
) VALUES (
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 Usage Step', 'action',
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
)
""")
# Need a sessions row as FK for usage log
session_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO sessions (
id, tree_id, user_id, account_id, tree_snapshot,
path_taken, decisions, custom_steps, started_at
) VALUES (
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
)
""")
log_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_usage_log (
id, step_id, user_id, account_id, session_id, used_at
) VALUES (
'{log_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'{session_id}', NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM step_usage_log WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B step_usage_log"
finally:
await admin_conn.execute(f"DELETE FROM step_usage_log WHERE id = '{log_id}'")
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
# ---------------------------------------------------------------------------
# target_lists
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see target lists belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
tl_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO target_lists (
id, account_id, created_by, name, targets, created_at, updated_at
) VALUES (
'{tl_id}', '{ACCOUNT_B_ID}', '{user_b_id}',
'Phase3 RLS Target List', '[]'::jsonb, NOW(), NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM target_lists WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B target_lists"
finally:
await admin_conn.execute(f"DELETE FROM target_lists WHERE id = '{tl_id}'")
# ---------------------------------------------------------------------------
# session_shares
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see session shares belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
tree_b_id = await _get_tree_b_id(admin_conn)
# Need a sessions row as FK
session_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO sessions (
id, tree_id, user_id, account_id, tree_snapshot,
path_taken, decisions, custom_steps, started_at
) VALUES (
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
)
""")
share_id = str(uuid.uuid4())
share_token = f"phase3-rls-test-{share_id[:8]}"
await admin_conn.execute(f"""
INSERT INTO session_shares (
id, session_id, account_id, share_token, visibility,
created_by, view_count, is_active, created_at, updated_at
) VALUES (
'{share_id}', '{session_id}', '{ACCOUNT_B_ID}',
'{share_token}', 'account', '{user_b_id}',
0, TRUE, NOW(), NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM session_shares WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B session_shares"
finally:
await admin_conn.execute(f"DELETE FROM session_shares WHERE id = '{share_id}'")
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
# ---------------------------------------------------------------------------
# audit_logs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see audit logs belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
log_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO audit_logs (
id, user_id, account_id, action, resource_type, created_at
) VALUES (
'{log_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'test.action', 'test_resource', NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM audit_logs WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B audit_logs"
finally:
await admin_conn.execute(f"DELETE FROM audit_logs WHERE id = '{log_id}'")
# ---------------------------------------------------------------------------
# tree_shares
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see tree shares belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
tree_b_id = await _get_tree_b_id(admin_conn)
share_id = str(uuid.uuid4())
share_token = f"phase3-tree-rls-{share_id[:8]}"
await admin_conn.execute(f"""
INSERT INTO tree_shares (
id, tree_id, account_id, share_token, created_by,
allow_forking, created_at
) VALUES (
'{share_id}', '{tree_b_id}', '{ACCOUNT_B_ID}',
'{share_token}', '{user_b_id}', TRUE, NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM tree_shares WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B tree_shares"
finally:
await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'")

View File

@@ -3,37 +3,10 @@ import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.team import Team
from app.models.user import User
from sqlalchemy import select
@pytest.fixture
async def auth_headers(client: AsyncClient, test_db: AsyncSession, test_user: dict):
"""Override auth_headers to ensure the test user has a team_id assigned."""
# Fetch the user from DB and assign a team
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
user = result.scalar_one()
# Create a team and assign the user to it
team = Team(name="Test Team")
test_db.add(team)
await test_db.flush()
user.team_id = team.id
await test_db.commit()
# Re-login to get a fresh token
login_data = {
"email": test_user["email"],
"password": test_user["password"],
}
resp = await client.post("/api/v1/auth/login/json", json=login_data)
assert resp.status_code == 200
token_data = resp.json()
return {"Authorization": f"Bearer {token_data['access_token']}"}
@pytest.mark.asyncio
async def test_create_target_list(client: AsyncClient, auth_headers: dict):
resp = await client.post(
@@ -107,25 +80,28 @@ async def test_delete_target_list(client: AsyncClient, auth_headers: dict):
assert get.status_code == 404
@pytest.mark.asyncio
async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: dict, test_db):
"""User from team B cannot access team A's list."""
async def test_cannot_access_other_accounts_list(client: AsyncClient, auth_headers: dict, test_db):
"""User from account B cannot access account A's target list."""
import uuid
from app.models.team import Team
from app.models.account import Account
from app.models.user import User
from app.core.security import get_password_hash
# Create team A list using existing auth_headers
# Create account A list using existing auth_headers
create = await client.post(
"/api/v1/target-lists/",
json={"name": "Team A List", "targets": [{"label": "SRV-A"}]},
json={"name": "Account A List", "targets": [{"label": "SRV-A"}]},
headers=auth_headers,
)
assert create.status_code == 201
list_id = create.json()["id"]
# Create a separate team B with its own user
team_b = Team(name=f"Team B {uuid.uuid4()}")
test_db.add(team_b)
# Create a separate account B with its own user
account_b = Account(
name=f"Account B {uuid.uuid4()}",
display_code=f"AB{str(uuid.uuid4())[:6].upper()}",
)
test_db.add(account_b)
await test_db.flush()
user_b = User(
@@ -133,11 +109,13 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
password_hash=get_password_hash("password123"),
name="User B",
is_active=True,
team_id=team_b.id,
account_id=account_b.id,
account_role="engineer",
role="engineer",
)
test_db.add(user_b)
await test_db.flush()
await test_db.commit()
# Get auth token for user B
login = await client.post(
@@ -148,6 +126,6 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
token_b = login.json()["access_token"]
headers_b = {"Authorization": f"Bearer {token_b}"}
# Team B cannot access Team A's list
# Account B cannot access Account A's list
resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b)
assert resp.status_code == 404

View File

@@ -117,6 +117,7 @@ class TestTreeSharing:
for i in range(3):
share = TreeShare(
tree_id=sample_tree.id,
account_id=sample_tree.account_id,
share_token=f"token_{i}_" + "x" * 56,
created_by=sample_tree.author_id,
allow_forking=i % 2 == 0
@@ -162,6 +163,7 @@ class TestTreeSharing:
# Create a share
share = TreeShare(
tree_id=sample_tree.id,
account_id=sample_tree.account_id,
share_token="public_test_token" + "x" * 47,
created_by=UUID(test_user["user_data"]["id"]),
allow_forking=True
@@ -192,6 +194,7 @@ class TestTreeSharing:
# Create expired share
share = TreeShare(
tree_id=sample_tree.id,
account_id=sample_tree.account_id,
share_token="expired_token" + "x" * 50,
created_by=UUID(test_user["user_data"]["id"]),
allow_forking=True,
@@ -209,6 +212,7 @@ class TestTreeSharing:
from uuid import UUID
share = TreeShare(
tree_id=sample_tree.id,
account_id=sample_tree.account_id,
share_token="inactive_tree_token" + "x" * 44,
created_by=UUID(test_user["user_data"]["id"]),
allow_forking=True
@@ -248,6 +252,37 @@ class TestTreeSharing:
tokens.add(token)
assert len(tokens) == 5
async def test_share_account_id_matches_tree_not_actor(
self, client: AsyncClient, sample_tree, auth_headers, test_db
):
"""Share account_id must equal tree.account_id, not the actor's account_id.
A super admin in a different account can share any tree. The resulting
TreeShare row must live in the tree-owner's account so that the tree
owner's RLS context covers it. If account_id were derived from the
actor instead, the share would vanish from the tree owner's view once
RLS is enabled.
"""
from uuid import UUID
from sqlalchemy import select
response = await client.post(
f"/api/v1/trees/{sample_tree.id}/share",
json={"allow_forking": True},
headers=auth_headers,
)
assert response.status_code == 201
share_token = response.json()["share_token"]
result = await test_db.execute(
select(TreeShare).where(TreeShare.share_token == share_token)
)
share = result.scalar_one()
assert share.account_id == sample_tree.account_id, (
"TreeShare.account_id must equal tree.account_id, not the actor's account. "
"Shares must live in the tree owner's tenant for RLS to cover them."
)
@pytest.mark.asyncio
async def test_migration_defaults_visibility_to_team(test_db):