diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c44717041..ba4828c1a 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -435,6 +435,8 @@ def _stream( logger.debug("got response from model") if streaming: response = self.client.converse_stream(**request) + # Track tool use events to fix stopReason for streaming responses + has_tool_use = False for chunk in response["stream"]: if ( "metadata" in chunk @@ -446,7 +448,24 @@ def _stream( for event in self._generate_redaction_events(): callback(event) - callback(chunk) + # Track if we see tool use events + if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"): + has_tool_use = True + + # Fix stopReason for streaming responses that contain tool use + if ( + has_tool_use + and "messageStop" in chunk + and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn" + ): + # Create corrected chunk with tool_use stopReason + modified_chunk = chunk.copy() + modified_chunk["messageStop"] = message_stop.copy() + modified_chunk["messageStop"]["stopReason"] = "tool_use" + logger.warning("Override stop reason from end_turn to tool_use") + callback(modified_chunk) + else: + callback(chunk) else: response = self.client.converse(**request) @@ -582,9 +601,17 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera yield {"contentBlockStop": {}} # Yield messageStop event + # Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side + current_stop_reason = response["stopReason"] + if current_stop_reason == "end_turn": + message_content = response["output"]["message"]["content"] + if any("toolUse" in content for content in message_content): + current_stop_reason = "tool_use" + logger.warning("Override stop reason from end_turn to tool_use") + yield { "messageStop": { - "stopReason": response["stopReason"], + "stopReason": current_stop_reason, "additionalModelResponseFields": response.get("additionalModelResponseFields"), } } diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index f1a2250e4..2f44c2e65 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1227,6 +1227,53 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist): assert "finished streaming response from model" in log_text +@pytest.mark.asyncio +async def test_stream_stop_reason_override_streaming(bedrock_client, model, messages, alist): + """Test that stopReason is overridden from end_turn to tool_use in streaming mode when tool use is detected.""" + bedrock_client.converse_stream.return_value = { + "stream": [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test_tool"}}}}, + {"contentBlockDelta": {"delta": {"test": {"input": '{"param": "value"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + } + + response = model.stream(messages) + events = await alist(response) + + # Find the messageStop event + message_stop_event = next(event for event in events if "messageStop" in event) + + # Verify stopReason was overridden to tool_use + assert message_stop_event["messageStop"]["stopReason"] == "tool_use" + + +@pytest.mark.asyncio +async def test_stream_stop_reason_override_non_streaming(bedrock_client, alist, messages): + """Test that stopReason is overridden from end_turn to tool_use in non-streaming mode when tool use is detected.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"param": "value"}}}], + } + }, + "stopReason": "end_turn", + } + + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + events = await alist(response) + + # Find the messageStop event + message_stop_event = next(event for event in events if "messageStop" in event) + + # Verify stopReason was overridden to tool_use + assert message_stop_event["messageStop"]["stopReason"] == "tool_use" + + def test_format_request_cleans_tool_result_content_blocks(model, model_id): """Test that format_request cleans toolResult blocks by removing extra fields.""" messages = [