feat: add generate_text_stream to AnthropicProvider for SSE support

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-03-28 23:02:35 +00:00
parent ca686c0901
commit 2f3781bfc2

View File

@@ -7,6 +7,7 @@ backends for JSON generation used by the AI Flow Builder.
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from app.core.config import settings from app.core.config import settings
@@ -54,6 +55,26 @@ class AIProvider(ABC):
""" """
... ...
async def generate_text_stream(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> "AsyncIterator[str]":
"""Stream a text response token by token.
Args:
system_prompt: System-level instruction for the model.
messages: List of message dicts with "role" and "content" keys.
max_tokens: Maximum output tokens.
Yields:
Text chunks as they are generated.
"""
raise NotImplementedError("Streaming not supported for this provider")
# Make this an async generator to satisfy type checker
yield "" # pragma: no cover
class GeminiProvider(AIProvider): class GeminiProvider(AIProvider):
"""Google Gemini provider using the google-genai SDK.""" """Google Gemini provider using the google-genai SDK."""
@@ -221,6 +242,23 @@ class AnthropicProvider(AIProvider):
# Anthropic doesn't differentiate between JSON and text mode # Anthropic doesn't differentiate between JSON and text mode
return await self.generate_json(system_prompt, messages, max_tokens) return await self.generate_json(system_prompt, messages, max_tokens)
async def generate_text_stream(
self,
system_prompt: str,
messages: list[dict[str, str]],
max_tokens: int = 4096,
) -> AsyncIterator[str]:
client = _get_anthropic_client(self._api_key, self._timeout)
async with client.messages.stream(
model=self._model,
max_tokens=max_tokens,
system=system_prompt,
messages=messages,
) as stream:
async for text in stream.text_stream:
yield text
def get_ai_provider(model: str | None = None) -> AIProvider: def get_ai_provider(model: str | None = None) -> AIProvider:
"""Factory that returns the configured AI provider. """Factory that returns the configured AI provider.