"""In-memory pub/sub bus for live escalation events. Single-process, non-durable. When a handoff fires, every connected SSE subscriber for the same `account_id` receives the event. Subscribers come and go as senior techs open and close the EscalationQueue page. Pre-PMF scale (3 pilots × 5-20 techs/MSP = ~15-60 concurrent subscribers total, single Railway replica) makes in-memory the right call. When the deployment scales horizontally, swap this for Redis pub/sub or similar — the public surface (`publish` / `subscribe`) is intentionally narrow so the swap is local. Events are JSON-serializable dicts. `publish()` is non-blocking (drops the event if a subscriber's queue is full rather than back-pressuring the caller). `subscribe()` MUST be paired with `unsubscribe()` in a finally block, or you leak queues. """ from __future__ import annotations import asyncio import logging from typing import Any from uuid import UUID logger = logging.getLogger(__name__) # Bound how many unconsumed events can sit in a subscriber's queue before # we start dropping. 64 is generous for the queue-page use case; if a # subscriber is that far behind, they're probably gone or stuck. _QUEUE_MAXSIZE = 64 class EscalationBus: """Account-scoped pub/sub for escalation arrival events.""" def __init__(self) -> None: self._subscribers: dict[UUID, set[asyncio.Queue[dict[str, Any]]]] = {} self._lock = asyncio.Lock() @staticmethod def _normalize_account_id(account_id: UUID | str) -> UUID: return account_id if isinstance(account_id, UUID) else UUID(str(account_id)) async def subscribe(self, account_id: UUID | str) -> asyncio.Queue[dict[str, Any]]: """Register a new subscriber queue for an account. Caller must invoke `unsubscribe(account_id, queue)` when the consumer disconnects. """ normalized_account_id = self._normalize_account_id(account_id) queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue( maxsize=_QUEUE_MAXSIZE ) async with self._lock: self._subscribers.setdefault(normalized_account_id, set()).add(queue) return queue async def unsubscribe( self, account_id: UUID | str, queue: asyncio.Queue[dict[str, Any]] ) -> None: normalized_account_id = self._normalize_account_id(account_id) async with self._lock: subs = self._subscribers.get(normalized_account_id) if subs is None: return subs.discard(queue) if not subs: self._subscribers.pop(normalized_account_id, None) async def publish(self, account_id: UUID | str, event: dict[str, Any]) -> int: """Fan event out to every subscriber for `account_id`. Returns the number of subscribers that successfully received the event. Drops the event for any subscriber whose queue is full (logs at warning level). """ normalized_account_id = self._normalize_account_id(account_id) async with self._lock: subs = list(self._subscribers.get(normalized_account_id, ())) if not subs: return 0 delivered = 0 for queue in subs: try: queue.put_nowait(event) delivered += 1 except asyncio.QueueFull: logger.warning( "EscalationBus: dropped event for full subscriber queue " "(account_id=%s, event=%s)", normalized_account_id, event.get("type", "?"), ) return delivered def subscriber_count(self, account_id: UUID | str) -> int: """Diagnostic — number of active subscribers for an account.""" normalized_account_id = self._normalize_account_id(account_id) return len(self._subscribers.get(normalized_account_id, ())) # Module-level singleton. FastAPI imports this; `subscribe()` and `publish()` # are coroutine-safe via the internal Lock. bus = EscalationBus()