Skip to content

Commit 7a5caad

Browse files
JackYPCOnlineJack Yuan
andauthored
fix: fix stop reason for bedrock model when stop_reason (#767)
* fix: fix stop reason for bedrock model when stop_reason is end_turn in tool use response. * change logger info to warning, optimize if condition * fix: add unit tests --------- Co-authored-by: Jack Yuan <[email protected]>
1 parent ae9d5ad commit 7a5caad

File tree

2 files changed

+76
-2
lines changed

2 files changed

+76
-2
lines changed

src/strands/models/bedrock.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,8 @@ def _stream(
435435
logger.debug("got response from model")
436436
if streaming:
437437
response = self.client.converse_stream(**request)
438+
# Track tool use events to fix stopReason for streaming responses
439+
has_tool_use = False
438440
for chunk in response["stream"]:
439441
if (
440442
"metadata" in chunk
@@ -446,7 +448,24 @@ def _stream(
446448
for event in self._generate_redaction_events():
447449
callback(event)
448450

449-
callback(chunk)
451+
# Track if we see tool use events
452+
if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"):
453+
has_tool_use = True
454+
455+
# Fix stopReason for streaming responses that contain tool use
456+
if (
457+
has_tool_use
458+
and "messageStop" in chunk
459+
and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn"
460+
):
461+
# Create corrected chunk with tool_use stopReason
462+
modified_chunk = chunk.copy()
463+
modified_chunk["messageStop"] = message_stop.copy()
464+
modified_chunk["messageStop"]["stopReason"] = "tool_use"
465+
logger.warning("Override stop reason from end_turn to tool_use")
466+
callback(modified_chunk)
467+
else:
468+
callback(chunk)
450469

451470
else:
452471
response = self.client.converse(**request)
@@ -582,9 +601,17 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
582601
yield {"contentBlockStop": {}}
583602

584603
# Yield messageStop event
604+
# Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side
605+
current_stop_reason = response["stopReason"]
606+
if current_stop_reason == "end_turn":
607+
message_content = response["output"]["message"]["content"]
608+
if any("toolUse" in content for content in message_content):
609+
current_stop_reason = "tool_use"
610+
logger.warning("Override stop reason from end_turn to tool_use")
611+
585612
yield {
586613
"messageStop": {
587-
"stopReason": response["stopReason"],
614+
"stopReason": current_stop_reason,
588615
"additionalModelResponseFields": response.get("additionalModelResponseFields"),
589616
}
590617
}

tests/strands/models/test_bedrock.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,53 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
12271227
assert "finished streaming response from model" in log_text
12281228

12291229

1230+
@pytest.mark.asyncio
1231+
async def test_stream_stop_reason_override_streaming(bedrock_client, model, messages, alist):
1232+
"""Test that stopReason is overridden from end_turn to tool_use in streaming mode when tool use is detected."""
1233+
bedrock_client.converse_stream.return_value = {
1234+
"stream": [
1235+
{"messageStart": {"role": "assistant"}},
1236+
{"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test_tool"}}}},
1237+
{"contentBlockDelta": {"delta": {"test": {"input": '{"param": "value"}'}}}},
1238+
{"contentBlockStop": {}},
1239+
{"messageStop": {"stopReason": "end_turn"}},
1240+
]
1241+
}
1242+
1243+
response = model.stream(messages)
1244+
events = await alist(response)
1245+
1246+
# Find the messageStop event
1247+
message_stop_event = next(event for event in events if "messageStop" in event)
1248+
1249+
# Verify stopReason was overridden to tool_use
1250+
assert message_stop_event["messageStop"]["stopReason"] == "tool_use"
1251+
1252+
1253+
@pytest.mark.asyncio
1254+
async def test_stream_stop_reason_override_non_streaming(bedrock_client, alist, messages):
1255+
"""Test that stopReason is overridden from end_turn to tool_use in non-streaming mode when tool use is detected."""
1256+
bedrock_client.converse.return_value = {
1257+
"output": {
1258+
"message": {
1259+
"role": "assistant",
1260+
"content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"param": "value"}}}],
1261+
}
1262+
},
1263+
"stopReason": "end_turn",
1264+
}
1265+
1266+
model = BedrockModel(model_id="test-model", streaming=False)
1267+
response = model.stream(messages)
1268+
events = await alist(response)
1269+
1270+
# Find the messageStop event
1271+
message_stop_event = next(event for event in events if "messageStop" in event)
1272+
1273+
# Verify stopReason was overridden to tool_use
1274+
assert message_stop_event["messageStop"]["stopReason"] == "tool_use"
1275+
1276+
12301277
def test_format_request_cleans_tool_result_content_blocks(model, model_id):
12311278
"""Test that format_request cleans toolResult blocks by removing extra fields."""
12321279
messages = [

0 commit comments

Comments
 (0)