diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index c44717041..fa0a0962b 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -28,8 +28,8 @@
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"
 DEFAULT_BEDROCK_REGION = "us-west-2"
+DEFAULT_BEDROCK_MODEL_ID = "anthropic.claude-sonnet-4-20250514-v1:0"
 
 BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
     "Input is too long for requested model",
@@ -119,13 +119,6 @@ def __init__(
         if region_name and boto_session:
             raise ValueError("Cannot specify both `region_name` and `boto_session`.")
 
-        self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID)
-        self.update_config(**model_config)
-
-        logger.debug("config=<%s> | initializing", self.config)
-
-        session = boto_session or boto3.Session()
-
         # Add strands-agents to the request user agent
         if boto_client_config:
             existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
@@ -140,8 +133,26 @@ def __init__(
         else:
             client_config = BotocoreConfig(user_agent_extra="strands-agents")
 
+        session = boto_session or boto3.Session()
         resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
 
+        # Only look up a region-specific default when the caller did not choose a model:
+        # an explicit `model_id` must not trigger an extra Bedrock API call or a spurious error.
+        if "model_id" in model_config:
+            resolved_model_id = model_config["model_id"]
+        else:
+            resolved_model_id = self._get_default_model_for_region(resolved_region)
+            if not resolved_model_id:
+                raise ValueError(
+                    f"default model {DEFAULT_BEDROCK_MODEL_ID} is not available in {resolved_region} region. "
+                    "Specify another model"
+                )
+
+        self.config = BedrockModel.BedrockConfig(model_id=resolved_model_id)
+        self.update_config(**model_config)
+
+        logger.debug("config=<%s> | initializing", self.config)
+
         self.client = session.client(
             service_name="bedrock-runtime",
             config=client_config,
@@ -355,6 +366,33 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
 
         return events
 
+    def _get_default_model_for_region(self, region: str) -> str:
+        """Find the inference profile that serves the default model in a region.
+
+        Args:
+            region: AWS region whose Bedrock inference profiles are searched.
+
+        Returns:
+            The first inference profile id containing `DEFAULT_BEDROCK_MODEL_ID`, or an empty string
+            when the region lists no such profile.
+        """
+        # NOTE(review): built from the default credential chain, so a caller-supplied
+        # `boto_session` is not used for this lookup.
+        client = boto3.client("bedrock", region_name=region)
+
+        # Results are paginated; follow `nextToken` so a match past the first page is found.
+        request_args: dict[str, str] = {}
+        while True:
+            response = client.list_inference_profiles(**request_args)
+            for profile in response["inferenceProfileSummaries"]:
+                if DEFAULT_BEDROCK_MODEL_ID in profile["inferenceProfileId"]:
+                    return str(profile["inferenceProfileId"])
+
+            next_token = response.get("nextToken")
+            if not next_token:
+                return ""
+            request_args["nextToken"] = next_token
+
     @override
     async def stream(
         self,
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index a8561abe4..fdfa9b83b 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -211,7 +211,7 @@ def test_agent__init__with_default_model():
     agent = Agent()
 
     assert isinstance(agent.model, BedrockModel)
-    assert agent.model.config["model_id"] == DEFAULT_BEDROCK_MODEL_ID
+    assert agent.model.config["model_id"] and DEFAULT_BEDROCK_MODEL_ID in agent.model.config["model_id"]
 
 
 def test_agent__init__with_explicit_model(mock_model):
@@ -891,7 +891,9 @@ def test_agent__del__(agent):
 def test_agent_init_with_no_model_or_model_id():
     agent = Agent()
     assert agent.model is not None
-    assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID
+    assert agent.model.get_config().get("model_id") and DEFAULT_BEDROCK_MODEL_ID in agent.model.get_config().get(
+        "model_id"
+    )
 
 
 def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator):
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index f1a2250e4..094999339 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -24,6 +24,17 @@ def session_cls():
         yield mock_session_cls
 
 
+@pytest.fixture
+def mock_bedrock_inference_profiles():
+    with unittest.mock.patch.object(strands.models.bedrock.boto3, "client") as mock_boto_client:
+        mock_bedrock = unittest.mock.MagicMock()
+        mock_bedrock.list_inference_profiles.return_value = {
+            "inferenceProfileSummaries": [{"inferenceProfileId": "us.anthropic.claude-sonnet-4-20250514-v1:0"}]
+        }
+        mock_boto_client.return_value = mock_bedrock
+        yield mock_boto_client
+
+
 @pytest.fixture
 def mock_client_method(session_cls):
     # the boto3.Session().client(...) method
@@ -31,10 +42,10 @@ def mock_client_method(session_cls):
 
 
 @pytest.fixture
-def bedrock_client(session_cls):
+def bedrock_client(session_cls, region="us-west-2"):
     mock_client = session_cls.return_value.client.return_value
     mock_client.meta = unittest.mock.MagicMock()
-    mock_client.meta.region_name = "us-west-2"
+    mock_client.meta.region_name = region
     yield mock_client
 
 
@@ -44,7 +55,7 @@ def model_id():
 
 
 @pytest.fixture
-def model(bedrock_client, model_id):
+def model(bedrock_client, mock_bedrock_inference_profiles, model_id):
     _ = bedrock_client
 
     return BedrockModel(model_id=model_id)
@@ -113,7 +124,7 @@ class TestOutputModel(pydantic.BaseModel):
     return TestOutputModel
 
 
-def test__init__default_model_id(bedrock_client):
+def test__init__default_model_id(bedrock_client, mock_bedrock_inference_profiles):
     """Test that BedrockModel uses DEFAULT_MODEL_ID when no model_id is provided."""
     _ = bedrock_client
     model = BedrockModel()
@@ -121,10 +132,10 @@ def test__init__default_model_id(bedrock_client):
     tru_model_id = model.get_config().get("model_id")
     exp_model_id = DEFAULT_BEDROCK_MODEL_ID
 
-    assert tru_model_id == exp_model_id
+    assert tru_model_id and exp_model_id in tru_model_id
 
 
-def test__init__with_default_region(session_cls, mock_client_method):
+def test__init__with_default_region(session_cls, mock_client_method, mock_bedrock_inference_profiles):
     """Test that BedrockModel uses the provided region."""
     with unittest.mock.patch.object(os, "environ", {}):
         BedrockModel()
@@ -133,7 +144,7 @@ def test__init__with_default_region(session_cls, mock_client_method):
         )
 
 
-def test__init__with_session_region(session_cls, mock_client_method):
+def test__init__with_session_region(session_cls, mock_client_method, mock_bedrock_inference_profiles):
     """Test that BedrockModel uses the provided region."""
     session_cls.return_value.region_name = "eu-blah-1"
 
@@ -142,14 +153,14 @@ def test__init__with_session_region(session_cls, mock_client_method):
     mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY, endpoint_url=None)
 
 
