feat: add batch_id/target_label to sessions and batch launch endpoint
- Add batch_id (UUID, nullable, indexed) and target_label (String 255, nullable) columns to the Session model - Manual Alembic migration 6e8128ef2aa8 applies both columns - POST /sessions/batch creates one session per target for maintenance flows; rejects empty targets (422) and non-maintenance trees (400) - SessionResponse schema exposes batch_id and target_label - 3 new integration tests, all 540 tests pass Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
"""add batch_id and target_label to sessions
|
||||
|
||||
Revision ID: 6e8128ef2aa8
|
||||
Revises: 5812e7df852f
|
||||
Create Date: 2026-02-17
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '6e8128ef2aa8'
|
||||
down_revision = '5812e7df852f'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('sessions', sa.Column('batch_id', postgresql.UUID(as_uuid=True), nullable=True))
|
||||
op.add_column('sessions', sa.Column('target_label', sa.String(255), nullable=True))
|
||||
op.create_index('ix_sessions_batch_id', 'sessions', ['batch_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_sessions_batch_id', table_name='sessions')
|
||||
op.drop_column('sessions', 'target_label')
|
||||
op.drop_column('sessions', 'batch_id')
|
||||
@@ -485,3 +485,83 @@ async def save_session_as_tree(
|
||||
tree_name=new_tree.name,
|
||||
message=f"Session saved as {'published' if request_data.status == 'published' else 'draft'} tree"
|
||||
)
|
||||
|
||||
|
||||
# ── Batch Launch (Maintenance Flows) ──────────────────────────────────────
|
||||
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
import uuid as uuid_mod
|
||||
|
||||
|
||||
class _BatchTarget(BaseModel):
|
||||
label: str = PydanticField(..., min_length=1, max_length=255)
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class _BatchLaunchRequest(BaseModel):
|
||||
tree_id: UUID
|
||||
targets: list[_BatchTarget] = PydanticField(..., min_length=1)
|
||||
|
||||
|
||||
class _BatchLaunchResponse(BaseModel):
|
||||
batch_id: str
|
||||
count: int
|
||||
sessions: list[dict]
|
||||
|
||||
|
||||
@router.post("/batch", status_code=201, response_model=_BatchLaunchResponse)
|
||||
async def batch_launch_sessions(
|
||||
data: _BatchLaunchRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Create one session per target for a maintenance flow batch run."""
|
||||
tree_result = await db.execute(select(Tree).where(Tree.id == data.tree_id))
|
||||
tree = tree_result.scalar_one_or_none()
|
||||
if not tree:
|
||||
raise HTTPException(status_code=404, detail="Tree not found")
|
||||
if tree.tree_type != "maintenance":
|
||||
raise HTTPException(status_code=400, detail="Batch launch is only for maintenance flows")
|
||||
|
||||
batch_id = uuid_mod.uuid4()
|
||||
created_sessions = []
|
||||
|
||||
for target in data.targets:
|
||||
tree_snapshot = {
|
||||
**tree.tree_structure,
|
||||
"name": tree.name,
|
||||
"description": tree.description,
|
||||
"tree_type": tree.tree_type,
|
||||
}
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=current_user.id,
|
||||
tree_snapshot=tree_snapshot,
|
||||
path_taken=[],
|
||||
decisions=[],
|
||||
custom_steps=[],
|
||||
session_variables={},
|
||||
batch_id=batch_id,
|
||||
target_label=target.label,
|
||||
)
|
||||
db.add(session)
|
||||
created_sessions.append(session)
|
||||
|
||||
await db.flush()
|
||||
for s in created_sessions:
|
||||
await db.refresh(s)
|
||||
await db.commit()
|
||||
|
||||
return _BatchLaunchResponse(
|
||||
batch_id=str(batch_id),
|
||||
count=len(created_sessions),
|
||||
sessions=[
|
||||
{
|
||||
"id": str(s.id),
|
||||
"batch_id": str(s.batch_id),
|
||||
"target_label": s.target_label,
|
||||
"tree_id": str(s.tree_id),
|
||||
}
|
||||
for s in created_sessions
|
||||
],
|
||||
)
|
||||
|
||||
@@ -65,3 +65,11 @@ class Session(Base):
|
||||
user: Mapped["User"] = relationship("User", back_populates="sessions")
|
||||
attachments: Mapped[list["Attachment"]] = relationship("Attachment", back_populates="session")
|
||||
shares: Mapped[list["SessionShare"]] = relationship("SessionShare", back_populates="session", cascade="all, delete-orphan")
|
||||
|
||||
# Batch tracking (maintenance flows)
|
||||
batch_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
target_label: Mapped[Optional[str]] = mapped_column(
|
||||
String(255), nullable=True
|
||||
)
|
||||
|
||||
@@ -78,6 +78,9 @@ class SessionResponse(BaseModel):
|
||||
def normalize_text_fields(cls, v):
|
||||
return v or ""
|
||||
|
||||
batch_id: Optional[UUID] = None
|
||||
target_label: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
92
backend/tests/test_batch_sessions.py
Normal file
92
backend/tests/test_batch_sessions.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for batch session launching (maintenance flows)."""
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
async def _create_maintenance_tree(client, headers):
|
||||
"""Helper: create a published maintenance tree."""
|
||||
resp = await client.post(
|
||||
"/api/v1/trees",
|
||||
json={
|
||||
"name": "Patch RDS Servers",
|
||||
"tree_type": "maintenance",
|
||||
"tree_structure": {
|
||||
"steps": [
|
||||
{"id": "s1", "type": "procedure_step", "title": "Install patch",
|
||||
"description": "Run installer", "content_type": "action"},
|
||||
{"id": "end", "type": "procedure_end", "title": "Done"},
|
||||
]
|
||||
},
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
return resp.json()["id"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_launch_creates_one_session_per_target(client: AsyncClient, auth_headers: dict):
|
||||
tree_id = await _create_maintenance_tree(client, auth_headers)
|
||||
resp = await client.post(
|
||||
"/api/v1/sessions/batch",
|
||||
json={
|
||||
"tree_id": tree_id,
|
||||
"targets": [
|
||||
{"label": "RDS-01", "notes": "192.168.1.10"},
|
||||
{"label": "RDS-02", "notes": "192.168.1.11"},
|
||||
{"label": "RDS-03"},
|
||||
],
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
data = resp.json()
|
||||
assert data["count"] == 3
|
||||
assert data["batch_id"] is not None
|
||||
sessions = data["sessions"]
|
||||
assert len(sessions) == 3
|
||||
# All share the same batch_id
|
||||
batch_ids = {s["batch_id"] for s in sessions}
|
||||
assert len(batch_ids) == 1
|
||||
# Each has a distinct target_label
|
||||
labels = {s["target_label"] for s in sessions}
|
||||
assert labels == {"RDS-01", "RDS-02", "RDS-03"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_launch_rejects_empty_targets(client: AsyncClient, auth_headers: dict):
|
||||
tree_id = await _create_maintenance_tree(client, auth_headers)
|
||||
resp = await client.post(
|
||||
"/api/v1/sessions/batch",
|
||||
json={"tree_id": tree_id, "targets": []},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_launch_rejects_non_maintenance_tree(client: AsyncClient, auth_headers: dict):
|
||||
"""Batch launch only works for maintenance flows."""
|
||||
# Create a procedural tree
|
||||
resp = await client.post(
|
||||
"/api/v1/trees",
|
||||
json={
|
||||
"name": "Regular Project",
|
||||
"tree_type": "procedural",
|
||||
"tree_structure": {
|
||||
"steps": [
|
||||
{"id": "s1", "type": "procedure_step", "title": "Step",
|
||||
"description": "Do it", "content_type": "action"},
|
||||
{"id": "end", "type": "procedure_end", "title": "Done"},
|
||||
]
|
||||
},
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
tree_id = resp.json()["id"]
|
||||
batch_resp = await client.post(
|
||||
"/api/v1/sessions/batch",
|
||||
json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert batch_resp.status_code == 400
|
||||
Reference in New Issue
Block a user