diff --git a/backend/app/api/endpoints/session_handoffs.py b/backend/app/api/endpoints/session_handoffs.py index 5e444bd2..5b62a3c5 100644 --- a/backend/app/api/endpoints/session_handoffs.py +++ b/backend/app/api/endpoints/session_handoffs.py @@ -1,19 +1,24 @@ """Handoff endpoints — unified park/escalate. - POST /ai-sessions/{id}/handoff — Create handoff + POST /ai-sessions/{id}/handoff — Create handoff GET /ai-sessions/{id}/handoffs — Handoff history POST /ai-sessions/{id}/handoffs/{hid}/claim — Claim session - GET /ai-sessions/queue — Team queue + GET /ai-sessions/queue — Team queue + GET /ai-sessions/escalations/stream — SSE: live escalation arrivals """ +import asyncio +import json import logging -from typing import Annotated +from typing import Annotated, AsyncGenerator from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin +from app.core.escalation_bus import bus as escalation_bus from app.models.user import User from app.models.ai_session import AISession from app.models.session_handoff import SessionHandoff @@ -127,3 +132,80 @@ async def get_queue( team_id=current_user.team_id, account_id=current_user.account_id, ) + + +# ─── Live escalation arrivals (SSE) ────────────────────────────────────────── +# +# Streams `handoff_created` events to subscribers in the same account_id as +# the new handoff. Connected EscalationQueue instances prepend the new card +# with the locked 200ms slide-in. Account-scoped: cross-tenant leakage is +# prevented at the bus.publish boundary (only handoff.account_id subscribers +# are notified) and re-enforced here by binding the subscription to +# current_user.account_id. +# +# Heartbeat: a `: keepalive\n\n` SSE comment every 25s keeps the connection +# alive through Railway / nginx default 60s idle timeouts. Reconnect policy +# is on the client (browser EventSource auto-reconnects; our fetch-based +# reader retries with backoff). + + +_HEARTBEAT_INTERVAL_S = 25 +_QUEUE_GET_TIMEOUT_S = 25 # < heartbeat so heartbeat fires reliably + + +@queue_router.get("/escalations/stream") +async def stream_escalations( + request: Request, + current_user: Annotated[User, Depends(require_engineer_or_admin)], +): + """SSE stream of new escalation arrivals for the current user's account. + + Role-gated to engineer/admin/owner so viewers can't subscribe (matches + the queue + claim role surface). One open connection per browser tab is + expected; the bus handles fan-out. + """ + if not current_user.account_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="No account" + ) + + account_id = current_user.account_id + + async def event_generator() -> AsyncGenerator[str, None]: + queue = await escalation_bus.subscribe(account_id) + try: + # Initial hello so the client knows the stream is live. + yield ( + "event: ready\n" + f"data: {json.dumps({'account_id': str(account_id)})}\n\n" + ) + + while True: + if await request.is_disconnected(): + break + try: + event = await asyncio.wait_for( + queue.get(), timeout=_QUEUE_GET_TIMEOUT_S + ) + except asyncio.TimeoutError: + # Heartbeat keeps the connection alive through proxies. + yield ": keepalive\n\n" + continue + + event_type = event.get("type", "message") + yield ( + f"event: {event_type}\n" + f"data: {json.dumps(event)}\n\n" + ) + finally: + await escalation_bus.unsubscribe(account_id, queue) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/backend/app/core/escalation_bus.py b/backend/app/core/escalation_bus.py new file mode 100644 index 00000000..bf623950 --- /dev/null +++ b/backend/app/core/escalation_bus.py @@ -0,0 +1,97 @@ +"""In-memory pub/sub bus for live escalation events. + +Single-process, non-durable. When a handoff fires, every connected SSE +subscriber for the same `account_id` receives the event. Subscribers come +and go as senior techs open and close the EscalationQueue page. + +Pre-PMF scale (3 pilots × 5-20 techs/MSP = ~15-60 concurrent subscribers +total, single Railway replica) makes in-memory the right call. When the +deployment scales horizontally, swap this for Redis pub/sub or similar — +the public surface (`publish` / `subscribe`) is intentionally narrow so +the swap is local. + +Events are JSON-serializable dicts. `publish()` is non-blocking (drops the +event if a subscriber's queue is full rather than back-pressuring the +caller). `subscribe()` MUST be paired with `unsubscribe()` in a finally +block, or you leak queues. +""" +from __future__ import annotations + +import asyncio +import logging +from typing import Any +from uuid import UUID + +logger = logging.getLogger(__name__) + + +# Bound how many unconsumed events can sit in a subscriber's queue before +# we start dropping. 64 is generous for the queue-page use case; if a +# subscriber is that far behind, they're probably gone or stuck. +_QUEUE_MAXSIZE = 64 + + +class EscalationBus: + """Account-scoped pub/sub for escalation arrival events.""" + + def __init__(self) -> None: + self._subscribers: dict[UUID, set[asyncio.Queue[dict[str, Any]]]] = {} + self._lock = asyncio.Lock() + + async def subscribe(self, account_id: UUID) -> asyncio.Queue[dict[str, Any]]: + """Register a new subscriber queue for an account. + + Caller must invoke `unsubscribe(account_id, queue)` when the + consumer disconnects. + """ + queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue( + maxsize=_QUEUE_MAXSIZE + ) + async with self._lock: + self._subscribers.setdefault(account_id, set()).add(queue) + return queue + + async def unsubscribe( + self, account_id: UUID, queue: asyncio.Queue[dict[str, Any]] + ) -> None: + async with self._lock: + subs = self._subscribers.get(account_id) + if subs is None: + return + subs.discard(queue) + if not subs: + self._subscribers.pop(account_id, None) + + async def publish(self, account_id: UUID, event: dict[str, Any]) -> int: + """Fan event out to every subscriber for `account_id`. + + Returns the number of subscribers that successfully received the + event. Drops the event for any subscriber whose queue is full + (logs at warning level). + """ + async with self._lock: + subs = list(self._subscribers.get(account_id, ())) + if not subs: + return 0 + delivered = 0 + for queue in subs: + try: + queue.put_nowait(event) + delivered += 1 + except asyncio.QueueFull: + logger.warning( + "EscalationBus: dropped event for full subscriber queue " + "(account_id=%s, event=%s)", + account_id, + event.get("type", "?"), + ) + return delivered + + def subscriber_count(self, account_id: UUID) -> int: + """Diagnostic — number of active subscribers for an account.""" + return len(self._subscribers.get(account_id, ())) + + +# Module-level singleton. FastAPI imports this; `subscribe()` and `publish()` +# are coroutine-safe via the internal Lock. +bus = EscalationBus() diff --git a/backend/app/services/handoff_manager.py b/backend/app/services/handoff_manager.py index fedc8a74..bc3717f9 100644 --- a/backend/app/services/handoff_manager.py +++ b/backend/app/services/handoff_manager.py @@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.email import EmailService +from app.core.escalation_bus import bus as escalation_bus from app.models.ai_session import AISession from app.models.session_branch import SessionBranch from app.models.session_handoff import SessionHandoff @@ -114,6 +115,29 @@ class HandoffManager: if handoff.intent != "escalate": return 0 + # Publish to the in-memory bus first so connected senior-tech inboxes + # see the new card slide in within ~1s of escalate. This path is + # fire-and-forget (no IO, just memory) so it can sit ahead of the + # email fan-out. + try: + await escalation_bus.publish( + handoff.account_id, + { + "type": "handoff_created", + "handoff_id": str(handoff.id), + "session_id": str(handoff.session_id), + "priority": handoff.priority, + "engineer_notes": handoff.engineer_notes or "", + "created_at": handoff.created_at.isoformat() + if handoff.created_at + else None, + }, + ) + except Exception: + logger.exception( + "EscalationBus publish failed for handoff %s", handoff.id + ) + try: recipients = ( await self.db.execute( diff --git a/backend/tests/test_escalation_bus.py b/backend/tests/test_escalation_bus.py new file mode 100644 index 00000000..50d10f3c --- /dev/null +++ b/backend/tests/test_escalation_bus.py @@ -0,0 +1,106 @@ +"""Unit tests for the in-memory escalation pub/sub bus.""" +import asyncio +from uuid import uuid4 + +import pytest + +from app.core.escalation_bus import EscalationBus + + +@pytest.mark.asyncio +async def test_publish_with_no_subscribers_returns_zero(): + bus = EscalationBus() + delivered = await bus.publish(uuid4(), {"type": "handoff_created"}) + assert delivered == 0 + + +@pytest.mark.asyncio +async def test_subscribe_then_publish_delivers_event(): + bus = EscalationBus() + account = uuid4() + queue = await bus.subscribe(account) + try: + delivered = await bus.publish(account, {"type": "handoff_created", "id": "x"}) + assert delivered == 1 + event = await asyncio.wait_for(queue.get(), timeout=1.0) + assert event == {"type": "handoff_created", "id": "x"} + finally: + await bus.unsubscribe(account, queue) + + +@pytest.mark.asyncio +async def test_two_subscribers_same_account_both_receive(): + bus = EscalationBus() + account = uuid4() + q1 = await bus.subscribe(account) + q2 = await bus.subscribe(account) + try: + delivered = await bus.publish(account, {"type": "x"}) + assert delivered == 2 + e1 = await asyncio.wait_for(q1.get(), timeout=1.0) + e2 = await asyncio.wait_for(q2.get(), timeout=1.0) + assert e1 == e2 == {"type": "x"} + finally: + await bus.unsubscribe(account, q1) + await bus.unsubscribe(account, q2) + + +@pytest.mark.asyncio +async def test_subscriber_in_other_account_does_not_receive(): + """Cross-tenant isolation is the whole point — sanity check it directly.""" + bus = EscalationBus() + account_a = uuid4() + account_b = uuid4() + q_a = await bus.subscribe(account_a) + q_b = await bus.subscribe(account_b) + try: + delivered = await bus.publish(account_a, {"type": "x"}) + assert delivered == 1 + + e_a = await asyncio.wait_for(q_a.get(), timeout=1.0) + assert e_a == {"type": "x"} + + # B's queue must remain empty. + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(q_b.get(), timeout=0.1) + finally: + await bus.unsubscribe(account_a, q_a) + await bus.unsubscribe(account_b, q_b) + + +@pytest.mark.asyncio +async def test_unsubscribe_drops_subscriber_count_to_zero(): + bus = EscalationBus() + account = uuid4() + q = await bus.subscribe(account) + assert bus.subscriber_count(account) == 1 + await bus.unsubscribe(account, q) + assert bus.subscriber_count(account) == 0 + + +@pytest.mark.asyncio +async def test_publish_drops_events_when_subscriber_queue_is_full(): + """A stuck subscriber must not back-pressure publishers.""" + bus = EscalationBus() + account = uuid4() + queue = await bus.subscribe(account) + try: + # Stuff the queue past capacity (maxsize is 64) without consuming. + for _ in range(65): + await bus.publish(account, {"type": "x"}) + # Sanity: queue holds at most maxsize. + assert queue.qsize() <= 64 + # Publishes after capacity didn't raise — they were dropped silently. + finally: + await bus.unsubscribe(account, queue) + + +@pytest.mark.asyncio +async def test_unsubscribe_unknown_queue_is_noop(): + """Defensive: unsubscribe on an account/queue that isn't registered + should not raise — finally blocks rely on this.""" + bus = EscalationBus() + account = uuid4() + fake_queue: asyncio.Queue = asyncio.Queue() + # Should not raise. + await bus.unsubscribe(account, fake_queue) diff --git a/backend/tests/test_handoff_manager.py b/backend/tests/test_handoff_manager.py index fc4644be..3a2836a5 100644 --- a/backend/tests/test_handoff_manager.py +++ b/backend/tests/test_handoff_manager.py @@ -278,6 +278,58 @@ async def test_dispatch_graceful_degradation_when_email_raises( assert sent == 0 +@pytest.mark.asyncio +async def test_dispatch_publishes_to_escalation_bus( + client: AsyncClient, test_user, auth_headers, test_db +): + """dispatch_escalation_notifications puts an event on the in-memory bus + so connected SSE subscribers see live arrivals.""" + from app.core.escalation_bus import bus as escalation_bus + + session = AISession( + user_id=test_user["user_data"]["id"], + account_id=test_user["user_data"]["account_id"], + session_type="guided", + intake_type="free_text", + intake_content={"text": "x"}, + problem_summary="VPN down", + status="active", + confidence_tier="discovery", + conversation_messages=[], + ) + test_db.add(session) + await test_db.commit() + + manager = HandoffManager(test_db) + handoff = await manager.create_handoff( + session_id=session.id, + intent="escalate", + engineer_notes="please help", + user_id=test_user["user_data"]["id"], + ) + await test_db.commit() + + from uuid import UUID as PyUUID + account_id = PyUUID(test_user["user_data"]["account_id"]) + + queue = await escalation_bus.subscribe(account_id) + try: + with patch( + "app.services.handoff_manager.EmailService.send_notification_email", + new=AsyncMock(return_value=True), + ): + await manager.dispatch_escalation_notifications(handoff) + + import asyncio + event = await asyncio.wait_for(queue.get(), timeout=1.0) + assert event["type"] == "handoff_created" + assert event["handoff_id"] == str(handoff.id) + assert event["session_id"] == str(session.id) + assert event["priority"] == "normal" + finally: + await escalation_bus.unsubscribe(account_id, queue) + + @pytest.mark.asyncio async def test_create_handoff_endpoint_dispatches_on_escalate( client: AsyncClient, test_user, auth_headers, test_db diff --git a/backend/tests/test_session_handoffs_api.py b/backend/tests/test_session_handoffs_api.py index 6edaac1e..6ddc307c 100644 --- a/backend/tests/test_session_handoffs_api.py +++ b/backend/tests/test_session_handoffs_api.py @@ -113,6 +113,49 @@ async def test_claim_blocked_for_viewer_role( assert "engineer" in claim_resp.json()["detail"].lower() +@pytest.mark.asyncio +async def test_escalations_stream_blocked_for_viewer( + client: AsyncClient, test_user, auth_headers, test_db +): + """SSE stream is role-gated to engineer-or-admin (matches queue/claim).""" + user_id = PyUUID(test_user["user_data"]["id"]) + user = ( + await test_db.execute(select(User).where(User.id == user_id)) + ).scalar_one() + user.account_role = "viewer" + await test_db.commit() + + resp = await client.get( + "/api/v1/ai-sessions/escalations/stream", headers=auth_headers + ) + assert resp.status_code == 403 + + +@pytest.mark.asyncio +async def test_escalations_stream_returns_sse_content_type( + client: AsyncClient, test_user, auth_headers, test_db +): + """Engineer/owner can open the SSE stream and gets text/event-stream + plus an initial `ready` event. Read just enough bytes to confirm the + handshake — the full pub/sub flow is covered by the bus + dispatcher + tests separately.""" + async with client.stream( + "GET", + "/api/v1/ai-sessions/escalations/stream", + headers=auth_headers, + ) as resp: + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + # First chunk must contain the ready event. + first = b"" + async for chunk in resp.aiter_bytes(): + first += chunk + if b"event: ready" in first and b"\n\n" in first: + break + assert b"event: ready" in first + assert b'"account_id"' in first + + @pytest.mark.asyncio async def test_claim_allowed_for_engineer_role( client: AsyncClient, test_user, auth_headers, test_db