fix: add auth guards, optimize refresh loop, and improve batch session tests
- Move mid-file pydantic/uuid imports to top of sessions.py - Add can_access_tree, is_active, and draft status guards to batch_launch_sessions - Remove notes field from _BatchTarget to keep API clean - Add max_length=100 cap to targets list in _BatchLaunchRequest - Hoist tree_snapshot computation above the session creation loop - Replace N db.refresh() calls with a single bulk select after flush - Add test_batch_launch_requires_auth and test_batch_launch_rejects_draft_tree tests - Fix trailing slash on /api/v1/trees/ URL in new test (caused 307 redirect) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
import uuid
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
from pydantic import BaseModel, Field as PydanticField
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@@ -489,18 +491,14 @@ async def save_session_as_tree(
|
|||||||
|
|
||||||
# ── Batch Launch (Maintenance Flows) ──────────────────────────────────────
|
# ── Batch Launch (Maintenance Flows) ──────────────────────────────────────
|
||||||
|
|
||||||
from pydantic import BaseModel, Field as PydanticField
|
|
||||||
import uuid as uuid_mod
|
|
||||||
|
|
||||||
|
|
||||||
class _BatchTarget(BaseModel):
|
class _BatchTarget(BaseModel):
|
||||||
label: str = PydanticField(..., min_length=1, max_length=255)
|
label: str = PydanticField(..., min_length=1, max_length=255)
|
||||||
notes: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class _BatchLaunchRequest(BaseModel):
|
class _BatchLaunchRequest(BaseModel):
|
||||||
tree_id: UUID
|
tree_id: UUID
|
||||||
targets: list[_BatchTarget] = PydanticField(..., min_length=1)
|
targets: list[_BatchTarget] = PydanticField(..., min_length=1, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
class _BatchLaunchResponse(BaseModel):
|
class _BatchLaunchResponse(BaseModel):
|
||||||
@@ -520,19 +518,31 @@ async def batch_launch_sessions(
|
|||||||
tree = tree_result.scalar_one_or_none()
|
tree = tree_result.scalar_one_or_none()
|
||||||
if not tree:
|
if not tree:
|
||||||
raise HTTPException(status_code=404, detail="Tree not found")
|
raise HTTPException(status_code=404, detail="Tree not found")
|
||||||
|
|
||||||
|
if not can_access_tree(current_user, tree):
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
if not tree.is_active:
|
||||||
|
raise HTTPException(status_code=400, detail="Cannot batch-launch an inactive flow")
|
||||||
|
|
||||||
|
if tree.status == 'draft':
|
||||||
|
raise HTTPException(status_code=400, detail="Cannot batch-launch a draft flow")
|
||||||
|
|
||||||
if tree.tree_type != "maintenance":
|
if tree.tree_type != "maintenance":
|
||||||
raise HTTPException(status_code=400, detail="Batch launch is only for maintenance flows")
|
raise HTTPException(status_code=400, detail="Batch launch is only for maintenance flows")
|
||||||
|
|
||||||
batch_id = uuid_mod.uuid4()
|
batch_id = uuid.uuid4()
|
||||||
created_sessions = []
|
created_sessions = []
|
||||||
|
|
||||||
|
# Hoist snapshot computation out of the loop — same tree for all targets
|
||||||
|
tree_snapshot = {
|
||||||
|
**tree.tree_structure,
|
||||||
|
"name": tree.name,
|
||||||
|
"description": tree.description,
|
||||||
|
"tree_type": tree.tree_type,
|
||||||
|
}
|
||||||
|
|
||||||
for target in data.targets:
|
for target in data.targets:
|
||||||
tree_snapshot = {
|
|
||||||
**tree.tree_structure,
|
|
||||||
"name": tree.name,
|
|
||||||
"description": tree.description,
|
|
||||||
"tree_type": tree.tree_type,
|
|
||||||
}
|
|
||||||
session = Session(
|
session = Session(
|
||||||
tree_id=tree.id,
|
tree_id=tree.id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -548,8 +558,9 @@ async def batch_launch_sessions(
|
|||||||
created_sessions.append(session)
|
created_sessions.append(session)
|
||||||
|
|
||||||
await db.flush()
|
await db.flush()
|
||||||
for s in created_sessions:
|
session_ids = [s.id for s in created_sessions]
|
||||||
await db.refresh(s)
|
result = await db.execute(select(Session).where(Session.id.in_(session_ids)))
|
||||||
|
created_sessions = result.scalars().all()
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
return _BatchLaunchResponse(
|
return _BatchLaunchResponse(
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ async def test_batch_launch_creates_one_session_per_target(client: AsyncClient,
|
|||||||
json={
|
json={
|
||||||
"tree_id": tree_id,
|
"tree_id": tree_id,
|
||||||
"targets": [
|
"targets": [
|
||||||
{"label": "RDS-01", "notes": "192.168.1.10"},
|
{"label": "RDS-01"},
|
||||||
{"label": "RDS-02", "notes": "192.168.1.11"},
|
{"label": "RDS-02"},
|
||||||
{"label": "RDS-03"},
|
{"label": "RDS-03"},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
@@ -90,3 +90,43 @@ async def test_batch_launch_rejects_non_maintenance_tree(client: AsyncClient, au
|
|||||||
headers=auth_headers,
|
headers=auth_headers,
|
||||||
)
|
)
|
||||||
assert batch_resp.status_code == 400
|
assert batch_resp.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_launch_requires_auth(client: AsyncClient):
|
||||||
|
"""Unauthenticated batch launch returns 401."""
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/sessions/batch",
|
||||||
|
json={"tree_id": "00000000-0000-0000-0000-000000000000", "targets": [{"label": "SRV-01"}]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_launch_rejects_draft_tree(client: AsyncClient, auth_headers: dict):
|
||||||
|
"""Batch launch against a draft maintenance tree returns 400."""
|
||||||
|
# Create a maintenance tree — trees default to 'published', so we explicitly set draft
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/trees",
|
||||||
|
json={
|
||||||
|
"name": "Draft Maintenance",
|
||||||
|
"tree_type": "maintenance",
|
||||||
|
"status": "draft",
|
||||||
|
"tree_structure": {
|
||||||
|
"steps": [
|
||||||
|
{"id": "s1", "type": "procedure_step", "title": "Step",
|
||||||
|
"description": "Do it", "content_type": "action"},
|
||||||
|
{"id": "end", "type": "procedure_end", "title": "Done"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201, resp.text
|
||||||
|
tree_id = resp.json()["id"]
|
||||||
|
batch_resp = await client.post(
|
||||||
|
"/api/v1/sessions/batch",
|
||||||
|
json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert batch_resp.status_code == 400
|
||||||
|
|||||||
Reference in New Issue
Block a user