Skip to content

Commit 8e26192

Browse files
committed
fix: add unit tests
1 parent 182fd38 commit 8e26192

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

tests/strands/models/test_bedrock.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,53 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
12271227
assert "finished streaming response from model" in log_text
12281228

12291229

1230+
@pytest.mark.asyncio
1231+
async def test_stream_stop_reason_override_streaming(bedrock_client, model, messages, alist):
1232+
"""Test that stopReason is overridden from end_turn to tool_use in streaming mode when tool use is detected."""
1233+
bedrock_client.converse_stream.return_value = {
1234+
"stream": [
1235+
{"messageStart": {"role": "assistant"}},
1236+
{"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test_tool"}}}},
1237+
{"contentBlockDelta": {"delta": {"test": {"input": '{"param": "value"}'}}}},
1238+
{"contentBlockStop": {}},
1239+
{"messageStop": {"stopReason": "end_turn"}},
1240+
]
1241+
}
1242+
1243+
response = model.stream(messages)
1244+
events = await alist(response)
1245+
1246+
# Find the messageStop event
1247+
message_stop_event = next(event for event in events if "messageStop" in event)
1248+
1249+
# Verify stopReason was overridden to tool_use
1250+
assert message_stop_event["messageStop"]["stopReason"] == "tool_use"
1251+
1252+
1253+
@pytest.mark.asyncio
1254+
async def test_stream_stop_reason_override_non_streaming(bedrock_client, alist, messages):
1255+
"""Test that stopReason is overridden from end_turn to tool_use in non-streaming mode when tool use is detected."""
1256+
bedrock_client.converse.return_value = {
1257+
"output": {
1258+
"message": {
1259+
"role": "assistant",
1260+
"content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"param": "value"}}}],
1261+
}
1262+
},
1263+
"stopReason": "end_turn",
1264+
}
1265+
1266+
model = BedrockModel(model_id="test-model", streaming=False)
1267+
response = model.stream(messages)
1268+
events = await alist(response)
1269+
1270+
# Find the messageStop event
1271+
message_stop_event = next(event for event in events if "messageStop" in event)
1272+
1273+
# Verify stopReason was overridden to tool_use
1274+
assert message_stop_event["messageStop"]["stopReason"] == "tool_use"
1275+
1276+
12301277
def test_format_request_cleans_tool_result_content_blocks(model, model_id):
12311278
"""Test that format_request cleans toolResult blocks by removing extra fields."""
12321279
messages = [

0 commit comments

Comments
 (0)