fix: add cross-team authorization to maintenance schedule endpoints
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,7 @@ import pytz
|
|||||||
|
|
||||||
from app.api.deps import get_current_active_user, get_db
|
from app.api.deps import get_current_active_user, get_db
|
||||||
from app.models.maintenance_schedule import MaintenanceSchedule
|
from app.models.maintenance_schedule import MaintenanceSchedule
|
||||||
|
from app.models.tree import Tree
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.maintenance_schedule import (
|
from app.schemas.maintenance_schedule import (
|
||||||
MaintenanceScheduleCreate,
|
MaintenanceScheduleCreate,
|
||||||
@@ -28,6 +29,19 @@ def _compute_next_run(cron_expression: str, tz_name: str) -> datetime:
|
|||||||
return cron.get_next(datetime).astimezone(timezone.utc)
|
return cron.get_next(datetime).astimezone(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_tree_or_403(tree_id: UUID, current_user: User, db: AsyncSession) -> "Tree":
|
||||||
|
"""Fetch tree and verify the current user's team owns it."""
|
||||||
|
result = await db.execute(select(Tree).where(Tree.id == tree_id))
|
||||||
|
tree = result.scalar_one_or_none()
|
||||||
|
if not tree:
|
||||||
|
raise HTTPException(status_code=404, detail="Tree not found")
|
||||||
|
# Super admins can access any tree; regular users must be on the same team
|
||||||
|
if not getattr(current_user, 'is_super_admin', False):
|
||||||
|
if tree.team_id != current_user.team_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
return tree
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=MaintenanceScheduleResponse, status_code=201)
|
@router.post("", response_model=MaintenanceScheduleResponse, status_code=201)
|
||||||
async def create_schedule(
|
async def create_schedule(
|
||||||
data: MaintenanceScheduleCreate,
|
data: MaintenanceScheduleCreate,
|
||||||
@@ -35,6 +49,9 @@ async def create_schedule(
|
|||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
):
|
):
|
||||||
"""Create a cron schedule for a maintenance flow. One per flow."""
|
"""Create a cron schedule for a maintenance flow. One per flow."""
|
||||||
|
# Verify user's team owns the tree
|
||||||
|
await _get_tree_or_403(data.tree_id, current_user, db)
|
||||||
|
|
||||||
# Check no existing schedule for this tree
|
# Check no existing schedule for this tree
|
||||||
existing = await db.execute(
|
existing = await db.execute(
|
||||||
select(MaintenanceSchedule).where(MaintenanceSchedule.tree_id == data.tree_id)
|
select(MaintenanceSchedule).where(MaintenanceSchedule.tree_id == data.tree_id)
|
||||||
@@ -69,6 +86,9 @@ async def get_schedule_for_tree(
|
|||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
):
|
):
|
||||||
"""Get the schedule for a specific maintenance flow."""
|
"""Get the schedule for a specific maintenance flow."""
|
||||||
|
# Verify user's team owns the tree before returning schedule data
|
||||||
|
await _get_tree_or_403(tree_id, current_user, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(MaintenanceSchedule).where(MaintenanceSchedule.tree_id == tree_id)
|
select(MaintenanceSchedule).where(MaintenanceSchedule.tree_id == tree_id)
|
||||||
)
|
)
|
||||||
@@ -93,6 +113,9 @@ async def update_schedule(
|
|||||||
if not schedule:
|
if not schedule:
|
||||||
raise HTTPException(status_code=404, detail="Schedule not found")
|
raise HTTPException(status_code=404, detail="Schedule not found")
|
||||||
|
|
||||||
|
# Verify user's team owns the tree this schedule belongs to
|
||||||
|
await _get_tree_or_403(schedule.tree_id, current_user, db)
|
||||||
|
|
||||||
update_fields = data.model_fields_set
|
update_fields = data.model_fields_set
|
||||||
if "cron_expression" in update_fields and data.cron_expression is not None:
|
if "cron_expression" in update_fields and data.cron_expression is not None:
|
||||||
schedule.cron_expression = data.cron_expression
|
schedule.cron_expression = data.cron_expression
|
||||||
|
|||||||
@@ -96,3 +96,45 @@ async def test_get_schedule_not_found(client: AsyncClient, auth_headers: dict):
|
|||||||
tree_id = await _create_maintenance_tree(client, auth_headers)
|
tree_id = await _create_maintenance_tree(client, auth_headers)
|
||||||
resp = await client.get(f"/api/v1/maintenance-schedules/tree/{tree_id}", headers=auth_headers)
|
resp = await client.get(f"/api/v1/maintenance-schedules/tree/{tree_id}", headers=auth_headers)
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cannot_schedule_other_teams_tree(client: AsyncClient, auth_headers: dict, test_db):
|
||||||
|
"""User cannot create a schedule for a tree belonging to another team."""
|
||||||
|
import uuid as _uuid
|
||||||
|
from app.models.team import Team
|
||||||
|
from app.models.tree import Tree
|
||||||
|
|
||||||
|
# Create a tree belonging to a DIFFERENT team directly in DB
|
||||||
|
other_team = Team(name=f"Other Team {_uuid.uuid4()}")
|
||||||
|
test_db.add(other_team)
|
||||||
|
await test_db.flush()
|
||||||
|
|
||||||
|
other_tree = Tree(
|
||||||
|
name="Other Team Tree",
|
||||||
|
tree_type="maintenance",
|
||||||
|
team_id=other_team.id,
|
||||||
|
tree_structure={
|
||||||
|
"steps": [
|
||||||
|
{"id": "s1", "type": "procedure_step", "title": "Step",
|
||||||
|
"description": "Do it", "content_type": "action"},
|
||||||
|
{"id": "end", "type": "procedure_end", "title": "Done"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
status="published",
|
||||||
|
visibility="team",
|
||||||
|
)
|
||||||
|
test_db.add(other_tree)
|
||||||
|
await test_db.flush()
|
||||||
|
|
||||||
|
# Current user (from auth_headers) tries to schedule it
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/maintenance-schedules",
|
||||||
|
json={
|
||||||
|
"tree_id": str(other_tree.id),
|
||||||
|
"cron_expression": "0 9 1 * *",
|
||||||
|
"timezone": "UTC",
|
||||||
|
},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code in (403, 404) # either is acceptable
|
||||||
|
|||||||
Reference in New Issue
Block a user