feat: tenant isolation Phase 3 — audit_logs, tree_shares, remaining RLS #135
59
backend/alembic/versions/04f013768235_enable_rls_phase3.py
Normal file
59
backend/alembic/versions/04f013768235_enable_rls_phase3.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Enable RLS on Phase 3 tables.
|
||||||
|
|
||||||
|
Tables covered:
|
||||||
|
- step_ratings (account_id NOT NULL since migration 7167e9374b0c)
|
||||||
|
- step_usage_log (account_id NOT NULL since migration 7167e9374b0c)
|
||||||
|
- target_lists (account_id NOT NULL since migration 2c6aabd89bc6)
|
||||||
|
- session_shares (account_id NOT NULL since session_share model)
|
||||||
|
- audit_logs (account_id NOT NULL since migration 2a9056eddd90)
|
||||||
|
- tree_shares (account_id NOT NULL since migration a05e1a1bea7c)
|
||||||
|
|
||||||
|
All use a standard intra-tenant isolation policy.
|
||||||
|
Token-based access to session_shares and tree_shares goes through
|
||||||
|
endpoints that use get_admin_db (BYPASSRLS), so a strict tenant
|
||||||
|
policy here is correct.
|
||||||
|
|
||||||
|
Revision ID: 04f013768235
|
||||||
|
Revises: a05e1a1bea7c
|
||||||
|
Create Date: 2026-04-11 00:00:00.000000
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = '04f013768235'
|
||||||
|
down_revision: Union[str, None] = 'a05e1a1bea7c'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
_CURRENT_ACCOUNT = (
|
||||||
|
"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||||
|
"'00000000-0000-0000-0000-000000000000')::uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}"
|
||||||
|
|
||||||
|
_PHASE3_TABLES = [
|
||||||
|
"step_ratings",
|
||||||
|
"step_usage_log",
|
||||||
|
"target_lists",
|
||||||
|
"session_shares",
|
||||||
|
"audit_logs",
|
||||||
|
"tree_shares",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
for table in _PHASE3_TABLES:
|
||||||
|
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||||
|
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||||
|
op.execute(f"""
|
||||||
|
CREATE POLICY tenant_isolation ON {table}
|
||||||
|
USING ({_STANDARD_USING})
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
for table in _PHASE3_TABLES:
|
||||||
|
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||||
|
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||||
|
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
"""Drop team_id from target_lists.
|
||||||
|
|
||||||
|
account_id (NOT NULL) is now the tenant isolation key; team_id is redundant.
|
||||||
|
All reads/writes use account_id via RLS + application filter.
|
||||||
|
|
||||||
|
Revision ID: 172ad76d7d20
|
||||||
|
Revises: 04f013768235
|
||||||
|
Create Date: 2026-04-11 00:00:00.000000
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision: str = '172ad76d7d20'
|
||||||
|
down_revision: Union[str, None] = '04f013768235'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.drop_index('ix_target_lists_team_id', table_name='target_lists', if_exists=True)
|
||||||
|
op.drop_constraint('target_lists_team_id_fkey', 'target_lists', type_='foreignkey')
|
||||||
|
op.drop_column('target_lists', 'team_id')
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.add_column('target_lists', sa.Column('team_id', sa.UUID(), nullable=True))
|
||||||
|
op.create_foreign_key(
|
||||||
|
'target_lists_team_id_fkey', 'target_lists', 'teams',
|
||||||
|
['team_id'], ['id'], ondelete='CASCADE',
|
||||||
|
)
|
||||||
|
op.create_index('ix_target_lists_team_id', 'target_lists', ['team_id'])
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
"""Add account_id to audit_logs and backfill via user_id.
|
||||||
|
|
||||||
|
Revision ID: 2a9056eddd90
|
||||||
|
Revises: 70a5dd746e83
|
||||||
|
Create Date: 2026-04-11 00:00:00.000000
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision: str = '2a9056eddd90'
|
||||||
|
down_revision: Union[str, None] = '70a5dd746e83'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column('audit_logs', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||||
|
op.create_foreign_key(
|
||||||
|
'fk_audit_logs_account_id', 'audit_logs', 'accounts',
|
||||||
|
['account_id'], ['id'], ondelete='CASCADE',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Backfill: derive from the acting user's account
|
||||||
|
op.execute("""
|
||||||
|
UPDATE audit_logs al
|
||||||
|
SET account_id = u.account_id
|
||||||
|
FROM users u
|
||||||
|
WHERE al.user_id = u.id
|
||||||
|
AND u.account_id IS NOT NULL
|
||||||
|
AND al.account_id IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
result = op.get_bind().execute(
|
||||||
|
sa.text("SELECT COUNT(*) FROM audit_logs WHERE account_id IS NULL")
|
||||||
|
)
|
||||||
|
count = result.scalar()
|
||||||
|
if count > 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"ROLLBACK: {count} audit_logs rows have NULL account_id after backfill. "
|
||||||
|
"All audit log entries must have an associated user with an account."
|
||||||
|
)
|
||||||
|
|
||||||
|
op.alter_column('audit_logs', 'account_id', nullable=False)
|
||||||
|
op.create_index('ix_audit_logs_account_id', 'audit_logs', ['account_id'])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index('ix_audit_logs_account_id', table_name='audit_logs')
|
||||||
|
op.drop_constraint('fk_audit_logs_account_id', 'audit_logs', type_='foreignkey')
|
||||||
|
op.drop_column('audit_logs', 'account_id')
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
"""Add account_id to tree_shares and backfill via tree owner's account.
|
||||||
|
|
||||||
|
The share belongs to the tree's tenant, not the actor who created it.
|
||||||
|
A super admin in account A can share a tree owned by account B; that share
|
||||||
|
must land in account B so account B's RLS filter sees it.
|
||||||
|
|
||||||
|
Revision ID: a05e1a1bea7c
|
||||||
|
Revises: 2a9056eddd90
|
||||||
|
Create Date: 2026-04-11 00:00:00.000000
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision: str = 'a05e1a1bea7c'
|
||||||
|
down_revision: Union[str, None] = '2a9056eddd90'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column('tree_shares', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||||
|
op.create_foreign_key(
|
||||||
|
'fk_tree_shares_account_id', 'tree_shares', 'accounts',
|
||||||
|
['account_id'], ['id'], ondelete='CASCADE',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Backfill: derive from the tree's account, not the creator's account.
|
||||||
|
# A share lives in the same tenant as its tree so that the tree owner's
|
||||||
|
# RLS context covers their own shares regardless of who created them.
|
||||||
|
op.execute("""
|
||||||
|
UPDATE tree_shares ts
|
||||||
|
SET account_id = t.account_id
|
||||||
|
FROM trees t
|
||||||
|
WHERE ts.tree_id = t.id
|
||||||
|
AND t.account_id IS NOT NULL
|
||||||
|
AND ts.account_id IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
result = op.get_bind().execute(
|
||||||
|
sa.text("SELECT COUNT(*) FROM tree_shares WHERE account_id IS NULL")
|
||||||
|
)
|
||||||
|
count = result.scalar()
|
||||||
|
if count > 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"ROLLBACK: {count} tree_shares rows have NULL account_id after backfill. "
|
||||||
|
"All share entries must have a creating user with an account."
|
||||||
|
)
|
||||||
|
|
||||||
|
op.alter_column('tree_shares', 'account_id', nullable=False)
|
||||||
|
op.create_index('ix_tree_shares_account_id', 'tree_shares', ['account_id'])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index('ix_tree_shares_account_id', table_name='tree_shares')
|
||||||
|
op.drop_constraint('fk_tree_shares_account_id', 'tree_shares', type_='foreignkey')
|
||||||
|
op.drop_column('tree_shares', 'account_id')
|
||||||
@@ -18,12 +18,10 @@ async def list_target_lists(
|
|||||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
):
|
):
|
||||||
"""List all target lists for the current user's team."""
|
"""List all target lists for the current user's account."""
|
||||||
if not current_user.team_id:
|
|
||||||
return []
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(TargetList)
|
select(TargetList)
|
||||||
.where(TargetList.team_id == current_user.team_id)
|
.where(TargetList.account_id == current_user.account_id)
|
||||||
.order_by(TargetList.name)
|
.order_by(TargetList.name)
|
||||||
)
|
)
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
@@ -36,11 +34,9 @@ async def create_target_list(
|
|||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
_: None = Depends(require_engineer_or_admin),
|
_: None = Depends(require_engineer_or_admin),
|
||||||
):
|
):
|
||||||
"""Create a new target list for the current team."""
|
"""Create a new target list for the current account."""
|
||||||
if not current_user.team_id:
|
|
||||||
raise HTTPException(status_code=400, detail="User must belong to a team")
|
|
||||||
target_list = TargetList(
|
target_list = TargetList(
|
||||||
team_id=current_user.team_id,
|
account_id=current_user.account_id,
|
||||||
created_by=current_user.id,
|
created_by=current_user.id,
|
||||||
name=data.name,
|
name=data.name,
|
||||||
description=data.description,
|
description=data.description,
|
||||||
@@ -61,7 +57,7 @@ async def get_target_list(
|
|||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(TargetList).where(
|
select(TargetList).where(
|
||||||
TargetList.id == list_id,
|
TargetList.id == list_id,
|
||||||
TargetList.team_id == current_user.team_id,
|
TargetList.account_id == current_user.account_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
target_list = result.scalar_one_or_none()
|
target_list = result.scalar_one_or_none()
|
||||||
@@ -81,7 +77,7 @@ async def update_target_list(
|
|||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(TargetList).where(
|
select(TargetList).where(
|
||||||
TargetList.id == list_id,
|
TargetList.id == list_id,
|
||||||
TargetList.team_id == current_user.team_id,
|
TargetList.account_id == current_user.account_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
target_list = result.scalar_one_or_none()
|
target_list = result.scalar_one_or_none()
|
||||||
@@ -91,7 +87,7 @@ async def update_target_list(
|
|||||||
if "name" in update_fields and data.name is not None:
|
if "name" in update_fields and data.name is not None:
|
||||||
target_list.name = data.name
|
target_list.name = data.name
|
||||||
if "description" in update_fields:
|
if "description" in update_fields:
|
||||||
target_list.description = data.description # allow setting to None
|
target_list.description = data.description
|
||||||
if "targets" in update_fields and data.targets is not None:
|
if "targets" in update_fields and data.targets is not None:
|
||||||
target_list.targets = [t.model_dump() for t in data.targets]
|
target_list.targets = [t.model_dump() for t in data.targets]
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -109,7 +105,7 @@ async def delete_target_list(
|
|||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(TargetList).where(
|
select(TargetList).where(
|
||||||
TargetList.id == list_id,
|
TargetList.id == list_id,
|
||||||
TargetList.team_id == current_user.team_id,
|
TargetList.account_id == current_user.account_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
target_list = result.scalar_one_or_none()
|
target_list = result.scalar_one_or_none()
|
||||||
|
|||||||
@@ -1048,6 +1048,7 @@ async def create_tree_share(
|
|||||||
# Create share
|
# Create share
|
||||||
tree_share = TreeShare(
|
tree_share = TreeShare(
|
||||||
tree_id=tree.id,
|
tree_id=tree.id,
|
||||||
|
account_id=tree.account_id, # share belongs to the tree's tenant, not the actor
|
||||||
share_token=share_token,
|
share_token=share_token,
|
||||||
created_by=current_user.id,
|
created_by=current_user.id,
|
||||||
allow_forking=share_data.allow_forking,
|
allow_forking=share_data.allow_forking,
|
||||||
|
|||||||
@@ -12,10 +12,19 @@ async def log_audit(
|
|||||||
resource_type: str,
|
resource_type: str,
|
||||||
resource_id: Optional[UUID] = None,
|
resource_id: Optional[UUID] = None,
|
||||||
details: Optional[dict] = None,
|
details: Optional[dict] = None,
|
||||||
|
account_id: Optional[UUID] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record an audit log entry. Does not commit — piggybacks on the caller's commit."""
|
"""Record an audit log entry. Does not commit — piggybacks on the caller's commit."""
|
||||||
|
if account_id is None:
|
||||||
|
# Derive from the acting user's account as a fallback (one extra query).
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models.user import User
|
||||||
|
result = await db.execute(select(User.account_id).where(User.id == user_id))
|
||||||
|
account_id = result.scalar_one()
|
||||||
|
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
account_id=account_id,
|
||||||
action=action,
|
action=action,
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
resource_id=resource_id,
|
resource_id=resource_id,
|
||||||
|
|||||||
@@ -21,6 +21,12 @@ class AuditLog(Base):
|
|||||||
nullable=False,
|
nullable=False,
|
||||||
index=True
|
index=True
|
||||||
)
|
)
|
||||||
|
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True
|
||||||
|
)
|
||||||
action: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
action: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||||
resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||||
resource_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
resource_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from app.core.database import Base
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.team import Team
|
|
||||||
from app.models.account import Account
|
from app.models.account import Account
|
||||||
|
|
||||||
|
|
||||||
@@ -18,10 +17,6 @@ class TargetList(Base):
|
|||||||
id: Mapped[uuid.UUID] = mapped_column(
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||||
)
|
)
|
||||||
team_id: Mapped[uuid.UUID] = mapped_column(
|
|
||||||
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"),
|
|
||||||
nullable=False, index=True
|
|
||||||
)
|
|
||||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||||
|
|||||||
@@ -25,6 +25,12 @@ class TreeShare(Base):
|
|||||||
nullable=False,
|
nullable=False,
|
||||||
index=True
|
index=True
|
||||||
)
|
)
|
||||||
|
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True
|
||||||
|
)
|
||||||
share_token: Mapped[str] = mapped_column(
|
share_token: Mapped[str] = mapped_column(
|
||||||
String(64),
|
String(64),
|
||||||
unique=True,
|
unique=True,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class TargetListUpdate(BaseModel):
|
|||||||
|
|
||||||
class TargetListResponse(BaseModel):
|
class TargetListResponse(BaseModel):
|
||||||
id: UUID
|
id: UUID
|
||||||
team_id: UUID
|
account_id: UUID
|
||||||
created_by: Optional[UUID]
|
created_by: Optional[UUID]
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str]
|
description: Optional[str]
|
||||||
|
|||||||
@@ -464,7 +464,6 @@ async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
|
|||||||
await test_db.flush()
|
await test_db.flush()
|
||||||
|
|
||||||
target_list = TargetList(
|
target_list = TargetList(
|
||||||
team_id=team.id,
|
|
||||||
account_id=account.id,
|
account_id=account.id,
|
||||||
created_by=user.id,
|
created_by=user.id,
|
||||||
name="Server Targets",
|
name="Server Targets",
|
||||||
|
|||||||
@@ -708,3 +708,249 @@ async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn,
|
|||||||
await admin_conn.execute(
|
await admin_conn.execute(
|
||||||
f"DELETE FROM step_library WHERE id = '{public_step_id}'"
|
f"DELETE FROM step_library WHERE id = '{public_step_id}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# Phase 3 RLS isolation tests
|
||||||
|
# Tables: step_ratings, step_usage_log, target_lists,
|
||||||
|
# session_shares, audit_logs, tree_shares
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers shared by Phase 3 fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _get_user_b_id(admin_conn) -> str:
|
||||||
|
row = await admin_conn.fetchrow(
|
||||||
|
"SELECT id FROM users WHERE email = 'rls-user-b@example.com'"
|
||||||
|
)
|
||||||
|
return str(row["id"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_tree_b_id(admin_conn) -> str:
|
||||||
|
row = await admin_conn.fetchrow(
|
||||||
|
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
|
||||||
|
)
|
||||||
|
return str(row["id"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# step_ratings
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see step ratings belonging to Account B."""
|
||||||
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
|
|
||||||
|
# Need a step_library row as FK target
|
||||||
|
step_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO step_library (
|
||||||
|
id, account_id, title, step_type, content,
|
||||||
|
visibility, is_active, created_at, updated_at
|
||||||
|
) VALUES (
|
||||||
|
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 RLS Step', 'action',
|
||||||
|
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
rating_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO step_ratings (
|
||||||
|
id, step_id, user_id, account_id, is_verified_use, is_visible,
|
||||||
|
created_at, updated_at
|
||||||
|
) VALUES (
|
||||||
|
'{rating_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||||
|
FALSE, TRUE, NOW(), NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try:
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM step_ratings WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B step_ratings"
|
||||||
|
finally:
|
||||||
|
await admin_conn.execute(f"DELETE FROM step_ratings WHERE id = '{rating_id}'")
|
||||||
|
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# step_usage_log
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see step usage logs belonging to Account B."""
|
||||||
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
|
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||||
|
|
||||||
|
step_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO step_library (
|
||||||
|
id, account_id, title, step_type, content,
|
||||||
|
visibility, is_active, created_at, updated_at
|
||||||
|
) VALUES (
|
||||||
|
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 Usage Step', 'action',
|
||||||
|
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Need a sessions row as FK for usage log
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO sessions (
|
||||||
|
id, tree_id, user_id, account_id, tree_snapshot,
|
||||||
|
path_taken, decisions, custom_steps, started_at
|
||||||
|
) VALUES (
|
||||||
|
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
log_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO step_usage_log (
|
||||||
|
id, step_id, user_id, account_id, session_id, used_at
|
||||||
|
) VALUES (
|
||||||
|
'{log_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'{session_id}', NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try:
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM step_usage_log WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B step_usage_log"
|
||||||
|
finally:
|
||||||
|
await admin_conn.execute(f"DELETE FROM step_usage_log WHERE id = '{log_id}'")
|
||||||
|
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
|
||||||
|
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# target_lists
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see target lists belonging to Account B."""
|
||||||
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
|
|
||||||
|
tl_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO target_lists (
|
||||||
|
id, account_id, created_by, name, targets, created_at, updated_at
|
||||||
|
) VALUES (
|
||||||
|
'{tl_id}', '{ACCOUNT_B_ID}', '{user_b_id}',
|
||||||
|
'Phase3 RLS Target List', '[]'::jsonb, NOW(), NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try:
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM target_lists WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B target_lists"
|
||||||
|
finally:
|
||||||
|
await admin_conn.execute(f"DELETE FROM target_lists WHERE id = '{tl_id}'")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# session_shares
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see session shares belonging to Account B."""
|
||||||
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
|
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||||
|
|
||||||
|
# Need a sessions row as FK
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO sessions (
|
||||||
|
id, tree_id, user_id, account_id, tree_snapshot,
|
||||||
|
path_taken, decisions, custom_steps, started_at
|
||||||
|
) VALUES (
|
||||||
|
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
share_id = str(uuid.uuid4())
|
||||||
|
share_token = f"phase3-rls-test-{share_id[:8]}"
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO session_shares (
|
||||||
|
id, session_id, account_id, share_token, visibility,
|
||||||
|
created_by, view_count, is_active, created_at, updated_at
|
||||||
|
) VALUES (
|
||||||
|
'{share_id}', '{session_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'{share_token}', 'account', '{user_b_id}',
|
||||||
|
0, TRUE, NOW(), NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try:
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM session_shares WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B session_shares"
|
||||||
|
finally:
|
||||||
|
await admin_conn.execute(f"DELETE FROM session_shares WHERE id = '{share_id}'")
|
||||||
|
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# audit_logs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see audit logs belonging to Account B."""
|
||||||
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
|
|
||||||
|
log_id = str(uuid.uuid4())
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO audit_logs (
|
||||||
|
id, user_id, account_id, action, resource_type, created_at
|
||||||
|
) VALUES (
|
||||||
|
'{log_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'test.action', 'test_resource', NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try:
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM audit_logs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B audit_logs"
|
||||||
|
finally:
|
||||||
|
await admin_conn.execute(f"DELETE FROM audit_logs WHERE id = '{log_id}'")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# tree_shares
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||||
|
"""Account A must not see tree shares belonging to Account B."""
|
||||||
|
user_b_id = await _get_user_b_id(admin_conn)
|
||||||
|
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||||
|
|
||||||
|
share_id = str(uuid.uuid4())
|
||||||
|
share_token = f"phase3-tree-rls-{share_id[:8]}"
|
||||||
|
await admin_conn.execute(f"""
|
||||||
|
INSERT INTO tree_shares (
|
||||||
|
id, tree_id, account_id, share_token, created_by,
|
||||||
|
allow_forking, created_at
|
||||||
|
) VALUES (
|
||||||
|
'{share_id}', '{tree_b_id}', '{ACCOUNT_B_ID}',
|
||||||
|
'{share_token}', '{user_b_id}', TRUE, NOW()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try:
|
||||||
|
rows = await conn_a.fetch(
|
||||||
|
f"SELECT id FROM tree_shares WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||||
|
)
|
||||||
|
assert len(rows) == 0, "Account A should not see Account B tree_shares"
|
||||||
|
finally:
|
||||||
|
await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'")
|
||||||
|
|||||||
@@ -3,37 +3,10 @@ import pytest
|
|||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.models.team import Team
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def auth_headers(client: AsyncClient, test_db: AsyncSession, test_user: dict):
|
|
||||||
"""Override auth_headers to ensure the test user has a team_id assigned."""
|
|
||||||
# Fetch the user from DB and assign a team
|
|
||||||
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
|
|
||||||
user = result.scalar_one()
|
|
||||||
|
|
||||||
# Create a team and assign the user to it
|
|
||||||
team = Team(name="Test Team")
|
|
||||||
test_db.add(team)
|
|
||||||
await test_db.flush()
|
|
||||||
|
|
||||||
user.team_id = team.id
|
|
||||||
await test_db.commit()
|
|
||||||
|
|
||||||
# Re-login to get a fresh token
|
|
||||||
login_data = {
|
|
||||||
"email": test_user["email"],
|
|
||||||
"password": test_user["password"],
|
|
||||||
}
|
|
||||||
resp = await client.post("/api/v1/auth/login/json", json=login_data)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
token_data = resp.json()
|
|
||||||
return {"Authorization": f"Bearer {token_data['access_token']}"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_target_list(client: AsyncClient, auth_headers: dict):
|
async def test_create_target_list(client: AsyncClient, auth_headers: dict):
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
@@ -107,25 +80,28 @@ async def test_delete_target_list(client: AsyncClient, auth_headers: dict):
|
|||||||
assert get.status_code == 404
|
assert get.status_code == 404
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: dict, test_db):
|
async def test_cannot_access_other_accounts_list(client: AsyncClient, auth_headers: dict, test_db):
|
||||||
"""User from team B cannot access team A's list."""
|
"""User from account B cannot access account A's target list."""
|
||||||
import uuid
|
import uuid
|
||||||
from app.models.team import Team
|
from app.models.account import Account
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.core.security import get_password_hash
|
from app.core.security import get_password_hash
|
||||||
|
|
||||||
# Create team A list using existing auth_headers
|
# Create account A list using existing auth_headers
|
||||||
create = await client.post(
|
create = await client.post(
|
||||||
"/api/v1/target-lists/",
|
"/api/v1/target-lists/",
|
||||||
json={"name": "Team A List", "targets": [{"label": "SRV-A"}]},
|
json={"name": "Account A List", "targets": [{"label": "SRV-A"}]},
|
||||||
headers=auth_headers,
|
headers=auth_headers,
|
||||||
)
|
)
|
||||||
assert create.status_code == 201
|
assert create.status_code == 201
|
||||||
list_id = create.json()["id"]
|
list_id = create.json()["id"]
|
||||||
|
|
||||||
# Create a separate team B with its own user
|
# Create a separate account B with its own user
|
||||||
team_b = Team(name=f"Team B {uuid.uuid4()}")
|
account_b = Account(
|
||||||
test_db.add(team_b)
|
name=f"Account B {uuid.uuid4()}",
|
||||||
|
display_code=f"AB{str(uuid.uuid4())[:6].upper()}",
|
||||||
|
)
|
||||||
|
test_db.add(account_b)
|
||||||
await test_db.flush()
|
await test_db.flush()
|
||||||
|
|
||||||
user_b = User(
|
user_b = User(
|
||||||
@@ -133,11 +109,13 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
|
|||||||
password_hash=get_password_hash("password123"),
|
password_hash=get_password_hash("password123"),
|
||||||
name="User B",
|
name="User B",
|
||||||
is_active=True,
|
is_active=True,
|
||||||
team_id=team_b.id,
|
account_id=account_b.id,
|
||||||
|
account_role="engineer",
|
||||||
role="engineer",
|
role="engineer",
|
||||||
)
|
)
|
||||||
test_db.add(user_b)
|
test_db.add(user_b)
|
||||||
await test_db.flush()
|
await test_db.flush()
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
# Get auth token for user B
|
# Get auth token for user B
|
||||||
login = await client.post(
|
login = await client.post(
|
||||||
@@ -148,6 +126,6 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
|
|||||||
token_b = login.json()["access_token"]
|
token_b = login.json()["access_token"]
|
||||||
headers_b = {"Authorization": f"Bearer {token_b}"}
|
headers_b = {"Authorization": f"Bearer {token_b}"}
|
||||||
|
|
||||||
# Team B cannot access Team A's list
|
# Account B cannot access Account A's list
|
||||||
resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b)
|
resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b)
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ class TestTreeSharing:
|
|||||||
for i in range(3):
|
for i in range(3):
|
||||||
share = TreeShare(
|
share = TreeShare(
|
||||||
tree_id=sample_tree.id,
|
tree_id=sample_tree.id,
|
||||||
|
account_id=sample_tree.account_id,
|
||||||
share_token=f"token_{i}_" + "x" * 56,
|
share_token=f"token_{i}_" + "x" * 56,
|
||||||
created_by=sample_tree.author_id,
|
created_by=sample_tree.author_id,
|
||||||
allow_forking=i % 2 == 0
|
allow_forking=i % 2 == 0
|
||||||
@@ -162,6 +163,7 @@ class TestTreeSharing:
|
|||||||
# Create a share
|
# Create a share
|
||||||
share = TreeShare(
|
share = TreeShare(
|
||||||
tree_id=sample_tree.id,
|
tree_id=sample_tree.id,
|
||||||
|
account_id=sample_tree.account_id,
|
||||||
share_token="public_test_token" + "x" * 47,
|
share_token="public_test_token" + "x" * 47,
|
||||||
created_by=UUID(test_user["user_data"]["id"]),
|
created_by=UUID(test_user["user_data"]["id"]),
|
||||||
allow_forking=True
|
allow_forking=True
|
||||||
@@ -192,6 +194,7 @@ class TestTreeSharing:
|
|||||||
# Create expired share
|
# Create expired share
|
||||||
share = TreeShare(
|
share = TreeShare(
|
||||||
tree_id=sample_tree.id,
|
tree_id=sample_tree.id,
|
||||||
|
account_id=sample_tree.account_id,
|
||||||
share_token="expired_token" + "x" * 50,
|
share_token="expired_token" + "x" * 50,
|
||||||
created_by=UUID(test_user["user_data"]["id"]),
|
created_by=UUID(test_user["user_data"]["id"]),
|
||||||
allow_forking=True,
|
allow_forking=True,
|
||||||
@@ -209,6 +212,7 @@ class TestTreeSharing:
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
share = TreeShare(
|
share = TreeShare(
|
||||||
tree_id=sample_tree.id,
|
tree_id=sample_tree.id,
|
||||||
|
account_id=sample_tree.account_id,
|
||||||
share_token="inactive_tree_token" + "x" * 44,
|
share_token="inactive_tree_token" + "x" * 44,
|
||||||
created_by=UUID(test_user["user_data"]["id"]),
|
created_by=UUID(test_user["user_data"]["id"]),
|
||||||
allow_forking=True
|
allow_forking=True
|
||||||
@@ -248,6 +252,37 @@ class TestTreeSharing:
|
|||||||
tokens.add(token)
|
tokens.add(token)
|
||||||
assert len(tokens) == 5
|
assert len(tokens) == 5
|
||||||
|
|
||||||
|
async def test_share_account_id_matches_tree_not_actor(
|
||||||
|
self, client: AsyncClient, sample_tree, auth_headers, test_db
|
||||||
|
):
|
||||||
|
"""Share account_id must equal tree.account_id, not the actor's account_id.
|
||||||
|
|
||||||
|
A super admin in a different account can share any tree. The resulting
|
||||||
|
TreeShare row must live in the tree-owner's account so that the tree
|
||||||
|
owner's RLS context covers it. If account_id were derived from the
|
||||||
|
actor instead, the share would vanish from the tree owner's view once
|
||||||
|
RLS is enabled.
|
||||||
|
"""
|
||||||
|
from uuid import UUID
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/trees/{sample_tree.id}/share",
|
||||||
|
json={"allow_forking": True},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 201
|
||||||
|
share_token = response.json()["share_token"]
|
||||||
|
|
||||||
|
result = await test_db.execute(
|
||||||
|
select(TreeShare).where(TreeShare.share_token == share_token)
|
||||||
|
)
|
||||||
|
share = result.scalar_one()
|
||||||
|
assert share.account_id == sample_tree.account_id, (
|
||||||
|
"TreeShare.account_id must equal tree.account_id, not the actor's account. "
|
||||||
|
"Shares must live in the tree owner's tenant for RLS to cover them."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_migration_defaults_visibility_to_team(test_db):
|
async def test_migration_defaults_visibility_to_team(test_db):
|
||||||
|
|||||||
Reference in New Issue
Block a user