"""Tests for batch session launching (maintenance flows).""" import pytest from httpx import AsyncClient async def _create_maintenance_tree(client, headers): """Helper: create a published maintenance tree.""" resp = await client.post( "/api/v1/trees", json={ "name": "Patch RDS Servers", "tree_type": "maintenance", "tree_structure": { "steps": [ {"id": "s1", "type": "procedure_step", "title": "Install patch", "description": "Run installer", "content_type": "action"}, {"id": "end", "type": "procedure_end", "title": "Done"}, ] }, }, headers=headers, ) assert resp.status_code == 201, resp.text return resp.json()["id"] @pytest.mark.asyncio async def test_batch_launch_creates_one_session_per_target(client: AsyncClient, auth_headers: dict): tree_id = await _create_maintenance_tree(client, auth_headers) resp = await client.post( "/api/v1/sessions/batch", json={ "tree_id": tree_id, "targets": [ {"label": "RDS-01"}, {"label": "RDS-02"}, {"label": "RDS-03"}, ], }, headers=auth_headers, ) assert resp.status_code == 201, resp.text data = resp.json() assert data["count"] == 3 assert data["batch_id"] is not None sessions = data["sessions"] assert len(sessions) == 3 # All share the same batch_id batch_ids = {s["batch_id"] for s in sessions} assert len(batch_ids) == 1 # Each has a distinct target_label labels = {s["target_label"] for s in sessions} assert labels == {"RDS-01", "RDS-02", "RDS-03"} @pytest.mark.asyncio async def test_batch_launch_rejects_empty_targets(client: AsyncClient, auth_headers: dict): tree_id = await _create_maintenance_tree(client, auth_headers) resp = await client.post( "/api/v1/sessions/batch", json={"tree_id": tree_id, "targets": []}, headers=auth_headers, ) assert resp.status_code == 422 @pytest.mark.asyncio async def test_batch_launch_rejects_non_maintenance_tree(client: AsyncClient, auth_headers: dict): """Batch launch only works for maintenance flows.""" # Create a procedural tree resp = await client.post( "/api/v1/trees", json={ "name": "Regular Project", "tree_type": "procedural", "tree_structure": { "steps": [ {"id": "s1", "type": "procedure_step", "title": "Step", "description": "Do it", "content_type": "action"}, {"id": "end", "type": "procedure_end", "title": "Done"}, ] }, }, headers=auth_headers, ) tree_id = resp.json()["id"] batch_resp = await client.post( "/api/v1/sessions/batch", json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]}, headers=auth_headers, ) assert batch_resp.status_code == 400 @pytest.mark.asyncio async def test_batch_launch_requires_auth(client: AsyncClient): """Unauthenticated batch launch returns 401.""" resp = await client.post( "/api/v1/sessions/batch", json={"tree_id": "00000000-0000-0000-0000-000000000000", "targets": [{"label": "SRV-01"}]}, ) assert resp.status_code == 401 @pytest.mark.asyncio async def test_batch_launch_rejects_draft_tree(client: AsyncClient, auth_headers: dict): """Batch launch against a draft maintenance tree returns 400.""" # Create a maintenance tree — trees default to 'published', so we explicitly set draft resp = await client.post( "/api/v1/trees", json={ "name": "Draft Maintenance", "tree_type": "maintenance", "status": "draft", "tree_structure": { "steps": [ {"id": "s1", "type": "procedure_step", "title": "Step", "description": "Do it", "content_type": "action"}, {"id": "end", "type": "procedure_end", "title": "Done"}, ] }, }, headers=auth_headers, ) assert resp.status_code == 201, resp.text tree_id = resp.json()["id"] batch_resp = await client.post( "/api/v1/sessions/batch", json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]}, headers=auth_headers, ) assert batch_resp.status_code == 400