- Move mid-file pydantic/uuid imports to top of sessions.py - Add can_access_tree, is_active, and draft status guards to batch_launch_sessions - Remove notes field from _BatchTarget to keep API clean - Add max_length=100 cap to targets list in _BatchLaunchRequest - Hoist tree_snapshot computation above the session creation loop - Replace N db.refresh() calls with a single bulk select after flush - Add test_batch_launch_requires_auth and test_batch_launch_rejects_draft_tree tests - Fix trailing slash on /api/v1/trees/ URL in new test (caused 307 redirect) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
579 lines
20 KiB
Python
579 lines
20 KiB
Python
from datetime import datetime, timezone
|
|
from typing import Annotated, Optional
|
|
from uuid import UUID
|
|
import uuid
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
from fastapi.responses import PlainTextResponse
|
|
from pydantic import BaseModel, Field as PydanticField
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
|
|
from app.core.database import get_db
|
|
from app.models.tree import Tree
|
|
from app.models.session import Session
|
|
from app.models.user import User
|
|
from app.schemas.session import (
|
|
SessionCreate,
|
|
SessionUpdate,
|
|
SessionResponse,
|
|
SessionExport,
|
|
ScratchpadUpdate,
|
|
SaveAsTreeRequest,
|
|
SaveAsTreeResponse,
|
|
SessionComplete,
|
|
)
|
|
from app.api.deps import get_current_active_user
|
|
from app.core.permissions import can_access_tree
|
|
from app.services.export_service import generate_markdown_export, generate_text_export, generate_html_export, generate_psa_export
|
|
|
|
# All troubleshooting-session endpoints below are mounted under /sessions.
router = APIRouter(tags=["sessions"], prefix="/sessions")
|
|
|
|
|
|
_LIKE_ESCAPE = "\\"


def _contains_pattern(term: str) -> str:
    """Return an ILIKE pattern that matches ``term`` as a literal substring.

    ``%`` and ``_`` are LIKE wildcards, so raw user input such as ``"50%"`` or
    ``"PC_01"`` would otherwise match far more rows than intended. Those
    characters, and the escape character itself, are prefixed with
    ``_LIKE_ESCAPE``; pass ``escape=_LIKE_ESCAPE`` to ``ilike`` together with
    the returned pattern.
    """
    escaped = (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


@router.get("", response_model=list[SessionResponse])
async def list_sessions(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
    completed: Optional[bool] = Query(None, description="Filter by completion status"),
    ticket_number: Optional[str] = Query(None, description="Search by ticket number (partial match)"),
    client_name: Optional[str] = Query(None, description="Search by client name (partial match)"),
    tree_name: Optional[str] = Query(None, description="Filter by tree name from snapshot"),
    started_after: Optional[datetime] = Query(None, description="Filter sessions started after this datetime"),
    started_before: Optional[datetime] = Query(None, description="Filter sessions started before this datetime"),
    completed_after: Optional[datetime] = Query(None, description="Filter sessions completed after this datetime"),
    completed_before: Optional[datetime] = Query(None, description="Filter sessions completed before this datetime"),
    skip: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=100)
):
    """List user's troubleshooting sessions with comprehensive filtering.

    Only sessions owned by the current user are returned. Text filters
    (``ticket_number``, ``client_name``, ``tree_name``) are case-insensitive
    substring matches in which ``%`` and ``_`` are matched literally. Date
    filters are inclusive bounds. Results are ordered newest first by
    ``started_at`` and paginated with ``skip``/``limit``.
    """
    query = select(Session).where(Session.user_id == current_user.id)

    # Completion status filter: completed sessions have completed_at set.
    if completed is not None:
        query = query.where(
            Session.completed_at.isnot(None) if completed else Session.completed_at.is_(None)
        )

    # Case-insensitive literal substring searches.
    if ticket_number:
        query = query.where(
            Session.ticket_number.ilike(_contains_pattern(ticket_number), escape=_LIKE_ESCAPE)
        )
    if client_name:
        query = query.where(
            Session.client_name.ilike(_contains_pattern(client_name), escape=_LIKE_ESCAPE)
        )

    # Tree name lives in the JSONB snapshot taken when the session started.
    if tree_name:
        query = query.where(
            Session.tree_snapshot['name'].astext.ilike(_contains_pattern(tree_name), escape=_LIKE_ESCAPE)
        )

    # Inclusive date range filters.
    if started_after:
        query = query.where(Session.started_at >= started_after)
    if started_before:
        query = query.where(Session.started_at <= started_before)
    if completed_after:
        query = query.where(Session.completed_at >= completed_after)
    if completed_before:
        query = query.where(Session.completed_at <= completed_before)

    query = query.order_by(Session.started_at.desc()).offset(skip).limit(limit)

    result = await db.execute(query)
    return result.scalars().all()
