feat: tenant isolation Phase 3 — audit_logs, tree_shares, remaining RLS #135

Merged
chihlasm merged 2 commits from feat/tenant-isolation-phase-3 into main 2026-04-11 08:28:47 +00:00
15 changed files with 526 additions and 56 deletions

View File

@@ -0,0 +1,59 @@
"""Enable RLS on Phase 3 tables.
Tables covered:
- step_ratings (account_id NOT NULL since migration 7167e9374b0c)
- step_usage_log (account_id NOT NULL since migration 7167e9374b0c)
- target_lists (account_id NOT NULL since migration 2c6aabd89bc6)
- session_shares (account_id NOT NULL since session_share model)
- audit_logs (account_id NOT NULL since migration 2a9056eddd90)
- tree_shares (account_id NOT NULL since migration a05e1a1bea7c)
All use a standard intra-tenant isolation policy.
Token-based access to session_shares and tree_shares goes through
endpoints that use get_admin_db (BYPASSRLS), so a strict tenant
policy here is correct.
Revision ID: 04f013768235
Revises: a05e1a1bea7c
Create Date: 2026-04-11 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
revision: str = '04f013768235'
down_revision: Union[str, None] = 'a05e1a1bea7c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_CURRENT_ACCOUNT = (
"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
"'00000000-0000-0000-0000-000000000000')::uuid"
)
_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}"
_PHASE3_TABLES = [
"step_ratings",
"step_usage_log",
"target_lists",
"session_shares",
"audit_logs",
"tree_shares",
]
def upgrade() -> None:
for table in _PHASE3_TABLES:
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON {table}
USING ({_STANDARD_USING})
""")
def downgrade() -> None:
for table in _PHASE3_TABLES:
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")

View File

