fix(tests): stabilize escalation SSE backend tests
Co-Authored-By: Codex <noreply@openai.com>
This commit is contained in:
@@ -38,39 +38,46 @@ class EscalationBus:
|
||||
self._subscribers: dict[UUID, set[asyncio.Queue[dict[str, Any]]]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def subscribe(self, account_id: UUID) -> asyncio.Queue[dict[str, Any]]:
|
||||
@staticmethod
|
||||
def _normalize_account_id(account_id: UUID | str) -> UUID:
|
||||
return account_id if isinstance(account_id, UUID) else UUID(str(account_id))
|
||||
|
||||
async def subscribe(self, account_id: UUID | str) -> asyncio.Queue[dict[str, Any]]:
|
||||
"""Register a new subscriber queue for an account.
|
||||
|
||||
Caller must invoke `unsubscribe(account_id, queue)` when the
|
||||
consumer disconnects.
|
||||
"""
|
||||
normalized_account_id = self._normalize_account_id(account_id)
|
||||
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(
|
||||
maxsize=_QUEUE_MAXSIZE
|
||||
)
|
||||
async with self._lock:
|
||||
self._subscribers.setdefault(account_id, set()).add(queue)
|
||||
self._subscribers.setdefault(normalized_account_id, set()).add(queue)
|
||||
return queue
|
||||
|
||||
async def unsubscribe(
|
||||
self, account_id: UUID, queue: asyncio.Queue[dict[str, Any]]
|
||||
self, account_id: UUID | str, queue: asyncio.Queue[dict[str, Any]]
|
||||
) -> None:
|
||||
normalized_account_id = self._normalize_account_id(account_id)
|
||||
async with self._lock:
|
||||
subs = self._subscribers.get(account_id)
|
||||
subs = self._subscribers.get(normalized_account_id)
|
||||
if subs is None:
|
||||
return
|
||||
subs.discard(queue)
|
||||
if not subs:
|
||||
self._subscribers.pop(account_id, None)
|
||||
self._subscribers.pop(normalized_account_id, None)
|
||||
|
||||
async def publish(self, account_id: UUID, event: dict[str, Any]) -> int:
|
||||
async def publish(self, account_id: UUID | str, event: dict[str, Any]) -> int:
|
||||
"""Fan event out to every subscriber for `account_id`.
|
||||
|
||||
Returns the number of subscribers that successfully received the
|
||||
event. Drops the event for any subscriber whose queue is full
|
||||
(logs at warning level).
|
||||
"""
|
||||
normalized_account_id = self._normalize_account_id(account_id)
|
||||
async with self._lock:
|
||||
subs = list(self._subscribers.get(account_id, ()))
|
||||
subs = list(self._subscribers.get(normalized_account_id, ()))
|
||||
if not subs:
|
||||
return 0
|
||||
delivered = 0
|
||||
@@ -82,14 +89,15 @@ class EscalationBus:
|
||||
logger.warning(
|
||||
"EscalationBus: dropped event for full subscriber queue "
|
||||
"(account_id=%s, event=%s)",
|
||||
account_id,
|
||||
normalized_account_id,
|
||||
event.get("type", "?"),
|
||||
)
|
||||
return delivered
|
||||
|
||||
def subscriber_count(self, account_id: UUID) -> int:
|
||||
def subscriber_count(self, account_id: UUID | str) -> int:
|
||||
"""Diagnostic — number of active subscribers for an account."""
|
||||
return len(self._subscribers.get(account_id, ()))
|
||||
normalized_account_id = self._normalize_account_id(account_id)
|
||||
return len(self._subscribers.get(normalized_account_id, ()))
|
||||
|
||||
|
||||
# Module-level singleton. FastAPI imports this; `subscribe()` and `publish()`
|
||||
|
||||
Reference in New Issue
Block a user