|
|
|
|
|
|
@router.get("/{session_id}", response_model=SessionResponse)
async def get_session(
    session_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Get a specific session.

    Responds 404 when no session has ``session_id`` and 403 when the session
    belongs to a different user.
    """
    lookup = await db.execute(select(Session).where(Session.id == session_id))
    record = lookup.scalar_one_or_none()

    # Guard clauses: unknown id first, then ownership.
    if not record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
    if record.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have access to this session",
        )

    return record
|
|
|
|
|
|
@router.post("", response_model=SessionResponse, status_code=status.HTTP_201_CREATED)
async def start_session(
    session_data: SessionCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Start a new troubleshooting session.

    Loads the requested tree, verifies the caller may use it, validates the
    required intake-form fields for procedural trees, then stores a new
    session carrying a snapshot of the tree (structure plus metadata) and
    increments the tree's usage count.

    Raises:
        HTTPException 404: the tree does not exist.
        HTTPException 403: the caller cannot access the tree.
        HTTPException 400: the tree is inactive.
        HTTPException 422: required intake-form fields are missing.
    """
    result = await db.execute(select(Tree).where(Tree.id == session_data.tree_id))
    tree = result.scalar_one_or_none()

    if not tree:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tree not found"
        )

    # Authorize before reporting anything about the tree's state (same order as
    # batch_launch_sessions), so callers without access get 403 instead of
    # learning that the tree is inactive.
    if not can_access_tree(current_user, tree):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have access to this tree"
        )

    if not tree.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot start session with inactive tree"
        )

    # For procedural trees with intake forms, every required field must have
    # a truthy value in the submitted session variables.
    session_variables = session_data.session_variables or {}
    if tree.tree_type == 'procedural' and tree.intake_form:
        missing_fields = [
            field["label"]
            for field in tree.intake_form
            if field.get("required") and not session_variables.get(field["variable_name"])
        ]
        if missing_fields:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Missing required intake form fields: {', '.join(missing_fields)}"
            )

    # Snapshot the tree (includes tree metadata for filtering/export) so later
    # edits to the tree do not change this session's view of it.
    tree_snapshot = {
        **tree.tree_structure,
        "name": tree.name,
        "description": tree.description,
        "category": tree.category,
        "version": tree.version,
        "tree_type": tree.tree_type,
    }

    new_session = Session(
        tree_id=tree.id,
        user_id=current_user.id,
        tree_snapshot=tree_snapshot,
        path_taken=[],
        decisions=[],
        ticket_number=session_data.ticket_number,
        client_name=session_data.client_name,
        session_variables=session_variables,
    )

    tree.usage_count += 1

    db.add(new_session)
    await db.commit()
    # Reload so server-generated columns are present in the response.
    await db.refresh(new_session)
    return new_session
