feat: add HandoffManager service with dual-write and integration tests

Unified park/escalate handoff management with snapshot generation,
AI diagnostic assessment for escalations (via _call_ai), claim workflow
that reactivates sessions, PSA push via existing psa_documentation_service,
and team queue query. Dual-writes to ai_sessions.escalation_package for
backward compatibility.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-03-24 08:42:21 +00:00
parent 7b4060a4d1
commit f84b868d13
2 changed files with 404 additions and 0 deletions

View File

@@ -0,0 +1,289 @@
"""Handoff management — unified park/escalate with dual-write backward compat.
Creates handoff snapshots, AI assessments (for escalations), claim workflow,
and queue queries. Dual-writes to ai_sessions.escalation_package for
backward compatibility with the existing escalation queue.
"""
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ai_session import AISession
from app.models.session_branch import SessionBranch
from app.models.session_handoff import SessionHandoff
logger = logging.getLogger(__name__)
class HandoffManager:
"""Unified park/escalate handoff management."""
def __init__(self, db: AsyncSession):
self.db = db
async def create_handoff(
self,
session_id: UUID,
intent: str,
engineer_notes: str | None,
user_id: UUID,
priority: str = "normal",
) -> SessionHandoff:
"""Create a handoff (park or escalate).
Generates snapshot, updates session status, dual-writes to
escalation_package for backward compat.
"""
result = await self.db.execute(
select(AISession).where(AISession.id == session_id)
)
session = result.scalar_one_or_none()
if not session:
raise ValueError(f"Session {session_id} not found")
# Generate snapshot
snapshot = await self._generate_snapshot(session)
# Generate AI assessment for escalations
ai_assessment = None
ai_assessment_data = None
if intent == "escalate":
ai_assessment, ai_assessment_data = await self._generate_ai_assessment(session)
handoff = SessionHandoff(
session_id=session_id,
handed_off_by=user_id,
intent=intent,
source_branch_id=session.active_branch_id,
snapshot=snapshot,
ai_assessment=ai_assessment,
ai_assessment_data=ai_assessment_data,
engineer_notes=engineer_notes,
priority=priority,
)
self.db.add(handoff)
# Update session status
if intent == "park":
session.status = "paused"
elif intent == "escalate":
session.status = "escalated"
session.handoff_count = (session.handoff_count or 0) + 1
# Dual-write for backward compat
session.escalation_package = {
"snapshot": snapshot,
"intent": intent,
"engineer_notes": engineer_notes,
"handoff_id": str(handoff.id),
}
await self.db.flush()
return handoff
async def _generate_snapshot(self, session: AISession) -> dict[str, Any]:
"""Generate a snapshot of the session state at handoff time."""
snapshot: dict[str, Any] = {
"problem_summary": session.problem_summary,
"problem_domain": session.problem_domain,
"status": session.status,
"step_count": session.step_count,
"confidence_tier": session.confidence_tier,
}
# Add branch map if branching is active
if session.is_branching:
branches_result = await self.db.execute(
select(SessionBranch)
.where(SessionBranch.session_id == session.id)
.order_by(SessionBranch.branch_order)
)
branches = list(branches_result.scalars().all())
branch_map = []
for b in branches:
branch_map.append({
"id": str(b.id),
"label": b.label,
"status": b.status,
"status_reason": b.status_reason,
"parent_branch_id": str(b.parent_branch_id) if b.parent_branch_id else None,
})
snapshot["branch_map"] = branch_map
snapshot["active_branch_id"] = str(session.active_branch_id) if session.active_branch_id else None
return snapshot
async def claim_session(
self,
handoff_id: UUID,
claiming_user_id: UUID,
) -> SessionHandoff:
"""Claim a handed-off session."""
result = await self.db.execute(
select(SessionHandoff).where(SessionHandoff.id == handoff_id)
)
handoff = result.scalar_one_or_none()
if not handoff:
raise ValueError(f"Handoff {handoff_id} not found")
handoff.claimed_by = claiming_user_id
handoff.claimed_at = datetime.now(timezone.utc)
# Reactivate session
session_result = await self.db.execute(
select(AISession).where(AISession.id == handoff.session_id)
)
session = session_result.scalar_one()
session.status = "active"
# Dual-write
session.escalated_to_id = claiming_user_id
await self.db.flush()
return handoff
async def _generate_ai_assessment(
self, session: AISession
) -> tuple[str | None, dict[str, Any] | None]:
"""Generate AI diagnostic assessment for escalation handoffs."""
try:
from app.services.assistant_chat_service import _call_ai
context = f"Problem: {session.problem_summary or 'Unknown'}\nDomain: {session.problem_domain or 'Unknown'}"
msgs = session.conversation_messages or []
# Include last 10 messages for context
recent = "\n".join(
f"[{m.get('role', '?')}]: {m.get('content', '')[:200]}"
for m in msgs[-10:]
)
assessment_text, _, _ = await _call_ai(
system_base="You are a diagnostic assessment generator for MSP escalations.",
rag_context="",
history=[],
new_message=(
f"Generate a brief diagnostic assessment for this escalation.\n"
f"{context}\n\nRecent conversation:\n{recent}\n\n"
f"Return: 1) Most likely cause, 2) Suggested next steps, 3) Confidence (low/medium/high)"
),
max_tokens=500,
)
assessment_data = {
"likely_cause": "See assessment text",
"suggested_steps": [],
"confidence": "medium",
}
return assessment_text, assessment_data
except Exception:
logger.exception("Failed to generate AI assessment")
return None, None
async def generate_briefing(
self, handoff_id: UUID, claiming_user_id: UUID
) -> str:
"""Generate a natural-language briefing for the engineer claiming the session."""
result = await self.db.execute(
select(SessionHandoff).where(SessionHandoff.id == handoff_id)
)
handoff = result.scalar_one_or_none()
if not handoff:
raise ValueError(f"Handoff {handoff_id} not found")
session_result = await self.db.execute(
select(AISession).where(AISession.id == handoff.session_id)
)
session = session_result.scalar_one()
from app.services.assistant_chat_service import _call_ai
snapshot_text = str(handoff.snapshot)[:2000]
briefing, _, _ = await _call_ai(
system_base="You are a handoff briefing generator for MSP teams.",
rag_context="",
history=[],
new_message=(
f"Generate a concise briefing for an engineer picking up this session.\n"
f"Problem: {session.problem_summary}\n"
f"Intent: {handoff.intent}\n"
f"Engineer notes: {handoff.engineer_notes or 'None'}\n"
f"Snapshot: {snapshot_text}\n"
f"AI Assessment: {handoff.ai_assessment or 'None'}"
),
max_tokens=500,
)
return briefing
async def push_to_psa(self, handoff_id: UUID) -> SessionHandoff:
"""Push handoff notes to PSA via existing psa_documentation_service."""
result = await self.db.execute(
select(SessionHandoff).where(SessionHandoff.id == handoff_id)
)
handoff = result.scalar_one_or_none()
if not handoff:
raise ValueError(f"Handoff {handoff_id} not found")
try:
from app.services.psa_documentation_service import push_session_notes
session_result = await self.db.execute(
select(AISession).where(AISession.id == handoff.session_id)
)
session = session_result.scalar_one()
if session.psa_ticket_id and session.psa_connection_id:
note_id = await push_session_notes(
session=session,
notes_content=handoff.ai_assessment or str(handoff.snapshot),
db=self.db,
)
handoff.psa_note_pushed = True
handoff.psa_note_id = note_id
except Exception:
logger.exception(f"Failed to push handoff {handoff_id} to PSA")
await self.db.flush()
return handoff
async def get_queue(
self,
team_id: UUID | None = None,
account_id: UUID | None = None,
) -> list[dict[str, Any]]:
"""Get team queue of parked + escalated sessions."""
query = (
select(SessionHandoff, AISession)
.join(AISession, SessionHandoff.session_id == AISession.id)
.where(SessionHandoff.claimed_by.is_(None))
.order_by(SessionHandoff.created_at.desc())
)
if team_id:
query = query.where(AISession.team_id == team_id)
elif account_id:
query = query.where(AISession.account_id == account_id)
result = await self.db.execute(query)
rows = result.all()
queue_items = []
for handoff, session in rows:
queue_items.append({
"handoff_id": handoff.id,
"session_id": session.id,
"intent": handoff.intent,
"problem_summary": session.problem_summary,
"problem_domain": session.problem_domain,
"priority": handoff.priority,
"engineer_notes": handoff.engineer_notes,
"created_at": handoff.created_at,
"claimed_by": handoff.claimed_by,
"claimed_at": handoff.claimed_at,
})
return queue_items

