feat: add generate_text_stream to AnthropicProvider for SSE support
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,7 @@ backends for JSON generation used by the AI Flow Builder.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
@@ -54,6 +55,26 @@ class AIProvider(ABC):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def generate_text_stream(
|
||||||
|
self,
|
||||||
|
system_prompt: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
) -> "AsyncIterator[str]":
|
||||||
|
"""Stream a text response token by token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_prompt: System-level instruction for the model.
|
||||||
|
messages: List of message dicts with "role" and "content" keys.
|
||||||
|
max_tokens: Maximum output tokens.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Text chunks as they are generated.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Streaming not supported for this provider")
|
||||||
|
# Make this an async generator to satisfy type checker
|
||||||
|
yield "" # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
class GeminiProvider(AIProvider):
|
class GeminiProvider(AIProvider):
|
||||||
"""Google Gemini provider using the google-genai SDK."""
|
"""Google Gemini provider using the google-genai SDK."""
|
||||||
@@ -221,6 +242,23 @@ class AnthropicProvider(AIProvider):
|
|||||||
# Anthropic doesn't differentiate between JSON and text mode
|
# Anthropic doesn't differentiate between JSON and text mode
|
||||||
return await self.generate_json(system_prompt, messages, max_tokens)
|
return await self.generate_json(system_prompt, messages, max_tokens)
|
||||||
|
|
||||||
|
async def generate_text_stream(
|
||||||
|
self,
|
||||||
|
system_prompt: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
client = _get_anthropic_client(self._api_key, self._timeout)
|
||||||
|
|
||||||
|
async with client.messages.stream(
|
||||||
|
model=self._model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
system=system_prompt,
|
||||||
|
messages=messages,
|
||||||
|
) as stream:
|
||||||
|
async for text in stream.text_stream:
|
||||||
|
yield text
|
||||||
|
|
||||||
|
|
||||||
def get_ai_provider(model: str | None = None) -> AIProvider:
|
def get_ai_provider(model: str | None = None) -> AIProvider:
|
||||||
"""Factory that returns the configured AI provider.
|
"""Factory that returns the configured AI provider.
|
||||||
|
|||||||
Reference in New Issue
Block a user