|
|
|
|
|
|
@router.put("/{session_id}", response_model=SessionResponse)
async def update_session(
    session_id: UUID,
    session_data: SessionUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Update session (add decisions, notes, etc.).

    Only fields explicitly present in the request body are written. Responds
    404 for an unknown session, 403 for another user's session and 400 once
    the session has been completed.
    """
    record = (
        await db.execute(select(Session).where(Session.id == session_id))
    ).scalar_one_or_none()

    if not record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
    if record.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have access to this session",
        )
    if record.completed_at:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot update a completed session",
        )

    # mode='json' turns datetime values into ISO strings so they can be stored in JSONB.
    changes = session_data.model_dump(exclude_unset=True, mode='json')
    for attr_name, new_value in changes.items():
        setattr(record, attr_name, new_value)

    await db.commit()
    await db.refresh(record)
    return record
|
|
|
|
|
|
@router.post("/{session_id}/complete", response_model=SessionResponse)
async def complete_session(
    session_id: UUID,
    completion_data: SessionComplete,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Mark session as complete.

    Stamps ``completed_at`` with the current UTC time and records the outcome,
    outcome notes and next steps. Responds 404 / 403 for unknown or foreign
    sessions and 400 if the session was already completed.
    """
    record = (
        await db.execute(select(Session).where(Session.id == session_id))
    ).scalar_one_or_none()

    if not record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
    if record.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have access to this session",
        )
    if record.completed_at:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Session already completed",
        )

    record.completed_at = datetime.now(timezone.utc)
    record.outcome, record.outcome_notes, record.next_steps = (
        completion_data.outcome,
        completion_data.outcome_notes,
        completion_data.next_steps,
    )

    await db.commit()
    await db.refresh(record)
    return record
|
|
|
|
|
|
@router.patch("/{session_id}/scratchpad", response_model=SessionResponse)
async def update_scratchpad(
    session_id: UUID,
    data: ScratchpadUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Update session scratchpad. Accepts updates on both active and completed sessions.

    No completed-session check is performed, so the scratchpad stays editable
    after completion. Responds 404 / 403 for unknown or foreign sessions.
    """
    record = (
        await db.execute(select(Session).where(Session.id == session_id))
    ).scalar_one_or_none()

    if not record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
    if record.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have access to this session",
        )

    record.scratchpad = data.scratchpad
    await db.commit()
    await db.refresh(record)
    return record
|
|
|
|
|
|
@router.post("/{session_id}/export")
async def export_session(
    session_id: UUID,
    export_options: SessionExport,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Export session to formatted notes.

    Renders the session as markdown, HTML, PSA or plain text (any other format
    value falls back to plain text), substitutes session variables, and, when
    ``redaction_mode == "mask"``, masks the rendered output and appends a
    redaction footer if one is produced. Completed sessions are flagged as
    exported.

    Response headers:
        X-Redaction-Mode: the requested redaction mode.
        X-Redaction-Summary: JSON redaction summary, only when masking ran.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 403: the session belongs to another user.
        HTTPException 500: redaction failed; the underlying cause is logged.
    """
    import json
    import logging

    result = await db.execute(select(Session).where(Session.id == session_id))
    session = result.scalar_one_or_none()

    if not session:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Session not found"
        )

    if session.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have access to this session"
        )

    # Generate export based on format
    if export_options.format == "markdown":
        content = generate_markdown_export(session, export_options)
        media_type = "text/markdown"
    elif export_options.format == "html":
        content = generate_html_export(session, export_options)
        media_type = "text/html"
    elif export_options.format == "psa":
        content = generate_psa_export(session, export_options)
        media_type = "text/plain"
    else:  # text
        content = generate_text_export(session, export_options)
        media_type = "text/plain"

    # Resolve variables in export output
    session_vars = getattr(session, 'session_variables', None) or {}
    if session_vars:
        from app.services.variable_service import resolve_variables
        content = resolve_variables(content, session_vars)

    # Phase C: Apply redaction AFTER generation and variable resolution so
    # substituted values are masked too.
    redaction_summary = None
    if export_options.redaction_mode == "mask":
        from app.services.redaction_service import apply_redaction_to_text, format_redaction_footer
        try:
            content, redaction_summary = apply_redaction_to_text(content)
            footer = format_redaction_footer(redaction_summary)
            if footer:
                content += footer
        except Exception as exc:
            # Keep the client-facing message generic, but record the real cause
            # server-side and chain it instead of discarding it.
            logging.getLogger(__name__).exception(
                "Redaction processing failed for session %s", session_id
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Redaction processing failed"
            ) from exc

    # Only mark as exported if session is completed
    if session.completed_at:
        session.exported = True
        await db.commit()

    # Build response with redaction headers
    headers = {"X-Redaction-Mode": export_options.redaction_mode}
    if redaction_summary is not None:
        headers["X-Redaction-Summary"] = json.dumps(redaction_summary.to_dict())

    return PlainTextResponse(content=content, media_type=media_type, headers=headers)