-def test__init__with_custom_region(mock_client_method):
+def test__init__with_custom_region(mock_client_method, mock_bedrock_inference_profiles):
     """Test that BedrockModel uses the provided region."""
     custom_region = "us-east-1"
     BedrockModel(region_name=custom_region)
     mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY, endpoint_url=None)
 
 
-def test__init__with_default_environment_variable_region(mock_client_method):
+def test__init__with_default_environment_variable_region(mock_client_method, mock_bedrock_inference_profiles):
     """Test that BedrockModel uses the AWS_REGION since we code that in."""
     with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "eu-west-2"}):
         BedrockModel()
@@ -157,7 +168,7 @@ def test__init__with_default_environment_variable_region(mock_client_method):
     mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY, endpoint_url=None)
 
 
-def test__init__region_precedence(mock_client_method, session_cls):
+def test__init__region_precedence(mock_client_method, session_cls, mock_bedrock_inference_profiles):
     """Test that BedrockModel uses the correct ordering of precedence when determining region."""
     with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "us-environment-1"}) as mock_os_environ:
         session_cls.return_value.region_name = "us-session-1"
@@ -204,7 +215,7 @@ def test__init__with_region_and_session_raises_value_error():
         _ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1"))
 
 
-def test__init__default_user_agent(bedrock_client):
+def test__init__default_user_agent(bedrock_client, mock_bedrock_inference_profiles):
     """Set user agent when no boto_client_config is provided."""
     with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls:
         mock_session = mock_session_cls.return_value
@@ -218,7 +229,7 @@ def test__init__default_user_agent(bedrock_client):
         assert kwargs["config"].user_agent_extra == "strands-agents"
 
 
-def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client):
+def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client, mock_bedrock_inference_profiles):
     """Set user agent when boto_client_config is provided without user_agent_extra."""
     custom_config = BotocoreConfig(read_timeout=900)
 
@@ -235,7 +246,7 @@ def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client):
         assert kwargs["config"].read_timeout == 900
 
 