@@ -0,0 +1,32 @@
"""Drop team_id from target_lists.
account_id (NOT NULL) is now the tenant isolation key; team_id is redundant.
All reads/writes use account_id via RLS + application filter.
Revision ID: 172ad76d7d20
Revises: 04f013768235
Create Date: 2026-04-11 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '172ad76d7d20'
down_revision: Union[str, None] = '04f013768235'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.drop_index('ix_target_lists_team_id', table_name='target_lists', if_exists=True)
op.drop_constraint('target_lists_team_id_fkey', 'target_lists', type_='foreignkey')
op.drop_column('target_lists', 'team_id')
def downgrade() -> None:
op.add_column('target_lists', sa.Column('team_id', sa.UUID(), nullable=True))
op.create_foreign_key(
'target_lists_team_id_fkey', 'target_lists', 'teams',
['team_id'], ['id'], ondelete='CASCADE',
)
op.create_index('ix_target_lists_team_id', 'target_lists', ['team_id'])

View File

@@ -0,0 +1,51 @@
"""Add account_id to audit_logs and backfill via user_id.
Revision ID: 2a9056eddd90
Revises: 70a5dd746e83
Create Date: 2026-04-11 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '2a9056eddd90'
down_revision: Union[str, None] = '70a5dd746e83'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('audit_logs', sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
'fk_audit_logs_account_id', 'audit_logs', 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# Backfill: derive from the acting user's account
op.execute("""
UPDATE audit_logs al
SET account_id = u.account_id
FROM users u
WHERE al.user_id = u.id
AND u.account_id IS NOT NULL
AND al.account_id IS NULL
""")
result = op.get_bind().execute(
sa.text("SELECT COUNT(*) FROM audit_logs WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} audit_logs rows have NULL account_id after backfill. "
"All audit log entries must have an associated user with an account."
)
op.alter_column('audit_logs', 'account_id', nullable=False)
op.create_index('ix_audit_logs_account_id', 'audit_logs', ['account_id'])
def downgrade() -> None:
op.drop_index('ix_audit_logs_account_id', table_name='audit_logs')
op.drop_constraint('fk_audit_logs_account_id', 'audit_logs', type_='foreignkey')
op.drop_column('audit_logs', 'account_id')

View File

@@ -0,0 +1,57 @@
"""Add account_id to tree_shares and backfill via tree owner's account.
The share belongs to the tree's tenant, not the actor who created it.
A super admin in account A can share a tree owned by account B; that share
must land in account B so account B's RLS filter sees it.
Revision ID: a05e1a1bea7c
Revises: 2a9056eddd90
Create Date: 2026-04-11 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = 'a05e1a1bea7c'
down_revision: Union[str, None] = '2a9056eddd90'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('tree_shares', sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
'fk_tree_shares_account_id', 'tree_shares', 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# Backfill: derive from the tree's account, not the creator's account.
# A share lives in the same tenant as its tree so that the tree owner's
# RLS context covers their own shares regardless of who created them.
op.execute("""
UPDATE tree_shares ts
SET account_id = t.account_id
FROM trees t
WHERE ts.tree_id = t.id
AND t.account_id IS NOT NULL
AND ts.account_id IS NULL
""")
result = op.get_bind().execute(
sa.text("SELECT COUNT(*) FROM tree_shares WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} tree_shares rows have NULL account_id after backfill. "
"All share entries must have a creating user with an account."
)
op.alter_column('tree_shares', 'account_id', nullable=False)
op.create_index('ix_tree_shares_account_id', 'tree_shares', ['account_id'])
def downgrade() -> None:
op.drop_index('ix_tree_shares_account_id', table_name='tree_shares')
op.drop_constraint('fk_tree_shares_account_id', 'tree_shares', type_='foreignkey')
op.drop_column('tree_shares', 'account_id')

View File

@@ -18,12 +18,10 @@ async def list_target_lists(
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
): ):
"""List all target lists for the current user's team.""" """List all target lists for the current user's account."""
if not current_user.team_id:
return []
result = await db.execute( result = await db.execute(
select(TargetList) select(TargetList)
.where(TargetList.team_id == current_user.team_id) .where(TargetList.account_id == current_user.account_id)
.order_by(TargetList.name) .order_by(TargetList.name)
) )
return result.scalars().all() return result.scalars().all()
@@ -36,11 +34,9 @@ async def create_target_list(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
_: None = Depends(require_engineer_or_admin), _: None = Depends(require_engineer_or_admin),
): ):
"""Create a new target list for the current team.""" """Create a new target list for the current account."""
if not current_user.team_id:
raise HTTPException(status_code=400, detail="User must belong to a team")
target_list = TargetList( target_list = TargetList(
team_id=current_user.team_id, account_id=current_user.account_id,
created_by=current_user.id, created_by=current_user.id,
name=data.name, name=data.name,
description=data.description, description=data.description,
@@ -61,7 +57,7 @@ async def get_target_list(
result = await db.execute( result = await db.execute(
select(TargetList).where( select(TargetList).where(
TargetList.id == list_id, TargetList.id == list_id,
TargetList.team_id == current_user.team_id, TargetList.account_id == current_user.account_id,
) )
) )
target_list = result.scalar_one_or_none() target_list = result.scalar_one_or_none()
@@ -81,7 +77,7 @@ async def update_target_list(
result = await db.execute( result = await db.execute(
select(TargetList).where( select(TargetList).where(
TargetList.id == list_id, TargetList.id == list_id,
TargetList.team_id == current_user.team_id, TargetList.account_id == current_user.account_id,
) )
) )
target_list = result.scalar_one_or_none() target_list = result.scalar_one_or_none()
@@ -91,7 +87,7 @@ async def update_target_list(
if "name" in update_fields and data.name is not None: if "name" in update_fields and data.name is not None:
target_list.name = data.name target_list.name = data.name
if "description" in update_fields: if "description" in update_fields:
target_list.description = data.description # allow setting to None target_list.description = data.description
if "targets" in update_fields and data.targets is not None: if "targets" in update_fields and data.targets is not None:
target_list.targets = [t.model_dump() for t in data.targets] target_list.targets = [t.model_dump() for t in data.targets]
await db.commit() await db.commit()
@@ -109,7 +105,7 @@ async def delete_target_list(
result = await db.execute( result = await db.execute(
select(TargetList).where( select(TargetList).where(
TargetList.id == list_id, TargetList.id == list_id,
TargetList.team_id == current_user.team_id, TargetList.account_id == current_user.account_id,
) )
) )
target_list = result.scalar_one_or_none() target_list = result.scalar_one_or_none()

View File

@@ -1048,6 +1048,7 @@ async def create_tree_share(
# Create share # Create share
tree_share = TreeShare( tree_share = TreeShare(
tree_id=tree.id, tree_id=tree.id,
account_id=tree.account_id, # share belongs to the tree's tenant, not the actor
share_token=share_token, share_token=share_token,
created_by=current_user.id, created_by=current_user.id,
allow_forking=share_data.allow_forking, allow_forking=share_data.allow_forking,

View File

@@ -12,10 +12,19 @@ async def log_audit(
resource_type: str, resource_type: str,
resource_id: Optional[UUID] = None, resource_id: Optional[UUID] = None,
details: Optional[dict] = None, details: Optional[dict] = None,
account_id: Optional[UUID] = None,
) -> None: ) -> None:
"""Record an audit log entry. Does not commit — piggybacks on the caller's commit.""" """Record an audit log entry. Does not commit — piggybacks on the caller's commit."""
if account_id is None:
# Derive from the acting user's account as a fallback (one extra query).
from sqlalchemy import select
from app.models.user import User
result = await db.execute(select(User.account_id).where(User.id == user_id))
account_id = result.scalar_one()
entry = AuditLog( entry = AuditLog(
user_id=user_id, user_id=user_id,
account_id=account_id,
action=action, action=action,
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,

View File

@@ -21,6 +21,12 @@ class AuditLog(Base):
nullable=False, nullable=False,
index=True index=True
) )
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True
)
action: Mapped[str] = mapped_column(String(50), nullable=False, index=True) action: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True) resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
resource_id: Mapped[Optional[uuid.UUID]] = mapped_column( resource_id: Mapped[Optional[uuid.UUID]] = mapped_column(

View File

@@ -8,7 +8,6 @@ from app.core.database import Base
if TYPE_CHECKING: if TYPE_CHECKING:
from app.models.user import User from app.models.user import User
from app.models.team import Team
from app.models.account import Account from app.models.account import Account
@@ -18,10 +17,6 @@ class TargetList(Base):
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
) )
team_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"),
nullable=False, index=True
)
account_id: Mapped[uuid.UUID] = mapped_column( account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"), ForeignKey("accounts.id", ondelete="CASCADE"),

View File

@@ -25,6 +25,12 @@ class TreeShare(Base):
nullable=False, nullable=False,
index=True index=True
) )
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True
)
share_token: Mapped[str] = mapped_column( share_token: Mapped[str] = mapped_column(
String(64), String(64),
unique=True, unique=True,

View File

@@ -23,7 +23,7 @@ class TargetListUpdate(BaseModel):
class TargetListResponse(BaseModel): class TargetListResponse(BaseModel):
id: UUID id: UUID
team_id: UUID account_id: UUID
created_by: Optional[UUID] created_by: Optional[UUID]
name: str name: str
description: Optional[str] description: Optional[str]

View File

@@ -464,7 +464,6 @@ async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
await test_db.flush() await test_db.flush()
target_list = TargetList( target_list = TargetList(
team_id=team.id,
account_id=account.id, account_id=account.id,
created_by=user.id, created_by=user.id,
name="Server Targets", name="Server Targets",

View File

@@ -708,3 +708,249 @@ async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn,
await admin_conn.execute( await admin_conn.execute(
f"DELETE FROM step_library WHERE id = '{public_step_id}'" f"DELETE FROM step_library WHERE id = '{public_step_id}'"
) )
# ===========================================================================
# Phase 3 RLS isolation tests
# Tables: step_ratings, step_usage_log, target_lists,
# session_shares, audit_logs, tree_shares
# ===========================================================================
# ---------------------------------------------------------------------------
# Helpers shared by Phase 3 fixtures
# ---------------------------------------------------------------------------
async def _get_user_b_id(admin_conn) -> str:
row = await admin_conn.fetchrow(
"SELECT id FROM users WHERE email = 'rls-user-b@example.com'"
)
return str(row["id"])
async def _get_tree_b_id(admin_conn) -> str:
row = await admin_conn.fetchrow(
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
)
return str(row["id"])
# ---------------------------------------------------------------------------
# step_ratings
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see step ratings belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
# Need a step_library row as FK target
step_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_library (
id, account_id, title, step_type, content,
visibility, is_active, created_at, updated_at
) VALUES (
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 RLS Step', 'action',
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
)
""")
rating_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_ratings (
id, step_id, user_id, account_id, is_verified_use, is_visible,
created_at, updated_at
) VALUES (
'{rating_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
FALSE, TRUE, NOW(), NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM step_ratings WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B step_ratings"
finally:
await admin_conn.execute(f"DELETE FROM step_ratings WHERE id = '{rating_id}'")
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
# ---------------------------------------------------------------------------
# step_usage_log
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see step usage logs belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
tree_b_id = await _get_tree_b_id(admin_conn)
step_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_library (
id, account_id, title, step_type, content,
visibility, is_active, created_at, updated_at
) VALUES (
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 Usage Step', 'action',
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
)
""")
# Need a sessions row as FK for usage log
session_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO sessions (
id, tree_id, user_id, account_id, tree_snapshot,
path_taken, decisions, custom_steps, started_at
) VALUES (
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
)
""")
log_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO step_usage_log (
id, step_id, user_id, account_id, session_id, used_at
) VALUES (
'{log_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'{session_id}', NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM step_usage_log WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B step_usage_log"
finally:
await admin_conn.execute(f"DELETE FROM step_usage_log WHERE id = '{log_id}'")
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
# ---------------------------------------------------------------------------
# target_lists
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see target lists belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
tl_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO target_lists (
id, account_id, created_by, name, targets, created_at, updated_at
) VALUES (
'{tl_id}', '{ACCOUNT_B_ID}', '{user_b_id}',
'Phase3 RLS Target List', '[]'::jsonb, NOW(), NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM target_lists WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B target_lists"
finally:
await admin_conn.execute(f"DELETE FROM target_lists WHERE id = '{tl_id}'")
# ---------------------------------------------------------------------------
# session_shares
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see session shares belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
tree_b_id = await _get_tree_b_id(admin_conn)
# Need a sessions row as FK
session_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO sessions (
id, tree_id, user_id, account_id, tree_snapshot,
path_taken, decisions, custom_steps, started_at
) VALUES (
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
)
""")
share_id = str(uuid.uuid4())
share_token = f"phase3-rls-test-{share_id[:8]}"
await admin_conn.execute(f"""
INSERT INTO session_shares (
id, session_id, account_id, share_token, visibility,
created_by, view_count, is_active, created_at, updated_at
) VALUES (
'{share_id}', '{session_id}', '{ACCOUNT_B_ID}',
'{share_token}', 'account', '{user_b_id}',
0, TRUE, NOW(), NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM session_shares WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B session_shares"
finally:
await admin_conn.execute(f"DELETE FROM session_shares WHERE id = '{share_id}'")
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
# ---------------------------------------------------------------------------
# audit_logs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see audit logs belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
log_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO audit_logs (
id, user_id, account_id, action, resource_type, created_at
) VALUES (
'{log_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
'test.action', 'test_resource', NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM audit_logs WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B audit_logs"
finally:
await admin_conn.execute(f"DELETE FROM audit_logs WHERE id = '{log_id}'")
# ---------------------------------------------------------------------------
# tree_shares
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
"""Account A must not see tree shares belonging to Account B."""
user_b_id = await _get_user_b_id(admin_conn)
tree_b_id = await _get_tree_b_id(admin_conn)
share_id = str(uuid.uuid4())
share_token = f"phase3-tree-rls-{share_id[:8]}"
await admin_conn.execute(f"""
INSERT INTO tree_shares (
id, tree_id, account_id, share_token, created_by,
allow_forking, created_at
) VALUES (
'{share_id}', '{tree_b_id}', '{ACCOUNT_B_ID}',
'{share_token}', '{user_b_id}', TRUE, NOW()
)
""")
try:
rows = await conn_a.fetch(
f"SELECT id FROM tree_shares WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B tree_shares"
finally:
await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'")

View File

@@ -3,37 +3,10 @@ import pytest
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.models.team import Team
from app.models.user import User from app.models.user import User
from sqlalchemy import select from sqlalchemy import select
@pytest.fixture
async def auth_headers(client: AsyncClient, test_db: AsyncSession, test_user: dict):
"""Override auth_headers to ensure the test user has a team_id assigned."""
# Fetch the user from DB and assign a team
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
user = result.scalar_one()
# Create a team and assign the user to it
team = Team(name="Test Team")
test_db.add(team)
await test_db.flush()
user.team_id = team.id
await test_db.commit()
# Re-login to get a fresh token
login_data = {
"email": test_user["email"],
"password": test_user["password"],
}
resp = await client.post("/api/v1/auth/login/json", json=login_data)
assert resp.status_code == 200
token_data = resp.json()
return {"Authorization": f"Bearer {token_data['access_token']}"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_target_list(client: AsyncClient, auth_headers: dict): async def test_create_target_list(client: AsyncClient, auth_headers: dict):
resp = await client.post( resp = await client.post(
@@ -107,25 +80,28 @@ async def test_delete_target_list(client: AsyncClient, auth_headers: dict):
assert get.status_code == 404 assert get.status_code == 404
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: dict, test_db): async def test_cannot_access_other_accounts_list(client: AsyncClient, auth_headers: dict, test_db):
"""User from team B cannot access team A's list.""" """User from account B cannot access account A's target list."""
import uuid import uuid
from app.models.team import Team from app.models.account import Account
from app.models.user import User from app.models.user import User
from app.core.security import get_password_hash from app.core.security import get_password_hash
# Create team A list using existing auth_headers # Create account A list using existing auth_headers
create = await client.post( create = await client.post(
"/api/v1/target-lists/", "/api/v1/target-lists/",
json={"name": "Team A List", "targets": [{"label": "SRV-A"}]}, json={"name": "Account A List", "targets": [{"label": "SRV-A"}]},
headers=auth_headers, headers=auth_headers,
) )
assert create.status_code == 201 assert create.status_code == 201
list_id = create.json()["id"] list_id = create.json()["id"]
# Create a separate team B with its own user # Create a separate account B with its own user
team_b = Team(name=f"Team B {uuid.uuid4()}") account_b = Account(
test_db.add(team_b) name=f"Account B {uuid.uuid4()}",
display_code=f"AB{str(uuid.uuid4())[:6].upper()}",
)
test_db.add(account_b)
await test_db.flush() await test_db.flush()
user_b = User( user_b = User(
@@ -133,11 +109,13 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
password_hash=get_password_hash("password123"), password_hash=get_password_hash("password123"),
name="User B", name="User B",
is_active=True, is_active=True,
team_id=team_b.id, account_id=account_b.id,
account_role="engineer",
role="engineer", role="engineer",
) )
test_db.add(user_b) test_db.add(user_b)
await test_db.flush() await test_db.flush()
await test_db.commit()
# Get auth token for user B # Get auth token for user B
login = await client.post( login = await client.post(
@@ -148,6 +126,6 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
token_b = login.json()["access_token"] token_b = login.json()["access_token"]
headers_b = {"Authorization": f"Bearer {token_b}"} headers_b = {"Authorization": f"Bearer {token_b}"}
# Team B cannot access Team A's list # Account B cannot access Account A's list
resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b) resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b)
assert resp.status_code == 404 assert resp.status_code == 404

View File

@@ -117,6 +117,7 @@ class TestTreeSharing:
for i in range(3): for i in range(3):
share = TreeShare( share = TreeShare(
tree_id=sample_tree.id, tree_id=sample_tree.id,
account_id=sample_tree.account_id,
share_token=f"token_{i}_" + "x" * 56, share_token=f"token_{i}_" + "x" * 56,
created_by=sample_tree.author_id, created_by=sample_tree.author_id,
allow_forking=i % 2 == 0 allow_forking=i % 2 == 0
@@ -162,6 +163,7 @@ class TestTreeSharing:
# Create a share # Create a share
share = TreeShare( share = TreeShare(
tree_id=sample_tree.id, tree_id=sample_tree.id,
account_id=sample_tree.account_id,
share_token="public_test_token" + "x" * 47, share_token="public_test_token" + "x" * 47,
created_by=UUID(test_user["user_data"]["id"]), created_by=UUID(test_user["user_data"]["id"]),
allow_forking=True allow_forking=True
@@ -192,6 +194,7 @@ class TestTreeSharing:
# Create expired share # Create expired share
share = TreeShare( share = TreeShare(
tree_id=sample_tree.id, tree_id=sample_tree.id,
account_id=sample_tree.account_id,
share_token="expired_token" + "x" * 50, share_token="expired_token" + "x" * 50,
created_by=UUID(test_user["user_data"]["id"]), created_by=UUID(test_user["user_data"]["id"]),
allow_forking=True, allow_forking=True,
@@ -209,6 +212,7 @@ class TestTreeSharing:
from uuid import UUID from uuid import UUID
share = TreeShare( share = TreeShare(
tree_id=sample_tree.id, tree_id=sample_tree.id,
account_id=sample_tree.account_id,
share_token="inactive_tree_token" + "x" * 44, share_token="inactive_tree_token" + "x" * 44,
created_by=UUID(test_user["user_data"]["id"]), created_by=UUID(test_user["user_data"]["id"]),
allow_forking=True allow_forking=True
@@ -248,6 +252,37 @@ class TestTreeSharing:
tokens.add(token) tokens.add(token)
assert len(tokens) == 5 assert len(tokens) == 5
async def test_share_account_id_matches_tree_not_actor(
self, client: AsyncClient, sample_tree, auth_headers, test_db
):
"""Share account_id must equal tree.account_id, not the actor's account_id.
A super admin in a different account can share any tree. The resulting
TreeShare row must live in the tree-owner's account so that the tree
owner's RLS context covers it. If account_id were derived from the
actor instead, the share would vanish from the tree owner's view once
RLS is enabled.
"""
from uuid import UUID
from sqlalchemy import select
response = await client.post(
f"/api/v1/trees/{sample_tree.id}/share",
json={"allow_forking": True},
headers=auth_headers,
)
assert response.status_code == 201
share_token = response.json()["share_token"]
result = await test_db.execute(
select(TreeShare).where(TreeShare.share_token == share_token)
)
share = result.scalar_one()
assert share.account_id == sample_tree.account_id, (
"TreeShare.account_id must equal tree.account_id, not the actor's account. "
"Shares must live in the tree owner's tenant for RLS to cover them."
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_migration_defaults_visibility_to_team(test_db): async def test_migration_defaults_visibility_to_team(test_db):