|
|
|
|
|
|
# --- Save Session as Tree ---
|
|
|
|
|
|
@router.post("/{session_id}/save-as-tree", response_model=SaveAsTreeResponse, status_code=status.HTTP_201_CREATED)
async def save_session_as_tree(
    session_id: UUID,
    request_data: SaveAsTreeRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Save a session as a new tree.

    Converts the session's path_taken and custom_steps into a linear tree structure.
    The new tree is linked to the original tree via parent_tree_id (fork relationship).

    Args:
        session_id: ID of the session to save
        request_data: Tree name, description, and status
        db: Database session
        current_user: Current authenticated user

    Returns:
        SaveAsTreeResponse with new tree ID and name

    Raises:
        HTTPException 404: the session (owned by the caller) or its original tree is missing.
        HTTPException 422: status is 'published' but the converted tree fails validation.
        HTTPException 402: the caller's account has reached its tree limit.
    """
    from app.core.session_to_tree import convert_session_to_tree, generate_tree_name_from_session
    from app.core.tree_validation import can_publish_tree
    from app.core.subscriptions import check_tree_limit

    # Load the session; filtering on user_id means another user's session is a 404.
    result = await db.execute(
        select(Session).where(
            Session.id == session_id,
            Session.user_id == current_user.id
        )
    )
    session = result.scalar_one_or_none()

    if not session:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Session not found"
        )

    # Load the original tree to get metadata
    tree_result = await db.execute(
        select(Tree).where(Tree.id == session.tree_id)
    )
    original_tree = tree_result.scalar_one_or_none()

    if not original_tree:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Original tree not found"
        )

    # Convert session to tree structure
    tree_structure = convert_session_to_tree(
        session.path_taken,
        session.tree_snapshot,
        session.custom_steps,
        session.decisions
    )

    # Use the requested name, or derive one from the original tree and session.
    if request_data.tree_name:
        tree_name = request_data.tree_name
    else:
        tree_name = generate_tree_name_from_session(
            original_tree.name,
            session.ticket_number,
            session.client_name
        )

    # Validate if status is published
    if request_data.status == 'published':
        can_publish, validation_errors = can_publish_tree(
            tree_structure,
            tree_name,
            request_data.description,
            tree_type=original_tree.tree_type,
            intake_form=original_tree.intake_form,
        )
        if not can_publish:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail={
                    "message": "Cannot save as published tree with validation errors",
                    "errors": validation_errors
                }
            )

    # Check subscription tree limit
    if current_user.account_id:
        can_create, limit, count = await check_tree_limit(current_user.account_id, db)
        if not can_create:
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
            )

    # The "No ticket" fallback that used to live here was unreachable: this
    # f-string is only used when ticket_number is truthy.
    fork_reason = (
        f"Saved from session: {session.ticket_number}"
        if session.ticket_number
        else "Saved from troubleshooting session"
    )

    # Create the new tree as a fork of the original
    new_tree = Tree(
        name=tree_name,
        description=request_data.description or f"Saved from troubleshooting session on {session.started_at.strftime('%Y-%m-%d')}",
        tree_structure=tree_structure,
        author_id=current_user.id,
        account_id=current_user.account_id,
        status=request_data.status,
        is_public=False,
        is_default=False,
        # Fork tracking - link to original tree; a tree without a root is its own root.
        parent_tree_id=original_tree.id,
        root_tree_id=original_tree.root_tree_id or original_tree.id,
        fork_depth=original_tree.fork_depth + 1,
        fork_reason=fork_reason,
        parent_updated_at=original_tree.updated_at
    )

    db.add(new_tree)
    await db.commit()
    await db.refresh(new_tree)

    return SaveAsTreeResponse(
        tree_id=new_tree.id,
        tree_name=new_tree.name,
        message=f"Session saved as {'published' if request_data.status == 'published' else 'draft'} tree"
    )
|
|
|
|
|
|
# ── Batch Launch (Maintenance Flows) ──────────────────────────────────────
|
|
|
|
|
|
class _BatchTarget(BaseModel):
    """A single target of a batch launch, identified by a display label."""

    # Stored on the created session as target_label; 1-255 characters, required.
    label: Annotated[str, PydanticField(min_length=1, max_length=255)]
|
|
|
|
|
|
class _BatchLaunchRequest(BaseModel):
    """Request body for POST /sessions/batch."""

    # Tree to launch; batch_launch_sessions requires a maintenance flow.
    tree_id: UUID
    # One session is created per entry; between 1 and 100 entries, required.
    targets: Annotated[list[_BatchTarget], PydanticField(min_length=1, max_length=100)]
|
|
|
|
|
|
class _BatchLaunchResponse(BaseModel):
    """Response body for POST /sessions/batch."""

    # String form of the UUID shared by every session in the batch.
    batch_id: str = PydanticField(...)
    # Number of sessions created.
    count: int = PydanticField(...)
    # One summary dict per created session.
    sessions: list[dict] = PydanticField(...)
|
|
|
|
|
|
@router.post("/batch", status_code=status.HTTP_201_CREATED, response_model=_BatchLaunchResponse)
async def batch_launch_sessions(
    data: _BatchLaunchRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Create one session per target for a maintenance flow batch run.

    All created sessions share a freshly generated ``batch_id``; the response
    lists them in the same order as ``data.targets``.

    Raises:
        HTTPException 404: the tree does not exist.
        HTTPException 403: the caller cannot access the tree.
        HTTPException 400: the tree is inactive, still a draft, or not a
            maintenance flow.
    """
    tree_result = await db.execute(select(Tree).where(Tree.id == data.tree_id))
    tree = tree_result.scalar_one_or_none()
    if not tree:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tree not found")

    # Authorize before revealing anything about the tree's state.
    if not can_access_tree(current_user, tree):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")

    if not tree.is_active:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot batch-launch an inactive flow")

    if tree.status == 'draft':
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot batch-launch a draft flow")

    if tree.tree_type != "maintenance":
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Batch launch is only for maintenance flows")

    batch_id = uuid.uuid4()

    # Computed once since every target uses the same tree. Same keys as the
    # snapshot start_session stores, so filtering/export code sees identical
    # metadata on batch-launched sessions.
    tree_snapshot = {
        **tree.tree_structure,
        "name": tree.name,
        "description": tree.description,
        "category": tree.category,
        "version": tree.version,
        "tree_type": tree.tree_type,
    }

    created_sessions = [
        Session(
            tree_id=tree.id,
            user_id=current_user.id,
            tree_snapshot=tree_snapshot,
            path_taken=[],
            decisions=[],
            custom_steps=[],
            session_variables={},
            batch_id=batch_id,
            target_label=target.label,
        )
        for target in data.targets
    ]
    db.add_all(created_sessions)

    # Like start_session, each launched session counts as one use of the tree.
    tree.usage_count += len(created_sessions)

    # flush() assigns primary keys; every other field in the response was set
    # above, so no re-query is needed. Re-selecting with IN(...) would also
    # lose the target order, since SQL gives no ordering guarantee.
    await db.flush()

    # Build the payload before commit so it never touches attributes that may
    # be expired by the commit.
    response = _BatchLaunchResponse(
        batch_id=str(batch_id),
        count=len(created_sessions),
        sessions=[
            {
                "id": str(s.id),
                "batch_id": str(s.batch_id),
                "target_label": s.target_label,
                "tree_id": str(s.tree_id),
            }
            for s in created_sessions
        ],
    )

    await db.commit()
    return response
|