Merge pull request #135 from resolutionflow/feat/tenant-isolation-phase-3
feat: tenant isolation Phase 3 — audit_logs, tree_shares, remaining RLS
This commit was merged in pull request #135.
This commit is contained in:
59
backend/alembic/versions/04f013768235_enable_rls_phase3.py
Normal file
59
backend/alembic/versions/04f013768235_enable_rls_phase3.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Enable RLS on Phase 3 tables.
|
||||
|
||||
Tables covered:
|
||||
- step_ratings (account_id NOT NULL since migration 7167e9374b0c)
|
||||
- step_usage_log (account_id NOT NULL since migration 7167e9374b0c)
|
||||
- target_lists (account_id NOT NULL since migration 2c6aabd89bc6)
|
||||
- session_shares (account_id NOT NULL since session_share model)
|
||||
- audit_logs (account_id NOT NULL since migration 2a9056eddd90)
|
||||
- tree_shares (account_id NOT NULL since migration a05e1a1bea7c)
|
||||
|
||||
All use a standard intra-tenant isolation policy.
|
||||
Token-based access to session_shares and tree_shares goes through
|
||||
endpoints that use get_admin_db (BYPASSRLS), so a strict tenant
|
||||
policy here is correct.
|
||||
|
||||
Revision ID: 04f013768235
|
||||
Revises: a05e1a1bea7c
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
|
||||
revision: str = '04f013768235'
|
||||
down_revision: Union[str, None] = 'a05e1a1bea7c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_CURRENT_ACCOUNT = (
|
||||
"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000')::uuid"
|
||||
)
|
||||
|
||||
_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}"
|
||||
|
||||
_PHASE3_TABLES = [
|
||||
"step_ratings",
|
||||
"step_usage_log",
|
||||
"target_lists",
|
||||
"session_shares",
|
||||
"audit_logs",
|
||||
"tree_shares",
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table in _PHASE3_TABLES:
|
||||
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON {table}
|
||||
USING ({_STANDARD_USING})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in _PHASE3_TABLES:
|
||||
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Drop team_id from target_lists.
|
||||
|
||||
account_id (NOT NULL) is now the tenant isolation key; team_id is redundant.
|
||||
All reads/writes use account_id via RLS + application filter.
|
||||
|
||||
Revision ID: 172ad76d7d20
|
||||
Revises: 04f013768235
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '172ad76d7d20'
|
||||
down_revision: Union[str, None] = '04f013768235'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_index('ix_target_lists_team_id', table_name='target_lists', if_exists=True)
|
||||
op.drop_constraint('target_lists_team_id_fkey', 'target_lists', type_='foreignkey')
|
||||
op.drop_column('target_lists', 'team_id')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column('target_lists', sa.Column('team_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'target_lists_team_id_fkey', 'target_lists', 'teams',
|
||||
['team_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
op.create_index('ix_target_lists_team_id', 'target_lists', ['team_id'])
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Add account_id to audit_logs and backfill via user_id.
|
||||
|
||||
Revision ID: 2a9056eddd90
|
||||
Revises: 70a5dd746e83
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '2a9056eddd90'
|
||||
down_revision: Union[str, None] = '70a5dd746e83'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('audit_logs', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_audit_logs_account_id', 'audit_logs', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Backfill: derive from the acting user's account
|
||||
op.execute("""
|
||||
UPDATE audit_logs al
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE al.user_id = u.id
|
||||
AND u.account_id IS NOT NULL
|
||||
AND al.account_id IS NULL
|
||||
""")
|
||||
|
||||
result = op.get_bind().execute(
|
||||
sa.text("SELECT COUNT(*) FROM audit_logs WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} audit_logs rows have NULL account_id after backfill. "
|
||||
"All audit log entries must have an associated user with an account."
|
||||
)
|
||||
|
||||
op.alter_column('audit_logs', 'account_id', nullable=False)
|
||||
op.create_index('ix_audit_logs_account_id', 'audit_logs', ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_audit_logs_account_id', table_name='audit_logs')
|
||||
op.drop_constraint('fk_audit_logs_account_id', 'audit_logs', type_='foreignkey')
|
||||
op.drop_column('audit_logs', 'account_id')
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Add account_id to tree_shares and backfill via tree owner's account.
|
||||
|
||||
The share belongs to the tree's tenant, not the actor who created it.
|
||||
A super admin in account A can share a tree owned by account B; that share
|
||||
must land in account B so account B's RLS filter sees it.
|
||||
|
||||
Revision ID: a05e1a1bea7c
|
||||
Revises: 2a9056eddd90
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = 'a05e1a1bea7c'
|
||||
down_revision: Union[str, None] = '2a9056eddd90'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('tree_shares', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_tree_shares_account_id', 'tree_shares', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Backfill: derive from the tree's account, not the creator's account.
|
||||
# A share lives in the same tenant as its tree so that the tree owner's
|
||||
# RLS context covers their own shares regardless of who created them.
|
||||
op.execute("""
|
||||
UPDATE tree_shares ts
|
||||
SET account_id = t.account_id
|
||||
FROM trees t
|
||||
WHERE ts.tree_id = t.id
|
||||
AND t.account_id IS NOT NULL
|
||||
AND ts.account_id IS NULL
|
||||
""")
|
||||
|
||||
result = op.get_bind().execute(
|
||||
sa.text("SELECT COUNT(*) FROM tree_shares WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} tree_shares rows have NULL account_id after backfill. "
|
||||
"All share entries must have a creating user with an account."
|
||||
)
|
||||
|
||||
op.alter_column('tree_shares', 'account_id', nullable=False)
|
||||
op.create_index('ix_tree_shares_account_id', 'tree_shares', ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_tree_shares_account_id', table_name='tree_shares')
|
||||
op.drop_constraint('fk_tree_shares_account_id', 'tree_shares', type_='foreignkey')
|
||||
op.drop_column('tree_shares', 'account_id')
|
||||
@@ -18,12 +18,10 @@ async def list_target_lists(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""List all target lists for the current user's team."""
|
||||
if not current_user.team_id:
|
||||
return []
|
||||
"""List all target lists for the current user's account."""
|
||||
result = await db.execute(
|
||||
select(TargetList)
|
||||
.where(TargetList.team_id == current_user.team_id)
|
||||
.where(TargetList.account_id == current_user.account_id)
|
||||
.order_by(TargetList.name)
|
||||
)
|
||||
return result.scalars().all()
|
||||
@@ -36,11 +34,9 @@ async def create_target_list(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
_: None = Depends(require_engineer_or_admin),
|
||||
):
|
||||
"""Create a new target list for the current team."""
|
||||
if not current_user.team_id:
|
||||
raise HTTPException(status_code=400, detail="User must belong to a team")
|
||||
"""Create a new target list for the current account."""
|
||||
target_list = TargetList(
|
||||
team_id=current_user.team_id,
|
||||
account_id=current_user.account_id,
|
||||
created_by=current_user.id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
@@ -61,7 +57,7 @@ async def get_target_list(
|
||||
result = await db.execute(
|
||||
select(TargetList).where(
|
||||
TargetList.id == list_id,
|
||||
TargetList.team_id == current_user.team_id,
|
||||
TargetList.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target_list = result.scalar_one_or_none()
|
||||
@@ -81,7 +77,7 @@ async def update_target_list(
|
||||
result = await db.execute(
|
||||
select(TargetList).where(
|
||||
TargetList.id == list_id,
|
||||
TargetList.team_id == current_user.team_id,
|
||||
TargetList.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target_list = result.scalar_one_or_none()
|
||||
@@ -91,7 +87,7 @@ async def update_target_list(
|
||||
if "name" in update_fields and data.name is not None:
|
||||
target_list.name = data.name
|
||||
if "description" in update_fields:
|
||||
target_list.description = data.description # allow setting to None
|
||||
target_list.description = data.description
|
||||
if "targets" in update_fields and data.targets is not None:
|
||||
target_list.targets = [t.model_dump() for t in data.targets]
|
||||
await db.commit()
|
||||
@@ -109,7 +105,7 @@ async def delete_target_list(
|
||||
result = await db.execute(
|
||||
select(TargetList).where(
|
||||
TargetList.id == list_id,
|
||||
TargetList.team_id == current_user.team_id,
|
||||
TargetList.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target_list = result.scalar_one_or_none()
|
||||
|
||||
@@ -1048,6 +1048,7 @@ async def create_tree_share(
|
||||
# Create share
|
||||
tree_share = TreeShare(
|
||||
tree_id=tree.id,
|
||||
account_id=tree.account_id, # share belongs to the tree's tenant, not the actor
|
||||
share_token=share_token,
|
||||
created_by=current_user.id,
|
||||
allow_forking=share_data.allow_forking,
|
||||
|
||||
@@ -12,10 +12,19 @@ async def log_audit(
|
||||
resource_type: str,
|
||||
resource_id: Optional[UUID] = None,
|
||||
details: Optional[dict] = None,
|
||||
account_id: Optional[UUID] = None,
|
||||
) -> None:
|
||||
"""Record an audit log entry. Does not commit — piggybacks on the caller's commit."""
|
||||
if account_id is None:
|
||||
# Derive from the acting user's account as a fallback (one extra query).
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
result = await db.execute(select(User.account_id).where(User.id == user_id))
|
||||
account_id = result.scalar_one()
|
||||
|
||||
entry = AuditLog(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
|
||||
@@ -21,6 +21,12 @@ class AuditLog(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||
resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||
resource_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
|
||||
@@ -8,7 +8,6 @@ from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
|
||||
|
||||
@@ -18,10 +17,6 @@ class TargetList(Base):
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
team_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
|
||||
@@ -25,6 +25,12 @@ class TreeShare(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
share_token: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
unique=True,
|
||||
|
||||
@@ -23,7 +23,7 @@ class TargetListUpdate(BaseModel):
|
||||
|
||||
class TargetListResponse(BaseModel):
|
||||
id: UUID
|
||||
team_id: UUID
|
||||
account_id: UUID
|
||||
created_by: Optional[UUID]
|
||||
name: str
|
||||
description: Optional[str]
|
||||
|
||||
@@ -464,7 +464,6 @@ async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
|
||||
await test_db.flush()
|
||||
|
||||
target_list = TargetList(
|
||||
team_id=team.id,
|
||||
account_id=account.id,
|
||||
created_by=user.id,
|
||||
name="Server Targets",
|
||||
|
||||
@@ -708,3 +708,249 @@ async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn,
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM step_library WHERE id = '{public_step_id}'"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Phase 3 RLS isolation tests
|
||||
# Tables: step_ratings, step_usage_log, target_lists,
|
||||
# session_shares, audit_logs, tree_shares
|
||||
# ===========================================================================
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers shared by Phase 3 fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _get_user_b_id(admin_conn) -> str:
|
||||
row = await admin_conn.fetchrow(
|
||||
"SELECT id FROM users WHERE email = 'rls-user-b@example.com'"
|
||||
)
|
||||
return str(row["id"])
|
||||
|
||||
|
||||
async def _get_tree_b_id(admin_conn) -> str:
|
||||
row = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
|
||||
)
|
||||
return str(row["id"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# step_ratings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see step ratings belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
# Need a step_library row as FK target
|
||||
step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 RLS Step', 'action',
|
||||
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
rating_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_ratings (
|
||||
id, step_id, user_id, account_id, is_verified_use, is_visible,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
'{rating_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
FALSE, TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_ratings WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B step_ratings"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM step_ratings WHERE id = '{rating_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# step_usage_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see step usage logs belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 Usage Step', 'action',
|
||||
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
# Need a sessions row as FK for usage log
|
||||
session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO sessions (
|
||||
id, tree_id, user_id, account_id, tree_snapshot,
|
||||
path_taken, decisions, custom_steps, started_at
|
||||
) VALUES (
|
||||
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
log_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_usage_log (
|
||||
id, step_id, user_id, account_id, session_id, used_at
|
||||
) VALUES (
|
||||
'{log_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'{session_id}', NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_usage_log WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B step_usage_log"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM step_usage_log WHERE id = '{log_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# target_lists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see target lists belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
tl_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO target_lists (
|
||||
id, account_id, created_by, name, targets, created_at, updated_at
|
||||
) VALUES (
|
||||
'{tl_id}', '{ACCOUNT_B_ID}', '{user_b_id}',
|
||||
'Phase3 RLS Target List', '[]'::jsonb, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM target_lists WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B target_lists"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM target_lists WHERE id = '{tl_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_shares
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see session shares belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
# Need a sessions row as FK
|
||||
session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO sessions (
|
||||
id, tree_id, user_id, account_id, tree_snapshot,
|
||||
path_taken, decisions, custom_steps, started_at
|
||||
) VALUES (
|
||||
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
share_id = str(uuid.uuid4())
|
||||
share_token = f"phase3-rls-test-{share_id[:8]}"
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO session_shares (
|
||||
id, session_id, account_id, share_token, visibility,
|
||||
created_by, view_count, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{share_id}', '{session_id}', '{ACCOUNT_B_ID}',
|
||||
'{share_token}', 'account', '{user_b_id}',
|
||||
0, TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_shares WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_shares"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM session_shares WHERE id = '{share_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# audit_logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see audit logs belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
log_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO audit_logs (
|
||||
id, user_id, account_id, action, resource_type, created_at
|
||||
) VALUES (
|
||||
'{log_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'test.action', 'test_resource', NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM audit_logs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B audit_logs"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM audit_logs WHERE id = '{log_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tree_shares
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see tree shares belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
share_id = str(uuid.uuid4())
|
||||
share_token = f"phase3-tree-rls-{share_id[:8]}"
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO tree_shares (
|
||||
id, tree_id, account_id, share_token, created_by,
|
||||
allow_forking, created_at
|
||||
) VALUES (
|
||||
'{share_id}', '{tree_b_id}', '{ACCOUNT_B_ID}',
|
||||
'{share_token}', '{user_b_id}', TRUE, NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_shares WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B tree_shares"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'")
|
||||
|
||||
@@ -3,37 +3,10 @@ import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.team import Team
|
||||
from app.models.user import User
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client: AsyncClient, test_db: AsyncSession, test_user: dict):
|
||||
"""Override auth_headers to ensure the test user has a team_id assigned."""
|
||||
# Fetch the user from DB and assign a team
|
||||
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
|
||||
user = result.scalar_one()
|
||||
|
||||
# Create a team and assign the user to it
|
||||
team = Team(name="Test Team")
|
||||
test_db.add(team)
|
||||
await test_db.flush()
|
||||
|
||||
user.team_id = team.id
|
||||
await test_db.commit()
|
||||
|
||||
# Re-login to get a fresh token
|
||||
login_data = {
|
||||
"email": test_user["email"],
|
||||
"password": test_user["password"],
|
||||
}
|
||||
resp = await client.post("/api/v1/auth/login/json", json=login_data)
|
||||
assert resp.status_code == 200
|
||||
token_data = resp.json()
|
||||
return {"Authorization": f"Bearer {token_data['access_token']}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_target_list(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post(
|
||||
@@ -107,25 +80,28 @@ async def test_delete_target_list(client: AsyncClient, auth_headers: dict):
|
||||
assert get.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""User from team B cannot access team A's list."""
|
||||
async def test_cannot_access_other_accounts_list(client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""User from account B cannot access account A's target list."""
|
||||
import uuid
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
# Create team A list using existing auth_headers
|
||||
# Create account A list using existing auth_headers
|
||||
create = await client.post(
|
||||
"/api/v1/target-lists/",
|
||||
json={"name": "Team A List", "targets": [{"label": "SRV-A"}]},
|
||||
json={"name": "Account A List", "targets": [{"label": "SRV-A"}]},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create.status_code == 201
|
||||
list_id = create.json()["id"]
|
||||
|
||||
# Create a separate team B with its own user
|
||||
team_b = Team(name=f"Team B {uuid.uuid4()}")
|
||||
test_db.add(team_b)
|
||||
# Create a separate account B with its own user
|
||||
account_b = Account(
|
||||
name=f"Account B {uuid.uuid4()}",
|
||||
display_code=f"AB{str(uuid.uuid4())[:6].upper()}",
|
||||
)
|
||||
test_db.add(account_b)
|
||||
await test_db.flush()
|
||||
|
||||
user_b = User(
|
||||
@@ -133,11 +109,13 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
|
||||
password_hash=get_password_hash("password123"),
|
||||
name="User B",
|
||||
is_active=True,
|
||||
team_id=team_b.id,
|
||||
account_id=account_b.id,
|
||||
account_role="engineer",
|
||||
role="engineer",
|
||||
)
|
||||
test_db.add(user_b)
|
||||
await test_db.flush()
|
||||
await test_db.commit()
|
||||
|
||||
# Get auth token for user B
|
||||
login = await client.post(
|
||||
@@ -148,6 +126,6 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
|
||||
token_b = login.json()["access_token"]
|
||||
headers_b = {"Authorization": f"Bearer {token_b}"}
|
||||
|
||||
# Team B cannot access Team A's list
|
||||
# Account B cannot access Account A's list
|
||||
resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@@ -117,6 +117,7 @@ class TestTreeSharing:
|
||||
for i in range(3):
|
||||
share = TreeShare(
|
||||
tree_id=sample_tree.id,
|
||||
account_id=sample_tree.account_id,
|
||||
share_token=f"token_{i}_" + "x" * 56,
|
||||
created_by=sample_tree.author_id,
|
||||
allow_forking=i % 2 == 0
|
||||
@@ -162,6 +163,7 @@ class TestTreeSharing:
|
||||
# Create a share
|
||||
share = TreeShare(
|
||||
tree_id=sample_tree.id,
|
||||
account_id=sample_tree.account_id,
|
||||
share_token="public_test_token" + "x" * 47,
|
||||
created_by=UUID(test_user["user_data"]["id"]),
|
||||
allow_forking=True
|
||||
@@ -192,6 +194,7 @@ class TestTreeSharing:
|
||||
# Create expired share
|
||||
share = TreeShare(
|
||||
tree_id=sample_tree.id,
|
||||
account_id=sample_tree.account_id,
|
||||
share_token="expired_token" + "x" * 50,
|
||||
created_by=UUID(test_user["user_data"]["id"]),
|
||||
allow_forking=True,
|
||||
@@ -209,6 +212,7 @@ class TestTreeSharing:
|
||||
from uuid import UUID
|
||||
share = TreeShare(
|
||||
tree_id=sample_tree.id,
|
||||
account_id=sample_tree.account_id,
|
||||
share_token="inactive_tree_token" + "x" * 44,
|
||||
created_by=UUID(test_user["user_data"]["id"]),
|
||||
allow_forking=True
|
||||
@@ -248,6 +252,37 @@ class TestTreeSharing:
|
||||
tokens.add(token)
|
||||
assert len(tokens) == 5
|
||||
|
||||
async def test_share_account_id_matches_tree_not_actor(
|
||||
self, client: AsyncClient, sample_tree, auth_headers, test_db
|
||||
):
|
||||
"""Share account_id must equal tree.account_id, not the actor's account_id.
|
||||
|
||||
A super admin in a different account can share any tree. The resulting
|
||||
TreeShare row must live in the tree-owner's account so that the tree
|
||||
owner's RLS context covers it. If account_id were derived from the
|
||||
actor instead, the share would vanish from the tree owner's view once
|
||||
RLS is enabled.
|
||||
"""
|
||||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/trees/{sample_tree.id}/share",
|
||||
json={"allow_forking": True},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
share_token = response.json()["share_token"]
|
||||
|
||||
result = await test_db.execute(
|
||||
select(TreeShare).where(TreeShare.share_token == share_token)
|
||||
)
|
||||
share = result.scalar_one()
|
||||
assert share.account_id == sample_tree.account_id, (
|
||||
"TreeShare.account_id must equal tree.account_id, not the actor's account. "
|
||||
"Shares must live in the tree owner's tenant for RLS to cover them."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration_defaults_visibility_to_team(test_db):
|
||||
|
||||
Reference in New Issue
Block a user