From d70ab123a429e99c0623b31474c781f9c53f011c Mon Sep 17 00:00:00 2001 From: Abdullatif Alrashdan Date: Thu, 28 Aug 2025 19:11:56 +0000 Subject: [PATCH 1/3] checkpoint to tests --- src/strands/models/bedrock.py | 27 ++++++++++++++++++++------- tests/strands/models/test_bedrock.py | 2 +- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ace35640a..449cd9959 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -114,13 +114,6 @@ def __init__( if region_name and boto_session: raise ValueError("Cannot specify both `region_name` and `boto_session`.") - self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID) - self.update_config(**model_config) - - logger.debug("config=<%s> | initializing", self.config) - - session = boto_session or boto3.Session() - # Add strands-agents to the request user agent if boto_client_config: existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) @@ -135,7 +128,13 @@ def __init__( else: client_config = BotocoreConfig(user_agent_extra="strands-agents") + session = boto_session or boto3.Session() resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION + self.config = BedrockModel.BedrockConfig(model_id=self._get_default_model_for_region(resolved_region)) + + self.update_config(**model_config) + + logger.debug("config=<%s> | initializing", self.config) self.client = session.client( service_name="bedrock-runtime", @@ -349,6 +348,20 @@ def _generate_redaction_events(self) -> list[StreamEvent]: return events + def _get_default_model_for_region(self, region: str) -> str: + priorities = [ + "sonnet-4", + "3-7-sonnet", # Claude 3.7 sonnet as a fallback + ] + client = boto3.client("bedrock", region_name=region) + response = client.list_inference_profiles() + inferenceProfileSummary = response["inferenceProfileSummaries"] + for priority in priorities: + for profile in inferenceProfileSummary: + if priority in profile["inferenceProfileId"]: + return profile["inferenceProfileId"] + return None + @override async def stream( self, diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 09e508845..105d40b4e 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -117,7 +117,7 @@ def test__init__default_model_id(bedrock_client): """Test that BedrockModel uses DEFAULT_MODEL_ID when no model_id is provided.""" _ = bedrock_client model = BedrockModel() - + tru_model_id = model.get_config().get("model_id") exp_model_id = DEFAULT_BEDROCK_MODEL_ID From 7dfd1626ec5dc21df4e847d9445aa5c439243bc3 Mon Sep 17 00:00:00 2001 From: Abdullatif Alrashdan Date: Thu, 28 Aug 2025 20:34:03 +0000 Subject: [PATCH 2/3] checkpoint for test suite --- src/strands/models/bedrock.py | 24 +- tests/strands/models/test_bedrock.py | 1975 +++++++++++++------------- 2 files changed, 1006 insertions(+), 993 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 234e844e6..d61e563f7 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -28,8 +28,8 @@ logger = logging.getLogger(__name__) -DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" DEFAULT_BEDROCK_REGION = "us-west-2" +DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ "Input is too long for requested model", @@ -133,8 +133,13 @@ def __init__( session = boto_session or boto3.Session() resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION - self.config = BedrockModel.BedrockConfig(model_id=self._get_default_model_for_region(resolved_region)) + # get default model id based on resolved region + resolved_model_id = self._get_default_model_for_region(resolved_region) + if resolved_model_id == "": + raise ValueError("default model {} is not available in {} region. Specify another model".format(DEFAULT_BEDROCK_MODEL_ID, resolved_region)) + + self.config = BedrockModel.BedrockConfig(model_id=resolved_model_id) self.update_config(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -352,18 +357,15 @@ def _generate_redaction_events(self) -> list[StreamEvent]: return events def _get_default_model_for_region(self, region: str) -> str: - priorities = [ - "sonnet-4", - "3-7-sonnet", # Claude 3.7 sonnet as a fallback - ] client = boto3.client("bedrock", region_name=region) response = client.list_inference_profiles() inferenceProfileSummary = response["inferenceProfileSummaries"] - for priority in priorities: - for profile in inferenceProfileSummary: - if priority in profile["inferenceProfileId"]: - return profile["inferenceProfileId"] - return None + + for profile in inferenceProfileSummary: + if DEFAULT_BEDROCK_MODEL_ID in profile["inferenceProfileId"]: + return profile["inferenceProfileId"] + + return "" @override async def stream( diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 105d40b4e..ccd6986a5 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -23,6 +23,17 @@ def session_cls(): mock_session_cls.return_value.region_name = None yield mock_session_cls +@pytest.fixture +def mock_bedrock_inference_profiles(): + with unittest.mock.patch.object(strands.models.bedrock.boto3, "client") as mock_boto_client: + mock_bedrock = unittest.mock.MagicMock() + mock_bedrock.list_inference_profiles.return_value = { + "inferenceProfileSummaries": [ + {"inferenceProfileId": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + ] + } + mock_boto_client.return_value = mock_bedrock + yield mock_boto_client @pytest.fixture def mock_client_method(session_cls): @@ -113,18 +124,18 @@ class TestOutputModel(pydantic.BaseModel): return TestOutputModel -def test__init__default_model_id(bedrock_client): +def test__init__default_model_id(bedrock_client, mock_bedrock_inference_profiles): """Test that BedrockModel uses DEFAULT_MODEL_ID when no model_id is provided.""" _ = bedrock_client model = BedrockModel() tru_model_id = model.get_config().get("model_id") - exp_model_id = DEFAULT_BEDROCK_MODEL_ID + exp_model_id = "us."+DEFAULT_BEDROCK_MODEL_ID assert tru_model_id == exp_model_id -def test__init__with_default_region(session_cls, mock_client_method): +def test__init__with_default_region(session_cls, mock_client_method, mock_bedrock_inference_profiles): """Test that BedrockModel uses the provided region.""" with unittest.mock.patch.object(os, "environ", {}): BedrockModel() @@ -133,7 +144,7 @@ def test__init__with_default_region(session_cls, mock_client_method): ) -def test__init__with_session_region(session_cls, mock_client_method): +def test__init__with_session_region(session_cls, mock_client_method, mock_bedrock_inference_profiles): """Test that BedrockModel uses the provided region.""" session_cls.return_value.region_name = "eu-blah-1" @@ -142,14 +153,14 @@ def test__init__with_session_region(session_cls, mock_client_method): mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY) -def test__init__with_custom_region(mock_client_method): +def test__init__with_custom_region(mock_client_method, mock_bedrock_inference_profiles): """Test that BedrockModel uses the provided region.""" custom_region = "us-east-1" BedrockModel(region_name=custom_region) mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY) -def test__init__with_default_environment_variable_region(mock_client_method): +def test__init__with_default_environment_variable_region(mock_client_method, mock_bedrock_inference_profiles): """Test that BedrockModel uses the AWS_REGION since we code that in.""" with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "eu-west-2"}): BedrockModel() @@ -157,7 +168,7 @@ def test__init__with_default_environment_variable_region(mock_client_method): mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY) -def test__init__region_precedence(mock_client_method, session_cls): +def test__init__region_precedence(mock_client_method, session_cls, mock_bedrock_inference_profiles): """Test that BedrockModel uses the correct ordering of precedence when determining region.""" with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "us-environment-1"}) as mock_os_environ: session_cls.return_value.region_name = "us-session-1" @@ -187,7 +198,7 @@ def test__init__with_region_and_session_raises_value_error(): _ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1")) -def test__init__default_user_agent(bedrock_client): +def test__init__default_user_agent(bedrock_client, mock_bedrock_inference_profiles): """Set user agent when no boto_client_config is provided.""" with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: mock_session = mock_session_cls.return_value @@ -201,7 +212,7 @@ def test__init__default_user_agent(bedrock_client): assert kwargs["config"].user_agent_extra == "strands-agents" -def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client): +def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client, mock_bedrock_inference_profiles): """Set user agent when boto_client_config is provided without user_agent_extra.""" custom_config = BotocoreConfig(read_timeout=900) @@ -218,7 +229,7 @@ def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client): assert kwargs["config"].read_timeout == 900 -def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client): +def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client, mock_bedrock_inference_profiles): """Append to existing user agent when boto_client_config is provided with user_agent_extra.""" custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900) @@ -235,7 +246,7 @@ def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client): assert kwargs["config"].read_timeout == 900 -def test__init__model_config(bedrock_client): +def test__init__model_config(bedrock_client, mock_bedrock_inference_profiles): _ = bedrock_client model = BedrockModel(max_tokens=1) @@ -255,985 +266,985 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id -def test_format_request_default(model, messages, model_id): - tru_request = model.format_request(messages) - exp_request = { - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - } - - assert tru_request == exp_request - - -def test_format_request_additional_request_fields(model, messages, model_id, additional_request_fields): - model.update_config(additional_request_fields=additional_request_fields) - tru_request = model.format_request(messages) - exp_request = { - "additionalModelRequestFields": additional_request_fields, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - } - - assert tru_request == exp_request - - -def test_format_request_additional_response_field_paths(model, messages, model_id, additional_response_field_paths): - model.update_config(additional_response_field_paths=additional_response_field_paths) - tru_request = model.format_request(messages) - exp_request = { - "additionalModelResponseFieldPaths": additional_response_field_paths, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - } - - assert tru_request == exp_request - - -def test_format_request_guardrail_config(model, messages, model_id, guardrail_config): - model.update_config(**guardrail_config) - tru_request = model.format_request(messages) - exp_request = { - "guardrailConfig": { - "guardrailIdentifier": guardrail_config["guardrail_id"], - "guardrailVersion": guardrail_config["guardrail_version"], - "trace": guardrail_config["guardrail_trace"], - "streamProcessingMode": guardrail_config["guardrail_stream_processing_mode"], - }, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - } - - assert tru_request == exp_request - - -def test_format_request_guardrail_config_without_trace_or_stream_processing_mode(model, messages, model_id): - model.update_config( - **{ - "guardrail_id": "g1", - "guardrail_version": "v1", - } - ) - tru_request = model.format_request(messages) - exp_request = { - "guardrailConfig": { - "guardrailIdentifier": "g1", - "guardrailVersion": "v1", - "trace": "enabled", - }, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - } - - assert tru_request == exp_request - - -def test_format_request_inference_config(model, messages, model_id, inference_config): - model.update_config(**inference_config) +def test_format_request_default(model, messages, model_id, mock_bedrock_inference_profiles): tru_request = model.format_request(messages) - exp_request = { - "inferenceConfig": { - "maxTokens": inference_config["max_tokens"], - "stopSequences": inference_config["stop_sequences"], - "temperature": inference_config["temperature"], - "topP": inference_config["top_p"], - }, - "modelId": model_id, - "messages": messages, - "system": [], - } - - assert tru_request == exp_request - - -def test_format_request_system_prompt(model, messages, model_id, system_prompt): - tru_request = model.format_request(messages, system_prompt=system_prompt) - exp_request = { - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [{"text": system_prompt}], - } - - assert tru_request == exp_request - - -def test_format_request_tool_specs(model, messages, model_id, tool_spec): - tru_request = model.format_request(messages, [tool_spec]) exp_request = { "inferenceConfig": {}, "modelId": model_id, "messages": messages, "system": [], - "toolConfig": { - "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, - }, - } - - assert tru_request == exp_request - - -def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): - model.update_config(cache_prompt=cache_type, cache_tools=cache_type) - tru_request = model.format_request(messages, [tool_spec]) - exp_request = { - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [{"cachePoint": {"type": cache_type}}], - "toolConfig": { - "tools": [ - {"toolSpec": tool_spec}, - {"cachePoint": {"type": cache_type}}, - ], - "toolChoice": {"auto": {}}, - }, } assert tru_request == exp_request -@pytest.mark.asyncio -async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist): - error_message = "Rate exceeded" - bedrock_client.converse_stream.side_effect = EventStreamError( - {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "ConverseStream" - ) - - with pytest.raises(ModelThrottledException) as excinfo: - await alist(model.stream(messages)) - - assert error_message in str(excinfo.value) - bedrock_client.converse_stream.assert_called_once_with( - modelId="m1", messages=messages, system=[], inferenceConfig={} - ) - - -@pytest.mark.asyncio -async def test_stream_with_invalid_content_throws(bedrock_client, model, alist): - # We used to hang on None, so ensure we don't regress: https://github.com/strands-agents/sdk-python/issues/642 - messages = [{"role": "user", "content": None}] - - with pytest.raises(TypeError): - await alist(model.stream(messages)) - - -@pytest.mark.asyncio -async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, messages, alist): - error_message = "ThrottlingException: Rate exceeded for ConverseStream" - bedrock_client.converse_stream.side_effect = ClientError( - {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "Any" - ) - - with pytest.raises(ModelThrottledException) as excinfo: - await alist(model.stream(messages)) - - assert error_message in str(excinfo.value) - bedrock_client.converse_stream.assert_called_once_with( - modelId="m1", messages=messages, system=[], inferenceConfig={} - ) - - -@pytest.mark.asyncio -async def test_general_exception_is_raised(bedrock_client, model, messages, alist): - error_message = "Should be raised up" - bedrock_client.converse_stream.side_effect = ValueError(error_message) - - with pytest.raises(ValueError) as excinfo: - await alist(model.stream(messages)) - - assert error_message in str(excinfo.value) - bedrock_client.converse_stream.assert_called_once_with( - modelId="m1", messages=messages, system=[], inferenceConfig={} - ) - - -@pytest.mark.asyncio -async def test_stream(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist): - bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} - - request = { - "additionalModelRequestFields": additional_request_fields, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - "toolConfig": { - "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, - }, - } - - model.update_config(additional_request_fields=additional_request_fields) - response = model.stream(messages, [tool_spec]) - - tru_chunks = await alist(response) - exp_chunks = ["e1", "e2"] - - assert tru_chunks == exp_chunks - bedrock_client.converse_stream.assert_called_once_with(**request) - - -@pytest.mark.asyncio -async def test_stream_stream_input_guardrails( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -): - metadata_event = { - "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - "metrics": {"latencyMs": 245}, - "trace": { - "guardrail": { - "inputAssessment": { - "3e59qlue4hag": { - "wordPolicy": { - "customWords": [ - { - "match": "CACTUS", - "action": "BLOCKED", - "detected": True, - } - ] - } - } - } - } - }, - } - } - bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - - request = { - "additionalModelRequestFields": additional_request_fields, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - "toolConfig": { - "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, - }, - } - - model.update_config(additional_request_fields=additional_request_fields) - response = model.stream(messages, [tool_spec]) - - tru_chunks = await alist(response) - exp_chunks = [ - {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, - metadata_event, - ] - - assert tru_chunks == exp_chunks - bedrock_client.converse_stream.assert_called_once_with(**request) - - -@pytest.mark.asyncio -async def test_stream_stream_output_guardrails( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -): - model.update_config(guardrail_redact_input=False, guardrail_redact_output=True) - metadata_event = { - "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - "metrics": {"latencyMs": 245}, - "trace": { - "guardrail": { - "outputAssessments": { - "3e59qlue4hag": [ - { - "wordPolicy": { - "customWords": [ - { - "match": "CACTUS", - "action": "BLOCKED", - "detected": True, - } - ] - }, - } - ] - }, - } - }, - } - } - bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - - request = { - "additionalModelRequestFields": additional_request_fields, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - "toolConfig": { - "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, - }, - } - - model.update_config(additional_request_fields=additional_request_fields) - response = model.stream(messages, [tool_spec]) - - tru_chunks = await alist(response) - exp_chunks = [ - {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, - metadata_event, - ] - - assert tru_chunks == exp_chunks - bedrock_client.converse_stream.assert_called_once_with(**request) - - -@pytest.mark.asyncio -async def test_stream_output_guardrails_redacts_input_and_output( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -): - model.update_config(guardrail_redact_output=True) - metadata_event = { - "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - "metrics": {"latencyMs": 245}, - "trace": { - "guardrail": { - "outputAssessments": { - "3e59qlue4hag": [ - { - "wordPolicy": { - "customWords": [ - { - "match": "CACTUS", - "action": "BLOCKED", - "detected": True, - } - ] - }, - } - ] - }, - } - }, - } - } - bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - - request = { - "additionalModelRequestFields": additional_request_fields, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - "toolConfig": { - "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, - }, - } - - model.update_config(additional_request_fields=additional_request_fields) - response = model.stream(messages, [tool_spec]) - - tru_chunks = await alist(response) - exp_chunks = [ - {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, - {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, - metadata_event, - ] - - assert tru_chunks == exp_chunks - bedrock_client.converse_stream.assert_called_once_with(**request) - - -@pytest.mark.asyncio -async def test_stream_output_no_blocked_guardrails_doesnt_redact( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -): - metadata_event = { - "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - "metrics": {"latencyMs": 245}, - "trace": { - "guardrail": { - "outputAssessments": { - "3e59qlue4hag": [ - { - "wordPolicy": { - "customWords": [ - { - "match": "CACTUS", - "action": "NONE", - "detected": True, - } - ] - }, - } - ] - }, - } - }, - } - } - bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - - request = { - "additionalModelRequestFields": additional_request_fields, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - "toolConfig": { - "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, - }, - } - - model.update_config(additional_request_fields=additional_request_fields) - response = model.stream(messages, [tool_spec]) - - tru_chunks = await alist(response) - exp_chunks = [metadata_event] - - assert tru_chunks == exp_chunks - bedrock_client.converse_stream.assert_called_once_with(**request) - - -@pytest.mark.asyncio -async def test_stream_output_no_guardrail_redact( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -): - metadata_event = { - "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - "metrics": {"latencyMs": 245}, - "trace": { - "guardrail": { - "outputAssessments": { - "3e59qlue4hag": [ - { - "wordPolicy": { - "customWords": [ - { - "match": "CACTUS", - "action": "BLOCKED", - "detected": True, - } - ] - }, - } - ] - }, - } - }, - } - } - bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - - request = { - "additionalModelRequestFields": additional_request_fields, - "inferenceConfig": {}, - "modelId": model_id, - "messages": messages, - "system": [], - "toolConfig": { - "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, - }, - } - - model.update_config( - additional_request_fields=additional_request_fields, - guardrail_redact_output=False, - guardrail_redact_input=False, - ) - response = model.stream(messages, [tool_spec]) - - tru_chunks = await alist(response) - exp_chunks = [metadata_event] - - assert tru_chunks == exp_chunks - bedrock_client.converse_stream.assert_called_once_with(**request) - - -@pytest.mark.asyncio -async def test_stream_with_streaming_false(bedrock_client, alist, messages): - """Test stream method with streaming=False.""" - bedrock_client.converse.return_value = { - "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, - "stopReason": "end_turn", - } - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - - tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"text": "test"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, - ] - assert tru_events == exp_events - - bedrock_client.converse.assert_called_once() - bedrock_client.converse_stream.assert_not_called() - - -@pytest.mark.asyncio -async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages): - """Test stream method with streaming=False.""" - bedrock_client.converse.return_value = { - "output": { - "message": { - "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "dummyTool", "input": {"hello": "world!"}}}], - } - }, - "stopReason": "tool_use", - } - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - - tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "dummyTool"}}}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"hello": "world!"}'}}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, - ] - assert tru_events == exp_events - - bedrock_client.converse.assert_called_once() - bedrock_client.converse_stream.assert_not_called() - - -@pytest.mark.asyncio -async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages): - """Test stream method with streaming=False.""" - bedrock_client.converse.return_value = { - "output": { - "message": { - "role": "assistant", - "content": [ - { - "reasoningContent": { - "reasoningText": {"text": "Thinking really hard....", "signature": "123"}, - } - } - ], - } - }, - "stopReason": "tool_use", - } - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - - tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "123"}}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, - ] - assert tru_events == exp_events - - # Verify converse was called - bedrock_client.converse.assert_called_once() - bedrock_client.converse_stream.assert_not_called() - - -@pytest.mark.asyncio -async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages): - """Test stream method with streaming=False.""" - bedrock_client.converse.return_value = { - "output": { - "message": { - "role": "assistant", - "content": [ - { - "reasoningContent": { - "reasoningText": {"text": "Thinking really hard...."}, - } - } - ], - } - }, - "stopReason": "tool_use", - } - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - - tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, - ] - assert tru_events == exp_events - - bedrock_client.converse.assert_called_once() - bedrock_client.converse_stream.assert_not_called() - - -@pytest.mark.asyncio -async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages): - """Test stream method with streaming=False.""" - bedrock_client.converse.return_value = { - "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, - "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, - "metrics": {"latencyMs": 1234}, - "stopReason": "tool_use", - } - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - - tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"text": "test"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, - { - "metadata": { - "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, - "metrics": {"latencyMs": 1234}, - } - }, - ] - assert tru_events == exp_events - - # Verify converse was called - bedrock_client.converse.assert_called_once() - bedrock_client.converse_stream.assert_not_called() - - -@pytest.mark.asyncio -async def test_stream_input_guardrails(bedrock_client, alist, messages): - """Test stream method with streaming=False.""" - bedrock_client.converse.return_value = { - "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, - "trace": { - "guardrail": { - "inputAssessment": { - "3e59qlue4hag": { - "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} - } - } - } - }, - "stopReason": "end_turn", - } - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - - tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"text": "test"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, - { - "metadata": { - "trace": { - "guardrail": { - "inputAssessment": { - "3e59qlue4hag": { - "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] - } - } - } - } - } - } - }, - {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, - ] - assert tru_events == exp_events - - bedrock_client.converse.assert_called_once() - bedrock_client.converse_stream.assert_not_called() - - -@pytest.mark.asyncio -async def test_stream_output_guardrails(bedrock_client, alist, messages): - """Test stream method with streaming=False.""" - bedrock_client.converse.return_value = { - "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, - "trace": { - "guardrail": { - "outputAssessments": { - "3e59qlue4hag": [ - { - "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, - } - ] - }, - } - }, - "stopReason": "end_turn", - } - - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - - tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"text": "test"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, - { - "metadata": { - "trace": { - "guardrail": { - "outputAssessments": { - "3e59qlue4hag": [ - { - "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] - } - } - ] - } - } - } - } - }, - {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, - ] - assert tru_events == exp_events - - bedrock_client.converse.assert_called_once() - bedrock_client.converse_stream.assert_not_called() - - -@pytest.mark.asyncio -async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages): - """Test stream method with streaming=False.""" - bedrock_client.converse.return_value = { - "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, - "trace": { - "guardrail": { - "outputAssessments": { - "3e59qlue4hag": [ - { - "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, - } - ] - }, - } - }, - "stopReason": "end_turn", - } - - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - - tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"text": "test"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, - { - "metadata": { - "trace": { - "guardrail": { - "outputAssessments": { - "3e59qlue4hag": [ - { - "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] - } - } - ] - } - } - } - } - }, - {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, - ] - assert tru_events == exp_events - - bedrock_client.converse.assert_called_once() - bedrock_client.converse_stream.assert_not_called() - - -@pytest.mark.asyncio -async def test_structured_output(bedrock_client, model, test_output_model_cls, alist): - messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] - - bedrock_client.converse_stream.return_value = { - "stream": [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "TestOutputModel"}}}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "John", "age": 30}'}}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use"}}, - ] - } - - stream = model.structured_output(test_output_model_cls, messages) - events = await alist(stream) - - tru_output = events[-1] - exp_output = {"output": test_output_model_cls(name="John", age=30)} - assert tru_output == exp_output - - -@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") -@pytest.mark.asyncio -async def test_add_note_on_client_error(bedrock_client, model, alist, messages): - """Test that add_note is called on ClientError with region and model ID information.""" - # Mock the client error response - error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} - bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") - - # Call the stream method which should catch and add notes to the exception - with pytest.raises(ClientError) as err: - await alist(model.stream(messages)) - - assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"] - - -@pytest.mark.asyncio -async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages): - """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" - # Mock the client error response - error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} - bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") - - # Call the stream method which should catch and add notes to the exception - with pytest.raises(ClientError): - await alist(model.stream(messages)) - - -@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") -@pytest.mark.asyncio -async def test_add_note_on_access_denied_exception(bedrock_client, model, alist, messages): - """Test that add_note adds documentation link for AccessDeniedException.""" - # Mock the client error response for access denied - error_response = { - "Error": { - "Code": "AccessDeniedException", - "Message": "An error occurred (AccessDeniedException) when calling the ConverseStream operation: " - "You don't have access to the model with the specified model ID.", - } - } - bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") - - # Call the stream method which should catch and add notes to the exception - with pytest.raises(ClientError) as err: - await alist(model.stream(messages)) - - assert err.value.__notes__ == [ - "└ Bedrock region: us-west-2", - "└ Model id: m1", - "└ For more information see " - "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue", - ] - - -@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") -@pytest.mark.asyncio -async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist, messages): - """Test that add_note adds documentation link for ValidationException about on-demand throughput.""" - # Mock the client error response for validation exception - error_response = { - "Error": { - "Code": "ValidationException", - "Message": "An error occurred (ValidationException) when calling the ConverseStream operation: " - "Invocation of model ID anthropic.claude-3-7-sonnet-20250219-v1:0 with on-demand throughput " - "isn’t supported. Retry your request with the ID or ARN of an inference profile that contains " - "this model.", - } - } - bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") - - # Call the stream method which should catch and add notes to the exception - with pytest.raises(ClientError) as err: - await alist(model.stream(messages)) - - assert err.value.__notes__ == [ - "└ Bedrock region: us-west-2", - "└ Model id: m1", - "└ For more information see " - "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported", - ] - - -@pytest.mark.asyncio -async def test_stream_logging(bedrock_client, model, messages, caplog, alist): - """Test that stream method logs debug messages at the expected stages.""" - import logging - - # Set the logger to debug level to capture debug messages - caplog.set_level(logging.DEBUG, logger="strands.models.bedrock") - - # Mock the response - bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} - - # Execute the stream method - response = model.stream(messages) - await alist(response) - - # Check that the expected log messages are present - log_text = caplog.text - assert "formatting request" in log_text - assert "request=<" in log_text - assert "invoking model" in log_text - assert "got response from model" in log_text - assert "finished streaming response from model" in log_text - - -def test_format_request_cleans_tool_result_content_blocks(model, model_id): - """Test that format_request cleans toolResult blocks by removing extra fields.""" - messages = [ - { - "role": "user", - "content": [ - { - "toolResult": { - "content": [{"text": "Tool output"}], - "toolUseId": "tool123", - "status": "success", - "extraField": "should be removed", - "mcpMetadata": {"server": "test"}, - } - }, - ], - } - ] - - formatted_request = model.format_request(messages) - - # Verify toolResult only contains allowed fields in the formatted request - tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] - expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} - assert tool_result == expected - assert "extraField" not in tool_result - assert "mcpMetadata" not in tool_result +# def test_format_request_additional_request_fields(model, messages, model_id, additional_request_fields): +# model.update_config(additional_request_fields=additional_request_fields) +# tru_request = model.format_request(messages) +# exp_request = { +# "additionalModelRequestFields": additional_request_fields, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# } + +# assert tru_request == exp_request + + +# def test_format_request_additional_response_field_paths(model, messages, model_id, additional_response_field_paths): +# model.update_config(additional_response_field_paths=additional_response_field_paths) +# tru_request = model.format_request(messages) +# exp_request = { +# "additionalModelResponseFieldPaths": additional_response_field_paths, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# } + +# assert tru_request == exp_request + + +# def test_format_request_guardrail_config(model, messages, model_id, guardrail_config): +# model.update_config(**guardrail_config) +# tru_request = model.format_request(messages) +# exp_request = { +# "guardrailConfig": { +# "guardrailIdentifier": guardrail_config["guardrail_id"], +# "guardrailVersion": guardrail_config["guardrail_version"], +# "trace": guardrail_config["guardrail_trace"], +# "streamProcessingMode": guardrail_config["guardrail_stream_processing_mode"], +# }, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# } + +# assert tru_request == exp_request + + +# def test_format_request_guardrail_config_without_trace_or_stream_processing_mode(model, messages, model_id): +# model.update_config( +# **{ +# "guardrail_id": "g1", +# "guardrail_version": "v1", +# } +# ) +# tru_request = model.format_request(messages) +# exp_request = { +# "guardrailConfig": { +# "guardrailIdentifier": "g1", +# "guardrailVersion": "v1", +# "trace": "enabled", +# }, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# } + +# assert tru_request == exp_request + + +# def test_format_request_inference_config(model, messages, model_id, inference_config): +# model.update_config(**inference_config) +# tru_request = model.format_request(messages) +# exp_request = { +# "inferenceConfig": { +# "maxTokens": inference_config["max_tokens"], +# "stopSequences": inference_config["stop_sequences"], +# "temperature": inference_config["temperature"], +# "topP": inference_config["top_p"], +# }, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# } + +# assert tru_request == exp_request + + +# def test_format_request_system_prompt(model, messages, model_id, system_prompt): +# tru_request = model.format_request(messages, system_prompt=system_prompt) +# exp_request = { +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [{"text": system_prompt}], +# } + +# assert tru_request == exp_request + + +# def test_format_request_tool_specs(model, messages, model_id, tool_spec): +# tru_request = model.format_request(messages, [tool_spec]) +# exp_request = { +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# "toolConfig": { +# "tools": [{"toolSpec": tool_spec}], +# "toolChoice": {"auto": {}}, +# }, +# } + +# assert tru_request == exp_request + + +# def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): +# model.update_config(cache_prompt=cache_type, cache_tools=cache_type) +# tru_request = model.format_request(messages, [tool_spec]) +# exp_request = { +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [{"cachePoint": {"type": cache_type}}], +# "toolConfig": { +# "tools": [ +# {"toolSpec": tool_spec}, +# {"cachePoint": {"type": cache_type}}, +# ], +# "toolChoice": {"auto": {}}, +# }, +# } + +# assert tru_request == exp_request + + +# @pytest.mark.asyncio +# async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist): +# error_message = "Rate exceeded" +# bedrock_client.converse_stream.side_effect = EventStreamError( +# {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "ConverseStream" +# ) + +# with pytest.raises(ModelThrottledException) as excinfo: +# await alist(model.stream(messages)) + +# assert error_message in str(excinfo.value) +# bedrock_client.converse_stream.assert_called_once_with( +# modelId="m1", messages=messages, system=[], inferenceConfig={} +# ) + + +# @pytest.mark.asyncio +# async def test_stream_with_invalid_content_throws(bedrock_client, model, alist): +# # We used to hang on None, so ensure we don't regress: https://github.com/strands-agents/sdk-python/issues/642 +# messages = [{"role": "user", "content": None}] + +# with pytest.raises(TypeError): +# await alist(model.stream(messages)) + + +# @pytest.mark.asyncio +# async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, messages, alist): +# error_message = "ThrottlingException: Rate exceeded for ConverseStream" +# bedrock_client.converse_stream.side_effect = ClientError( +# {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "Any" +# ) + +# with pytest.raises(ModelThrottledException) as excinfo: +# await alist(model.stream(messages)) + +# assert error_message in str(excinfo.value) +# bedrock_client.converse_stream.assert_called_once_with( +# modelId="m1", messages=messages, system=[], inferenceConfig={} +# ) + + +# @pytest.mark.asyncio +# async def test_general_exception_is_raised(bedrock_client, model, messages, alist): +# error_message = "Should be raised up" +# bedrock_client.converse_stream.side_effect = ValueError(error_message) + +# with pytest.raises(ValueError) as excinfo: +# await alist(model.stream(messages)) + +# assert error_message in str(excinfo.value) +# bedrock_client.converse_stream.assert_called_once_with( +# modelId="m1", messages=messages, system=[], inferenceConfig={} +# ) + + +# @pytest.mark.asyncio +# async def test_stream(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist): +# bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} + +# request = { +# "additionalModelRequestFields": additional_request_fields, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# "toolConfig": { +# "tools": [{"toolSpec": tool_spec}], +# "toolChoice": {"auto": {}}, +# }, +# } + +# model.update_config(additional_request_fields=additional_request_fields) +# response = model.stream(messages, [tool_spec]) + +# tru_chunks = await alist(response) +# exp_chunks = ["e1", "e2"] + +# assert tru_chunks == exp_chunks +# bedrock_client.converse_stream.assert_called_once_with(**request) + + +# @pytest.mark.asyncio +# async def test_stream_stream_input_guardrails( +# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist +# ): +# metadata_event = { +# "metadata": { +# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, +# "metrics": {"latencyMs": 245}, +# "trace": { +# "guardrail": { +# "inputAssessment": { +# "3e59qlue4hag": { +# "wordPolicy": { +# "customWords": [ +# { +# "match": "CACTUS", +# "action": "BLOCKED", +# "detected": True, +# } +# ] +# } +# } +# } +# } +# }, +# } +# } +# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + +# request = { +# "additionalModelRequestFields": additional_request_fields, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# "toolConfig": { +# "tools": [{"toolSpec": tool_spec}], +# "toolChoice": {"auto": {}}, +# }, +# } + +# model.update_config(additional_request_fields=additional_request_fields) +# response = model.stream(messages, [tool_spec]) + +# tru_chunks = await alist(response) +# exp_chunks = [ +# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, +# metadata_event, +# ] + +# assert tru_chunks == exp_chunks +# bedrock_client.converse_stream.assert_called_once_with(**request) + + +# @pytest.mark.asyncio +# async def test_stream_stream_output_guardrails( +# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist +# ): +# model.update_config(guardrail_redact_input=False, guardrail_redact_output=True) +# metadata_event = { +# "metadata": { +# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, +# "metrics": {"latencyMs": 245}, +# "trace": { +# "guardrail": { +# "outputAssessments": { +# "3e59qlue4hag": [ +# { +# "wordPolicy": { +# "customWords": [ +# { +# "match": "CACTUS", +# "action": "BLOCKED", +# "detected": True, +# } +# ] +# }, +# } +# ] +# }, +# } +# }, +# } +# } +# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + +# request = { +# "additionalModelRequestFields": additional_request_fields, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# "toolConfig": { +# "tools": [{"toolSpec": tool_spec}], +# "toolChoice": {"auto": {}}, +# }, +# } + +# model.update_config(additional_request_fields=additional_request_fields) +# response = model.stream(messages, [tool_spec]) + +# tru_chunks = await alist(response) +# exp_chunks = [ +# {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, +# metadata_event, +# ] + +# assert tru_chunks == exp_chunks +# bedrock_client.converse_stream.assert_called_once_with(**request) + + +# @pytest.mark.asyncio +# async def test_stream_output_guardrails_redacts_input_and_output( +# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist +# ): +# model.update_config(guardrail_redact_output=True) +# metadata_event = { +# "metadata": { +# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, +# "metrics": {"latencyMs": 245}, +# "trace": { +# "guardrail": { +# "outputAssessments": { +# "3e59qlue4hag": [ +# { +# "wordPolicy": { +# "customWords": [ +# { +# "match": "CACTUS", +# "action": "BLOCKED", +# "detected": True, +# } +# ] +# }, +# } +# ] +# }, +# } +# }, +# } +# } +# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + +# request = { +# "additionalModelRequestFields": additional_request_fields, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# "toolConfig": { +# "tools": [{"toolSpec": tool_spec}], +# "toolChoice": {"auto": {}}, +# }, +# } + +# model.update_config(additional_request_fields=additional_request_fields) +# response = model.stream(messages, [tool_spec]) + +# tru_chunks = await alist(response) +# exp_chunks = [ +# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, +# {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, +# metadata_event, +# ] + +# assert tru_chunks == exp_chunks +# bedrock_client.converse_stream.assert_called_once_with(**request) + + +# @pytest.mark.asyncio +# async def test_stream_output_no_blocked_guardrails_doesnt_redact( +# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist +# ): +# metadata_event = { +# "metadata": { +# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, +# "metrics": {"latencyMs": 245}, +# "trace": { +# "guardrail": { +# "outputAssessments": { +# "3e59qlue4hag": [ +# { +# "wordPolicy": { +# "customWords": [ +# { +# "match": "CACTUS", +# "action": "NONE", +# "detected": True, +# } +# ] +# }, +# } +# ] +# }, +# } +# }, +# } +# } +# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + +# request = { +# "additionalModelRequestFields": additional_request_fields, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# "toolConfig": { +# "tools": [{"toolSpec": tool_spec}], +# "toolChoice": {"auto": {}}, +# }, +# } + +# model.update_config(additional_request_fields=additional_request_fields) +# response = model.stream(messages, [tool_spec]) + +# tru_chunks = await alist(response) +# exp_chunks = [metadata_event] + +# assert tru_chunks == exp_chunks +# bedrock_client.converse_stream.assert_called_once_with(**request) + + +# @pytest.mark.asyncio +# async def test_stream_output_no_guardrail_redact( +# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist +# ): +# metadata_event = { +# "metadata": { +# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, +# "metrics": {"latencyMs": 245}, +# "trace": { +# "guardrail": { +# "outputAssessments": { +# "3e59qlue4hag": [ +# { +# "wordPolicy": { +# "customWords": [ +# { +# "match": "CACTUS", +# "action": "BLOCKED", +# "detected": True, +# } +# ] +# }, +# } +# ] +# }, +# } +# }, +# } +# } +# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + +# request = { +# "additionalModelRequestFields": additional_request_fields, +# "inferenceConfig": {}, +# "modelId": model_id, +# "messages": messages, +# "system": [], +# "toolConfig": { +# "tools": [{"toolSpec": tool_spec}], +# "toolChoice": {"auto": {}}, +# }, +# } + +# model.update_config( +# additional_request_fields=additional_request_fields, +# guardrail_redact_output=False, +# guardrail_redact_input=False, +# ) +# response = model.stream(messages, [tool_spec]) + +# tru_chunks = await alist(response) +# exp_chunks = [metadata_event] + +# assert tru_chunks == exp_chunks +# bedrock_client.converse_stream.assert_called_once_with(**request) + + +# @pytest.mark.asyncio +# async def test_stream_with_streaming_false(bedrock_client, alist, messages): +# """Test stream method with streaming=False.""" +# bedrock_client.converse.return_value = { +# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, +# "stopReason": "end_turn", +# } + +# # Create model and call stream +# model = BedrockModel(model_id="test-model", streaming=False) +# response = model.stream(messages) + +# tru_events = await alist(response) +# exp_events = [ +# {"messageStart": {"role": "assistant"}}, +# {"contentBlockDelta": {"delta": {"text": "test"}}}, +# {"contentBlockStop": {}}, +# {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, +# ] +# assert tru_events == exp_events + +# bedrock_client.converse.assert_called_once() +# bedrock_client.converse_stream.assert_not_called() + + +# @pytest.mark.asyncio +# async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages): +# """Test stream method with streaming=False.""" +# bedrock_client.converse.return_value = { +# "output": { +# "message": { +# "role": "assistant", +# "content": [{"toolUse": {"toolUseId": "123", "name": "dummyTool", "input": {"hello": "world!"}}}], +# } +# }, +# "stopReason": "tool_use", +# } + +# # Create model and call stream +# model = BedrockModel(model_id="test-model", streaming=False) +# response = model.stream(messages) + +# tru_events = await alist(response) +# exp_events = [ +# {"messageStart": {"role": "assistant"}}, +# {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "dummyTool"}}}}, +# {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"hello": "world!"}'}}}}, +# {"contentBlockStop": {}}, +# {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, +# ] +# assert tru_events == exp_events + +# bedrock_client.converse.assert_called_once() +# bedrock_client.converse_stream.assert_not_called() + + +# @pytest.mark.asyncio +# async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages): +# """Test stream method with streaming=False.""" +# bedrock_client.converse.return_value = { +# "output": { +# "message": { +# "role": "assistant", +# "content": [ +# { +# "reasoningContent": { +# "reasoningText": {"text": "Thinking really hard....", "signature": "123"}, +# } +# } +# ], +# } +# }, +# "stopReason": "tool_use", +# } + +# # Create model and call stream +# model = BedrockModel(model_id="test-model", streaming=False) +# response = model.stream(messages) + +# tru_events = await alist(response) +# exp_events = [ +# {"messageStart": {"role": "assistant"}}, +# {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, +# {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "123"}}}}, +# {"contentBlockStop": {}}, +# {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, +# ] +# assert tru_events == exp_events + +# # Verify converse was called +# bedrock_client.converse.assert_called_once() +# bedrock_client.converse_stream.assert_not_called() + + +# @pytest.mark.asyncio +# async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages): +# """Test stream method with streaming=False.""" +# bedrock_client.converse.return_value = { +# "output": { +# "message": { +# "role": "assistant", +# "content": [ +# { +# "reasoningContent": { +# "reasoningText": {"text": "Thinking really hard...."}, +# } +# } +# ], +# } +# }, +# "stopReason": "tool_use", +# } + +# # Create model and call stream +# model = BedrockModel(model_id="test-model", streaming=False) +# response = model.stream(messages) + +# tru_events = await alist(response) +# exp_events = [ +# {"messageStart": {"role": "assistant"}}, +# {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, +# {"contentBlockStop": {}}, +# {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, +# ] +# assert tru_events == exp_events + +# bedrock_client.converse.assert_called_once() +# bedrock_client.converse_stream.assert_not_called() + + +# @pytest.mark.asyncio +# async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages): +# """Test stream method with streaming=False.""" +# bedrock_client.converse.return_value = { +# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, +# "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, +# "metrics": {"latencyMs": 1234}, +# "stopReason": "tool_use", +# } + +# # Create model and call stream +# model = BedrockModel(model_id="test-model", streaming=False) +# response = model.stream(messages) + +# tru_events = await alist(response) +# exp_events = [ +# {"messageStart": {"role": "assistant"}}, +# {"contentBlockDelta": {"delta": {"text": "test"}}}, +# {"contentBlockStop": {}}, +# {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, +# { +# "metadata": { +# "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, +# "metrics": {"latencyMs": 1234}, +# } +# }, +# ] +# assert tru_events == exp_events + +# # Verify converse was called +# bedrock_client.converse.assert_called_once() +# bedrock_client.converse_stream.assert_not_called() + + +# @pytest.mark.asyncio +# async def test_stream_input_guardrails(bedrock_client, alist, messages): +# """Test stream method with streaming=False.""" +# bedrock_client.converse.return_value = { +# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, +# "trace": { +# "guardrail": { +# "inputAssessment": { +# "3e59qlue4hag": { +# "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} +# } +# } +# } +# }, +# "stopReason": "end_turn", +# } + +# # Create model and call stream +# model = BedrockModel(model_id="test-model", streaming=False) +# response = model.stream(messages) + +# tru_events = await alist(response) +# exp_events = [ +# {"messageStart": {"role": "assistant"}}, +# {"contentBlockDelta": {"delta": {"text": "test"}}}, +# {"contentBlockStop": {}}, +# {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, +# { +# "metadata": { +# "trace": { +# "guardrail": { +# "inputAssessment": { +# "3e59qlue4hag": { +# "wordPolicy": { +# "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] +# } +# } +# } +# } +# } +# } +# }, +# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, +# ] +# assert tru_events == exp_events + +# bedrock_client.converse.assert_called_once() +# bedrock_client.converse_stream.assert_not_called() + + +# @pytest.mark.asyncio +# async def test_stream_output_guardrails(bedrock_client, alist, messages): +# """Test stream method with streaming=False.""" +# bedrock_client.converse.return_value = { +# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, +# "trace": { +# "guardrail": { +# "outputAssessments": { +# "3e59qlue4hag": [ +# { +# "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, +# } +# ] +# }, +# } +# }, +# "stopReason": "end_turn", +# } + +# model = BedrockModel(model_id="test-model", streaming=False) +# response = model.stream(messages) + +# tru_events = await alist(response) +# exp_events = [ +# {"messageStart": {"role": "assistant"}}, +# {"contentBlockDelta": {"delta": {"text": "test"}}}, +# {"contentBlockStop": {}}, +# {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, +# { +# "metadata": { +# "trace": { +# "guardrail": { +# "outputAssessments": { +# "3e59qlue4hag": [ +# { +# "wordPolicy": { +# "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] +# } +# } +# ] +# } +# } +# } +# } +# }, +# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, +# ] +# assert tru_events == exp_events + +# bedrock_client.converse.assert_called_once() +# bedrock_client.converse_stream.assert_not_called() + + +# @pytest.mark.asyncio +# async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages): +# """Test stream method with streaming=False.""" +# bedrock_client.converse.return_value = { +# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, +# "trace": { +# "guardrail": { +# "outputAssessments": { +# "3e59qlue4hag": [ +# { +# "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, +# } +# ] +# }, +# } +# }, +# "stopReason": "end_turn", +# } + +# model = BedrockModel(model_id="test-model", streaming=False) +# response = model.stream(messages) + +# tru_events = await alist(response) +# exp_events = [ +# {"messageStart": {"role": "assistant"}}, +# {"contentBlockDelta": {"delta": {"text": "test"}}}, +# {"contentBlockStop": {}}, +# {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, +# { +# "metadata": { +# "trace": { +# "guardrail": { +# "outputAssessments": { +# "3e59qlue4hag": [ +# { +# "wordPolicy": { +# "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] +# } +# } +# ] +# } +# } +# } +# } +# }, +# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, +# ] +# assert tru_events == exp_events + +# bedrock_client.converse.assert_called_once() +# bedrock_client.converse_stream.assert_not_called() + + +# @pytest.mark.asyncio +# async def test_structured_output(bedrock_client, model, test_output_model_cls, alist): +# messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + +# bedrock_client.converse_stream.return_value = { +# "stream": [ +# {"messageStart": {"role": "assistant"}}, +# {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "TestOutputModel"}}}}, +# {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "John", "age": 30}'}}}}, +# {"contentBlockStop": {}}, +# {"messageStop": {"stopReason": "tool_use"}}, +# ] +# } + +# stream = model.structured_output(test_output_model_cls, messages) +# events = await alist(stream) + +# tru_output = events[-1] +# exp_output = {"output": test_output_model_cls(name="John", age=30)} +# assert tru_output == exp_output + + +# @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") +# @pytest.mark.asyncio +# async def test_add_note_on_client_error(bedrock_client, model, alist, messages): +# """Test that add_note is called on ClientError with region and model ID information.""" +# # Mock the client error response +# error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} +# bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + +# # Call the stream method which should catch and add notes to the exception +# with pytest.raises(ClientError) as err: +# await alist(model.stream(messages)) + +# assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"] + + +# @pytest.mark.asyncio +# async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages): +# """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" +# # Mock the client error response +# error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} +# bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + +# # Call the stream method which should catch and add notes to the exception +# with pytest.raises(ClientError): +# await alist(model.stream(messages)) + + +# @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") +# @pytest.mark.asyncio +# async def test_add_note_on_access_denied_exception(bedrock_client, model, alist, messages): +# """Test that add_note adds documentation link for AccessDeniedException.""" +# # Mock the client error response for access denied +# error_response = { +# "Error": { +# "Code": "AccessDeniedException", +# "Message": "An error occurred (AccessDeniedException) when calling the ConverseStream operation: " +# "You don't have access to the model with the specified model ID.", +# } +# } +# bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + +# # Call the stream method which should catch and add notes to the exception +# with pytest.raises(ClientError) as err: +# await alist(model.stream(messages)) + +# assert err.value.__notes__ == [ +# "└ Bedrock region: us-west-2", +# "└ Model id: m1", +# "└ For more information see " +# "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue", +# ] + + +# @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") +# @pytest.mark.asyncio +# async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist, messages): +# """Test that add_note adds documentation link for ValidationException about on-demand throughput.""" +# # Mock the client error response for validation exception +# error_response = { +# "Error": { +# "Code": "ValidationException", +# "Message": "An error occurred (ValidationException) when calling the ConverseStream operation: " +# "Invocation of model ID anthropic.claude-3-7-sonnet-20250219-v1:0 with on-demand throughput " +# "isn’t supported. Retry your request with the ID or ARN of an inference profile that contains " +# "this model.", +# } +# } +# bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + +# # Call the stream method which should catch and add notes to the exception +# with pytest.raises(ClientError) as err: +# await alist(model.stream(messages)) + +# assert err.value.__notes__ == [ +# "└ Bedrock region: us-west-2", +# "└ Model id: m1", +# "└ For more information see " +# "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported", +# ] + + +# @pytest.mark.asyncio +# async def test_stream_logging(bedrock_client, model, messages, caplog, alist): +# """Test that stream method logs debug messages at the expected stages.""" +# import logging + +# # Set the logger to debug level to capture debug messages +# caplog.set_level(logging.DEBUG, logger="strands.models.bedrock") + +# # Mock the response +# bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} + +# # Execute the stream method +# response = model.stream(messages) +# await alist(response) + +# # Check that the expected log messages are present +# log_text = caplog.text +# assert "formatting request" in log_text +# assert "request=<" in log_text +# assert "invoking model" in log_text +# assert "got response from model" in log_text +# assert "finished streaming response from model" in log_text + + +# def test_format_request_cleans_tool_result_content_blocks(model, model_id): +# """Test that format_request cleans toolResult blocks by removing extra fields.""" +# messages = [ +# { +# "role": "user", +# "content": [ +# { +# "toolResult": { +# "content": [{"text": "Tool output"}], +# "toolUseId": "tool123", +# "status": "success", +# "extraField": "should be removed", +# "mcpMetadata": {"server": "test"}, +# } +# }, +# ], +# } +# ] + +# formatted_request = model.format_request(messages) + +# # Verify toolResult only contains allowed fields in the formatted request +# tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] +# expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} +# assert tool_result == expected +# assert "extraField" not in tool_result +# assert "mcpMetadata" not in tool_result From 4db15c86514dfc09988fe103ab033cd800b47bf2 Mon Sep 17 00:00:00 2001 From: Abdullatif Alrashdan Date: Fri, 29 Aug 2025 13:05:52 +0000 Subject: [PATCH 3/3] Fix: resolve default region-specific model initialization in default agent settings --- src/strands/models/bedrock.py | 29 +- tests/strands/agent/test_agent.py | 6 +- tests/strands/models/test_bedrock.py | 1982 +++++++++++++------------- 3 files changed, 1024 insertions(+), 993 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index d61e563f7..9808d6571 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) DEFAULT_BEDROCK_REGION = "us-west-2" -DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" +DEFAULT_BEDROCK_MODEL_ID = "anthropic.claude-sonnet-4-20250514-v1:0" BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ "Input is too long for requested model", @@ -137,7 +137,11 @@ def __init__( # get default model id based on resolved region resolved_model_id = self._get_default_model_for_region(resolved_region) if resolved_model_id == "": - raise ValueError("default model {} is not available in {} region. Specify another model".format(DEFAULT_BEDROCK_MODEL_ID, resolved_region)) + raise ValueError( + "default model {} is not available in {} region. Specify another model".format( + DEFAULT_BEDROCK_MODEL_ID, resolved_region + ) + ) self.config = BedrockModel.BedrockConfig(model_id=resolved_model_id) self.update_config(**model_config) @@ -357,15 +361,18 @@ def _generate_redaction_events(self) -> list[StreamEvent]: return events def _get_default_model_for_region(self, region: str) -> str: - client = boto3.client("bedrock", region_name=region) - response = client.list_inference_profiles() - inferenceProfileSummary = response["inferenceProfileSummaries"] - - for profile in inferenceProfileSummary: - if DEFAULT_BEDROCK_MODEL_ID in profile["inferenceProfileId"]: - return profile["inferenceProfileId"] - - return "" + try: + client = boto3.client("bedrock", region_name=region) + response = client.list_inference_profiles() + inference_profile_summary = response["inferenceProfileSummaries"] + + for profile in inference_profile_summary: + if DEFAULT_BEDROCK_MODEL_ID in profile["inferenceProfileId"]: + return str(profile["inferenceProfileId"]) + + return "" + except ClientError as e: + raise e @override async def stream( diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index a8561abe4..fdfa9b83b 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -211,7 +211,7 @@ def test_agent__init__with_default_model(): agent = Agent() assert isinstance(agent.model, BedrockModel) - assert agent.model.config["model_id"] == DEFAULT_BEDROCK_MODEL_ID + assert agent.model.config["model_id"] and DEFAULT_BEDROCK_MODEL_ID in agent.model.config["model_id"] def test_agent__init__with_explicit_model(mock_model): @@ -891,7 +891,9 @@ def test_agent__del__(agent): def test_agent_init_with_no_model_or_model_id(): agent = Agent() assert agent.model is not None - assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID + assert agent.model.get_config().get("model_id") and DEFAULT_BEDROCK_MODEL_ID in agent.model.get_config().get( + "model_id" + ) def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index ccd6986a5..5586e479e 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -23,18 +23,18 @@ def session_cls(): mock_session_cls.return_value.region_name = None yield mock_session_cls + @pytest.fixture def mock_bedrock_inference_profiles(): with unittest.mock.patch.object(strands.models.bedrock.boto3, "client") as mock_boto_client: mock_bedrock = unittest.mock.MagicMock() mock_bedrock.list_inference_profiles.return_value = { - "inferenceProfileSummaries": [ - {"inferenceProfileId": "us.anthropic.claude-sonnet-4-20250514-v1:0"} - ] + "inferenceProfileSummaries": [{"inferenceProfileId": "us.anthropic.claude-sonnet-4-20250514-v1:0"}] } mock_boto_client.return_value = mock_bedrock yield mock_boto_client + @pytest.fixture def mock_client_method(session_cls): # the boto3.Session().client(...) method @@ -42,10 +42,10 @@ def mock_client_method(session_cls): @pytest.fixture -def bedrock_client(session_cls): +def bedrock_client(session_cls, region="us-west-2"): mock_client = session_cls.return_value.client.return_value mock_client.meta = unittest.mock.MagicMock() - mock_client.meta.region_name = "us-west-2" + mock_client.meta.region_name = region yield mock_client @@ -55,7 +55,7 @@ def model_id(): @pytest.fixture -def model(bedrock_client, model_id): +def model(bedrock_client, mock_bedrock_inference_profiles, model_id): _ = bedrock_client return BedrockModel(model_id=model_id) @@ -128,11 +128,11 @@ def test__init__default_model_id(bedrock_client, mock_bedrock_inference_profiles """Test that BedrockModel uses DEFAULT_MODEL_ID when no model_id is provided.""" _ = bedrock_client model = BedrockModel() - + tru_model_id = model.get_config().get("model_id") - exp_model_id = "us."+DEFAULT_BEDROCK_MODEL_ID + exp_model_id = DEFAULT_BEDROCK_MODEL_ID - assert tru_model_id == exp_model_id + assert tru_model_id and exp_model_id in tru_model_id def test__init__with_default_region(session_cls, mock_client_method, mock_bedrock_inference_profiles): @@ -266,9 +266,79 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id -def test_format_request_default(model, messages, model_id, mock_bedrock_inference_profiles): +def test_format_request_default(model, messages, model_id): + tru_request = model.format_request(messages) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + } + + assert tru_request == exp_request + + +def test_format_request_additional_request_fields(model, messages, model_id, additional_request_fields): + model.update_config(additional_request_fields=additional_request_fields) + tru_request = model.format_request(messages) + exp_request = { + "additionalModelRequestFields": additional_request_fields, + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + } + + assert tru_request == exp_request + + +def test_format_request_additional_response_field_paths(model, messages, model_id, additional_response_field_paths): + model.update_config(additional_response_field_paths=additional_response_field_paths) + tru_request = model.format_request(messages) + exp_request = { + "additionalModelResponseFieldPaths": additional_response_field_paths, + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + } + + assert tru_request == exp_request + + +def test_format_request_guardrail_config(model, messages, model_id, guardrail_config): + model.update_config(**guardrail_config) + tru_request = model.format_request(messages) + exp_request = { + "guardrailConfig": { + "guardrailIdentifier": guardrail_config["guardrail_id"], + "guardrailVersion": guardrail_config["guardrail_version"], + "trace": guardrail_config["guardrail_trace"], + "streamProcessingMode": guardrail_config["guardrail_stream_processing_mode"], + }, + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + } + + assert tru_request == exp_request + + +def test_format_request_guardrail_config_without_trace_or_stream_processing_mode(model, messages, model_id): + model.update_config( + **{ + "guardrail_id": "g1", + "guardrail_version": "v1", + } + ) tru_request = model.format_request(messages) exp_request = { + "guardrailConfig": { + "guardrailIdentifier": "g1", + "guardrailVersion": "v1", + "trace": "enabled", + }, "inferenceConfig": {}, "modelId": model_id, "messages": messages, @@ -278,973 +348,925 @@ def test_format_request_default(model, messages, model_id, mock_bedrock_inferenc assert tru_request == exp_request -# def test_format_request_additional_request_fields(model, messages, model_id, additional_request_fields): -# model.update_config(additional_request_fields=additional_request_fields) -# tru_request = model.format_request(messages) -# exp_request = { -# "additionalModelRequestFields": additional_request_fields, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# } - -# assert tru_request == exp_request - - -# def test_format_request_additional_response_field_paths(model, messages, model_id, additional_response_field_paths): -# model.update_config(additional_response_field_paths=additional_response_field_paths) -# tru_request = model.format_request(messages) -# exp_request = { -# "additionalModelResponseFieldPaths": additional_response_field_paths, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# } - -# assert tru_request == exp_request - - -# def test_format_request_guardrail_config(model, messages, model_id, guardrail_config): -# model.update_config(**guardrail_config) -# tru_request = model.format_request(messages) -# exp_request = { -# "guardrailConfig": { -# "guardrailIdentifier": guardrail_config["guardrail_id"], -# "guardrailVersion": guardrail_config["guardrail_version"], -# "trace": guardrail_config["guardrail_trace"], -# "streamProcessingMode": guardrail_config["guardrail_stream_processing_mode"], -# }, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# } - -# assert tru_request == exp_request - - -# def test_format_request_guardrail_config_without_trace_or_stream_processing_mode(model, messages, model_id): -# model.update_config( -# **{ -# "guardrail_id": "g1", -# "guardrail_version": "v1", -# } -# ) -# tru_request = model.format_request(messages) -# exp_request = { -# "guardrailConfig": { -# "guardrailIdentifier": "g1", -# "guardrailVersion": "v1", -# "trace": "enabled", -# }, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# } - -# assert tru_request == exp_request - - -# def test_format_request_inference_config(model, messages, model_id, inference_config): -# model.update_config(**inference_config) -# tru_request = model.format_request(messages) -# exp_request = { -# "inferenceConfig": { -# "maxTokens": inference_config["max_tokens"], -# "stopSequences": inference_config["stop_sequences"], -# "temperature": inference_config["temperature"], -# "topP": inference_config["top_p"], -# }, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# } - -# assert tru_request == exp_request - - -# def test_format_request_system_prompt(model, messages, model_id, system_prompt): -# tru_request = model.format_request(messages, system_prompt=system_prompt) -# exp_request = { -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [{"text": system_prompt}], -# } - -# assert tru_request == exp_request - - -# def test_format_request_tool_specs(model, messages, model_id, tool_spec): -# tru_request = model.format_request(messages, [tool_spec]) -# exp_request = { -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# "toolConfig": { -# "tools": [{"toolSpec": tool_spec}], -# "toolChoice": {"auto": {}}, -# }, -# } - -# assert tru_request == exp_request - - -# def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): -# model.update_config(cache_prompt=cache_type, cache_tools=cache_type) -# tru_request = model.format_request(messages, [tool_spec]) -# exp_request = { -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [{"cachePoint": {"type": cache_type}}], -# "toolConfig": { -# "tools": [ -# {"toolSpec": tool_spec}, -# {"cachePoint": {"type": cache_type}}, -# ], -# "toolChoice": {"auto": {}}, -# }, -# } - -# assert tru_request == exp_request - - -# @pytest.mark.asyncio -# async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist): -# error_message = "Rate exceeded" -# bedrock_client.converse_stream.side_effect = EventStreamError( -# {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "ConverseStream" -# ) - -# with pytest.raises(ModelThrottledException) as excinfo: -# await alist(model.stream(messages)) - -# assert error_message in str(excinfo.value) -# bedrock_client.converse_stream.assert_called_once_with( -# modelId="m1", messages=messages, system=[], inferenceConfig={} -# ) - - -# @pytest.mark.asyncio -# async def test_stream_with_invalid_content_throws(bedrock_client, model, alist): -# # We used to hang on None, so ensure we don't regress: https://github.com/strands-agents/sdk-python/issues/642 -# messages = [{"role": "user", "content": None}] - -# with pytest.raises(TypeError): -# await alist(model.stream(messages)) - - -# @pytest.mark.asyncio -# async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, messages, alist): -# error_message = "ThrottlingException: Rate exceeded for ConverseStream" -# bedrock_client.converse_stream.side_effect = ClientError( -# {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "Any" -# ) - -# with pytest.raises(ModelThrottledException) as excinfo: -# await alist(model.stream(messages)) - -# assert error_message in str(excinfo.value) -# bedrock_client.converse_stream.assert_called_once_with( -# modelId="m1", messages=messages, system=[], inferenceConfig={} -# ) - - -# @pytest.mark.asyncio -# async def test_general_exception_is_raised(bedrock_client, model, messages, alist): -# error_message = "Should be raised up" -# bedrock_client.converse_stream.side_effect = ValueError(error_message) - -# with pytest.raises(ValueError) as excinfo: -# await alist(model.stream(messages)) - -# assert error_message in str(excinfo.value) -# bedrock_client.converse_stream.assert_called_once_with( -# modelId="m1", messages=messages, system=[], inferenceConfig={} -# ) - - -# @pytest.mark.asyncio -# async def test_stream(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist): -# bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} - -# request = { -# "additionalModelRequestFields": additional_request_fields, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# "toolConfig": { -# "tools": [{"toolSpec": tool_spec}], -# "toolChoice": {"auto": {}}, -# }, -# } - -# model.update_config(additional_request_fields=additional_request_fields) -# response = model.stream(messages, [tool_spec]) - -# tru_chunks = await alist(response) -# exp_chunks = ["e1", "e2"] - -# assert tru_chunks == exp_chunks -# bedrock_client.converse_stream.assert_called_once_with(**request) - - -# @pytest.mark.asyncio -# async def test_stream_stream_input_guardrails( -# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -# ): -# metadata_event = { -# "metadata": { -# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, -# "metrics": {"latencyMs": 245}, -# "trace": { -# "guardrail": { -# "inputAssessment": { -# "3e59qlue4hag": { -# "wordPolicy": { -# "customWords": [ -# { -# "match": "CACTUS", -# "action": "BLOCKED", -# "detected": True, -# } -# ] -# } -# } -# } -# } -# }, -# } -# } -# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - -# request = { -# "additionalModelRequestFields": additional_request_fields, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# "toolConfig": { -# "tools": [{"toolSpec": tool_spec}], -# "toolChoice": {"auto": {}}, -# }, -# } - -# model.update_config(additional_request_fields=additional_request_fields) -# response = model.stream(messages, [tool_spec]) - -# tru_chunks = await alist(response) -# exp_chunks = [ -# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, -# metadata_event, -# ] - -# assert tru_chunks == exp_chunks -# bedrock_client.converse_stream.assert_called_once_with(**request) - - -# @pytest.mark.asyncio -# async def test_stream_stream_output_guardrails( -# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -# ): -# model.update_config(guardrail_redact_input=False, guardrail_redact_output=True) -# metadata_event = { -# "metadata": { -# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, -# "metrics": {"latencyMs": 245}, -# "trace": { -# "guardrail": { -# "outputAssessments": { -# "3e59qlue4hag": [ -# { -# "wordPolicy": { -# "customWords": [ -# { -# "match": "CACTUS", -# "action": "BLOCKED", -# "detected": True, -# } -# ] -# }, -# } -# ] -# }, -# } -# }, -# } -# } -# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - -# request = { -# "additionalModelRequestFields": additional_request_fields, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# "toolConfig": { -# "tools": [{"toolSpec": tool_spec}], -# "toolChoice": {"auto": {}}, -# }, -# } - -# model.update_config(additional_request_fields=additional_request_fields) -# response = model.stream(messages, [tool_spec]) - -# tru_chunks = await alist(response) -# exp_chunks = [ -# {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, -# metadata_event, -# ] - -# assert tru_chunks == exp_chunks -# bedrock_client.converse_stream.assert_called_once_with(**request) - - -# @pytest.mark.asyncio -# async def test_stream_output_guardrails_redacts_input_and_output( -# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -# ): -# model.update_config(guardrail_redact_output=True) -# metadata_event = { -# "metadata": { -# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, -# "metrics": {"latencyMs": 245}, -# "trace": { -# "guardrail": { -# "outputAssessments": { -# "3e59qlue4hag": [ -# { -# "wordPolicy": { -# "customWords": [ -# { -# "match": "CACTUS", -# "action": "BLOCKED", -# "detected": True, -# } -# ] -# }, -# } -# ] -# }, -# } -# }, -# } -# } -# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - -# request = { -# "additionalModelRequestFields": additional_request_fields, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# "toolConfig": { -# "tools": [{"toolSpec": tool_spec}], -# "toolChoice": {"auto": {}}, -# }, -# } - -# model.update_config(additional_request_fields=additional_request_fields) -# response = model.stream(messages, [tool_spec]) - -# tru_chunks = await alist(response) -# exp_chunks = [ -# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, -# {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, -# metadata_event, -# ] - -# assert tru_chunks == exp_chunks -# bedrock_client.converse_stream.assert_called_once_with(**request) - - -# @pytest.mark.asyncio -# async def test_stream_output_no_blocked_guardrails_doesnt_redact( -# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -# ): -# metadata_event = { -# "metadata": { -# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, -# "metrics": {"latencyMs": 245}, -# "trace": { -# "guardrail": { -# "outputAssessments": { -# "3e59qlue4hag": [ -# { -# "wordPolicy": { -# "customWords": [ -# { -# "match": "CACTUS", -# "action": "NONE", -# "detected": True, -# } -# ] -# }, -# } -# ] -# }, -# } -# }, -# } -# } -# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - -# request = { -# "additionalModelRequestFields": additional_request_fields, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# "toolConfig": { -# "tools": [{"toolSpec": tool_spec}], -# "toolChoice": {"auto": {}}, -# }, -# } - -# model.update_config(additional_request_fields=additional_request_fields) -# response = model.stream(messages, [tool_spec]) - -# tru_chunks = await alist(response) -# exp_chunks = [metadata_event] - -# assert tru_chunks == exp_chunks -# bedrock_client.converse_stream.assert_called_once_with(**request) - - -# @pytest.mark.asyncio -# async def test_stream_output_no_guardrail_redact( -# bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist -# ): -# metadata_event = { -# "metadata": { -# "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, -# "metrics": {"latencyMs": 245}, -# "trace": { -# "guardrail": { -# "outputAssessments": { -# "3e59qlue4hag": [ -# { -# "wordPolicy": { -# "customWords": [ -# { -# "match": "CACTUS", -# "action": "BLOCKED", -# "detected": True, -# } -# ] -# }, -# } -# ] -# }, -# } -# }, -# } -# } -# bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} - -# request = { -# "additionalModelRequestFields": additional_request_fields, -# "inferenceConfig": {}, -# "modelId": model_id, -# "messages": messages, -# "system": [], -# "toolConfig": { -# "tools": [{"toolSpec": tool_spec}], -# "toolChoice": {"auto": {}}, -# }, -# } - -# model.update_config( -# additional_request_fields=additional_request_fields, -# guardrail_redact_output=False, -# guardrail_redact_input=False, -# ) -# response = model.stream(messages, [tool_spec]) - -# tru_chunks = await alist(response) -# exp_chunks = [metadata_event] - -# assert tru_chunks == exp_chunks -# bedrock_client.converse_stream.assert_called_once_with(**request) - - -# @pytest.mark.asyncio -# async def test_stream_with_streaming_false(bedrock_client, alist, messages): -# """Test stream method with streaming=False.""" -# bedrock_client.converse.return_value = { -# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, -# "stopReason": "end_turn", -# } - -# # Create model and call stream -# model = BedrockModel(model_id="test-model", streaming=False) -# response = model.stream(messages) - -# tru_events = await alist(response) -# exp_events = [ -# {"messageStart": {"role": "assistant"}}, -# {"contentBlockDelta": {"delta": {"text": "test"}}}, -# {"contentBlockStop": {}}, -# {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, -# ] -# assert tru_events == exp_events - -# bedrock_client.converse.assert_called_once() -# bedrock_client.converse_stream.assert_not_called() - - -# @pytest.mark.asyncio -# async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages): -# """Test stream method with streaming=False.""" -# bedrock_client.converse.return_value = { -# "output": { -# "message": { -# "role": "assistant", -# "content": [{"toolUse": {"toolUseId": "123", "name": "dummyTool", "input": {"hello": "world!"}}}], -# } -# }, -# "stopReason": "tool_use", -# } - -# # Create model and call stream -# model = BedrockModel(model_id="test-model", streaming=False) -# response = model.stream(messages) - -# tru_events = await alist(response) -# exp_events = [ -# {"messageStart": {"role": "assistant"}}, -# {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "dummyTool"}}}}, -# {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"hello": "world!"}'}}}}, -# {"contentBlockStop": {}}, -# {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, -# ] -# assert tru_events == exp_events - -# bedrock_client.converse.assert_called_once() -# bedrock_client.converse_stream.assert_not_called() - - -# @pytest.mark.asyncio -# async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages): -# """Test stream method with streaming=False.""" -# bedrock_client.converse.return_value = { -# "output": { -# "message": { -# "role": "assistant", -# "content": [ -# { -# "reasoningContent": { -# "reasoningText": {"text": "Thinking really hard....", "signature": "123"}, -# } -# } -# ], -# } -# }, -# "stopReason": "tool_use", -# } - -# # Create model and call stream -# model = BedrockModel(model_id="test-model", streaming=False) -# response = model.stream(messages) - -# tru_events = await alist(response) -# exp_events = [ -# {"messageStart": {"role": "assistant"}}, -# {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, -# {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "123"}}}}, -# {"contentBlockStop": {}}, -# {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, -# ] -# assert tru_events == exp_events - -# # Verify converse was called -# bedrock_client.converse.assert_called_once() -# bedrock_client.converse_stream.assert_not_called() - - -# @pytest.mark.asyncio -# async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages): -# """Test stream method with streaming=False.""" -# bedrock_client.converse.return_value = { -# "output": { -# "message": { -# "role": "assistant", -# "content": [ -# { -# "reasoningContent": { -# "reasoningText": {"text": "Thinking really hard...."}, -# } -# } -# ], -# } -# }, -# "stopReason": "tool_use", -# } - -# # Create model and call stream -# model = BedrockModel(model_id="test-model", streaming=False) -# response = model.stream(messages) - -# tru_events = await alist(response) -# exp_events = [ -# {"messageStart": {"role": "assistant"}}, -# {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, -# {"contentBlockStop": {}}, -# {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, -# ] -# assert tru_events == exp_events - -# bedrock_client.converse.assert_called_once() -# bedrock_client.converse_stream.assert_not_called() - - -# @pytest.mark.asyncio -# async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages): -# """Test stream method with streaming=False.""" -# bedrock_client.converse.return_value = { -# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, -# "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, -# "metrics": {"latencyMs": 1234}, -# "stopReason": "tool_use", -# } - -# # Create model and call stream -# model = BedrockModel(model_id="test-model", streaming=False) -# response = model.stream(messages) - -# tru_events = await alist(response) -# exp_events = [ -# {"messageStart": {"role": "assistant"}}, -# {"contentBlockDelta": {"delta": {"text": "test"}}}, -# {"contentBlockStop": {}}, -# {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, -# { -# "metadata": { -# "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, -# "metrics": {"latencyMs": 1234}, -# } -# }, -# ] -# assert tru_events == exp_events - -# # Verify converse was called -# bedrock_client.converse.assert_called_once() -# bedrock_client.converse_stream.assert_not_called() - - -# @pytest.mark.asyncio -# async def test_stream_input_guardrails(bedrock_client, alist, messages): -# """Test stream method with streaming=False.""" -# bedrock_client.converse.return_value = { -# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, -# "trace": { -# "guardrail": { -# "inputAssessment": { -# "3e59qlue4hag": { -# "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} -# } -# } -# } -# }, -# "stopReason": "end_turn", -# } - -# # Create model and call stream -# model = BedrockModel(model_id="test-model", streaming=False) -# response = model.stream(messages) - -# tru_events = await alist(response) -# exp_events = [ -# {"messageStart": {"role": "assistant"}}, -# {"contentBlockDelta": {"delta": {"text": "test"}}}, -# {"contentBlockStop": {}}, -# {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, -# { -# "metadata": { -# "trace": { -# "guardrail": { -# "inputAssessment": { -# "3e59qlue4hag": { -# "wordPolicy": { -# "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] -# } -# } -# } -# } -# } -# } -# }, -# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, -# ] -# assert tru_events == exp_events - -# bedrock_client.converse.assert_called_once() -# bedrock_client.converse_stream.assert_not_called() - - -# @pytest.mark.asyncio -# async def test_stream_output_guardrails(bedrock_client, alist, messages): -# """Test stream method with streaming=False.""" -# bedrock_client.converse.return_value = { -# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, -# "trace": { -# "guardrail": { -# "outputAssessments": { -# "3e59qlue4hag": [ -# { -# "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, -# } -# ] -# }, -# } -# }, -# "stopReason": "end_turn", -# } - -# model = BedrockModel(model_id="test-model", streaming=False) -# response = model.stream(messages) - -# tru_events = await alist(response) -# exp_events = [ -# {"messageStart": {"role": "assistant"}}, -# {"contentBlockDelta": {"delta": {"text": "test"}}}, -# {"contentBlockStop": {}}, -# {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, -# { -# "metadata": { -# "trace": { -# "guardrail": { -# "outputAssessments": { -# "3e59qlue4hag": [ -# { -# "wordPolicy": { -# "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] -# } -# } -# ] -# } -# } -# } -# } -# }, -# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, -# ] -# assert tru_events == exp_events - -# bedrock_client.converse.assert_called_once() -# bedrock_client.converse_stream.assert_not_called() - - -# @pytest.mark.asyncio -# async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages): -# """Test stream method with streaming=False.""" -# bedrock_client.converse.return_value = { -# "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, -# "trace": { -# "guardrail": { -# "outputAssessments": { -# "3e59qlue4hag": [ -# { -# "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, -# } -# ] -# }, -# } -# }, -# "stopReason": "end_turn", -# } - -# model = BedrockModel(model_id="test-model", streaming=False) -# response = model.stream(messages) - -# tru_events = await alist(response) -# exp_events = [ -# {"messageStart": {"role": "assistant"}}, -# {"contentBlockDelta": {"delta": {"text": "test"}}}, -# {"contentBlockStop": {}}, -# {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, -# { -# "metadata": { -# "trace": { -# "guardrail": { -# "outputAssessments": { -# "3e59qlue4hag": [ -# { -# "wordPolicy": { -# "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] -# } -# } -# ] -# } -# } -# } -# } -# }, -# {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, -# ] -# assert tru_events == exp_events - -# bedrock_client.converse.assert_called_once() -# bedrock_client.converse_stream.assert_not_called() - - -# @pytest.mark.asyncio -# async def test_structured_output(bedrock_client, model, test_output_model_cls, alist): -# messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] - -# bedrock_client.converse_stream.return_value = { -# "stream": [ -# {"messageStart": {"role": "assistant"}}, -# {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "TestOutputModel"}}}}, -# {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "John", "age": 30}'}}}}, -# {"contentBlockStop": {}}, -# {"messageStop": {"stopReason": "tool_use"}}, -# ] -# } - -# stream = model.structured_output(test_output_model_cls, messages) -# events = await alist(stream) - -# tru_output = events[-1] -# exp_output = {"output": test_output_model_cls(name="John", age=30)} -# assert tru_output == exp_output - - -# @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") -# @pytest.mark.asyncio -# async def test_add_note_on_client_error(bedrock_client, model, alist, messages): -# """Test that add_note is called on ClientError with region and model ID information.""" -# # Mock the client error response -# error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} -# bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") - -# # Call the stream method which should catch and add notes to the exception -# with pytest.raises(ClientError) as err: -# await alist(model.stream(messages)) - -# assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"] - - -# @pytest.mark.asyncio -# async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages): -# """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" -# # Mock the client error response -# error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} -# bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") - -# # Call the stream method which should catch and add notes to the exception -# with pytest.raises(ClientError): -# await alist(model.stream(messages)) - - -# @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") -# @pytest.mark.asyncio -# async def test_add_note_on_access_denied_exception(bedrock_client, model, alist, messages): -# """Test that add_note adds documentation link for AccessDeniedException.""" -# # Mock the client error response for access denied -# error_response = { -# "Error": { -# "Code": "AccessDeniedException", -# "Message": "An error occurred (AccessDeniedException) when calling the ConverseStream operation: " -# "You don't have access to the model with the specified model ID.", -# } -# } -# bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") - -# # Call the stream method which should catch and add notes to the exception -# with pytest.raises(ClientError) as err: -# await alist(model.stream(messages)) - -# assert err.value.__notes__ == [ -# "└ Bedrock region: us-west-2", -# "└ Model id: m1", -# "└ For more information see " -# "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue", -# ] - - -# @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") -# @pytest.mark.asyncio -# async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist, messages): -# """Test that add_note adds documentation link for ValidationException about on-demand throughput.""" -# # Mock the client error response for validation exception -# error_response = { -# "Error": { -# "Code": "ValidationException", -# "Message": "An error occurred (ValidationException) when calling the ConverseStream operation: " -# "Invocation of model ID anthropic.claude-3-7-sonnet-20250219-v1:0 with on-demand throughput " -# "isn’t supported. Retry your request with the ID or ARN of an inference profile that contains " -# "this model.", -# } -# } -# bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") - -# # Call the stream method which should catch and add notes to the exception -# with pytest.raises(ClientError) as err: -# await alist(model.stream(messages)) - -# assert err.value.__notes__ == [ -# "└ Bedrock region: us-west-2", -# "└ Model id: m1", -# "└ For more information see " -# "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported", -# ] - - -# @pytest.mark.asyncio -# async def test_stream_logging(bedrock_client, model, messages, caplog, alist): -# """Test that stream method logs debug messages at the expected stages.""" -# import logging - -# # Set the logger to debug level to capture debug messages -# caplog.set_level(logging.DEBUG, logger="strands.models.bedrock") - -# # Mock the response -# bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} - -# # Execute the stream method -# response = model.stream(messages) -# await alist(response) - -# # Check that the expected log messages are present -# log_text = caplog.text -# assert "formatting request" in log_text -# assert "request=<" in log_text -# assert "invoking model" in log_text -# assert "got response from model" in log_text -# assert "finished streaming response from model" in log_text - - -# def test_format_request_cleans_tool_result_content_blocks(model, model_id): -# """Test that format_request cleans toolResult blocks by removing extra fields.""" -# messages = [ -# { -# "role": "user", -# "content": [ -# { -# "toolResult": { -# "content": [{"text": "Tool output"}], -# "toolUseId": "tool123", -# "status": "success", -# "extraField": "should be removed", -# "mcpMetadata": {"server": "test"}, -# } -# }, -# ], -# } -# ] - -# formatted_request = model.format_request(messages) - -# # Verify toolResult only contains allowed fields in the formatted request -# tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] -# expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} -# assert tool_result == expected -# assert "extraField" not in tool_result -# assert "mcpMetadata" not in tool_result +def test_format_request_inference_config(model, messages, model_id, inference_config): + model.update_config(**inference_config) + tru_request = model.format_request(messages) + exp_request = { + "inferenceConfig": { + "maxTokens": inference_config["max_tokens"], + "stopSequences": inference_config["stop_sequences"], + "temperature": inference_config["temperature"], + "topP": inference_config["top_p"], + }, + "modelId": model_id, + "messages": messages, + "system": [], + } + + assert tru_request == exp_request + + +def test_format_request_system_prompt(model, messages, model_id, system_prompt): + tru_request = model.format_request(messages, system_prompt=system_prompt) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [{"text": system_prompt}], + } + + assert tru_request == exp_request + + +def test_format_request_tool_specs(model, messages, model_id, tool_spec): + tru_request = model.format_request(messages, [tool_spec]) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": {"auto": {}}, + }, + } + + assert tru_request == exp_request + + +def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): + model.update_config(cache_prompt=cache_type, cache_tools=cache_type) + tru_request = model.format_request(messages, [tool_spec]) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [{"cachePoint": {"type": cache_type}}], + "toolConfig": { + "tools": [ + {"toolSpec": tool_spec}, + {"cachePoint": {"type": cache_type}}, + ], + "toolChoice": {"auto": {}}, + }, + } + + assert tru_request == exp_request + + +@pytest.mark.asyncio +async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist): + error_message = "Rate exceeded" + bedrock_client.converse_stream.side_effect = EventStreamError( + {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "ConverseStream" + ) + + with pytest.raises(ModelThrottledException) as excinfo: + await alist(model.stream(messages)) + + assert error_message in str(excinfo.value) + bedrock_client.converse_stream.assert_called_once_with( + modelId="m1", messages=messages, system=[], inferenceConfig={} + ) + + +@pytest.mark.asyncio +async def test_stream_with_invalid_content_throws(bedrock_client, model, alist): + # We used to hang on None, so ensure we don't regress: https://github.com/strands-agents/sdk-python/issues/642 + messages = [{"role": "user", "content": None}] + + with pytest.raises(TypeError): + await alist(model.stream(messages)) + + +@pytest.mark.asyncio +async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, messages, alist): + error_message = "ThrottlingException: Rate exceeded for ConverseStream" + bedrock_client.converse_stream.side_effect = ClientError( + {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "Any" + ) + + with pytest.raises(ModelThrottledException) as excinfo: + await alist(model.stream(messages)) + + assert error_message in str(excinfo.value) + bedrock_client.converse_stream.assert_called_once_with( + modelId="m1", messages=messages, system=[], inferenceConfig={} + ) + + +@pytest.mark.asyncio +async def test_general_exception_is_raised(bedrock_client, model, messages, alist): + error_message = "Should be raised up" + bedrock_client.converse_stream.side_effect = ValueError(error_message) + + with pytest.raises(ValueError) as excinfo: + await alist(model.stream(messages)) + + assert error_message in str(excinfo.value) + bedrock_client.converse_stream.assert_called_once_with( + modelId="m1", messages=messages, system=[], inferenceConfig={} + ) + + +@pytest.mark.asyncio +async def test_stream(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist): + bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} + + request = { + "additionalModelRequestFields": additional_request_fields, + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": {"auto": {}}, + }, + } + + model.update_config(additional_request_fields=additional_request_fields) + response = model.stream(messages, [tool_spec]) + + tru_chunks = await alist(response) + exp_chunks = ["e1", "e2"] + + assert tru_chunks == exp_chunks + bedrock_client.converse_stream.assert_called_once_with(**request) + + +@pytest.mark.asyncio +async def test_stream_stream_input_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist +): + metadata_event = { + "metadata": { + "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "metrics": {"latencyMs": 245}, + "trace": { + "guardrail": { + "inputAssessment": { + "3e59qlue4hag": { + "wordPolicy": { + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] + } + } + } + } + }, + } + } + bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + + request = { + "additionalModelRequestFields": additional_request_fields, + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": {"auto": {}}, + }, + } + + model.update_config(additional_request_fields=additional_request_fields) + response = model.stream(messages, [tool_spec]) + + tru_chunks = await alist(response) + exp_chunks = [ + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + metadata_event, + ] + + assert tru_chunks == exp_chunks + bedrock_client.converse_stream.assert_called_once_with(**request) + + +@pytest.mark.asyncio +async def test_stream_stream_output_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist +): + model.update_config(guardrail_redact_input=False, guardrail_redact_output=True) + metadata_event = { + "metadata": { + "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "metrics": {"latencyMs": 245}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] + }, + } + ] + }, + } + }, + } + } + bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + + request = { + "additionalModelRequestFields": additional_request_fields, + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": {"auto": {}}, + }, + } + + model.update_config(additional_request_fields=additional_request_fields) + response = model.stream(messages, [tool_spec]) + + tru_chunks = await alist(response) + exp_chunks = [ + {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, + metadata_event, + ] + + assert tru_chunks == exp_chunks + bedrock_client.converse_stream.assert_called_once_with(**request) + + +@pytest.mark.asyncio +async def test_stream_output_guardrails_redacts_input_and_output( + bedrock_client, + mock_bedrock_inference_profiles, + model, + messages, + tool_spec, + model_id, + additional_request_fields, + alist, +): + model.update_config(guardrail_redact_output=True) + metadata_event = { + "metadata": { + "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "metrics": {"latencyMs": 245}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] + }, + } + ] + }, + } + }, + } + } + bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + + request = { + "additionalModelRequestFields": additional_request_fields, + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": {"auto": {}}, + }, + } + + model.update_config(additional_request_fields=additional_request_fields) + response = model.stream(messages, [tool_spec]) + + tru_chunks = await alist(response) + exp_chunks = [ + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, + metadata_event, + ] + + assert tru_chunks == exp_chunks + bedrock_client.converse_stream.assert_called_once_with(**request) + + +@pytest.mark.asyncio +async def test_stream_output_no_blocked_guardrails_doesnt_redact( + bedrock_client, + mock_bedrock_inference_profiles, + model, + messages, + tool_spec, + model_id, + additional_request_fields, + alist, +): + metadata_event = { + "metadata": { + "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "metrics": {"latencyMs": 245}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [ + { + "match": "CACTUS", + "action": "NONE", + "detected": True, + } + ] + }, + } + ] + }, + } + }, + } + } + bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + + request = { + "additionalModelRequestFields": additional_request_fields, + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": {"auto": {}}, + }, + } + + model.update_config(additional_request_fields=additional_request_fields) + response = model.stream(messages, [tool_spec]) + + tru_chunks = await alist(response) + exp_chunks = [metadata_event] + + assert tru_chunks == exp_chunks + bedrock_client.converse_stream.assert_called_once_with(**request) + + +@pytest.mark.asyncio +async def test_stream_output_no_guardrail_redact( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist +): + metadata_event = { + "metadata": { + "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "metrics": {"latencyMs": 245}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] + }, + } + ] + }, + } + }, + } + } + bedrock_client.converse_stream.return_value = {"stream": [metadata_event]} + + request = { + "additionalModelRequestFields": additional_request_fields, + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": {"auto": {}}, + }, + } + + model.update_config( + additional_request_fields=additional_request_fields, + guardrail_redact_output=False, + guardrail_redact_input=False, + ) + response = model.stream(messages, [tool_spec]) + + tru_chunks = await alist(response) + exp_chunks = [metadata_event] + + assert tru_chunks == exp_chunks + bedrock_client.converse_stream.assert_called_once_with(**request) + + +@pytest.mark.asyncio +async def test_stream_with_streaming_false(bedrock_client, mock_bedrock_inference_profiles, alist, messages): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "stopReason": "end_turn", + } + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + ] + assert tru_events == exp_events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +@pytest.mark.asyncio +async def test_stream_with_streaming_false_and_tool_use( + bedrock_client, mock_bedrock_inference_profiles, alist, messages +): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "dummyTool", "input": {"hello": "world!"}}}], + } + }, + "stopReason": "tool_use", + } + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "dummyTool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"hello": "world!"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + assert tru_events == exp_events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +@pytest.mark.asyncio +async def test_stream_with_streaming_false_and_reasoning( + bedrock_client, mock_bedrock_inference_profiles, alist, messages +): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "Thinking really hard....", "signature": "123"}, + } + } + ], + } + }, + "stopReason": "tool_use", + } + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "123"}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + assert tru_events == exp_events + + # Verify converse was called + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +@pytest.mark.asyncio +async def test_stream_and_reasoning_no_signature(bedrock_client, mock_bedrock_inference_profiles, alist, messages): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "Thinking really hard...."}, + } + } + ], + } + }, + "stopReason": "tool_use", + } + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + assert tru_events == exp_events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +@pytest.mark.asyncio +async def test_stream_with_streaming_false_with_metrics_and_usage( + bedrock_client, mock_bedrock_inference_profiles, alist, messages +): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "metrics": {"latencyMs": 1234}, + "stopReason": "tool_use", + } + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + { + "metadata": { + "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "metrics": {"latencyMs": 1234}, + } + }, + ] + assert tru_events == exp_events + + # Verify converse was called + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +@pytest.mark.asyncio +async def test_stream_input_guardrails(bedrock_client, mock_bedrock_inference_profiles, alist, messages): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "inputAssessment": { + "3e59qlue4hag": { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} + } + } + } + }, + "stopReason": "end_turn", + } + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "inputAssessment": { + "3e59qlue4hag": { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + assert tru_events == exp_events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +@pytest.mark.asyncio +async def test_stream_output_guardrails(bedrock_client, mock_bedrock_inference_profiles, alist, messages): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, + } + ] + }, + } + }, + "stopReason": "end_turn", + } + + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + ] + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + assert tru_events == exp_events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +@pytest.mark.asyncio +async def test_stream_output_guardrails_redacts_output( + bedrock_client, mock_bedrock_inference_profiles, alist, messages +): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, + } + ] + }, + } + }, + "stopReason": "end_turn", + } + + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + ] + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + assert tru_events == exp_events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +@pytest.mark.asyncio +async def test_structured_output(bedrock_client, model, test_output_model_cls, alist): + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + bedrock_client.converse_stream.return_value = { + "stream": [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "TestOutputModel"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "John", "age": 30}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + } + + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + tru_output = events[-1] + exp_output = {"output": test_output_model_cls(name="John", age=30)} + assert tru_output == exp_output + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") +@pytest.mark.asyncio +async def test_add_note_on_client_error(bedrock_client, model, alist, messages): + """Test that add_note is called on ClientError with region and model ID information.""" + # Mock the client error response + error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} + bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + + # Call the stream method which should catch and add notes to the exception + with pytest.raises(ClientError) as err: + await alist(model.stream(messages)) + + assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"] + + +@pytest.mark.asyncio +async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages): + """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" + # Mock the client error response + error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} + bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + + # Call the stream method which should catch and add notes to the exception + with pytest.raises(ClientError): + await alist(model.stream(messages)) + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") +@pytest.mark.asyncio +async def test_add_note_on_access_denied_exception(bedrock_client, model, alist, messages): + """Test that add_note adds documentation link for AccessDeniedException.""" + # Mock the client error response for access denied + error_response = { + "Error": { + "Code": "AccessDeniedException", + "Message": "An error occurred (AccessDeniedException) when calling the ConverseStream operation: " + "You don't have access to the model with the specified model ID.", + } + } + bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + + # Call the stream method which should catch and add notes to the exception + with pytest.raises(ClientError) as err: + await alist(model.stream(messages)) + + assert err.value.__notes__ == [ + "└ Bedrock region: us-west-2", + "└ Model id: m1", + "└ For more information see " + "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue", + ] + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") +@pytest.mark.asyncio +async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist, messages): + """Test that add_note adds documentation link for ValidationException about on-demand throughput.""" + # Mock the client error response for validation exception + error_response = { + "Error": { + "Code": "ValidationException", + "Message": "An error occurred (ValidationException) when calling the ConverseStream operation: " + "Invocation of model ID anthropic.claude-3-7-sonnet-20250219-v1:0 with on-demand throughput " + "isn’t supported. Retry your request with the ID or ARN of an inference profile that contains " + "this model.", + } + } + bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + + # Call the stream method which should catch and add notes to the exception + with pytest.raises(ClientError) as err: + await alist(model.stream(messages)) + + assert err.value.__notes__ == [ + "└ Bedrock region: us-west-2", + "└ Model id: m1", + "└ For more information see " + "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported", + ] + + +@pytest.mark.asyncio +async def test_stream_logging(bedrock_client, model, messages, caplog, alist): + """Test that stream method logs debug messages at the expected stages.""" + import logging + + # Set the logger to debug level to capture debug messages + caplog.set_level(logging.DEBUG, logger="strands.models.bedrock") + + # Mock the response + bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} + + # Execute the stream method + response = model.stream(messages) + await alist(response) + + # Check that the expected log messages are present + log_text = caplog.text + assert "formatting request" in log_text + assert "request=<" in log_text + assert "invoking model" in log_text + assert "got response from model" in log_text + assert "finished streaming response from model" in log_text + + +def test_format_request_cleans_tool_result_content_blocks(model, model_id): + """Test that format_request cleans toolResult blocks by removing extra fields.""" + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "content": [{"text": "Tool output"}], + "toolUseId": "tool123", + "status": "success", + "extraField": "should be removed", + "mcpMetadata": {"server": "test"}, + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + # Verify toolResult only contains allowed fields in the formatted request + tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] + expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} + assert tool_result == expected + assert "extraField" not in tool_result + assert "mcpMetadata" not in tool_result