diff --git a/backend/alembic/versions/0f1ca2af3647_add_maintenance_tree_type.py b/backend/alembic/versions/0f1ca2af3647_add_maintenance_tree_type.py new file mode 100644 index 00000000..dd29ec95 --- /dev/null +++ b/backend/alembic/versions/0f1ca2af3647_add_maintenance_tree_type.py @@ -0,0 +1,33 @@ +"""add maintenance tree type + +Revision ID: 0f1ca2af3647 +Revises: 039 +Create Date: 2026-02-17 10:25:54.959861 + +""" +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '0f1ca2af3647' +down_revision: Union[str, None] = '039' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute("ALTER TABLE trees DROP CONSTRAINT ck_trees_tree_type") + op.execute( + "ALTER TABLE trees ADD CONSTRAINT ck_trees_tree_type " + "CHECK (tree_type IN ('troubleshooting', 'procedural', 'maintenance'))" + ) + + +def downgrade() -> None: + op.execute("UPDATE trees SET tree_type = 'procedural' WHERE tree_type = 'maintenance'") + op.execute("ALTER TABLE trees DROP CONSTRAINT ck_trees_tree_type") + op.execute( + "ALTER TABLE trees ADD CONSTRAINT ck_trees_tree_type " + "CHECK (tree_type IN ('troubleshooting', 'procedural'))" + ) diff --git a/backend/alembic/versions/0fd2a90a9c2c_add_maintenance_schedules_table.py b/backend/alembic/versions/0fd2a90a9c2c_add_maintenance_schedules_table.py new file mode 100644 index 00000000..14e518b3 --- /dev/null +++ b/backend/alembic/versions/0fd2a90a9c2c_add_maintenance_schedules_table.py @@ -0,0 +1,44 @@ +"""add maintenance_schedules table + +Revision ID: 0fd2a90a9c2c +Revises: 6e8128ef2aa8 +Create Date: 2026-02-17 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '0fd2a90a9c2c' +down_revision = '6e8128ef2aa8' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + 'maintenance_schedules', + sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column('tree_id', postgresql.UUID(as_uuid=True), + sa.ForeignKey('trees.id', ondelete='CASCADE'), nullable=False), + sa.Column('created_by', postgresql.UUID(as_uuid=True), + sa.ForeignKey('users.id', ondelete='SET NULL'), nullable=True), + sa.Column('cron_expression', sa.String(100), nullable=False), + sa.Column('timezone', sa.String(100), nullable=False, server_default='UTC'), + sa.Column('target_list_id', postgresql.UUID(as_uuid=True), + sa.ForeignKey('target_lists.id', ondelete='SET NULL'), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), + sa.Column('next_run_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('last_run_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), + ) + op.create_index('ix_maintenance_schedules_tree_id', 'maintenance_schedules', ['tree_id']) + op.create_unique_constraint('uq_maintenance_schedules_tree_id', 'maintenance_schedules', ['tree_id']) + + +def downgrade() -> None: + op.drop_constraint('uq_maintenance_schedules_tree_id', 'maintenance_schedules', type_='unique') + op.drop_index('ix_maintenance_schedules_tree_id', table_name='maintenance_schedules') + op.drop_table('maintenance_schedules') diff --git a/backend/alembic/versions/5812e7df852f_add_target_lists_table.py b/backend/alembic/versions/5812e7df852f_add_target_lists_table.py new file mode 100644 index 00000000..353496e9 --- /dev/null +++ b/backend/alembic/versions/5812e7df852f_add_target_lists_table.py @@ -0,0 +1,39 @@ +"""add target_lists table + +Revision ID: 5812e7df852f +Revises: 0f1ca2af3647 +Create Date: 2026-02-17 11:20:42.919564 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = '5812e7df852f' +down_revision: Union[str, None] = '0f1ca2af3647' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + 'target_lists', + sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column('team_id', postgresql.UUID(as_uuid=True), sa.ForeignKey('teams.id', ondelete='CASCADE'), nullable=False), + sa.Column('created_by', postgresql.UUID(as_uuid=True), sa.ForeignKey('users.id', ondelete='SET NULL'), nullable=True), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('targets', postgresql.JSONB(), nullable=False, server_default='[]'), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), + ) + op.create_index('ix_target_lists_team_id', 'target_lists', ['team_id']) + + +def downgrade() -> None: + op.drop_index('ix_target_lists_team_id', table_name='target_lists') + op.drop_table('target_lists') diff --git a/backend/alembic/versions/6e8128ef2aa8_add_batch_id_and_target_label_to_.py b/backend/alembic/versions/6e8128ef2aa8_add_batch_id_and_target_label_to_.py new file mode 100644 index 00000000..ee1b9fc2 --- /dev/null +++ b/backend/alembic/versions/6e8128ef2aa8_add_batch_id_and_target_label_to_.py @@ -0,0 +1,28 @@ +"""add batch_id and target_label to sessions + +Revision ID: 6e8128ef2aa8 +Revises: 5812e7df852f +Create Date: 2026-02-17 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '6e8128ef2aa8' +down_revision = '5812e7df852f' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('sessions', sa.Column('batch_id', postgresql.UUID(as_uuid=True), nullable=True)) + op.add_column('sessions', sa.Column('target_label', sa.String(255), nullable=True)) + op.create_index('ix_sessions_batch_id', 'sessions', ['batch_id']) + + +def downgrade() -> None: + op.drop_index('ix_sessions_batch_id', table_name='sessions') + op.drop_column('sessions', 'target_label') + op.drop_column('sessions', 'batch_id') diff --git a/backend/app/api/endpoints/maintenance_schedules.py b/backend/app/api/endpoints/maintenance_schedules.py new file mode 100644 index 00000000..581df52d --- /dev/null +++ b/backend/app/api/endpoints/maintenance_schedules.py @@ -0,0 +1,153 @@ +"""Maintenance schedule CRUD endpoints.""" +from typing import Annotated +from uuid import UUID +from datetime import datetime, timezone +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from croniter import croniter +import pytz + +from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin +from app.models.maintenance_schedule import MaintenanceSchedule +from app.models.tree import Tree +from app.models.user import User +from app.schemas.maintenance_schedule import ( + MaintenanceScheduleCreate, + MaintenanceScheduleUpdate, + MaintenanceScheduleResponse, +) + +router = APIRouter(prefix="/maintenance-schedules", tags=["maintenance-schedules"]) + + +def _compute_next_run(cron_expression: str, tz_name: str) -> datetime: + """Compute next run time from cron expression, returned as UTC.""" + tz = pytz.timezone(tz_name) + now = datetime.now(tz) + cron = croniter(cron_expression, now) + return cron.get_next(datetime).astimezone(timezone.utc) + + +async def _get_tree_or_403(tree_id: UUID, current_user: User, db: AsyncSession) -> "Tree": + """Fetch tree and verify the current user's team owns it.""" + result = await db.execute(select(Tree).where(Tree.id == tree_id)) + tree = result.scalar_one_or_none() + if not tree: + raise HTTPException(status_code=404, detail="Tree not found") + # Super admins can access any tree; regular users must be on the same team + if not getattr(current_user, 'is_super_admin', False): + if tree.team_id != current_user.team_id: + raise HTTPException(status_code=403, detail="Access denied") + return tree + + +@router.post("", response_model=MaintenanceScheduleResponse, status_code=201) +async def create_schedule( + data: MaintenanceScheduleCreate, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Create a cron schedule for a maintenance flow. One per flow.""" + # Verify user's team owns the tree + await _get_tree_or_403(data.tree_id, current_user, db) + + # Check no existing schedule for this tree + existing = await db.execute( + select(MaintenanceSchedule).where(MaintenanceSchedule.tree_id == data.tree_id) + ) + if existing.scalar_one_or_none(): + raise HTTPException(status_code=409, detail="Schedule already exists for this tree") + + try: + next_run = _compute_next_run(data.cron_expression, data.timezone) + except (ValueError, KeyError) as e: + raise HTTPException(status_code=422, detail=f"Invalid cron expression or timezone: {e}") + + schedule = MaintenanceSchedule( + tree_id=data.tree_id, + created_by=current_user.id, + cron_expression=data.cron_expression, + timezone=data.timezone, + target_list_id=data.target_list_id, + is_active=True, + next_run_at=next_run, + ) + db.add(schedule) + await db.commit() + await db.refresh(schedule) + + from app.core.scheduler import register_schedule + register_schedule(schedule) + + return schedule + + +@router.get("/tree/{tree_id}", response_model=MaintenanceScheduleResponse) +async def get_schedule_for_tree( + tree_id: UUID, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Get the schedule for a specific maintenance flow.""" + # Verify user's team owns the tree before returning schedule data + await _get_tree_or_403(tree_id, current_user, db) + + result = await db.execute( + select(MaintenanceSchedule).where(MaintenanceSchedule.tree_id == tree_id) + ) + schedule = result.scalar_one_or_none() + if not schedule: + raise HTTPException(status_code=404, detail="No schedule found for this tree") + return schedule + + +@router.patch("/{schedule_id}", response_model=MaintenanceScheduleResponse) +async def update_schedule( + schedule_id: UUID, + data: MaintenanceScheduleUpdate, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Update a schedule (disable, change cron, change timezone, change target list).""" + result = await db.execute( + select(MaintenanceSchedule).where(MaintenanceSchedule.id == schedule_id) + ) + schedule = result.scalar_one_or_none() + if not schedule: + raise HTTPException(status_code=404, detail="Schedule not found") + + # Verify user's team owns the tree this schedule belongs to + await _get_tree_or_403(schedule.tree_id, current_user, db) + + update_fields = data.model_fields_set + was_active = schedule.is_active + if "cron_expression" in update_fields and data.cron_expression is not None: + schedule.cron_expression = data.cron_expression + if "timezone" in update_fields and data.timezone is not None: + schedule.timezone = data.timezone + if "target_list_id" in update_fields: + schedule.target_list_id = data.target_list_id + if "is_active" in update_fields and data.is_active is not None: + schedule.is_active = data.is_active + + # Recompute next_run_at if schedule timing changed or schedule is being re-activated + reactivating = "is_active" in update_fields and data.is_active is True and not was_active + if "cron_expression" in update_fields or "timezone" in update_fields or reactivating: + try: + schedule.next_run_at = _compute_next_run(schedule.cron_expression, schedule.timezone) + except (ValueError, KeyError) as e: + raise HTTPException(status_code=422, detail=f"Invalid cron expression or timezone: {e}") + + await db.commit() + await db.refresh(schedule) + + from app.core.scheduler import register_schedule, unregister_schedule + if schedule.is_active: + register_schedule(schedule) + else: + unregister_schedule(str(schedule.id)) + + return schedule diff --git a/backend/app/api/endpoints/sessions.py b/backend/app/api/endpoints/sessions.py index 08b45875..38c3976a 100644 --- a/backend/app/api/endpoints/sessions.py +++ b/backend/app/api/endpoints/sessions.py @@ -1,8 +1,10 @@ from datetime import datetime, timezone from typing import Annotated, Optional from uuid import UUID +import uuid from fastapi import APIRouter, Depends, HTTPException, status, Query from fastapi.responses import PlainTextResponse +from pydantic import BaseModel, Field as PydanticField from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select @@ -485,3 +487,95 @@ async def save_session_as_tree( tree_name=new_tree.name, message=f"Session saved as {'published' if request_data.status == 'published' else 'draft'} tree" ) + + +# ── Batch Launch (Maintenance Flows) ────────────────────────────────────── + + +class _BatchTarget(BaseModel): + label: str = PydanticField(..., min_length=1, max_length=255) + + +class _BatchLaunchRequest(BaseModel): + tree_id: UUID + targets: list[_BatchTarget] = PydanticField(..., min_length=1, max_length=100) + + +class _BatchLaunchResponse(BaseModel): + batch_id: str + count: int + sessions: list[dict] + + +@router.post("/batch", status_code=201, response_model=_BatchLaunchResponse) +async def batch_launch_sessions( + data: _BatchLaunchRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Create one session per target for a maintenance flow batch run.""" + tree_result = await db.execute(select(Tree).where(Tree.id == data.tree_id)) + tree = tree_result.scalar_one_or_none() + if not tree: + raise HTTPException(status_code=404, detail="Tree not found") + + if not can_access_tree(current_user, tree): + raise HTTPException(status_code=403, detail="Access denied") + + if not tree.is_active: + raise HTTPException(status_code=400, detail="Cannot batch-launch an inactive flow") + + if tree.status == 'draft': + raise HTTPException(status_code=400, detail="Cannot batch-launch a draft flow") + + if not current_user.is_super_admin and tree.team_id != current_user.team_id: + raise HTTPException(status_code=403, detail="Access denied") + + if tree.tree_type != "maintenance": + raise HTTPException(status_code=400, detail="Batch launch is only for maintenance flows") + + batch_id = uuid.uuid4() + created_sessions = [] + + # Hoist snapshot computation out of the loop — same tree for all targets + tree_snapshot = { + **tree.tree_structure, + "name": tree.name, + "description": tree.description, + "tree_type": tree.tree_type, + } + + for target in data.targets: + session = Session( + tree_id=tree.id, + user_id=current_user.id, + tree_snapshot=tree_snapshot, + path_taken=[], + decisions=[], + custom_steps=[], + session_variables={}, + batch_id=batch_id, + target_label=target.label, + ) + db.add(session) + created_sessions.append(session) + + await db.flush() + session_ids = [s.id for s in created_sessions] + result = await db.execute(select(Session).where(Session.id.in_(session_ids))) + created_sessions = result.scalars().all() + await db.commit() + + return _BatchLaunchResponse( + batch_id=str(batch_id), + count=len(created_sessions), + sessions=[ + { + "id": str(s.id), + "batch_id": str(s.batch_id), + "target_label": s.target_label, + "tree_id": str(s.tree_id), + } + for s in created_sessions + ], + ) diff --git a/backend/app/api/endpoints/target_lists.py b/backend/app/api/endpoints/target_lists.py new file mode 100644 index 00000000..0bfac439 --- /dev/null +++ b/backend/app/api/endpoints/target_lists.py @@ -0,0 +1,119 @@ +"""Target lists CRUD endpoints.""" +from typing import Annotated +from uuid import UUID +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin +from app.models.target_list import TargetList +from app.models.user import User +from app.schemas.target_list import TargetListCreate, TargetListUpdate, TargetListResponse + +router = APIRouter(prefix="/target-lists", tags=["target-lists"]) + + +@router.get("/", response_model=list[TargetListResponse]) +async def list_target_lists( + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """List all target lists for the current user's team.""" + if not current_user.team_id: + return [] + result = await db.execute( + select(TargetList) + .where(TargetList.team_id == current_user.team_id) + .order_by(TargetList.name) + ) + return result.scalars().all() + + +@router.post("/", response_model=TargetListResponse, status_code=201) +async def create_target_list( + data: TargetListCreate, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Create a new target list for the current team.""" + if not current_user.team_id: + raise HTTPException(status_code=400, detail="User must belong to a team") + target_list = TargetList( + team_id=current_user.team_id, + created_by=current_user.id, + name=data.name, + description=data.description, + targets=[t.model_dump() for t in data.targets], + ) + db.add(target_list) + await db.commit() + await db.refresh(target_list) + return target_list + + +@router.get("/{list_id}", response_model=TargetListResponse) +async def get_target_list( + list_id: UUID, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +): + result = await db.execute( + select(TargetList).where( + TargetList.id == list_id, + TargetList.team_id == current_user.team_id, + ) + ) + target_list = result.scalar_one_or_none() + if not target_list: + raise HTTPException(status_code=404, detail="Target list not found") + return target_list + + +@router.put("/{list_id}", response_model=TargetListResponse) +async def update_target_list( + list_id: UUID, + data: TargetListUpdate, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + result = await db.execute( + select(TargetList).where( + TargetList.id == list_id, + TargetList.team_id == current_user.team_id, + ) + ) + target_list = result.scalar_one_or_none() + if not target_list: + raise HTTPException(status_code=404, detail="Target list not found") + update_fields = data.model_fields_set + if "name" in update_fields and data.name is not None: + target_list.name = data.name + if "description" in update_fields: + target_list.description = data.description # allow setting to None + if "targets" in update_fields and data.targets is not None: + target_list.targets = [t.model_dump() for t in data.targets] + await db.commit() + await db.refresh(target_list) + return target_list + + +@router.delete("/{list_id}", status_code=204) +async def delete_target_list( + list_id: UUID, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + result = await db.execute( + select(TargetList).where( + TargetList.id == list_id, + TargetList.team_id == current_user.team_id, + ) + ) + target_list = result.scalar_one_or_none() + if not target_list: + raise HTTPException(status_code=404, detail="Target list not found") + await db.delete(target_list) + await db.commit() diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 2e96ddc7..08580d80 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -2,6 +2,8 @@ from fastapi import APIRouter from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks, shares, shared, tree_markdown from app.api.endpoints import admin_dashboard, admin_audit, admin_plan_limits, admin_feature_flags, admin_settings, admin_categories from app.api.endpoints import ratings, analytics +from app.api.endpoints import target_lists +from app.api.endpoints import maintenance_schedules api_router = APIRouter() @@ -28,3 +30,5 @@ api_router.include_router(shared.router) # Public endpoints (no auth) api_router.include_router(tree_markdown.router) api_router.include_router(ratings.router) api_router.include_router(analytics.router) +api_router.include_router(target_lists.router) +api_router.include_router(maintenance_schedules.router) diff --git a/backend/app/core/scheduler.py b/backend/app/core/scheduler.py new file mode 100644 index 00000000..3370ce24 --- /dev/null +++ b/backend/app/core/scheduler.py @@ -0,0 +1,144 @@ +"""APScheduler integration for maintenance flow auto-session creation.""" +import logging +import uuid +from datetime import datetime, timezone + +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.schedulers.base import SchedulerNotRunningError +from apscheduler.jobstores.base import JobLookupError +from apscheduler.triggers.cron import CronTrigger +import pytz +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +logger = logging.getLogger(__name__) + +scheduler = AsyncIOScheduler() + + +async def _fire_maintenance_schedule(schedule_id: str) -> None: + """Create batch sessions for a scheduled maintenance run.""" + # Import all models first to ensure SQLAlchemy mapper relationships resolve + import app.models # noqa: F401 + from app.core.database import async_session_maker + from app.models.maintenance_schedule import MaintenanceSchedule + from app.models.session import Session + from app.models.target_list import TargetList + from app.models.tree import Tree + + async with async_session_maker() as db: + try: + result = await db.execute( + select(MaintenanceSchedule).where( + MaintenanceSchedule.id == uuid.UUID(schedule_id), + MaintenanceSchedule.is_active == True, + ) + ) + schedule = result.scalar_one_or_none() + if not schedule: + logger.warning(f"Schedule {schedule_id} not found or inactive") + return + + tree_result = await db.execute( + select(Tree).where(Tree.id == schedule.tree_id) + ) + tree = tree_result.scalar_one_or_none() + if not tree: + logger.error(f"Tree {schedule.tree_id} not found for schedule {schedule_id}") + return + + # Resolve targets + targets: list[dict] = [] + if schedule.target_list_id: + list_result = await db.execute( + select(TargetList).where(TargetList.id == schedule.target_list_id) + ) + target_list = list_result.scalar_one_or_none() + if target_list: + targets = list(target_list.targets) + + if not targets: + targets = [{"label": "Unassigned"}] + + batch_id = uuid.uuid4() + tree_snapshot = tree.tree_structure + + sessions_to_add = [] + for target in targets: + session = Session( + tree_id=tree.id, + user_id=schedule.created_by, + tree_snapshot=tree_snapshot, + path_taken=[], + decisions=[], + custom_steps=[], + session_variables={}, + batch_id=batch_id, + target_label=target.get("label", ""), + ) + sessions_to_add.append(session) + + for s in sessions_to_add: + db.add(s) + + # Update schedule tracking + schedule.last_run_at = datetime.now(timezone.utc) + from croniter import croniter + tz = pytz.timezone(schedule.timezone) + now = datetime.now(tz) + cron = croniter(schedule.cron_expression, now) + schedule.next_run_at = cron.get_next(datetime).astimezone(timezone.utc) + + await db.commit() + logger.info( + f"Schedule {schedule_id} fired: created {len(sessions_to_add)} sessions " + f"(batch {batch_id}) for tree '{tree.name}'" + ) + except Exception: + logger.exception(f"Error firing maintenance schedule {schedule_id}") + await db.rollback() + + +async def load_all_schedules(db: AsyncSession) -> None: + """Load all active schedules into APScheduler on startup.""" + # Import all models to ensure SQLAlchemy mapper relationships resolve + # before any ORM queries are executed. + import app.models # noqa: F401 + from app.models.maintenance_schedule import MaintenanceSchedule + result = await db.execute( + select(MaintenanceSchedule).where(MaintenanceSchedule.is_active == True) + ) + schedules = result.scalars().all() + for schedule in schedules: + register_schedule(schedule) + logger.info(f"Loaded {len(schedules)} active maintenance schedule(s)") + + +def register_schedule(schedule) -> None: + """Register or update a schedule in APScheduler.""" + job_id = f"maintenance_{schedule.id}" + try: + tz = pytz.timezone(schedule.timezone) + trigger = CronTrigger.from_crontab(schedule.cron_expression, timezone=tz) + scheduler.add_job( + _fire_maintenance_schedule, + trigger=trigger, + id=job_id, + args=[str(schedule.id)], + replace_existing=True, + misfire_grace_time=3600, + ) + logger.info(f"Registered schedule {schedule.id} ({schedule.cron_expression})") + except Exception: + logger.exception(f"Failed to register schedule {schedule.id}") + + +def unregister_schedule(schedule_id: str) -> None: + """Remove a schedule from APScheduler.""" + job_id = f"maintenance_{schedule_id}" + if scheduler.get_job(job_id): + try: + scheduler.remove_job(job_id) + logger.info(f"Unregistered schedule {schedule_id}") + except (SchedulerNotRunningError, JobLookupError): + logger.warning(f"Could not remove job {job_id}: scheduler not running or job already removed") diff --git a/backend/app/core/tree_validation.py b/backend/app/core/tree_validation.py index 8d079a1f..de562889 100644 --- a/backend/app/core/tree_validation.py +++ b/backend/app/core/tree_validation.py @@ -1,6 +1,8 @@ """Tree validation helper module for draft/published workflow.""" from typing import Any +PROCEDURAL_TREE_TYPES = {"procedural", "maintenance"} + class TreeValidationError(Exception): """Custom exception for tree validation errors.""" @@ -224,14 +226,14 @@ def can_publish_tree( errors.append({"field": "name", "message": "Tree must have a name to be published"}) # Validate structure based on tree type - if tree_type == "procedural": + if tree_type in PROCEDURAL_TREE_TYPES: structure_valid, structure_errors = validate_procedural_structure(tree_structure) else: structure_valid, structure_errors = validate_tree_structure(tree_structure) errors.extend(structure_errors) # Validate intake form if present (procedural only) - if intake_form and tree_type == "procedural": + if intake_form and tree_type in PROCEDURAL_TREE_TYPES: form_valid, form_errors = _validate_intake_form(intake_form) errors.extend(form_errors) diff --git a/backend/app/main.py b/backend/app/main.py index 0aab6815..5536e832 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -6,11 +6,12 @@ from slowapi import _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from app.core.config import settings -from app.core.database import init_db +from app.core.database import init_db, async_session_maker from app.core.logging_config import setup_logging from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware from app.core.rate_limit import limiter from app.api.router import api_router +from app.core.scheduler import scheduler, load_all_schedules # Initialize logging configuration setup_logging() @@ -26,8 +27,16 @@ async def lifespan(app: FastAPI): logger.info(f"ALLOW_RAILWAY_ORIGINS: {settings.ALLOW_RAILWAY_ORIGINS}") # Note: In production, use Alembic migrations instead of init_db # await init_db() + + # Start maintenance schedule runner + scheduler.start() + async with async_session_maker() as db: + await load_all_schedules(db) + yield + # Shutdown + scheduler.shutdown(wait=False) logger.info("Shutting down ResolutionFlow API server...") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 674f7627..a3c6f276 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -5,6 +5,7 @@ from .subscription import Subscription from .plan_limits import PlanLimits from .account_invite import AccountInvite from .tree import Tree +from .tree_share import TreeShare from .session import Session from .attachment import Attachment from .invite_code import InviteCode @@ -22,6 +23,8 @@ from .account_limit_override import AccountLimitOverride from .feature_flag import FeatureFlag, PlanFeatureDefault, AccountFeatureOverride from .platform_setting import PlatformSetting from .user_pinned_tree import UserPinnedTree +from .target_list import TargetList +from .maintenance_schedule import MaintenanceSchedule __all__ = [ "User", @@ -31,6 +34,7 @@ __all__ = [ "PlanLimits", "AccountInvite", "Tree", + "TreeShare", "Session", "Attachment", "InviteCode", @@ -55,4 +59,6 @@ __all__ = [ "AccountFeatureOverride", "PlatformSetting", "UserPinnedTree", + "TargetList", + "MaintenanceSchedule", ] diff --git a/backend/app/models/maintenance_schedule.py b/backend/app/models/maintenance_schedule.py new file mode 100644 index 00000000..91280eb4 --- /dev/null +++ b/backend/app/models/maintenance_schedule.py @@ -0,0 +1,45 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional +from sqlalchemy import String, DateTime, ForeignKey, Boolean, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + + +class MaintenanceSchedule(Base): + __tablename__ = "maintenance_schedules" + __table_args__ = ( + UniqueConstraint("tree_id", name="uq_maintenance_schedules_tree_id"), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + tree_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("trees.id", ondelete="CASCADE"), + nullable=False, index=True + ) + created_by: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) + cron_expression: Mapped[str] = mapped_column(String(100), nullable=False) + timezone: Mapped[str] = mapped_column(String(100), nullable=False, default="UTC") + target_list_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), ForeignKey("target_lists.id", ondelete="SET NULL"), nullable=True + ) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + next_run_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), nullable=True + ) + last_run_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), nullable=True + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) diff --git a/backend/app/models/session.py b/backend/app/models/session.py index 91e40a50..f5085c81 100644 --- a/backend/app/models/session.py +++ b/backend/app/models/session.py @@ -65,3 +65,11 @@ class Session(Base): user: Mapped["User"] = relationship("User", back_populates="sessions") attachments: Mapped[list["Attachment"]] = relationship("Attachment", back_populates="session") shares: Mapped[list["SessionShare"]] = relationship("SessionShare", back_populates="session", cascade="all, delete-orphan") + + # Batch tracking (maintenance flows) + batch_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), nullable=True, index=True + ) + target_label: Mapped[Optional[str]] = mapped_column( + String(255), nullable=True + ) diff --git a/backend/app/models/target_list.py b/backend/app/models/target_list.py new file mode 100644 index 00000000..f2dbd7ac --- /dev/null +++ b/backend/app/models/target_list.py @@ -0,0 +1,38 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional, TYPE_CHECKING +from sqlalchemy import String, Text, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID, JSONB +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.user import User + from app.models.team import Team + + +class TargetList(Base): + __tablename__ = "target_lists" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + team_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), + nullable=False, index=True + ) + created_by: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + # targets: [{"label": "RDS-01", "notes": "optional notes"}, ...] + targets: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) diff --git a/backend/app/models/tree.py b/backend/app/models/tree.py index 49486a21..825e1d6c 100644 --- a/backend/app/models/tree.py +++ b/backend/app/models/tree.py @@ -29,7 +29,7 @@ class Tree(Base): name='ck_trees_status' ), CheckConstraint( - "tree_type IN ('troubleshooting', 'procedural')", + "tree_type IN ('troubleshooting', 'procedural', 'maintenance')", name='ck_trees_tree_type' ), ) diff --git a/backend/app/schemas/maintenance_schedule.py b/backend/app/schemas/maintenance_schedule.py new file mode 100644 index 00000000..5a191b2e --- /dev/null +++ b/backend/app/schemas/maintenance_schedule.py @@ -0,0 +1,34 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID +from pydantic import BaseModel, Field + + +class MaintenanceScheduleCreate(BaseModel): + tree_id: UUID + cron_expression: str = Field(..., min_length=9, max_length=100) + timezone: str = Field("UTC", max_length=100) + target_list_id: Optional[UUID] = None + + +class MaintenanceScheduleUpdate(BaseModel): + cron_expression: Optional[str] = Field(None, min_length=9, max_length=100) + timezone: Optional[str] = Field(None, max_length=100) + target_list_id: Optional[UUID] = None + is_active: Optional[bool] = None + + +class MaintenanceScheduleResponse(BaseModel): + id: UUID + tree_id: UUID + created_by: Optional[UUID] + cron_expression: str + timezone: str + target_list_id: Optional[UUID] + is_active: bool + next_run_at: Optional[datetime] + last_run_at: Optional[datetime] + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} diff --git a/backend/app/schemas/session.py b/backend/app/schemas/session.py index 4f00f618..3aefa3b7 100644 --- a/backend/app/schemas/session.py +++ b/backend/app/schemas/session.py @@ -78,6 +78,9 @@ class SessionResponse(BaseModel): def normalize_text_fields(cls, v): return v or "" + batch_id: Optional[UUID] = None + target_label: Optional[str] = None + class Config: from_attributes = True diff --git a/backend/app/schemas/target_list.py b/backend/app/schemas/target_list.py new file mode 100644 index 00000000..0016d393 --- /dev/null +++ b/backend/app/schemas/target_list.py @@ -0,0 +1,34 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID +from pydantic import BaseModel, Field + + +class TargetEntry(BaseModel): + label: str = Field(..., min_length=1, max_length=255) + notes: Optional[str] = Field(None, max_length=500) + + +class TargetListCreate(BaseModel): + name: str = Field(..., min_length=1, max_length=255) + description: Optional[str] = None + targets: list[TargetEntry] = Field(..., min_length=1) + + +class TargetListUpdate(BaseModel): + name: Optional[str] = Field(None, min_length=1, max_length=255) + description: Optional[str] = None + targets: Optional[list[TargetEntry]] = Field(None, min_length=1) + + +class TargetListResponse(BaseModel): + id: UUID + team_id: UUID + created_by: Optional[UUID] + name: str + description: Optional[str] + targets: list[TargetEntry] + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} diff --git a/backend/app/schemas/tree.py b/backend/app/schemas/tree.py index 236b6290..c19c5f70 100644 --- a/backend/app/schemas/tree.py +++ b/backend/app/schemas/tree.py @@ -7,7 +7,7 @@ import re # --- Tree Type --- -TreeType = Literal['troubleshooting', 'procedural'] +TreeType = Literal['troubleshooting', 'procedural', 'maintenance'] # --- Intake Form Schemas --- diff --git a/backend/requirements.txt b/backend/requirements.txt index 28975527..686c481e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -30,3 +30,6 @@ resend==2.21.0 # Utilities python-dotenv==1.0.1 +croniter>=2.0.0 +pytz>=2024.1 +apscheduler>=3.10.4 diff --git a/backend/tests/test_batch_sessions.py b/backend/tests/test_batch_sessions.py new file mode 100644 index 00000000..6598d71b --- /dev/null +++ b/backend/tests/test_batch_sessions.py @@ -0,0 +1,132 @@ +"""Tests for batch session launching (maintenance flows).""" +import pytest +from httpx import AsyncClient + + +async def _create_maintenance_tree(client, headers): + """Helper: create a published maintenance tree.""" + resp = await client.post( + "/api/v1/trees", + json={ + "name": "Patch RDS Servers", + "tree_type": "maintenance", + "tree_structure": { + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Install patch", + "description": "Run installer", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + }, + headers=headers, + ) + assert resp.status_code == 201, resp.text + return resp.json()["id"] + + +@pytest.mark.asyncio +async def test_batch_launch_creates_one_session_per_target(client: AsyncClient, auth_headers: dict): + tree_id = await _create_maintenance_tree(client, auth_headers) + resp = await client.post( + "/api/v1/sessions/batch", + json={ + "tree_id": tree_id, + "targets": [ + {"label": "RDS-01"}, + {"label": "RDS-02"}, + {"label": "RDS-03"}, + ], + }, + headers=auth_headers, + ) + assert resp.status_code == 201, resp.text + data = resp.json() + assert data["count"] == 3 + assert data["batch_id"] is not None + sessions = data["sessions"] + assert len(sessions) == 3 + # All share the same batch_id + batch_ids = {s["batch_id"] for s in sessions} + assert len(batch_ids) == 1 + # Each has a distinct target_label + labels = {s["target_label"] for s in sessions} + assert labels == {"RDS-01", "RDS-02", "RDS-03"} + + +@pytest.mark.asyncio +async def test_batch_launch_rejects_empty_targets(client: AsyncClient, auth_headers: dict): + tree_id = await _create_maintenance_tree(client, auth_headers) + resp = await client.post( + "/api/v1/sessions/batch", + json={"tree_id": tree_id, "targets": []}, + headers=auth_headers, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_batch_launch_rejects_non_maintenance_tree(client: AsyncClient, auth_headers: dict): + """Batch launch only works for maintenance flows.""" + # Create a procedural tree + resp = await client.post( + "/api/v1/trees", + json={ + "name": "Regular Project", + "tree_type": "procedural", + "tree_structure": { + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Step", + "description": "Do it", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + }, + headers=auth_headers, + ) + tree_id = resp.json()["id"] + batch_resp = await client.post( + "/api/v1/sessions/batch", + json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]}, + headers=auth_headers, + ) + assert batch_resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_batch_launch_requires_auth(client: AsyncClient): + """Unauthenticated batch launch returns 401.""" + resp = await client.post( + "/api/v1/sessions/batch", + json={"tree_id": "00000000-0000-0000-0000-000000000000", "targets": [{"label": "SRV-01"}]}, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_batch_launch_rejects_draft_tree(client: AsyncClient, auth_headers: dict): + """Batch launch against a draft maintenance tree returns 400.""" + # Create a maintenance tree — trees default to 'published', so we explicitly set draft + resp = await client.post( + "/api/v1/trees", + json={ + "name": "Draft Maintenance", + "tree_type": "maintenance", + "status": "draft", + "tree_structure": { + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Step", + "description": "Do it", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + }, + headers=auth_headers, + ) + assert resp.status_code == 201, resp.text + tree_id = resp.json()["id"] + batch_resp = await client.post( + "/api/v1/sessions/batch", + json={"tree_id": tree_id, "targets": [{"label": "SRV-01"}]}, + headers=auth_headers, + ) + assert batch_resp.status_code == 400 diff --git a/backend/tests/test_maintenance_schedules.py b/backend/tests/test_maintenance_schedules.py new file mode 100644 index 00000000..32c5666d --- /dev/null +++ b/backend/tests/test_maintenance_schedules.py @@ -0,0 +1,140 @@ +"""Tests for maintenance schedule CRUD.""" +import pytest +from httpx import AsyncClient + + +async def _create_maintenance_tree(client, headers): + resp = await client.post( + "/api/v1/trees", + json={ + "name": "Scheduled Patch", + "tree_type": "maintenance", + "tree_structure": { + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Step", + "description": "Do it", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + }, + headers=headers, + ) + assert resp.status_code == 201, resp.text + return resp.json()["id"] + + +@pytest.mark.asyncio +async def test_create_schedule(client: AsyncClient, auth_headers: dict): + tree_id = await _create_maintenance_tree(client, auth_headers) + resp = await client.post( + "/api/v1/maintenance-schedules", + json={ + "tree_id": tree_id, + "cron_expression": "0 9 15 * *", + "timezone": "America/New_York", + }, + headers=auth_headers, + ) + assert resp.status_code == 201, resp.text + data = resp.json() + assert data["cron_expression"] == "0 9 15 * *" + assert data["timezone"] == "America/New_York" + assert data["is_active"] is True + assert data["next_run_at"] is not None + + +@pytest.mark.asyncio +async def test_duplicate_schedule_rejected(client: AsyncClient, auth_headers: dict): + """Cannot create two schedules for the same tree.""" + tree_id = await _create_maintenance_tree(client, auth_headers) + await client.post( + "/api/v1/maintenance-schedules", + json={"tree_id": tree_id, "cron_expression": "0 0 1 * *", "timezone": "UTC"}, + headers=auth_headers, + ) + resp = await client.post( + "/api/v1/maintenance-schedules", + json={"tree_id": tree_id, "cron_expression": "0 6 1 * *", "timezone": "UTC"}, + headers=auth_headers, + ) + assert resp.status_code == 409 + + +@pytest.mark.asyncio +async def test_get_schedule_for_tree(client: AsyncClient, auth_headers: dict): + tree_id = await _create_maintenance_tree(client, auth_headers) + await client.post( + "/api/v1/maintenance-schedules", + json={"tree_id": tree_id, "cron_expression": "0 0 1 * *", "timezone": "UTC"}, + headers=auth_headers, + ) + resp = await client.get(f"/api/v1/maintenance-schedules/tree/{tree_id}", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["cron_expression"] == "0 0 1 * *" + + +@pytest.mark.asyncio +async def test_disable_schedule(client: AsyncClient, auth_headers: dict): + tree_id = await _create_maintenance_tree(client, auth_headers) + create = await client.post( + "/api/v1/maintenance-schedules", + json={"tree_id": tree_id, "cron_expression": "0 6 * * 1", "timezone": "UTC"}, + headers=auth_headers, + ) + sched_id = create.json()["id"] + resp = await client.patch( + f"/api/v1/maintenance-schedules/{sched_id}", + json={"is_active": False}, + headers=auth_headers, + ) + assert resp.status_code == 200 + assert resp.json()["is_active"] is False + + +@pytest.mark.asyncio +async def test_get_schedule_not_found(client: AsyncClient, auth_headers: dict): + tree_id = await _create_maintenance_tree(client, auth_headers) + resp = await client.get(f"/api/v1/maintenance-schedules/tree/{tree_id}", headers=auth_headers) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_cannot_schedule_other_teams_tree(client: AsyncClient, auth_headers: dict, test_db): + """User cannot create a schedule for a tree belonging to another team.""" + import uuid as _uuid + from app.models.team import Team + from app.models.tree import Tree + + # Create a tree belonging to a DIFFERENT team directly in DB + other_team = Team(name=f"Other Team {_uuid.uuid4()}") + test_db.add(other_team) + await test_db.flush() + + other_tree = Tree( + name="Other Team Tree", + tree_type="maintenance", + team_id=other_team.id, + tree_structure={ + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Step", + "description": "Do it", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + status="published", + visibility="team", + ) + test_db.add(other_tree) + await test_db.flush() + + # Current user (from auth_headers) tries to schedule it + resp = await client.post( + "/api/v1/maintenance-schedules", + json={ + "tree_id": str(other_tree.id), + "cron_expression": "0 9 1 * *", + "timezone": "UTC", + }, + headers=auth_headers, + ) + assert resp.status_code in (403, 404) # either is acceptable diff --git a/backend/tests/test_maintenance_tree_type.py b/backend/tests/test_maintenance_tree_type.py new file mode 100644 index 00000000..0ba3866b --- /dev/null +++ b/backend/tests/test_maintenance_tree_type.py @@ -0,0 +1,73 @@ +"""Tests for maintenance tree type.""" +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_create_maintenance_tree(client: AsyncClient, auth_headers: dict): + """Maintenance tree type is accepted by the API.""" + resp = await client.post( + "/api/v1/trees", + json={ + "name": "Update FSLogix", + "description": "Monthly FSLogix update procedure", + "tree_type": "maintenance", + "tree_structure": { + "steps": [ + {"id": "step-1", "type": "procedure_step", "title": "Download installer", + "description": "Get latest FSLogix from Microsoft", "content_type": "action"}, + {"id": "step-end", "type": "procedure_end", "title": "Complete"}, + ] + }, + }, + headers=auth_headers, + ) + assert resp.status_code == 201, resp.text + data = resp.json() + assert data["tree_type"] == "maintenance" + + +@pytest.mark.asyncio +async def test_list_maintenance_trees_filter(client: AsyncClient, auth_headers: dict): + """Filtering by tree_type=maintenance returns only maintenance trees.""" + await client.post( + "/api/v1/trees", + json={ + "name": "Maintenance Only", + "tree_type": "maintenance", + "tree_structure": { + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Step", + "description": "Do it", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + }, + headers=auth_headers, + ) + resp = await client.get("/api/v1/trees?tree_type=maintenance", headers=auth_headers) + assert resp.status_code == 200 + trees = resp.json() + assert all(t["tree_type"] == "maintenance" for t in trees) + assert len(trees) >= 1 + + +@pytest.mark.asyncio +async def test_invalid_tree_type_rejected(client: AsyncClient, auth_headers: dict): + """An unrecognized tree_type value is rejected with 422.""" + resp = await client.post( + "/api/v1/trees", + json={ + "name": "Bad Type", + "tree_type": "garbage", + "tree_structure": { + "steps": [ + {"id": "s1", "type": "procedure_step", "title": "Step", + "description": "Do it", "content_type": "action"}, + {"id": "end", "type": "procedure_end", "title": "Done"}, + ] + }, + }, + headers=auth_headers, + ) + assert resp.status_code == 422 diff --git a/backend/tests/test_target_lists.py b/backend/tests/test_target_lists.py new file mode 100644 index 00000000..a40cfb48 --- /dev/null +++ b/backend/tests/test_target_lists.py @@ -0,0 +1,153 @@ +"""Tests for target lists CRUD.""" +import pytest +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.team import Team +from app.models.user import User +from sqlalchemy import select + + +@pytest.fixture +async def auth_headers(client: AsyncClient, test_db: AsyncSession, test_user: dict): + """Override auth_headers to ensure the test user has a team_id assigned.""" + # Fetch the user from DB and assign a team + result = await test_db.execute(select(User).where(User.email == test_user["email"])) + user = result.scalar_one() + + # Create a team and assign the user to it + team = Team(name="Test Team") + test_db.add(team) + await test_db.flush() + + user.team_id = team.id + await test_db.commit() + + # Re-login to get a fresh token + login_data = { + "email": test_user["email"], + "password": test_user["password"], + } + resp = await client.post("/api/v1/auth/login/json", json=login_data) + assert resp.status_code == 200 + token_data = resp.json() + return {"Authorization": f"Bearer {token_data['access_token']}"} + + +@pytest.mark.asyncio +async def test_create_target_list(client: AsyncClient, auth_headers: dict): + resp = await client.post( + "/api/v1/target-lists/", + json={ + "name": "RDS Farm A", + "description": "Production RDS servers", + "targets": [ + {"label": "RDS-01", "notes": "192.168.1.10"}, + {"label": "RDS-02", "notes": "192.168.1.11"}, + ], + }, + headers=auth_headers, + ) + assert resp.status_code == 201, resp.text + data = resp.json() + assert data["name"] == "RDS Farm A" + assert len(data["targets"]) == 2 + + +@pytest.mark.asyncio +async def test_list_target_lists(client: AsyncClient, auth_headers: dict): + resp = await client.get("/api/v1/target-lists/", headers=auth_headers) + assert resp.status_code == 200 + assert isinstance(resp.json(), list) + + +@pytest.mark.asyncio +async def test_get_target_list(client: AsyncClient, auth_headers: dict): + create = await client.post( + "/api/v1/target-lists/", + json={"name": "Get Test", "targets": [{"label": "SRV-01"}]}, + headers=auth_headers, + ) + list_id = create.json()["id"] + resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["name"] == "Get Test" + + +@pytest.mark.asyncio +async def test_update_target_list(client: AsyncClient, auth_headers: dict): + create = await client.post( + "/api/v1/target-lists/", + json={"name": "Old Name", "targets": [{"label": "SRV-01"}]}, + headers=auth_headers, + ) + list_id = create.json()["id"] + resp = await client.put( + f"/api/v1/target-lists/{list_id}", + json={"name": "New Name", "targets": [{"label": "SRV-01"}, {"label": "SRV-02"}]}, + headers=auth_headers, + ) + assert resp.status_code == 200 + assert resp.json()["name"] == "New Name" + assert len(resp.json()["targets"]) == 2 + + +@pytest.mark.asyncio +async def test_delete_target_list(client: AsyncClient, auth_headers: dict): + create = await client.post( + "/api/v1/target-lists/", + json={"name": "To Delete", "targets": [{"label": "X"}]}, + headers=auth_headers, + ) + list_id = create.json()["id"] + resp = await client.delete(f"/api/v1/target-lists/{list_id}", headers=auth_headers) + assert resp.status_code == 204 + + get = await client.get(f"/api/v1/target-lists/{list_id}", headers=auth_headers) + assert get.status_code == 404 + +@pytest.mark.asyncio +async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: dict, test_db): + """User from team B cannot access team A's list.""" + import uuid + from app.models.team import Team + from app.models.user import User + from app.core.security import get_password_hash + + # Create team A list using existing auth_headers + create = await client.post( + "/api/v1/target-lists/", + json={"name": "Team A List", "targets": [{"label": "SRV-A"}]}, + headers=auth_headers, + ) + assert create.status_code == 201 + list_id = create.json()["id"] + + # Create a separate team B with its own user + team_b = Team(name=f"Team B {uuid.uuid4()}") + test_db.add(team_b) + await test_db.flush() + + user_b = User( + email=f"userb_{uuid.uuid4()}@test.com", + password_hash=get_password_hash("password123"), + name="User B", + is_active=True, + team_id=team_b.id, + role="engineer", + ) + test_db.add(user_b) + await test_db.flush() + + # Get auth token for user B + login = await client.post( + "/api/v1/auth/login/json", + json={"email": user_b.email, "password": "password123"}, + ) + assert login.status_code == 200 + token_b = login.json()["access_token"] + headers_b = {"Authorization": f"Bearer {token_b}"} + + # Team B cannot access Team A's list + resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b) + assert resp.status_code == 404 diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 04bec2a5..3f6dd659 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -13,3 +13,5 @@ export { default as adminApi } from './admin' export { treeMarkdownApi } from './treeMarkdown' export { default as pinnedFlowsApi } from './pinnedFlows' export { default as analyticsApi } from './analytics' +export { targetListsApi } from './targetLists' +export { maintenanceSchedulesApi, batchLaunchApi } from './maintenanceSchedules' diff --git a/frontend/src/api/maintenanceSchedules.ts b/frontend/src/api/maintenanceSchedules.ts new file mode 100644 index 00000000..e63ad693 --- /dev/null +++ b/frontend/src/api/maintenanceSchedules.ts @@ -0,0 +1,24 @@ +import { apiClient } from './client' +import type { + MaintenanceSchedule, + MaintenanceScheduleCreate, + MaintenanceScheduleUpdate, + BatchLaunchRequest, + BatchLaunchResponse, +} from '@/types' + +export const maintenanceSchedulesApi = { + getForTree: (treeId: string): Promise => + apiClient.get(`/maintenance-schedules/tree/${treeId}`).then(r => r.data), + + create: (data: MaintenanceScheduleCreate): Promise => + apiClient.post('/maintenance-schedules', data).then(r => r.data), + + update: (id: string, data: MaintenanceScheduleUpdate): Promise => + apiClient.patch(`/maintenance-schedules/${id}`, data).then(r => r.data), +} + +export const batchLaunchApi = { + launch: (data: BatchLaunchRequest): Promise => + apiClient.post('/sessions/batch', data).then(r => r.data), +} diff --git a/frontend/src/api/targetLists.ts b/frontend/src/api/targetLists.ts new file mode 100644 index 00000000..28508b3a --- /dev/null +++ b/frontend/src/api/targetLists.ts @@ -0,0 +1,19 @@ +import { apiClient } from './client' +import type { TargetList, TargetListCreate } from '@/types' + +export const targetListsApi = { + list: (): Promise => + apiClient.get('/target-lists/').then(r => r.data), + + get: (id: string): Promise => + apiClient.get(`/target-lists/${id}`).then(r => r.data), + + create: (data: TargetListCreate): Promise => + apiClient.post('/target-lists/', data).then(r => r.data), + + update: (id: string, data: Partial): Promise => + apiClient.put(`/target-lists/${id}`, data).then(r => r.data), + + delete: (id: string): Promise => + apiClient.delete(`/target-lists/${id}`).then(() => undefined), +} diff --git a/frontend/src/components/layout/Sidebar.tsx b/frontend/src/components/layout/Sidebar.tsx index cd00f5b2..c092601f 100644 --- a/frontend/src/components/layout/Sidebar.tsx +++ b/frontend/src/components/layout/Sidebar.tsx @@ -29,7 +29,7 @@ export function Sidebar() { const [activeTags, setActiveTags] = useState([]) const [activeSessionCount, setActiveSessionCount] = useState(0) const [pinnedFlows, setPinnedFlows] = useState([]) - const [treeCounts, setTreeCounts] = useState({ total: 0, troubleshooting: 0, procedural: 0 }) + const [treeCounts, setTreeCounts] = useState({ total: 0, troubleshooting: 0, procedural: 0, maintenance: 0 }) // Fetch sidebar data on mount useEffect(() => { @@ -55,7 +55,8 @@ export function Sidebar() { const total = allTrees.length const troubleshooting = allTrees.filter(t => t.tree_type === 'troubleshooting').length const procedural = allTrees.filter(t => t.tree_type === 'procedural').length - setTreeCounts({ total, troubleshooting, procedural }) + const maintenance = allTrees.filter(t => t.tree_type === 'maintenance').length + setTreeCounts({ total, troubleshooting, procedural, maintenance }) } catch { // Silently handle errors } @@ -145,6 +146,7 @@ export function Sidebar() { children={[ { href: '/trees?type=troubleshooting', label: 'Troubleshooting', count: treeCounts.troubleshooting || undefined }, { href: '/trees?type=procedural', label: 'Projects', count: treeCounts.procedural || undefined }, + { href: '/trees?type=maintenance', label: 'Maintenance', count: treeCounts.maintenance || undefined }, ]} /> diff --git a/frontend/src/components/library/TreeGridView.tsx b/frontend/src/components/library/TreeGridView.tsx index a6d8cb8b..a633a862 100644 --- a/frontend/src/components/library/TreeGridView.tsx +++ b/frontend/src/components/library/TreeGridView.tsx @@ -1,5 +1,5 @@ import { Link } from 'react-router-dom' -import { Pencil, Globe, Lock, Trash2, GitBranch, FileText } from 'lucide-react' +import { Pencil, Globe, Lock, Trash2, GitBranch, FileText, Wrench } from 'lucide-react' import type { TreeListItem } from '@/types' import { TagBadges } from '@/components/common/TagBadges' import { AddToFolderMenu } from './AddToFolderMenu' @@ -41,6 +41,12 @@ export function TreeGridView({ Draft )} + {tree.tree_type === 'maintenance' && ( + + + Maintenance + + )}
{tree.is_public ? ( diff --git a/frontend/src/components/library/TreeListView.tsx b/frontend/src/components/library/TreeListView.tsx index 82227b39..7485d164 100644 --- a/frontend/src/components/library/TreeListView.tsx +++ b/frontend/src/components/library/TreeListView.tsx @@ -1,5 +1,5 @@ import { Link } from 'react-router-dom' -import { Pencil, Globe, Lock, GitBranch, FileText, Trash2 } from 'lucide-react' +import { Pencil, Globe, Lock, GitBranch, FileText, Trash2, Wrench } from 'lucide-react' import type { TreeListItem } from '@/types' import { TagBadges } from '@/components/common/TagBadges' import { AddToFolderMenu } from './AddToFolderMenu' @@ -42,6 +42,12 @@ export function TreeListView({ Draft )} + {tree.tree_type === 'maintenance' && ( + + + Maintenance + + )} {tree.is_public ? ( diff --git a/frontend/src/components/library/TreeTableView.tsx b/frontend/src/components/library/TreeTableView.tsx index bd4ea7ae..c09101ed 100644 --- a/frontend/src/components/library/TreeTableView.tsx +++ b/frontend/src/components/library/TreeTableView.tsx @@ -1,6 +1,6 @@ import { useState } from 'react' import { Link } from 'react-router-dom' -import { Pencil, Globe, Lock, ChevronUp, ChevronDown, GitBranch, FileText, Trash2 } from 'lucide-react' +import { Pencil, Globe, Lock, ChevronUp, ChevronDown, GitBranch, FileText, Trash2, Wrench } from 'lucide-react' import type { TreeListItem } from '@/types' import { TagBadges } from '@/components/common/TagBadges' import { AddToFolderMenu } from './AddToFolderMenu' @@ -144,6 +144,12 @@ export function TreeTableView({ Draft )} + {tree.tree_type === 'maintenance' && ( + + + Maintenance + + )} {tree.is_public ? ( diff --git a/frontend/src/components/maintenance/BatchLaunchModal.tsx b/frontend/src/components/maintenance/BatchLaunchModal.tsx new file mode 100644 index 00000000..46c091c9 --- /dev/null +++ b/frontend/src/components/maintenance/BatchLaunchModal.tsx @@ -0,0 +1,202 @@ +import { useState, useEffect } from 'react' +import { X, List, Clock, PenLine, ExternalLink } from 'lucide-react' +import { cn } from '@/lib/utils' +import { toast } from '@/lib/toast' +import { targetListsApi, batchLaunchApi } from '@/api' +import type { TargetList, TargetEntry } from '@/types' + +interface BatchLaunchModalProps { + treeId: string + treeName: string + onClose: () => void + onLaunched: (batchId: string, count: number) => void +} + +type TabId = 'manual' | 'saved' | 'previous' | 'psa' + +export function BatchLaunchModal({ treeId, treeName, onClose, onLaunched }: BatchLaunchModalProps) { + const [activeTab, setActiveTab] = useState('manual') + const [savedLists, setSavedLists] = useState(null) + const [selectedListId, setSelectedListId] = useState(null) + const [manualInput, setManualInput] = useState('') + const [isLaunching, setIsLaunching] = useState(false) + + useEffect(() => { + if (activeTab === 'saved' && savedLists === null) { + targetListsApi.list() + .then(setSavedLists) + .catch(() => toast.error('Failed to load saved lists')) + } + }, [activeTab, savedLists]) + + const getTargets = (): TargetEntry[] => { + if (activeTab === 'saved' && selectedListId && savedLists) { + const list = savedLists.find(l => l.id === selectedListId) + return list?.targets ?? [] + } + if (activeTab === 'manual') { + return manualInput + .split('\n') + .map(l => l.trim()) + .filter(Boolean) + .map(label => ({ label })) + } + return [] + } + + const targets = getTargets() + + const handleLaunch = async () => { + if (targets.length === 0) { + toast.error('Add at least one target before launching') + return + } + if (targets.length > 100) { + toast.error('Maximum 100 targets per batch') + return + } + setIsLaunching(true) + try { + const result = await batchLaunchApi.launch({ tree_id: treeId, targets }) + toast.success(`${result.count} sessions created`) + onLaunched(result.batch_id, result.count) + } catch { + toast.error('Failed to launch batch') + } finally { + setIsLaunching(false) + } + } + + const tabs: { id: TabId; label: string; icon: React.ReactNode }[] = [ + { id: 'manual', label: 'Manual Entry', icon: }, + { id: 'saved', label: 'Saved List', icon: }, + { id: 'previous', label: 'Previous Run', icon: }, + { id: 'psa', label: 'PSA / RMM', icon: }, + ] + + return ( +
+
+ {/* Header */} +
+
+

Batch Launch

+

{treeName}

+
+ +
+ + {/* Tabs */} +
+ {tabs.map(tab => ( + + ))} +
+ + {/* Content */} +
+ {activeTab === 'manual' && ( +
+ +