Files
resolutionflow/backend/app/core/escalation_bus.py

106 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""In-memory pub/sub bus for live escalation events.
Single-process, non-durable. When a handoff fires, every connected SSE
subscriber for the same `account_id` receives the event. Subscribers come
and go as senior techs open and close the EscalationQueue page.
Pre-PMF scale (3 pilots x 5-20 techs/MSP = ~15-60 concurrent subscribers
total, on a single Railway replica) makes in-memory the right call. When the
deployment scales horizontally, swap this for Redis pub/sub or similar —
the public surface (`publish` / `subscribe`) is intentionally narrow so
the swap is local.
Events are JSON-serializable dicts. `publish()` is non-blocking: when a
subscriber's queue is full, the event is dropped for that subscriber only,
rather than back-pressuring the caller. `subscribe()` MUST be paired with
`unsubscribe()` in a `finally` block, or you leak queues.
"""
from __future__ import annotations

import asyncio
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
from uuid import UUID
logger = logging.getLogger(__name__)

# Per-subscriber backlog cap. Once this many events sit unread in a
# subscriber's queue, `publish()` drops further events for that subscriber
# instead of blocking the publisher. 64 leaves ample headroom for the queue
# page; a consumer that falls this far behind has most likely stalled or left.
_QUEUE_MAXSIZE: int = 64
class EscalationBus:
    """Account-scoped pub/sub for escalation arrival events.

    Each account id maps to the set of bounded ``asyncio.Queue`` objects
    belonging to its currently connected subscribers. Use the bus only from
    the event-loop thread: asyncio queues and locks are not thread-safe.
    """

    def __init__(self) -> None:
        # account_id -> queues of the subscribers currently connected for it.
        # An account's entry is removed when its last subscriber unsubscribes.
        self._subscribers: dict[UUID, set[asyncio.Queue[dict[str, Any]]]] = {}
        # Guards `_subscribers`. No critical section below awaits while
        # holding it, so it is only ever held briefly.
        self._lock = asyncio.Lock()

    @staticmethod
    def _normalize_account_id(account_id: UUID | str) -> UUID:
        """Return ``account_id`` as a ``UUID`` so str and UUID keys match.

        Raises:
            ValueError: if ``account_id`` is a string that is not a valid UUID.
        """
        if isinstance(account_id, UUID):
            return account_id
        return UUID(str(account_id))

    async def subscribe(self, account_id: UUID | str) -> asyncio.Queue[dict[str, Any]]:
        """Register and return a new bounded subscriber queue for an account.

        The caller must invoke ``unsubscribe(account_id, queue)`` when the
        consumer disconnects, ideally in a ``finally`` block. Prefer
        ``subscription()``, which does this automatically.

        Raises:
            ValueError: if ``account_id`` is not a valid UUID.
        """
        normalized_account_id = self._normalize_account_id(account_id)
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=_QUEUE_MAXSIZE)
        async with self._lock:
            self._subscribers.setdefault(normalized_account_id, set()).add(queue)
        return queue

    async def unsubscribe(
        self, account_id: UUID | str, queue: asyncio.Queue[dict[str, Any]]
    ) -> None:
        """Remove ``queue`` from the account's subscribers.

        This is a no-op if the queue is not registered. When the last
        subscriber leaves, the account entry is dropped so the map does not
        accumulate empty sets.
        """
        normalized_account_id = self._normalize_account_id(account_id)
        async with self._lock:
            subs = self._subscribers.get(normalized_account_id)
            if subs is None:
                return
            subs.discard(queue)
            if not subs:
                self._subscribers.pop(normalized_account_id, None)

    @asynccontextmanager
    async def subscription(
        self, account_id: UUID | str
    ) -> AsyncIterator[asyncio.Queue[dict[str, Any]]]:
        """Subscribe for the duration of an ``async with`` block.

        ``unsubscribe`` runs on exit, including when the body raises or the
        task is cancelled, so the queue cannot leak::

            async with bus.subscription(account_id) as queue:
                event = await queue.get()
        """
        queue = await self.subscribe(account_id)
        try:
            yield queue
        finally:
            await self.unsubscribe(account_id, queue)

    async def publish(self, account_id: UUID | str, event: dict[str, Any]) -> int:
        """Fan ``event`` out to every subscriber for ``account_id``.

        Never blocks. The event is dropped (and a warning logged) for any
        subscriber whose queue is full. The same ``event`` object is enqueued
        for every subscriber, so consumers should treat it as read-only.

        Returns:
            The number of subscribers whose queue accepted the event.
        """
        normalized_account_id = self._normalize_account_id(account_id)
        # Snapshot under the lock so (un)subscribes cannot mutate the set
        # while we iterate it.
        async with self._lock:
            subs = list(self._subscribers.get(normalized_account_id, ()))
        if not subs:
            return 0
        delivered = 0
        for queue in subs:
            try:
                queue.put_nowait(event)
                delivered += 1
            except asyncio.QueueFull:
                logger.warning(
                    "EscalationBus: dropped event for full subscriber queue "
                    "(account_id=%s, event=%s)",
                    normalized_account_id,
                    event.get("type", "?"),
                )
        return delivered

    def subscriber_count(self, account_id: UUID | str) -> int:
        """Diagnostic: the number of active subscribers for an account."""
        normalized_account_id = self._normalize_account_id(account_id)
        return len(self._subscribers.get(normalized_account_id, ()))
# Module-level singleton imported by the FastAPI app. `subscribe()`,
# `unsubscribe()` and `publish()` serialize their bookkeeping through the
# bus's internal asyncio.Lock, so concurrent coroutines on one event loop can
# call them safely. They are NOT thread-safe (asyncio queues and locks are
# not), so call them only from the event-loop thread.
bus: EscalationBus = EscalationBus()