"""In-memory pub/sub bus for live escalation events.

Single-process, non-durable. When a handoff fires, every connected SSE
subscriber for the same `account_id` receives the event. Subscribers come
and go as senior techs open and close the EscalationQueue page.

Pre-PMF (pre-product-market-fit) scale — 3 pilots × 5-20 techs/MSP =
~15-60 concurrent subscribers total, on a single Railway replica — makes
in-memory the right call. When the deployment scales horizontally, swap this
for Redis pub/sub or similar; the public surface (`publish` / `subscribe` /
`unsubscribe`) is intentionally narrow so the swap stays local.

Events are JSON-serializable dicts. `publish()` is non-blocking: it drops the
event for any subscriber whose queue is full rather than back-pressuring the
caller. `subscribe()` MUST be paired with `unsubscribe()` in a `finally`
block, or you leak queues.
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
from typing import Any
|
||
from uuid import UUID
|
||
|
||
logger = logging.getLogger(__name__)


# Upper bound on unconsumed events per subscriber queue; used as the `maxsize`
# of every queue created by `EscalationBus.subscribe`. Once a queue holds this
# many events, `EscalationBus.publish` drops further events for that subscriber
# (logging a warning) instead of blocking the publisher. 64 is generous for the
# queue-page use case; a subscriber that far behind is probably gone or stuck.
_QUEUE_MAXSIZE = 64
|
||
|
||
|
||
class EscalationBus:
    """Fan escalation events out to the live subscriber queues of one account.

    Subscribers are bounded ``asyncio.Queue`` objects keyed by account UUID.
    Account ids may be passed as ``UUID`` instances or as any value whose
    ``str()`` parses as a UUID; both spellings address the same subscribers.
    """

    def __init__(self) -> None:
        # account UUID -> set of subscriber queues. A bucket is created by the
        # first subscribe for an account and deleted by the unsubscribe that
        # empties it, so empty sets are never left behind.
        self._subscribers: dict[UUID, set[asyncio.Queue[dict[str, Any]]]] = {}
        # Serializes registry access between coroutines. This is an asyncio
        # lock: it coordinates coroutines, not OS threads.
        self._lock = asyncio.Lock()

    @staticmethod
    def _normalize_account_id(account_id: UUID | str) -> UUID:
        """Return ``account_id`` as a ``UUID``; UUID instances pass through as-is.

        Raises:
            ValueError: if ``str(account_id)`` is not a well-formed UUID string.
        """
        if isinstance(account_id, UUID):
            return account_id
        return UUID(str(account_id))

    async def subscribe(self, account_id: UUID | str) -> asyncio.Queue[dict[str, Any]]:
        """Create, register and return a fresh bounded queue for ``account_id``.

        The queue receives every event published for the same account from
        now on. The caller owns the cleanup: hand the queue back through
        ``unsubscribe(account_id, queue)`` (ideally in a ``finally`` block) when
        the consumer disconnects, otherwise it stays registered.

        Raises:
            ValueError: if ``account_id`` is not a valid UUID.
        """
        key = self._normalize_account_id(account_id)
        inbox: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=_QUEUE_MAXSIZE)
        async with self._lock:
            bucket = self._subscribers.get(key)
            if bucket is None:
                bucket = set()
                self._subscribers[key] = bucket
            bucket.add(inbox)
        return inbox

    async def unsubscribe(
        self, account_id: UUID | str, queue: asyncio.Queue[dict[str, Any]]
    ) -> None:
        """Deregister ``queue`` from ``account_id``.

        Unknown accounts and queues that are not (or no longer) registered are
        silently ignored, so repeated calls are harmless. The account's entry
        is removed once its last queue goes away.

        Raises:
            ValueError: if ``account_id`` is not a valid UUID.
        """
        key = self._normalize_account_id(account_id)
        async with self._lock:
            bucket = self._subscribers.get(key)
            if bucket is not None:
                bucket.discard(queue)
                if len(bucket) == 0:
                    del self._subscribers[key]

    async def publish(self, account_id: UUID | str, event: dict[str, Any]) -> int:
        """Offer ``event`` to every queue currently subscribed to ``account_id``.

        Never waits on a consumer: each queue gets ``put_nowait``, and a queue
        that is already full simply misses this event (a warning is logged
        for it). The very same ``event`` object is placed in every queue; it
        is not copied.

        Returns:
            The number of queues that accepted the event (0 if none are
            subscribed).

        Raises:
            ValueError: if ``account_id`` is not a valid UUID.
        """
        key = self._normalize_account_id(account_id)
        async with self._lock:
            # Snapshot under the lock; the fan-out below works on the copy.
            targets = list(self._subscribers.get(key, ()))

        accepted = 0
        for target in targets:
            try:
                target.put_nowait(event)
            except asyncio.QueueFull:
                logger.warning(
                    "EscalationBus: dropped event for full subscriber queue "
                    "(account_id=%s, event=%s)",
                    key,
                    event.get("type", "?"),
                )
            else:
                accepted += 1
        return accepted

    def subscriber_count(self, account_id: UUID | str) -> int:
        """Return how many queues are registered for ``account_id`` (diagnostic).

        Reads the registry directly without taking the lock.
        """
        bucket = self._subscribers.get(self._normalize_account_id(account_id))
        return 0 if bucket is None else len(bucket)
|
||
|
||
|
||
# Module-level singleton shared by every importer of this module (the FastAPI
# routes, per the original note — NOTE(review): confirm the importers). Its
# state lives in this process only; separate workers/replicas each get their
# own bus. `subscribe()`, `unsubscribe()` and `publish()` serialize registry
# access through the instance's asyncio.Lock, which coordinates coroutines on
# an event loop — it is not a thread lock.
bus = EscalationBus()
|