-def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client):
+def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client, mock_bedrock_inference_profiles):
     """Append to existing user agent when boto_client_config is provided with user_agent_extra."""
     custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900)
 
@@ -252,7 +263,7 @@ def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client):
         assert kwargs["config"].read_timeout == 900
 
 
-def test__init__model_config(bedrock_client):
+def test__init__model_config(bedrock_client, mock_bedrock_inference_profiles):
     _ = bedrock_client
 
     model = BedrockModel(max_tokens=1)
@@ -614,7 +625,14 @@ async def test_stream_stream_output_guardrails(
 
 @pytest.mark.asyncio
 async def test_stream_output_guardrails_redacts_input_and_output(
-    bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist
+    bedrock_client,
+    mock_bedrock_inference_profiles,
+    model,
+    messages,
+    tool_spec,
+    model_id,
+    additional_request_fields,
+    alist,
 ):
     model.update_config(guardrail_redact_output=True)
     metadata_event = {
@@ -672,7 +690,14 @@ async def test_stream_output_guardrails_redacts_input_and_output(
 
 @pytest.mark.asyncio
 async def test_stream_output_no_blocked_guardrails_doesnt_redact(
-    bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist
+    bedrock_client,
+    mock_bedrock_inference_profiles,
+    model,
+    messages,
+    tool_spec,
+    model_id,
+    additional_request_fields,
+    alist,
 ):
     metadata_event = {
         "metadata": {
@@ -781,7 +806,7 @@ async def test_stream_output_no_guardrail_redact(
 
 
 @pytest.mark.asyncio
-async def test_stream_with_streaming_false(bedrock_client, alist, messages):
+async def test_stream_with_streaming_false(bedrock_client, mock_bedrock_inference_profiles, alist, messages):
     """Test stream method with streaming=False."""
     bedrock_client.converse.return_value = {
         "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -806,7 +831,9 @@ async def test_stream_with_streaming_false(bedrock_client, alist, messages):
 
 
 @pytest.mark.asyncio
-async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages):
+async def test_stream_with_streaming_false_and_tool_use(
+    bedrock_client, mock_bedrock_inference_profiles, alist, messages
+):
     """Test stream method with streaming=False."""
     bedrock_client.converse.return_value = {
         "output": {
@@ -837,7 +864,9 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, m
 
 
 @pytest.mark.asyncio
-async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages):
+async def test_stream_with_streaming_false_and_reasoning(
+    bedrock_client, mock_bedrock_inference_profiles, alist, messages
+):
     """Test stream method with streaming=False."""
     bedrock_client.converse.return_value = {
         "output": {
@@ -875,7 +904,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist,
 
 
 @pytest.mark.asyncio
-async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages):
+async def test_stream_and_reasoning_no_signature(bedrock_client, mock_bedrock_inference_profiles, alist, messages):
     """Test stream method with streaming=False."""
     bedrock_client.converse.return_value = {
         "output": {
@@ -911,7 +940,9 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages
 
 
 @pytest.mark.asyncio
-async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages):
+async def test_stream_with_streaming_false_with_metrics_and_usage(
+    bedrock_client, mock_bedrock_inference_profiles, alist, messages
+):
     """Test stream method with streaming=False."""
     bedrock_client.converse.return_value = {
         "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -945,7 +976,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client
 
 
 @pytest.mark.asyncio
-async def test_stream_input_guardrails(bedrock_client, alist, messages):
+async def test_stream_input_guardrails(bedrock_client, mock_bedrock_inference_profiles, alist, messages):
     """Test stream method with streaming=False."""
     bedrock_client.converse.return_value = {
         "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -995,7 +1026,7 @@ async def test_stream_input_guardrails(bedrock_client, alist, messages):
 
 
 @pytest.mark.asyncio
-async def test_stream_output_guardrails(bedrock_client, alist, messages):
+async def test_stream_output_guardrails(bedrock_client, mock_bedrock_inference_profiles, alist, messages):
     """Test stream method with streaming=False."""
     bedrock_client.converse.return_value = {
         "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -1048,7 +1079,9 @@ async def test_stream_output_guardrails(bedrock_client, alist, messages):
 
 
 @pytest.mark.asyncio
-async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages):
+async def test_stream_output_guardrails_redacts_output(
+    bedrock_client, mock_bedrock_inference_profiles, alist, messages
+):
     """Test stream method with streaming=False."""
     bedrock_client.converse.return_value = {
         "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},