From 5a3af9c87e1ee896cbb7e40fd82a5e4dfc20b598 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Tue, 17 Feb 2026 12:29:04 -0500 Subject: [PATCH] fix: add auth guards, optimize refresh loop, and improve batch session tests - Move mid-file pydantic/uuid imports to top of sessions.py - Add can_access_tree, is_active, and draft status guards to batch_launch_sessions - Remove notes field from _BatchTarget to keep API clean - Add max_length=100 cap to targets list in _BatchLaunchRequest - Hoist tree_snapshot computation above the session creation loop - Replace N db.refresh() calls with a single bulk select after flush - Add test_batch_launch_requires_auth and test_batch_launch_rejects_draft_tree tests - Fix trailing slash on /api/v1/trees/ URL in new test (caused 307 redirect) Co-Authored-By: Claude Sonnet 4.5 --- backend/app/api/endpoints/sessions.py | 39 +++++++++++++++--------- backend/tests/test_batch_sessions.py | 44 +++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 16 deletions(-) diff --git a/backend/app/api/endpoints/sessions.py b/backend/app/api/endpoints/sessions.py index 912c03cc..0e4911b7 100644 --- a/backend/app/api/endpoints/sessions.py +++ b/backend/app/api/endpoints/sessions.py @@ -1,8 +1,10 @@ from datetime import datetime, timezone from typing import Annotated, Optional from uuid import UUID +import uuid from fastapi import APIRouter, Depends, HTTPException, status, Query from fastapi.responses import PlainTextResponse +from pydantic import BaseModel, Field as PydanticField from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select @@ -489,18 +491,14 @@ async def save_session_as_tree( # ── Batch Launch (Maintenance Flows) ────────────────────────────────────── -from pydantic import BaseModel, Field as PydanticField -import uuid as uuid_mod - class _BatchTarget(BaseModel): label: str = PydanticField(..., min_length=1, max_length=255) - notes: Optional[str] = None class _BatchLaunchRequest(BaseModel): tree_id: UUID - targets: list[_BatchTarget] = PydanticField(..., min_length=1) + targets: list[_BatchTarget] = PydanticField(..., min_length=1, max_length=100) class _BatchLaunchResponse(BaseModel): @@ -520,19 +518,31 @@ async def batch_launch_sessions( tree = tree_result.scalar_one_or_none() if not tree: raise HTTPException(status_code=404, detail="Tree not found") + + if not can_access_tree(current_user, tree): + raise HTTPException(status_code=403, detail="Access denied") + + if not tree.is_active: + raise HTTPException(status_code=400, detail="Cannot batch-launch an inactive flow") + + if tree.status == 'draft': + raise HTTPException(status_code=400, detail="Cannot batch-launch a draft flow") + if tree.tree_type != "maintenance": raise HTTPException(status_code=400, detail="Batch launch is only for maintenance flows") - batch_id = uuid_mod.uuid4() + batch_id = uuid.uuid4() created_sessions = [] + # Hoist snapshot computation out of the loop — same tree for all targets + tree_snapshot = { + **tree.tree_structure, + "name": tree.name, + "description": tree.description, + "tree_type": tree.tree_type, + } + for target in data.targets: - tree_snapshot = { - **tree.tree_structure, - "name": tree.name, - "description": tree.description, - "tree_type": tree.tree_type, - } session = Session( tree_id=tree.id, user_id=current_user.id, @@ -548,8 +558,9 @@ async def batch_launch_sessions( created_sessions.append(session) await db.flush() - for s in created_sessions: - await db.refresh(s) + session_ids = [s.id for s in created_sessions] + result = await db.execute(select(Session).where(Session.id.in_(session_ids))) + created_sessions = result.scalars().all() await db.commit() return _BatchLaunchResponse( diff --git a/backend/tests/test_batch_sessions.py b/backend/tests/test_batch_sessions.py index 41938b58..6598d71b 100644 --- a/backend/tests/test_batch_sessions.py +++ b/backend/tests/test_batch_sessions.py @@ -32,8 +32,8 @@ async def test_batch_launch_creates_one_session_per_target(client: AsyncClient, json={ "tree_id": tree_id, "targets": [ - {"label": "RDS-01", "notes": "192.168.1.10"}, - {"label": "RDS-02", "notes": "192.168.1.11"}, + {"label": "RDS-01"}, + {"label": "RDS-02"}, {"label": "RDS-03"}, ], }, @@ -90,3 +90,43 @@ async def test_batch_launch_rejects_non_maintenance_tree(client: AsyncClient, au headers=auth_headers, ) assert batch_resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_batch_launch_requires_auth(client: AsyncClient): + """Unauthenticated batch launch returns 401.""" + resp = await client.post( + "/api/v1/sessions/batch", + json={"tree_id": "00000000-0000-0000-0000-000000000000", "targets": [{"label": "SRV-01"}]}, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_batch_launch_rejects_draft_tree(client: AsyncClient, auth_headers: dict): + """Batch launch against a draft maintenance tree returns 400.""" + # Create a maintenance tree — trees default to 'published', so we explicitly set draft + resp = await client.post( + "/api/v1/trees", + json={ + "name": "Draft Maintenance", + "tree_type": "maintenance", + "status": "draft", + "tree_structure": { + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Step", + "description": "Do it", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + }, + headers=auth_headers, + ) + assert resp.status_code == 201, resp.text + tree_id = resp.json()["id"] + batch_resp = await client.post( + "/api/v1/sessions/batch", + json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]}, + headers=auth_headers, + ) + assert batch_resp.status_code == 400