fix: add auth guards, optimize refresh loop, and improve batch session tests

- Move mid-file pydantic/uuid imports to top of sessions.py
- Add can_access_tree, is_active, and draft status guards to batch_launch_sessions
- Remove notes field from _BatchTarget to keep API clean
- Add max_length=100 cap to targets list in _BatchLaunchRequest
- Hoist tree_snapshot computation above the session creation loop
- Replace N db.refresh() calls with a single bulk select after flush
- Add test_batch_launch_requires_auth and test_batch_launch_rejects_draft_tree tests
- Fix trailing slash on /api/v1/trees/ URL in new test (caused 307 redirect)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-17 12:29:04 -05:00
parent b78a50c8c5
commit 5a3af9c87e
2 changed files with 67 additions and 16 deletions

View File

@@ -1,8 +1,10 @@
from datetime import datetime, timezone
from typing import Annotated, Optional
from uuid import UUID
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, Field as PydanticField
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
@@ -489,18 +491,14 @@ async def save_session_as_tree(
# ── Batch Launch (Maintenance Flows) ──────────────────────────────────────
from pydantic import BaseModel, Field as PydanticField
import uuid as uuid_mod
class _BatchTarget(BaseModel):
label: str = PydanticField(..., min_length=1, max_length=255)
notes: Optional[str] = None
class _BatchLaunchRequest(BaseModel):
tree_id: UUID
targets: list[_BatchTarget] = PydanticField(..., min_length=1)
targets: list[_BatchTarget] = PydanticField(..., min_length=1, max_length=100)
class _BatchLaunchResponse(BaseModel):
@@ -520,19 +518,31 @@ async def batch_launch_sessions(
tree = tree_result.scalar_one_or_none()
if not tree:
raise HTTPException(status_code=404, detail="Tree not found")
if not can_access_tree(current_user, tree):
raise HTTPException(status_code=403, detail="Access denied")
if not tree.is_active:
raise HTTPException(status_code=400, detail="Cannot batch-launch an inactive flow")
if tree.status == 'draft':
raise HTTPException(status_code=400, detail="Cannot batch-launch a draft flow")
if tree.tree_type != "maintenance":
raise HTTPException(status_code=400, detail="Batch launch is only for maintenance flows")
batch_id = uuid_mod.uuid4()
batch_id = uuid.uuid4()
created_sessions = []
# Hoist snapshot computation out of the loop — same tree for all targets
tree_snapshot = {
**tree.tree_structure,
"name": tree.name,
"description": tree.description,
"tree_type": tree.tree_type,
}
for target in data.targets:
tree_snapshot = {
**tree.tree_structure,
"name": tree.name,
"description": tree.description,
"tree_type": tree.tree_type,
}
session = Session(
tree_id=tree.id,
user_id=current_user.id,
@@ -548,8 +558,9 @@ async def batch_launch_sessions(
created_sessions.append(session)
await db.flush()
for s in created_sessions:
await db.refresh(s)
session_ids = [s.id for s in created_sessions]
result = await db.execute(select(Session).where(Session.id.in_(session_ids)))
created_sessions = result.scalars().all()
await db.commit()
return _BatchLaunchResponse(

View File

@@ -32,8 +32,8 @@ async def test_batch_launch_creates_one_session_per_target(client: AsyncClient,
json={
"tree_id": tree_id,
"targets": [
{"label": "RDS-01", "notes": "192.168.1.10"},
{"label": "RDS-02", "notes": "192.168.1.11"},
{"label": "RDS-01"},
{"label": "RDS-02"},
{"label": "RDS-03"},
],
},
@@ -90,3 +90,43 @@ async def test_batch_launch_rejects_non_maintenance_tree(client: AsyncClient, au
headers=auth_headers,
)
assert batch_resp.status_code == 400
@pytest.mark.asyncio
async def test_batch_launch_requires_auth(client: AsyncClient):
"""Unauthenticated batch launch returns 401."""
resp = await client.post(
"/api/v1/sessions/batch",
json={"tree_id": "00000000-0000-0000-0000-000000000000", "targets": [{"label": "SRV-01"}]},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_batch_launch_rejects_draft_tree(client: AsyncClient, auth_headers: dict):
"""Batch launch against a draft maintenance tree returns 400."""
# Create a maintenance tree — trees default to 'published', so we explicitly set draft
resp = await client.post(
"/api/v1/trees",
json={
"name": "Draft Maintenance",
"tree_type": "maintenance",
"status": "draft",
"tree_structure": {
"steps": [
{"id": "s1", "type": "procedure_step", "title": "Step",
"description": "Do it", "content_type": "action"},
{"id": "end", "type": "procedure_end", "title": "Done"},
]
},
},
headers=auth_headers,
)
assert resp.status_code == 201, resp.text
tree_id = resp.json()["id"]
batch_resp = await client.post(
"/api/v1/sessions/batch",
json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]},
headers=auth_headers,
)
assert batch_resp.status_code == 400