chore: resolve merge conflicts with main
- deps.py: keep require_tenant_context + require_admin_db (RLS deps); drop unused get_tenant_context stub from Phase 0 - categories.py: keep both PLATFORM_ACCOUNT_ID and tenant_filter imports (body uses both) - tenant-isolation spec: keep main's resolved TargetList/teams audit answers Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -519,11 +519,15 @@ async def save_task_lane(
|
||||
_: None = Depends(require_engineer_or_admin),
|
||||
):
|
||||
"""Save the current task lane state including user's in-progress responses."""
|
||||
session = await db.get(AISession, session_id)
|
||||
result = await db.execute(
|
||||
select(AISession).where(
|
||||
AISession.id == session_id,
|
||||
AISession.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
session = result.scalar_one_or_none()
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
if session.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not your session")
|
||||
|
||||
payload = {
|
||||
"questions": [q.model_dump() for q in body.questions],
|
||||
@@ -762,13 +766,13 @@ async def search_sessions(
|
||||
limit: int = Query(5, ge=1, le=20),
|
||||
):
|
||||
"""Search AI sessions by content using full-text search. Used by Command Palette."""
|
||||
# Sessions are user-scoped. The list endpoint uses user_id only;
|
||||
# search must be consistent. Cross-user access requires explicit
|
||||
# escalation or session sharing — not ambient account membership.
|
||||
result = await db.execute(
|
||||
select(AISession)
|
||||
.where(
|
||||
or_(
|
||||
AISession.user_id == current_user.id,
|
||||
AISession.account_id == current_user.account_id,
|
||||
),
|
||||
AISession.user_id == current_user.id,
|
||||
text("ai_sessions.search_vector @@ plainto_tsquery('english', :q)"),
|
||||
)
|
||||
.params(q=q)
|
||||
@@ -901,7 +905,7 @@ async def get_session(
|
||||
pkg = session.escalation_package or {}
|
||||
is_handler = pkg.get("picked_up_by") == str(current_user.id)
|
||||
if session.user_id != current_user.id and session.escalated_to_id != current_user.id and not is_handler:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
|
||||
|
||||
return _build_session_detail(session)
|
||||
|
||||
@@ -917,6 +921,18 @@ async def get_documentation(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Get auto-generated documentation for a session."""
|
||||
# Verify session ownership — owner only. Documentation endpoints require direct
|
||||
# ownership; escalated_to_id / picked_up_by handlers use get_session (read-only).
|
||||
# This is consistent with stream_documentation which has the same owner-only check.
|
||||
result = await db.execute(
|
||||
select(AISession).where(
|
||||
AISession.id == session_id,
|
||||
AISession.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
if not result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
try:
|
||||
return await flowpilot_engine.get_session_documentation(
|
||||
session_id=session_id,
|
||||
@@ -942,13 +958,14 @@ async def stream_documentation(
|
||||
|
||||
# Verify session ownership
|
||||
result = await db.execute(
|
||||
select(AISession).where(AISession.id == session_id)
|
||||
select(AISession).where(
|
||||
AISession.id == session_id,
|
||||
AISession.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
session = result.scalar_one_or_none()
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
if session.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
@@ -1043,6 +1060,19 @@ async def retry_psa_push_endpoint(
|
||||
"""Manually retry a failed PSA documentation push."""
|
||||
from app.models.psa_post_log import PsaPostLog
|
||||
|
||||
# Verify the session belongs to the current user
|
||||
session_result = await db.execute(
|
||||
select(AISession).where(
|
||||
AISession.id == session_id,
|
||||
AISession.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
if not session_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found",
|
||||
)
|
||||
|
||||
# Find the latest failed push log for this session
|
||||
result = await db.execute(
|
||||
select(PsaPostLog)
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.filters import tenant_filter
|
||||
from app.models import User, Session, Tree, SessionRating
|
||||
from app.schemas.analytics import (
|
||||
TeamAnalyticsResponse, PersonalAnalyticsResponse, FlowAnalyticsResponse,
|
||||
@@ -290,8 +291,13 @@ async def get_flow_analytics(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Analytics for a specific flow."""
|
||||
# Verify tree exists
|
||||
result = await db.execute(select(Tree).where(Tree.id == tree_id))
|
||||
# Verify tree exists and belongs to the requesting user's account.
|
||||
result = await db.execute(
|
||||
select(Tree).where(
|
||||
Tree.id == tree_id,
|
||||
tenant_filter(Tree, current_user.account_id),
|
||||
)
|
||||
)
|
||||
tree = result.scalar_one_or_none()
|
||||
if not tree:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.schemas.category import CategoryCreate, CategoryUpdate, CategoryRespons
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.permissions import can_manage_category, can_create_category
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
from app.core.filters import tenant_filter
|
||||
|
||||
router = APIRouter(prefix="/categories", tags=["categories"])
|
||||
|
||||
@@ -109,10 +110,12 @@ async def get_category(
|
||||
detail="You don't have access to this category"
|
||||
)
|
||||
|
||||
# Get tree count
|
||||
# Get tree count — scoped to the requesting account so cross-account
|
||||
# trees in shared categories are not counted.
|
||||
count_query = select(func.count(Tree.id)).where(
|
||||
Tree.category_id == category.id,
|
||||
Tree.is_active == True
|
||||
Tree.is_active == True,
|
||||
tenant_filter(Tree, current_user.account_id),
|
||||
)
|
||||
count_result = await db.execute(count_query)
|
||||
tree_count = count_result.scalar() or 0
|
||||
|
||||
@@ -29,8 +29,8 @@ def _compute_next_run(cron_expression: str, tz_name: str) -> datetime:
|
||||
return cron.get_next(datetime).astimezone(timezone.utc)
|
||||
|
||||
|
||||
async def _get_tree_or_403(tree_id: UUID, current_user: User, db: AsyncSession) -> "Tree":
|
||||
"""Fetch tree and verify the current user's team owns it."""
|
||||
async def _get_tree_or_404(tree_id: UUID, current_user: User, db: AsyncSession) -> "Tree":
|
||||
"""Fetch tree and verify the current user's team owns it. Raises 404 if not found or access denied."""
|
||||
result = await db.execute(select(Tree).where(Tree.id == tree_id))
|
||||
tree = result.scalar_one_or_none()
|
||||
if not tree:
|
||||
@@ -38,7 +38,7 @@ async def _get_tree_or_403(tree_id: UUID, current_user: User, db: AsyncSession)
|
||||
# Super admins can access any tree; regular users must be on the same team
|
||||
if not getattr(current_user, 'is_super_admin', False):
|
||||
if tree.team_id != current_user.team_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
raise HTTPException(status_code=404, detail="Tree not found")
|
||||
return tree
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ async def create_schedule(
|
||||
):
|
||||
"""Create a cron schedule for a maintenance flow. One per flow."""
|
||||
# Verify user's team owns the tree
|
||||
tree = await _get_tree_or_403(data.tree_id, current_user, db)
|
||||
tree = await _get_tree_or_404(data.tree_id, current_user, db)
|
||||
if tree.tree_type != "maintenance":
|
||||
raise HTTPException(status_code=400, detail="Schedules are only supported for maintenance flows")
|
||||
|
||||
@@ -94,7 +94,7 @@ async def get_schedule_for_tree(
|
||||
):
|
||||
"""Get the schedule for a specific maintenance flow."""
|
||||
# Verify user's team owns the tree before returning schedule data
|
||||
await _get_tree_or_403(tree_id, current_user, db)
|
||||
await _get_tree_or_404(tree_id, current_user, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(MaintenanceSchedule).where(MaintenanceSchedule.tree_id == tree_id)
|
||||
@@ -122,7 +122,7 @@ async def update_schedule(
|
||||
raise HTTPException(status_code=404, detail="Schedule not found")
|
||||
|
||||
# Verify user's team owns the tree this schedule belongs to
|
||||
await _get_tree_or_403(schedule.tree_id, current_user, db)
|
||||
await _get_tree_or_404(schedule.tree_id, current_user, db)
|
||||
|
||||
update_fields = data.model_fields_set
|
||||
was_active = schedule.is_active
|
||||
|
||||
@@ -143,8 +143,8 @@ async def get_session(
|
||||
|
||||
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this session"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
return session
|
||||
@@ -234,8 +234,8 @@ async def update_session(
|
||||
|
||||
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this session"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
if session.completed_at:
|
||||
@@ -281,8 +281,8 @@ async def complete_session(
|
||||
|
||||
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this session"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
if session.completed_at:
|
||||
@@ -319,8 +319,8 @@ async def update_scratchpad(
|
||||
|
||||
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this session"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
session.scratchpad = data.scratchpad
|
||||
@@ -348,8 +348,8 @@ async def update_session_variables(
|
||||
|
||||
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this session"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
if session.completed_at:
|
||||
@@ -387,8 +387,8 @@ async def export_session(
|
||||
|
||||
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this session"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
# PDF export — separate path with binary response
|
||||
@@ -830,8 +830,8 @@ async def link_ticket(
|
||||
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
|
||||
if not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this session",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found",
|
||||
)
|
||||
|
||||
# Unlink
|
||||
|
||||
@@ -72,8 +72,8 @@ async def create_share(
|
||||
|
||||
if session.user_id != current_user.id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only the session owner can create share links"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
# Require account_id for account-scoped shares
|
||||
@@ -170,8 +170,8 @@ async def revoke_share(
|
||||
|
||||
if share.created_by != current_user.id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only the share creator can revoke it"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Share not found"
|
||||
)
|
||||
|
||||
share.is_active = False
|
||||
|
||||
@@ -95,8 +95,8 @@ async def get_step_category(
|
||||
# Check access: global categories visible to all, account categories only to account members
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this step category"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Step category not found"
|
||||
)
|
||||
|
||||
return StepCategoryResponse(
|
||||
|
||||
@@ -47,10 +47,10 @@ async def get_step_or_404(
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
if check_view and not can_view_step(current_user, step):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to view this step")
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
if check_edit and not can_edit_step(current_user, step):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to modify this step")
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
return step
|
||||
|
||||
|
||||
@@ -106,8 +106,8 @@ async def get_tag(
|
||||
# Check access: global tags visible to all, account tags only to account members
|
||||
if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tag"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tag not found"
|
||||
)
|
||||
|
||||
return TagResponse.model_validate(tag)
|
||||
|
||||
@@ -612,9 +612,17 @@ async def update_tree(
|
||||
)
|
||||
|
||||
if not can_edit_tree(current_user, tree):
|
||||
# If the user can see this tree (same account, team visibility), give a 403 with
|
||||
# a clear message — returning 404 here would be confusing since GET returns 200.
|
||||
# For truly inaccessible trees (cross-account), return 404 to avoid confirming existence.
|
||||
if can_access_tree(current_user, tree):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have permission to edit this flow"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only edit your own trees"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tree not found"
|
||||
)
|
||||
|
||||
# Extract tags for separate handling
|
||||
@@ -1146,9 +1154,17 @@ async def update_tree_visibility(
|
||||
)
|
||||
|
||||
if not can_edit_tree(current_user, tree):
|
||||
# If the user can see this tree (same account, team visibility), give a 403 with
|
||||
# a clear message — returning 404 here would be confusing since GET returns 200.
|
||||
# For truly inaccessible trees (cross-account), return 404 to avoid confirming existence.
|
||||
if can_access_tree(current_user, tree):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have permission to edit this flow"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only edit your own trees"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tree not found"
|
||||
)
|
||||
|
||||
# Update visibility
|
||||
|
||||
@@ -255,9 +255,9 @@ async def get_upload_url(
|
||||
if upload is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found")
|
||||
|
||||
# Verify the upload belongs to the user's account
|
||||
# Verify the upload belongs to the user's account — 404 to avoid revealing existence
|
||||
if upload.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found")
|
||||
|
||||
url = storage_service.get_presigned_url(upload.storage_key)
|
||||
return {"url": url}
|
||||
@@ -311,9 +311,9 @@ async def delete_upload(
|
||||
if upload is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found")
|
||||
|
||||
# Verify ownership
|
||||
# Verify ownership — 404 to avoid revealing existence
|
||||
if upload.uploaded_by != current_user.id and not current_user.is_super_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found")
|
||||
|
||||
# Delete from S3
|
||||
await storage_service.delete_file(upload.storage_key)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""
|
||||
Centralized query filters for ResolutionFlow.
|
||||
|
||||
Provides reusable SQLAlchemy filter builders for tree access control
|
||||
and step visibility, used across multiple endpoint modules.
|
||||
Provides reusable SQLAlchemy filter builders for tree access control,
|
||||
step visibility, and the canonical tenant_filter used by all queries
|
||||
on tenant-scoped tables.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import or_, and_, true as sa_true
|
||||
@@ -13,6 +15,18 @@ if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
def tenant_filter(model, account_id: uuid.UUID):
|
||||
"""Primary app-layer tenant filter.
|
||||
|
||||
MUST be used in every SELECT/UPDATE/DELETE on tenant tables.
|
||||
RLS (Phase 2) is the safety net — this is the primary enforcement.
|
||||
|
||||
Usage:
|
||||
stmt = select(Tree).where(tenant_filter(Tree, current_user.account_id), ...)
|
||||
"""
|
||||
return model.account_id == account_id
|
||||
|
||||
|
||||
def build_tree_access_filter(current_user: User):
|
||||
"""Build the access filter for trees based on user permissions.
|
||||
|
||||
@@ -36,10 +50,11 @@ def build_tree_access_filter(current_user: User):
|
||||
Tree.author_id == current_user.id,
|
||||
]
|
||||
if current_user.account_id:
|
||||
# Team-visible trees: use tenant_filter as the account match
|
||||
conditions.append(
|
||||
and_(
|
||||
Tree.visibility == 'team',
|
||||
Tree.account_id == current_user.account_id
|
||||
tenant_filter(Tree, current_user.account_id),
|
||||
)
|
||||
)
|
||||
return or_(*conditions)
|
||||
@@ -58,11 +73,14 @@ def build_step_visibility_filter(current_user: User):
|
||||
if current_user.account_id:
|
||||
return or_(
|
||||
StepLibrary.visibility == 'public',
|
||||
and_(StepLibrary.visibility == 'team', StepLibrary.account_id == current_user.account_id),
|
||||
StepLibrary.created_by == current_user.id # Own private steps
|
||||
and_(
|
||||
StepLibrary.visibility == 'team',
|
||||
tenant_filter(StepLibrary, current_user.account_id),
|
||||
),
|
||||
StepLibrary.created_by == current_user.id,
|
||||
)
|
||||
else:
|
||||
return or_(
|
||||
StepLibrary.visibility == 'public',
|
||||
StepLibrary.created_by == current_user.id
|
||||
StepLibrary.created_by == current_user.id,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
@@ -103,13 +103,23 @@ async def start_conversation(
|
||||
|
||||
Returns (conversation, greeting_message).
|
||||
"""
|
||||
# Load tree
|
||||
# Load tree — must be accessible to this account.
|
||||
# Allows own account's trees, default trees, and public trees.
|
||||
# Raises ValueError (caught by endpoint as 404) if not found or not accessible.
|
||||
result = await db.execute(
|
||||
select(Tree).options(selectinload(Tree.tags)).where(Tree.id == tree_id)
|
||||
select(Tree).options(selectinload(Tree.tags)).where(
|
||||
Tree.id == tree_id,
|
||||
or_(
|
||||
Tree.account_id == account_id,
|
||||
Tree.author_id == user_id,
|
||||
Tree.is_default == True,
|
||||
Tree.is_public == True,
|
||||
),
|
||||
)
|
||||
)
|
||||
tree = result.scalar_one_or_none()
|
||||
if not tree:
|
||||
raise ValueError(f"Tree {tree_id} not found")
|
||||
raise ValueError(f"Tree {tree_id} not found or not accessible")
|
||||
|
||||
conversation = CopilotConversation(
|
||||
user_id=user_id,
|
||||
|
||||
91
backend/scripts/check_tenant_filters.py
Normal file
91
backend/scripts/check_tenant_filters.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Tenant filter enforcement check.
|
||||
|
||||
Scans endpoint and service files for SQLAlchemy select() calls on known
|
||||
tenant tables and warns when account_id or tenant_filter is not present
|
||||
in the surrounding 15 lines (the typical extent of a single query).
|
||||
|
||||
Usage:
|
||||
python scripts/check_tenant_filters.py # warn mode (exits 0)
|
||||
python scripts/check_tenant_filters.py --fail # block mode (exits 1 on findings)
|
||||
"""
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Tables that must always be filtered by account_id or tenant_filter.
|
||||
# Extend this list as new tenant tables are added.
|
||||
TENANT_MODELS = [
|
||||
"Tree", "AISession", "Session", "StepLibrary", "FlowProposal",
|
||||
"CopilotConversation", "AssistantChat", "FileUpload", "KBImport",
|
||||
"PsaConnection", "PsaPostLog", "PsaMemberMapping", "AIChatSession",
|
||||
"AIConversation", "AIUsage", "Subscription", "AccountInvite",
|
||||
"Notification", "NotificationConfig", "SessionShare", "UserFolder",
|
||||
"UserPinnedTree", "SessionBranch", "SessionHandoff",
|
||||
"SessionResolutionOutput", "ForkPoint", "AISessionStep",
|
||||
"AISuggestion", "StepCategory", "TreeCategory", "TreeTag",
|
||||
"Attachment", "SessionSupportingData", "MaintenanceSchedule",
|
||||
"AuditLog", "ScriptBuilderSession", "ScriptTemplate",
|
||||
"StepRating", "StepUsageLog", "TargetList",
|
||||
]
|
||||
|
||||
# Directories to scan
|
||||
SCAN_DIRS = [
|
||||
Path("app/api/endpoints"),
|
||||
Path("app/services"),
|
||||
]
|
||||
|
||||
# Patterns that indicate the query is correctly scoped.
|
||||
# NOTE: user_id scoping is accepted for user-owned resources (sessions, folders, notifications).
|
||||
# For account-shared resources (trees, steps, etc.) use tenant_filter or account_id.
|
||||
SAFE_PATTERNS = [
|
||||
r"tenant_filter",
|
||||
r"account_id",
|
||||
r"user_id", # User-scoped resources (sessions, folders, notifications, etc.)
|
||||
r"is_super_admin", # Super admin queries intentionally bypass tenant filter
|
||||
r"# cross-tenant: approved", # Explicit approval comment
|
||||
]
|
||||
|
||||
SKIP_FILES = {
|
||||
"admin.py", # Super admin endpoints intentionally bypass tenant filter
|
||||
"admin_gallery.py", # Gallery management — super admin only, no tenant scoping needed
|
||||
"public_templates.py",# Public template browser — intentionally cross-tenant
|
||||
"auth.py", # Auth/registration — no account context during login/register
|
||||
"ratings.py", # Session ratings — user-scoped via session lookup chain
|
||||
}
|
||||
|
||||
findings = []
|
||||
|
||||
for scan_dir in SCAN_DIRS:
|
||||
if not scan_dir.exists():
|
||||
continue
|
||||
for path in sorted(scan_dir.glob("*.py")):
|
||||
if path.name in SKIP_FILES:
|
||||
continue
|
||||
lines = path.read_text().splitlines()
|
||||
for i, line in enumerate(lines):
|
||||
for model in TENANT_MODELS:
|
||||
if re.search(rf"\bselect\s*\(\s*{model}\b", line):
|
||||
# Check surrounding 15 lines for a safe pattern
|
||||
start = max(0, i - 2)
|
||||
end = min(len(lines), i + 15)
|
||||
context = "\n".join(lines[start:end])
|
||||
if not any(re.search(p, context) for p in SAFE_PATTERNS):
|
||||
findings.append(
|
||||
f"{path}:{i + 1}: select({model}) — no tenant_filter or account_id found in context"
|
||||
)
|
||||
|
||||
if findings:
|
||||
print(f"\n⚠ Tenant filter check — {len(findings)} warning(s):\n")
|
||||
for f in findings:
|
||||
print(f" {f}")
|
||||
print()
|
||||
if "--fail" in sys.argv:
|
||||
print("Run with --fail: exiting 1")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("Run in warn mode — not blocking. Pass --fail to block.")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("✓ Tenant filter check passed — no unscoped tenant table queries found.")
|
||||
sys.exit(0)
|
||||
578
backend/tests/test_tenant_isolation_p0.py
Normal file
578
backend/tests/test_tenant_isolation_p0.py
Normal file
@@ -0,0 +1,578 @@
|
||||
"""Phase 0 tenant-isolation tests.
|
||||
|
||||
Verifies that endpoints respect account boundaries and don't leak data
|
||||
across tenants. Each task group tests a specific endpoint fix.
|
||||
"""
|
||||
import uuid
|
||||
import datetime as dt
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.models.tree import Tree
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _create_account_and_user(db: AsyncSession, prefix: str):
|
||||
"""Create a fresh account + engineer user. Returns (account, user, plain_password)."""
|
||||
password = "TestPass123!"
|
||||
account = Account(
|
||||
name=f"{prefix}-corp",
|
||||
display_code=uuid.uuid4().hex[:8],
|
||||
)
|
||||
db.add(account)
|
||||
await db.flush()
|
||||
|
||||
user = User(
|
||||
email=f"{prefix}-{uuid.uuid4().hex[:6]}@example.com",
|
||||
name=f"{prefix} user",
|
||||
password_hash=get_password_hash(password),
|
||||
is_active=True,
|
||||
account_id=account.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
return account, user, password
|
||||
|
||||
|
||||
async def _login(client: AsyncClient, email: str, password: str) -> dict:
|
||||
"""Log in and return Authorization headers."""
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
assert resp.status_code == 200, f"Login failed: {resp.text}"
|
||||
token = resp.json()["access_token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
async def _create_private_tree(db: AsyncSession, account: Account, user: User) -> Tree:
|
||||
"""Create a private tree owned by the given account/user."""
|
||||
tree = Tree(
|
||||
name=f"Private Tree {uuid.uuid4().hex[:6]}",
|
||||
account_id=account.id,
|
||||
author_id=user.id,
|
||||
visibility="private",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure={"id": "root", "type": "start", "children": []},
|
||||
is_active=True,
|
||||
status="published",
|
||||
)
|
||||
db.add(tree)
|
||||
await db.flush()
|
||||
return tree
|
||||
|
||||
|
||||
# ── Task 3: Analytics flow endpoint ──────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analytics_flow_cannot_read_other_account_tree(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""Account A cannot read flow analytics for Account B's private tree."""
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "anl-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "anl-b")
|
||||
tree_b = await _create_private_tree(test_db, acct_b, user_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
|
||||
resp = await client.get(
|
||||
f"/api/v1/analytics/flows/{tree_b.id}",
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.status_code == 404, f"Expected 404, got {resp.status_code}: {resp.text}"
|
||||
|
||||
|
||||
# ── Task 4: Category tree count ───────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_category_tree_count_scoped_to_account(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""tree_count on a category must not include trees from other accounts."""
|
||||
from app.models.category import TreeCategory
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "cat-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "cat-b")
|
||||
|
||||
# Shared category (account_id=None means global)
|
||||
category = TreeCategory(
|
||||
name="Shared Category",
|
||||
slug=f"shared-cat-{uuid.uuid4().hex[:6]}",
|
||||
account_id=None,
|
||||
is_active=True,
|
||||
)
|
||||
test_db.add(category)
|
||||
await test_db.flush()
|
||||
|
||||
# 3 trees for account_b under this category
|
||||
for i in range(3):
|
||||
tree = Tree(
|
||||
name=f"B Tree {i}",
|
||||
account_id=acct_b.id,
|
||||
author_id=user_b.id,
|
||||
category_id=category.id,
|
||||
visibility="team",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure={"id": "root", "type": "start", "children": []},
|
||||
is_active=True,
|
||||
status="published",
|
||||
)
|
||||
test_db.add(tree)
|
||||
|
||||
# 1 tree for account_a under this category
|
||||
tree_a = Tree(
|
||||
name="A Tree",
|
||||
account_id=acct_a.id,
|
||||
author_id=user_a.id,
|
||||
category_id=category.id,
|
||||
visibility="team",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure={"id": "root", "type": "start", "children": []},
|
||||
is_active=True,
|
||||
status="published",
|
||||
)
|
||||
test_db.add(tree_a)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
|
||||
resp = await client.get(
|
||||
f"/api/v1/categories/{category.id}",
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
# account_a should only see their 1 tree, not account_b's 3
|
||||
assert resp.json()["tree_count"] == 1, (
|
||||
f"Expected tree_count=1 (own trees only), got {resp.json()['tree_count']}"
|
||||
)
|
||||
|
||||
|
||||
# ── Task 5: AI session search scope ──────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_session_search_cannot_see_other_users_sessions(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A cannot find User B's AI sessions via the search endpoint,
|
||||
even when both users are in the same account."""
|
||||
from app.models.ai_session import AISession
|
||||
|
||||
# Two users in the SAME account
|
||||
account = Account(name="Shared Corp", display_code=uuid.uuid4().hex[:8])
|
||||
test_db.add(account)
|
||||
await test_db.flush()
|
||||
|
||||
password = "TestPass123!"
|
||||
user_a = User(
|
||||
email=f"user-a-{uuid.uuid4().hex[:6]}@shared.com",
|
||||
name="User A",
|
||||
password_hash=get_password_hash(password),
|
||||
is_active=True,
|
||||
account_id=account.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
user_b = User(
|
||||
email=f"user-b-{uuid.uuid4().hex[:6]}@shared.com",
|
||||
name="User B",
|
||||
password_hash=get_password_hash(password),
|
||||
is_active=True,
|
||||
account_id=account.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
test_db.add_all([user_a, user_b])
|
||||
await test_db.flush()
|
||||
|
||||
# Session belonging to user_b with distinctive problem_summary
|
||||
session_b = AISession(
|
||||
user_id=user_b.id,
|
||||
account_id=account.id,
|
||||
problem_summary="CONFIDENTIAL: user_b's session",
|
||||
problem_domain="networking",
|
||||
status="resolved",
|
||||
)
|
||||
test_db.add(session_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, password)
|
||||
|
||||
resp = await client.get(
|
||||
"/api/v1/ai-sessions/search",
|
||||
params={"q": "CONFIDENTIAL"},
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
results = resp.json()
|
||||
ids = [r["id"] for r in results]
|
||||
assert str(session_b.id) not in ids, (
|
||||
"User A can see User B's session via search — cross-user leak within account"
|
||||
)
|
||||
|
||||
|
||||
# ── Task 6: Cross-tenant UUID audit ─────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tree_returns_404_not_403_for_other_account(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""Account A gets 404 (not 403) when accessing Account B's private tree."""
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-tree-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-tree-b")
|
||||
tree_b = await _create_private_tree(test_db, acct_b, user_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.get(f"/api/v1/trees/{tree_b.id}", headers=headers_a)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-tenant tree access, got {resp.status_code}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tree_returns_404_not_403_for_other_account(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""Account A gets 404 (not 403) when trying to update Account B's tree."""
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-upd-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-upd-b")
|
||||
tree_b = await _create_private_tree(test_db, acct_b, user_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.put(
|
||||
f"/api/v1/trees/{tree_b.id}",
|
||||
json={"name": "Hacked"},
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-tenant tree update, got {resp.status_code}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_returns_404_not_403_for_other_user(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A gets 404 (not 403) when accessing User B's session."""
|
||||
from app.models.session import Session
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-sess-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-sess-b")
|
||||
tree_b = await _create_private_tree(test_db, acct_b, user_b)
|
||||
|
||||
session_b = Session(
|
||||
tree_id=tree_b.id,
|
||||
user_id=user_b.id,
|
||||
tree_snapshot={"id": "root", "type": "start", "children": []},
|
||||
path_taken=[],
|
||||
decisions=[],
|
||||
)
|
||||
test_db.add(session_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.get(f"/api/v1/sessions/{session_b.id}", headers=headers_a)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-user session access, got {resp.status_code}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_session_get_returns_404_not_403_for_other_user(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A gets 404 (not 403) when accessing User B's AI session."""
|
||||
from app.models.ai_session import AISession
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-ais-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-ais-b")
|
||||
|
||||
ai_session_b = AISession(
|
||||
user_id=user_b.id,
|
||||
account_id=acct_b.id,
|
||||
problem_summary="Test session",
|
||||
problem_domain="networking",
|
||||
status="active",
|
||||
)
|
||||
test_db.add(ai_session_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.get(f"/api/v1/ai-sessions/{ai_session_b.id}", headers=headers_a)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-user AI session access, got {resp.status_code}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_session_retry_psa_push_requires_ownership(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A cannot retry PSA push for User B's AI session."""
|
||||
from app.models.ai_session import AISession
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-psa-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-psa-b")
|
||||
|
||||
ai_session_b = AISession(
|
||||
user_id=user_b.id,
|
||||
account_id=acct_b.id,
|
||||
problem_summary="PSA test",
|
||||
problem_domain="networking",
|
||||
status="resolved",
|
||||
)
|
||||
test_db.add(ai_session_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.post(
|
||||
f"/api/v1/ai-sessions/{ai_session_b.id}/retry-psa-push",
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-user retry-psa-push, got {resp.status_code}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_url_returns_404_not_403_for_other_account(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A gets 404 (not 403) when accessing User B's upload URL."""
|
||||
from app.models.file_upload import FileUpload
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-upl-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-upl-b")
|
||||
|
||||
upload_b = FileUpload(
|
||||
account_id=acct_b.id,
|
||||
uploaded_by=user_b.id,
|
||||
filename="secret.png",
|
||||
content_type="image/png",
|
||||
size_bytes=1024,
|
||||
storage_key="test/secret.png",
|
||||
)
|
||||
test_db.add(upload_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.get(f"/api/v1/uploads/{upload_b.id}/url", headers=headers_a)
|
||||
assert resp.status_code in (404, 503), (
|
||||
f"Expected 404 (or 503 if storage not configured) for cross-account upload, got {resp.status_code}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_share_revoke_returns_404_not_403_for_other_user(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A gets 404 (not 403) when revoking User B's share."""
|
||||
from app.models.session import Session
|
||||
from app.models.session_share import SessionShare
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-shr-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-shr-b")
|
||||
tree_b = await _create_private_tree(test_db, acct_b, user_b)
|
||||
|
||||
session_b = Session(
|
||||
tree_id=tree_b.id,
|
||||
user_id=user_b.id,
|
||||
tree_snapshot={"id": "root", "type": "start", "children": []},
|
||||
path_taken=[],
|
||||
decisions=[],
|
||||
)
|
||||
test_db.add(session_b)
|
||||
await test_db.flush()
|
||||
|
||||
share_b = SessionShare(
|
||||
session_id=session_b.id,
|
||||
account_id=acct_b.id,
|
||||
share_token="test-token-unique-" + uuid.uuid4().hex[:8],
|
||||
share_name="Test",
|
||||
visibility="public",
|
||||
created_by=user_b.id,
|
||||
)
|
||||
test_db.add(share_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.delete(f"/api/v1/shares/{share_b.id}", headers=headers_a)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-user share revoke, got {resp.status_code}"
|
||||
)
|
||||
|
||||
|
||||
# ── Task 6 (continued): steps, tags, step_categories, maintenance_schedules ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_other_account_step(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A gets 404 when reading a team-visibility step owned by Account B."""
|
||||
from app.models.step_library import StepLibrary
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-step-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-step-b")
|
||||
|
||||
# Create a team-visibility step owned by account B
|
||||
step_b = StepLibrary(
|
||||
title="Account B Confidential Step",
|
||||
step_type="action",
|
||||
content={"description": "secret step"},
|
||||
created_by=user_b.id,
|
||||
account_id=acct_b.id,
|
||||
visibility="team",
|
||||
is_active=True,
|
||||
)
|
||||
test_db.add(step_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.get(f"/api/v1/steps/{step_b.id}", headers=headers_a)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-account step access, got {resp.status_code}: {resp.text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_other_account_tag(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A gets 404 when reading a tag scoped to Account B."""
|
||||
from app.models.tag import TreeTag
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-tag-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-tag-b")
|
||||
|
||||
# Create an account-scoped tag for account B
|
||||
tag_b = TreeTag(
|
||||
name=f"account-b-tag-{uuid.uuid4().hex[:6]}",
|
||||
slug=f"account-b-tag-{uuid.uuid4().hex[:6]}",
|
||||
account_id=acct_b.id,
|
||||
)
|
||||
test_db.add(tag_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.get(f"/api/v1/tags/{tag_b.id}", headers=headers_a)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-account tag access, got {resp.status_code}: {resp.text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_other_account_step_category(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A gets 404 when reading a step category scoped to Account B."""
|
||||
from app.models.step_category import StepCategory
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-scat-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-scat-b")
|
||||
|
||||
# Create an account-scoped step category for account B
|
||||
category_b = StepCategory(
|
||||
name=f"Account B Category {uuid.uuid4().hex[:6]}",
|
||||
slug=f"account-b-cat-{uuid.uuid4().hex[:6]}",
|
||||
account_id=acct_b.id,
|
||||
is_active=True,
|
||||
)
|
||||
test_db.add(category_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.get(f"/api/v1/step-categories/{category_b.id}", headers=headers_a)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-account step category access, got {resp.status_code}: {resp.text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maintenance_schedule_returns_404_for_other_team(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""User A gets 404 when reading a maintenance schedule belonging to Team B's tree."""
|
||||
from app.models.team import Team
|
||||
from app.models.maintenance_schedule import MaintenanceSchedule
|
||||
|
||||
# Create two separate teams
|
||||
team_a = Team(name="Team A Corp")
|
||||
team_b = Team(name="Team B Corp")
|
||||
test_db.add_all([team_a, team_b])
|
||||
await test_db.flush()
|
||||
|
||||
# Create accounts and users, assign to respective teams
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-ms-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-ms-b")
|
||||
user_a.team_id = team_a.id
|
||||
user_b.team_id = team_b.id
|
||||
await test_db.flush()
|
||||
|
||||
# Create a maintenance tree owned by team B
|
||||
tree_b = Tree(
|
||||
name="Team B Maintenance Flow",
|
||||
account_id=acct_b.id,
|
||||
author_id=user_b.id,
|
||||
team_id=team_b.id,
|
||||
visibility="team",
|
||||
tree_type="maintenance",
|
||||
tree_structure={"id": "root", "type": "start", "children": []},
|
||||
is_active=True,
|
||||
status="published",
|
||||
)
|
||||
test_db.add(tree_b)
|
||||
await test_db.flush()
|
||||
|
||||
# Create a schedule for that tree
|
||||
schedule_b = MaintenanceSchedule(
|
||||
tree_id=tree_b.id,
|
||||
created_by=user_b.id,
|
||||
cron_expression="0 2 * * 0",
|
||||
timezone="UTC",
|
||||
is_active=True,
|
||||
next_run_at=dt.datetime(2026, 12, 31, tzinfo=dt.timezone.utc),
|
||||
)
|
||||
test_db.add(schedule_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.get(f"/api/v1/maintenance-schedules/tree/{tree_b.id}", headers=headers_a)
|
||||
assert resp.status_code == 404, (
|
||||
f"Expected 404 for cross-team maintenance schedule access, got {resp.status_code}: {resp.text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_documentation_returns_404_for_other_user_session(
|
||||
client: AsyncClient, test_db: AsyncSession
|
||||
):
|
||||
"""GET /ai-sessions/{id}/documentation must return 404 (not 403) for cross-user access."""
|
||||
from app.models.ai_session import AISession
|
||||
|
||||
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "doc-a")
|
||||
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "doc-b")
|
||||
|
||||
session_b = AISession(
|
||||
user_id=user_b.id,
|
||||
account_id=acct_b.id,
|
||||
problem_summary="B's confidential session",
|
||||
problem_domain="networking",
|
||||
status="resolved",
|
||||
)
|
||||
test_db.add(session_b)
|
||||
await test_db.commit()
|
||||
|
||||
headers_a = await _login(client, user_a.email, pass_a)
|
||||
resp = await client.get(
|
||||
f"/api/v1/ai-sessions/{session_b.id}/documentation",
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.status_code == 404, f"Expected 404, got {resp.status_code}: {resp.text}"
|
||||
Reference in New Issue
Block a user