diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 005eed3df..4af3fd20e 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -35,10 +35,13 @@ class LiteLLMConfig(TypedDict, total=False): params: Model parameters (e.g., max_tokens). For a complete list of supported parameters, see https://docs.litellm.ai/docs/completion/input#input-params-1. + streaming: Optional flag to indicate whether provider streaming should be used. + If omitted, defaults to True (preserves existing behaviour). """ model_id: str params: Optional[dict[str, Any]] + streaming: Optional[bool] def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None: """Initialize provider instance. diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index fc2e9c778..063a40a99 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -50,10 +50,13 @@ class OpenAIConfig(TypedDict, total=False): params: Model parameters (e.g., max_tokens). For a complete list of supported parameters, see https://platform.openai.com/docs/api-reference/chat/create. + streaming: Optional flag to indicate whether provider streaming should be used. + If omitted, defaults to True (preserves existing behaviour). """ model_id: str params: Optional[dict[str, Any]] + streaming: Optional[bool] def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None: """Initialize provider instance. @@ -263,7 +266,8 @@ def format_request( return { "messages": self.format_request_messages(messages, system_prompt), "model": self.config["model_id"], - "stream": True, + # Use configured streaming flag; default True to preserve previous behavior. + "stream": bool(self.get_config().get("streaming", True)), "stream_options": {"include_usage": True}, "tools": [ { @@ -352,6 +356,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + def _convert_non_streaming_to_streaming(self, response: Any) -> list[StreamEvent]: + """Convert a provider non-streaming response into streaming-style events. + + This helper intentionally *does not* emit the initial message_start/content_start events, + because the caller (stream) already yields them to preserve parity with streaming flow. + """ + events: list[StreamEvent] = [] + + # Extract main text content from first choice if available + if getattr(response, "choices", None): + choice = response.choices[0] + content = None + if hasattr(choice, "message") and hasattr(choice.message, "content"): + content = choice.message.content + + # handle str content + if isinstance(content, str): + events.append(self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": content})) + # handle list content (list of blocks/dicts) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict): + # reasoning content + if "reasoningContent" in block and isinstance(block["reasoningContent"], dict): + try: + text = block["reasoningContent"]["reasoningText"]["text"] + events.append( + self.format_chunk( + {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": text} + ) + ) + except Exception: + # fall back to keeping the block as text if malformed + pass + # text block + elif "text" in block: + events.append( + self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": block["text"]} + ) + ) + # ignore other block types for now + elif isinstance(block, str): + events.append( + self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": block}) + ) + + # content stop + events.append(self.format_chunk({"chunk_type": "content_stop"})) + + # message stop — convert finish reason if available + stop_reason = None + if getattr(response, "choices", None): + stop_reason = getattr(response.choices[0], "finish_reason", None) + events.append(self.format_chunk({"chunk_type": "message_stop", "data": stop_reason or "stop"})) + + # metadata (usage) if present + if getattr(response, "usage", None): + events.append(self.format_chunk({"chunk_type": "metadata", "data": response.usage})) + + return events + @override async def stream( self, @@ -409,50 +475,63 @@ async def stream( tool_calls: dict[int, list[Any]] = {} - async for event in response: - # Defensive: skip events with empty or missing choices - if not getattr(event, "choices", None): - continue - choice = event.choices[0] - - if choice.delta.content: - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} - ) - - if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: - yield self.format_chunk( - { - "chunk_type": "content_delta", - "data_type": "reasoning_content", - "data": choice.delta.reasoning_content, - } - ) + streaming = bool(self.get_config().get("streaming", True)) + + if streaming: + # response is an async iterator when streaming=True + async for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue + choice = event.choices[0] + + if choice.delta.content: + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + ) + + if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": choice.delta.reasoning_content, + } + ) - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) - if choice.finish_reason: - break + if choice.finish_reason: + break - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) - for tool_deltas in tool_calls.values(): - yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) + for tool_deltas in tool_calls.values(): + yield self.format_chunk( + {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} + ) - for tool_delta in tool_deltas: - yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + for tool_delta in tool_deltas: + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} + ) - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) - yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) + yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) - # Skip remaining events as we don't have use for anything except the final usage payload - async for event in response: - _ = event + # Skip remaining events as we don't have use for anything except the final usage payload + async for event in response: + _ = event - if event.usage: - yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + if event.usage: + yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + else: + # Non-streaming provider response — convert to streaming-style events (excluding the initial + # message_start/content_start because we already emitted them above). + for ev in self._convert_non_streaming_to_streaming(response): + yield ev logger.debug("finished streaming response from model") diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index f8c8568fe..2161a67ae 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -612,6 +612,52 @@ async def test_stream(openai_client, model_id, model, agenerator, alist): openai_client.chat.completions.create.assert_called_once_with(**expected_request) +@pytest.mark.asyncio +async def test_stream_respects_streaming_flag(openai_client, model_id, alist): + # Model configured to NOT stream + model = OpenAIModel(client_args={}, model_id=model_id, params={"max_tokens": 1}, streaming=False) + + # Mock a non-streaming response object + mock_choice = unittest.mock.Mock() + mock_choice.finish_reason = "stop" + mock_choice.message = unittest.mock.Mock() + mock_choice.message.content = "non-stream result" + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + mock_response.usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response) + + # Consume the generator and verify the events + response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}]) + tru_events = await alist(response_gen) + + expected_request = { + "max_tokens": 1, + "model": model_id, + "messages": [{"role": "user", "content": [{"text": "hi", "type": "text"}]}], + "stream": False, + "stream_options": {"include_usage": True}, + "tools": [], + } + openai_client.chat.completions.create.assert_called_once_with(**expected_request) + + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "non-stream result"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + "metrics": {"latencyMs": 0}, + } + }, + ] + assert tru_events == exp_events + + @pytest.mark.asyncio async def test_stream_empty(openai_client, model_id, model, agenerator, alist): mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)