fix(tests): stabilize escalation SSE backend tests

Co-Authored-By: Codex <noreply@openai.com>
This commit is contained in:
2026-04-27 19:47:43 -04:00
parent ba46fc5644
commit bc15952857
7 changed files with 153 additions and 73 deletions

View File

@@ -156,7 +156,10 @@ _QUEUE_GET_TIMEOUT_S = 25 # < heartbeat so heartbeat fires reliably
@queue_router.get("/escalations/stream")
async def stream_escalations(
request: Request,
current_user: Annotated[User, Depends(require_engineer_or_admin)],
current_user: Annotated[
User,
Depends(require_engineer_or_admin, scope="function"),
],
):
"""SSE stream of new escalation arrivals for the current user's account.

View File

@@ -38,39 +38,46 @@ class EscalationBus:
self._subscribers: dict[UUID, set[asyncio.Queue[dict[str, Any]]]] = {}
self._lock = asyncio.Lock()
async def subscribe(self, account_id: UUID) -> asyncio.Queue[dict[str, Any]]:
@staticmethod
def _normalize_account_id(account_id: UUID | str) -> UUID:
return account_id if isinstance(account_id, UUID) else UUID(str(account_id))
async def subscribe(self, account_id: UUID | str) -> asyncio.Queue[dict[str, Any]]:
"""Register a new subscriber queue for an account.
Caller must invoke `unsubscribe(account_id, queue)` when the
consumer disconnects.
"""
normalized_account_id = self._normalize_account_id(account_id)
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(
maxsize=_QUEUE_MAXSIZE
)
async with self._lock:
self._subscribers.setdefault(account_id, set()).add(queue)
self._subscribers.setdefault(normalized_account_id, set()).add(queue)
return queue
async def unsubscribe(
self, account_id: UUID, queue: asyncio.Queue[dict[str, Any]]
self, account_id: UUID | str, queue: asyncio.Queue[dict[str, Any]]
) -> None:
normalized_account_id = self._normalize_account_id(account_id)
async with self._lock:
subs = self._subscribers.get(account_id)
subs = self._subscribers.get(normalized_account_id)
if subs is None:
return
subs.discard(queue)
if not subs:
self._subscribers.pop(account_id, None)
self._subscribers.pop(normalized_account_id, None)
async def publish(self, account_id: UUID, event: dict[str, Any]) -> int:
async def publish(self, account_id: UUID | str, event: dict[str, Any]) -> int:
"""Fan event out to every subscriber for `account_id`.
Returns the number of subscribers that successfully received the
event. Drops the event for any subscriber whose queue is full
(logs at warning level).
"""
normalized_account_id = self._normalize_account_id(account_id)
async with self._lock:
subs = list(self._subscribers.get(account_id, ()))
subs = list(self._subscribers.get(normalized_account_id, ()))
if not subs:
return 0
delivered = 0
@@ -82,14 +89,15 @@ class EscalationBus:
logger.warning(
"EscalationBus: dropped event for full subscriber queue "
"(account_id=%s, event=%s)",
account_id,
normalized_account_id,
event.get("type", "?"),
)
return delivered
def subscriber_count(self, account_id: UUID) -> int:
def subscriber_count(self, account_id: UUID | str) -> int:
"""Diagnostic — number of active subscribers for an account."""
return len(self._subscribers.get(account_id, ()))
normalized_account_id = self._normalize_account_id(account_id)
return len(self._subscribers.get(normalized_account_id, ()))
# Module-level singleton. FastAPI imports this; `subscribe()` and `publish()`

View File

@@ -68,6 +68,21 @@ async def test_subscriber_in_other_account_does_not_receive():
await bus.unsubscribe(account_b, q_b)
@pytest.mark.asyncio
async def test_publish_normalizes_string_uuid_account_id():
"""ORM-created objects can briefly carry string UUIDs in-memory."""
bus = EscalationBus()
account = uuid4()
queue = await bus.subscribe(account)
try:
delivered = await bus.publish(str(account), {"type": "x"})
assert delivered == 1
event = await asyncio.wait_for(queue.get(), timeout=1.0)
assert event == {"type": "x"}
finally:
await bus.unsubscribe(str(account), queue)
@pytest.mark.asyncio
async def test_unsubscribe_drops_subscriber_count_to_zero():
bus = EscalationBus()

View File

@@ -9,6 +9,26 @@ from app.models.user import User
from app.services.handoff_manager import HandoffManager
@pytest.fixture(autouse=True)
def stub_ai_assessment():
"""Keep handoff tests focused on handoff behavior, not external AI calls."""
with patch.object(
HandoffManager,
"_generate_ai_assessment",
new=AsyncMock(
return_value=(
"Stub escalation assessment",
{
"likely_cause": "Stub",
"suggested_steps": [],
"confidence": "medium",
},
)
),
):
yield
@pytest.mark.asyncio
async def test_create_park_handoff(client: AsyncClient, test_user, auth_headers, test_db):
"""Parking a session creates a handoff with snapshot."""

View File

@@ -1,12 +1,41 @@
"""API endpoint tests for session handoffs."""
from unittest.mock import AsyncMock, patch
from uuid import UUID as PyUUID
import pytest
from httpx import AsyncClient
from sqlalchemy import select
from app.api.endpoints.session_handoffs import stream_escalations
from app.core.escalation_bus import bus as escalation_bus
from app.models.ai_session import AISession
from app.models.user import User
from app.services.handoff_manager import HandoffManager
class _ConnectedRequest:
async def is_disconnected(self) -> bool:
return False
@pytest.fixture(autouse=True)
def stub_ai_assessment():
"""Endpoint tests should not wait on the external AI assessment path."""
with patch.object(
HandoffManager,
"_generate_ai_assessment",
new=AsyncMock(
return_value=(
"Stub escalation assessment",
{
"likely_cause": "Stub",
"suggested_steps": [],
"confidence": "medium",
},
)
),
):
yield
@pytest.mark.asyncio
@@ -137,23 +166,30 @@ async def test_escalations_stream_returns_sse_content_type(
):
"""Engineer/owner can open the SSE stream and gets text/event-stream
plus an initial `ready` event. Read just enough bytes to confirm the
handshake — the full pub/sub flow is covered by the bus + dispatcher
tests separately."""
async with client.stream(
"GET",
"/api/v1/ai-sessions/escalations/stream",
headers=auth_headers,
) as resp:
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/event-stream")
# First chunk must contain the ready event.
first = b""
async for chunk in resp.aiter_bytes():
first += chunk
if b"event: ready" in first and b"\n\n" in first:
break
assert b"event: ready" in first
assert b'"account_id"' in first
handshake — the full pub/sub flow is covered by the bus + dispatcher tests
separately.
Do not use `client.stream()` here: HTTPX's ASGITransport buffers the whole
response body before returning, which hangs forever for an infinite SSE
stream.
"""
user_id = PyUUID(test_user["user_data"]["id"])
user = (
await test_db.execute(select(User).where(User.id == user_id))
).scalar_one()
resp = await stream_escalations(_ConnectedRequest(), current_user=user)
assert resp.media_type == "text/event-stream"
body_iterator = resp.body_iterator
try:
first = await anext(body_iterator)
finally:
await body_iterator.aclose()
assert "event: ready" in first
assert '"account_id"' in first
assert escalation_bus.subscriber_count(user.account_id) == 0
@pytest.mark.asyncio