Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ def _stream(
logger.debug("got response from model")
if streaming:
response = self.client.converse_stream(**request)
# Track tool use events to fix stopReason for streaming responses
has_tool_use = False
for chunk in response["stream"]:
if (
"metadata" in chunk
Expand All @@ -446,7 +448,24 @@ def _stream(
for event in self._generate_redaction_events():
callback(event)

callback(chunk)
# Track if we see tool use events
if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"):
has_tool_use = True

# Fix stopReason for streaming responses that contain tool use
if (
has_tool_use
and "messageStop" in chunk
and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn"
):
# Create corrected chunk with tool_use stopReason
modified_chunk = chunk.copy()
modified_chunk["messageStop"] = message_stop.copy()
modified_chunk["messageStop"]["stopReason"] = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")
callback(modified_chunk)
else:
callback(chunk)

else:
response = self.client.converse(**request)
Expand Down Expand Up @@ -582,9 +601,17 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
yield {"contentBlockStop": {}}

# Yield messageStop event
# Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side
current_stop_reason = response["stopReason"]
if current_stop_reason == "end_turn":
message_content = response["output"]["message"]["content"]
if any("toolUse" in content for content in message_content):
current_stop_reason = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")

yield {
"messageStop": {
"stopReason": response["stopReason"],
"stopReason": current_stop_reason,
"additionalModelResponseFields": response.get("additionalModelResponseFields"),
}
}
Expand Down
47 changes: 47 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,53 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
assert "finished streaming response from model" in log_text


@pytest.mark.asyncio
async def test_stream_stop_reason_override_streaming(bedrock_client, model, messages, alist):
"""Test that stopReason is overridden from end_turn to tool_use in streaming mode when tool use is detected."""
bedrock_client.converse_stream.return_value = {
"stream": [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test_tool"}}}},
{"contentBlockDelta": {"delta": {"test": {"input": '{"param": "value"}'}}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "end_turn"}},
]
}

response = model.stream(messages)
events = await alist(response)

# Find the messageStop event
message_stop_event = next(event for event in events if "messageStop" in event)

# Verify stopReason was overridden to tool_use
assert message_stop_event["messageStop"]["stopReason"] == "tool_use"


@pytest.mark.asyncio
async def test_stream_stop_reason_override_non_streaming(bedrock_client, alist, messages):
"""Test that stopReason is overridden from end_turn to tool_use in non-streaming mode when tool use is detected."""
bedrock_client.converse.return_value = {
"output": {
"message": {
"role": "assistant",
"content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"param": "value"}}}],
}
},
"stopReason": "end_turn",
}

model = BedrockModel(model_id="test-model", streaming=False)
response = model.stream(messages)
events = await alist(response)

# Find the messageStop event
message_stop_event = next(event for event in events if "messageStop" in event)

# Verify stopReason was overridden to tool_use
assert message_stop_event["messageStop"]["stopReason"] == "tool_use"


def test_format_request_cleans_tool_result_content_blocks(model, model_id):
"""Test that format_request cleans toolResult blocks by removing extra fields."""
messages = [
Expand Down
Loading