View File

@@ -0,0 +1,115 @@
"""Integration tests for HandoffManager service."""
import pytest
from httpx import AsyncClient
from app.models.ai_session import AISession
@pytest.mark.asyncio
async def test_create_park_handoff(client: AsyncClient, test_user, auth_headers, test_db):
"""Parking a session creates a handoff with snapshot."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[{"role": "user", "content": "help me"}],
)
test_db.add(session)
await test_db.flush()
from app.services.handoff_manager import HandoffManager
manager = HandoffManager(test_db)
handoff = await manager.create_handoff(
session_id=session.id,
intent="park",
engineer_notes="Waiting for client to provide logs",
user_id=test_user["user_data"]["id"],
)
assert handoff is not None
assert handoff.intent == "park"
assert handoff.engineer_notes == "Waiting for client to provide logs"
assert handoff.snapshot is not None
# Session should be paused
await test_db.refresh(session)
assert session.status == "paused"
assert session.handoff_count == 1
@pytest.mark.asyncio
async def test_create_escalate_handoff(client: AsyncClient, test_user, auth_headers, test_db):
"""Escalating creates handoff + dual-writes to escalation_package."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[],
)
test_db.add(session)
await test_db.flush()
from app.services.handoff_manager import HandoffManager
manager = HandoffManager(test_db)
handoff = await manager.create_handoff(
session_id=session.id,
intent="escalate",
engineer_notes="Need senior help",
user_id=test_user["user_data"]["id"],
)
assert handoff.intent == "escalate"
# Dual-write check
await test_db.refresh(session)
assert session.status == "escalated"
assert session.escalation_package is not None
assert "branch_map" in session.escalation_package or "snapshot" in session.escalation_package
@pytest.mark.asyncio
async def test_claim_session(client: AsyncClient, test_user, test_admin, auth_headers, test_db):
"""Claiming a handoff sets claimed_by and reactivates session."""
session = AISession(
user_id=test_user["user_data"]["id"],
account_id=test_user["user_data"]["account_id"],
session_type="guided",
intake_type="free_text",
intake_content={"text": "test"},
status="active",
confidence_tier="discovery",
conversation_messages=[],
)
test_db.add(session)
await test_db.flush()
from app.services.handoff_manager import HandoffManager
manager = HandoffManager(test_db)
handoff = await manager.create_handoff(
session_id=session.id,
intent="escalate",
engineer_notes="Need help",
user_id=test_user["user_data"]["id"],
)
claimed = await manager.claim_session(
handoff_id=handoff.id,
claiming_user_id=test_admin["user_data"]["id"],
)
assert claimed.claimed_by == test_admin["user_data"]["id"]
assert claimed.claimed_at is not None
await test_db.refresh(session)
assert session.status == "active"