diff --git a/backend/alembic/versions/6e8128ef2aa8_add_batch_id_and_target_label_to_.py b/backend/alembic/versions/6e8128ef2aa8_add_batch_id_and_target_label_to_.py new file mode 100644 index 00000000..ee1b9fc2 --- /dev/null +++ b/backend/alembic/versions/6e8128ef2aa8_add_batch_id_and_target_label_to_.py @@ -0,0 +1,28 @@ +"""add batch_id and target_label to sessions + +Revision ID: 6e8128ef2aa8 +Revises: 5812e7df852f +Create Date: 2026-02-17 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '6e8128ef2aa8' +down_revision = '5812e7df852f' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('sessions', sa.Column('batch_id', postgresql.UUID(as_uuid=True), nullable=True)) + op.add_column('sessions', sa.Column('target_label', sa.String(255), nullable=True)) + op.create_index('ix_sessions_batch_id', 'sessions', ['batch_id']) + + +def downgrade() -> None: + op.drop_index('ix_sessions_batch_id', table_name='sessions') + op.drop_column('sessions', 'target_label') + op.drop_column('sessions', 'batch_id') diff --git a/backend/app/api/endpoints/sessions.py b/backend/app/api/endpoints/sessions.py index 08b45875..912c03cc 100644 --- a/backend/app/api/endpoints/sessions.py +++ b/backend/app/api/endpoints/sessions.py @@ -485,3 +485,83 @@ async def save_session_as_tree( tree_name=new_tree.name, message=f"Session saved as {'published' if request_data.status == 'published' else 'draft'} tree" ) + + +# ── Batch Launch (Maintenance Flows) ────────────────────────────────────── + +from pydantic import BaseModel, Field as PydanticField +import uuid as uuid_mod + + +class _BatchTarget(BaseModel): + label: str = PydanticField(..., min_length=1, max_length=255) + notes: Optional[str] = None + + +class _BatchLaunchRequest(BaseModel): + tree_id: UUID + targets: list[_BatchTarget] = PydanticField(..., min_length=1) + + +class _BatchLaunchResponse(BaseModel): + batch_id: str + count: int + sessions: list[dict] + + +@router.post("/batch", status_code=201, response_model=_BatchLaunchResponse) +async def batch_launch_sessions( + data: _BatchLaunchRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Create one session per target for a maintenance flow batch run.""" + tree_result = await db.execute(select(Tree).where(Tree.id == data.tree_id)) + tree = tree_result.scalar_one_or_none() + if not tree: + raise HTTPException(status_code=404, detail="Tree not found") + if tree.tree_type != "maintenance": + raise HTTPException(status_code=400, detail="Batch launch is only for maintenance flows") + + batch_id = uuid_mod.uuid4() + created_sessions = [] + + for target in data.targets: + tree_snapshot = { + **tree.tree_structure, + "name": tree.name, + "description": tree.description, + "tree_type": tree.tree_type, + } + session = Session( + tree_id=tree.id, + user_id=current_user.id, + tree_snapshot=tree_snapshot, + path_taken=[], + decisions=[], + custom_steps=[], + session_variables={}, + batch_id=batch_id, + target_label=target.label, + ) + db.add(session) + created_sessions.append(session) + + await db.flush() + for s in created_sessions: + await db.refresh(s) + await db.commit() + + return _BatchLaunchResponse( + batch_id=str(batch_id), + count=len(created_sessions), + sessions=[ + { + "id": str(s.id), + "batch_id": str(s.batch_id), + "target_label": s.target_label, + "tree_id": str(s.tree_id), + } + for s in created_sessions + ], + ) diff --git a/backend/app/models/session.py b/backend/app/models/session.py index 91e40a50..f5085c81 100644 --- a/backend/app/models/session.py +++ b/backend/app/models/session.py @@ -65,3 +65,11 @@ class Session(Base): user: Mapped["User"] = relationship("User", back_populates="sessions") attachments: Mapped[list["Attachment"]] = relationship("Attachment", back_populates="session") shares: Mapped[list["SessionShare"]] = relationship("SessionShare", back_populates="session", cascade="all, delete-orphan") + + # Batch tracking (maintenance flows) + batch_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), nullable=True, index=True + ) + target_label: Mapped[Optional[str]] = mapped_column( + String(255), nullable=True + ) diff --git a/backend/app/schemas/session.py b/backend/app/schemas/session.py index 4f00f618..3aefa3b7 100644 --- a/backend/app/schemas/session.py +++ b/backend/app/schemas/session.py @@ -78,6 +78,9 @@ class SessionResponse(BaseModel): def normalize_text_fields(cls, v): return v or "" + batch_id: Optional[UUID] = None + target_label: Optional[str] = None + class Config: from_attributes = True diff --git a/backend/tests/test_batch_sessions.py b/backend/tests/test_batch_sessions.py new file mode 100644 index 00000000..41938b58 --- /dev/null +++ b/backend/tests/test_batch_sessions.py @@ -0,0 +1,92 @@ +"""Tests for batch session launching (maintenance flows).""" +import pytest +from httpx import AsyncClient + + +async def _create_maintenance_tree(client, headers): + """Helper: create a published maintenance tree.""" + resp = await client.post( + "/api/v1/trees", + json={ + "name": "Patch RDS Servers", + "tree_type": "maintenance", + "tree_structure": { + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Install patch", + "description": "Run installer", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + }, + headers=headers, + ) + assert resp.status_code == 201, resp.text + return resp.json()["id"] + + +@pytest.mark.asyncio +async def test_batch_launch_creates_one_session_per_target(client: AsyncClient, auth_headers: dict): + tree_id = await _create_maintenance_tree(client, auth_headers) + resp = await client.post( + "/api/v1/sessions/batch", + json={ + "tree_id": tree_id, + "targets": [ + {"label": "RDS-01", "notes": "192.168.1.10"}, + {"label": "RDS-02", "notes": "192.168.1.11"}, + {"label": "RDS-03"}, + ], + }, + headers=auth_headers, + ) + assert resp.status_code == 201, resp.text + data = resp.json() + assert data["count"] == 3 + assert data["batch_id"] is not None + sessions = data["sessions"] + assert len(sessions) == 3 + # All share the same batch_id + batch_ids = {s["batch_id"] for s in sessions} + assert len(batch_ids) == 1 + # Each has a distinct target_label + labels = {s["target_label"] for s in sessions} + assert labels == {"RDS-01", "RDS-02", "RDS-03"} + + +@pytest.mark.asyncio +async def test_batch_launch_rejects_empty_targets(client: AsyncClient, auth_headers: dict): + tree_id = await _create_maintenance_tree(client, auth_headers) + resp = await client.post( + "/api/v1/sessions/batch", + json={"tree_id": tree_id, "targets": []}, + headers=auth_headers, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_batch_launch_rejects_non_maintenance_tree(client: AsyncClient, auth_headers: dict): + """Batch launch only works for maintenance flows.""" + # Create a procedural tree + resp = await client.post( + "/api/v1/trees", + json={ + "name": "Regular Project", + "tree_type": "procedural", + "tree_structure": { + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Step", + "description": "Do it", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + }, + headers=auth_headers, + ) + tree_id = resp.json()["id"] + batch_resp = await client.post( + "/api/v1/sessions/batch", + json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]}, + headers=auth_headers, + ) + assert batch_resp.status_code == 400