feat: add batch_id/target_label to sessions and batch launch endpoint

- Add batch_id (UUID, nullable, indexed) and target_label (String 255,
  nullable) columns to the Session model
- Manual Alembic migration 6e8128ef2aa8 applies both columns
- POST /sessions/batch creates one session per target for maintenance
  flows; rejects empty targets (422) and non-maintenance trees (400)
- SessionResponse schema exposes batch_id and target_label
- 3 new integration tests, all 540 tests pass

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-17 11:48:08 -05:00
parent adcdf39d35
commit b78a50c8c5
5 changed files with 211 additions and 0 deletions

View File

@@ -0,0 +1,28 @@
"""add batch_id and target_label to sessions
Revision ID: 6e8128ef2aa8
Revises: 5812e7df852f
Create Date: 2026-02-17
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6e8128ef2aa8'
down_revision = '5812e7df852f'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column('sessions', sa.Column('batch_id', postgresql.UUID(as_uuid=True), nullable=True))
op.add_column('sessions', sa.Column('target_label', sa.String(255), nullable=True))
op.create_index('ix_sessions_batch_id', 'sessions', ['batch_id'])
def downgrade() -> None:
op.drop_index('ix_sessions_batch_id', table_name='sessions')
op.drop_column('sessions', 'target_label')
op.drop_column('sessions', 'batch_id')

View File

@@ -485,3 +485,83 @@ async def save_session_as_tree(
tree_name=new_tree.name,
message=f"Session saved as {'published' if request_data.status == 'published' else 'draft'} tree"
)
# ── Batch Launch (Maintenance Flows) ──────────────────────────────────────
from pydantic import BaseModel, Field as PydanticField
import uuid as uuid_mod
class _BatchTarget(BaseModel):
label: str = PydanticField(..., min_length=1, max_length=255)
notes: Optional[str] = None
class _BatchLaunchRequest(BaseModel):
tree_id: UUID
targets: list[_BatchTarget] = PydanticField(..., min_length=1)
class _BatchLaunchResponse(BaseModel):
batch_id: str
count: int
sessions: list[dict]
@router.post("/batch", status_code=201, response_model=_BatchLaunchResponse)
async def batch_launch_sessions(
data: _BatchLaunchRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Create one session per target for a maintenance flow batch run."""
tree_result = await db.execute(select(Tree).where(Tree.id == data.tree_id))
tree = tree_result.scalar_one_or_none()
if not tree:
raise HTTPException(status_code=404, detail="Tree not found")
if tree.tree_type != "maintenance":
raise HTTPException(status_code=400, detail="Batch launch is only for maintenance flows")
batch_id = uuid_mod.uuid4()
created_sessions = []
for target in data.targets:
tree_snapshot = {
**tree.tree_structure,
"name": tree.name,
"description": tree.description,
"tree_type": tree.tree_type,
}
session = Session(
tree_id=tree.id,
user_id=current_user.id,
tree_snapshot=tree_snapshot,
path_taken=[],
decisions=[],
custom_steps=[],
session_variables={},
batch_id=batch_id,
target_label=target.label,
)
db.add(session)
created_sessions.append(session)
await db.flush()
for s in created_sessions:
await db.refresh(s)
await db.commit()
return _BatchLaunchResponse(
batch_id=str(batch_id),
count=len(created_sessions),
sessions=[
{
"id": str(s.id),
"batch_id": str(s.batch_id),
"target_label": s.target_label,
"tree_id": str(s.tree_id),
}
for s in created_sessions
],
)

View File

@@ -65,3 +65,11 @@ class Session(Base):
user: Mapped["User"] = relationship("User", back_populates="sessions")
attachments: Mapped[list["Attachment"]] = relationship("Attachment", back_populates="session")
shares: Mapped[list["SessionShare"]] = relationship("SessionShare", back_populates="session", cascade="all, delete-orphan")
# Batch tracking (maintenance flows)
batch_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True), nullable=True, index=True
)
target_label: Mapped[Optional[str]] = mapped_column(
String(255), nullable=True
)

View File

@@ -78,6 +78,9 @@ class SessionResponse(BaseModel):
def normalize_text_fields(cls, v):
return v or ""
batch_id: Optional[UUID] = None
target_label: Optional[str] = None
class Config:
from_attributes = True

View File

@@ -0,0 +1,92 @@
"""Tests for batch session launching (maintenance flows)."""
import pytest
from httpx import AsyncClient
async def _create_maintenance_tree(client, headers):
"""Helper: create a published maintenance tree."""
resp = await client.post(
"/api/v1/trees",
json={
"name": "Patch RDS Servers",
"tree_type": "maintenance",
"tree_structure": {
"steps": [
{"id": "s1", "type": "procedure_step", "title": "Install patch",
"description": "Run installer", "content_type": "action"},
{"id": "end", "type": "procedure_end", "title": "Done"},
]
},
},
headers=headers,
)
assert resp.status_code == 201, resp.text
return resp.json()["id"]
@pytest.mark.asyncio
async def test_batch_launch_creates_one_session_per_target(client: AsyncClient, auth_headers: dict):
tree_id = await _create_maintenance_tree(client, auth_headers)
resp = await client.post(
"/api/v1/sessions/batch",
json={
"tree_id": tree_id,
"targets": [
{"label": "RDS-01", "notes": "192.168.1.10"},
{"label": "RDS-02", "notes": "192.168.1.11"},
{"label": "RDS-03"},
],
},
headers=auth_headers,
)
assert resp.status_code == 201, resp.text
data = resp.json()
assert data["count"] == 3
assert data["batch_id"] is not None
sessions = data["sessions"]
assert len(sessions) == 3
# All share the same batch_id
batch_ids = {s["batch_id"] for s in sessions}
assert len(batch_ids) == 1
# Each has a distinct target_label
labels = {s["target_label"] for s in sessions}
assert labels == {"RDS-01", "RDS-02", "RDS-03"}
@pytest.mark.asyncio
async def test_batch_launch_rejects_empty_targets(client: AsyncClient, auth_headers: dict):
tree_id = await _create_maintenance_tree(client, auth_headers)
resp = await client.post(
"/api/v1/sessions/batch",
json={"tree_id": tree_id, "targets": []},
headers=auth_headers,
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_batch_launch_rejects_non_maintenance_tree(client: AsyncClient, auth_headers: dict):
"""Batch launch only works for maintenance flows."""
# Create a procedural tree
resp = await client.post(
"/api/v1/trees",
json={
"name": "Regular Project",
"tree_type": "procedural",
"tree_structure": {
"steps": [
{"id": "s1", "type": "procedure_step", "title": "Step",
"description": "Do it", "content_type": "action"},
{"id": "end", "type": "procedure_end", "title": "Done"},
]
},
},
headers=auth_headers,
)
tree_id = resp.json()["id"]
batch_resp = await client.post(
"/api/v1/sessions/batch",
json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]},
headers=auth_headers,
)
assert batch_resp.status_code == 400