Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

logger = logging.getLogger(__name__)

DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"
DEFAULT_BEDROCK_REGION = "us-west-2"
DEFAULT_BEDROCK_MODEL_ID = "anthropic.claude-sonnet-4-20250514-v1:0"

BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
"Input is too long for requested model",
Expand Down Expand Up @@ -119,13 +119,6 @@ def __init__(
if region_name and boto_session:
raise ValueError("Cannot specify both `region_name` and `boto_session`.")

self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID)
self.update_config(**model_config)

logger.debug("config=<%s> | initializing", self.config)

session = boto_session or boto3.Session()

# Add strands-agents to the request user agent
if boto_client_config:
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
Expand All @@ -140,8 +133,23 @@ def __init__(
else:
client_config = BotocoreConfig(user_agent_extra="strands-agents")

session = boto_session or boto3.Session()
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION

# get default model id based on resolved region
resolved_model_id = self._get_default_model_for_region(resolved_region)
if resolved_model_id == "":
raise ValueError(
"default model {} is not available in {} region. Specify another model".format(
DEFAULT_BEDROCK_MODEL_ID, resolved_region
)
)

self.config = BedrockModel.BedrockConfig(model_id=resolved_model_id)
self.update_config(**model_config)

logger.debug("config=<%s> | initializing", self.config)

self.client = session.client(
service_name="bedrock-runtime",
config=client_config,
Expand Down Expand Up @@ -355,6 +363,20 @@ def _generate_redaction_events(self) -> list[StreamEvent]:

return events

def _get_default_model_for_region(self, region: str) -> str:
try:
client = boto3.client("bedrock", region_name=region)
response = client.list_inference_profiles()
inference_profile_summary = response["inferenceProfileSummaries"]

for profile in inference_profile_summary:
if DEFAULT_BEDROCK_MODEL_ID in profile["inferenceProfileId"]:
return str(profile["inferenceProfileId"])

return ""
except ClientError as e:
raise e

@override
async def stream(
self,
Expand Down
6 changes: 4 additions & 2 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_agent__init__with_default_model():
agent = Agent()

assert isinstance(agent.model, BedrockModel)
assert agent.model.config["model_id"] == DEFAULT_BEDROCK_MODEL_ID
assert agent.model.config["model_id"] and DEFAULT_BEDROCK_MODEL_ID in agent.model.config["model_id"]


def test_agent__init__with_explicit_model(mock_model):
Expand Down Expand Up @@ -891,7 +891,9 @@ def test_agent__del__(agent):
def test_agent_init_with_no_model_or_model_id():
agent = Agent()
assert agent.model is not None
assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID
assert agent.model.get_config().get("model_id") and DEFAULT_BEDROCK_MODEL_ID in agent.model.get_config().get(
"model_id"
)


def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator):
Expand Down
81 changes: 57 additions & 24 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,28 @@ def session_cls():
yield mock_session_cls


@pytest.fixture
def mock_bedrock_inference_profiles():
with unittest.mock.patch.object(strands.models.bedrock.boto3, "client") as mock_boto_client:
mock_bedrock = unittest.mock.MagicMock()
mock_bedrock.list_inference_profiles.return_value = {
"inferenceProfileSummaries": [{"inferenceProfileId": "us.anthropic.claude-sonnet-4-20250514-v1:0"}]
}
mock_boto_client.return_value = mock_bedrock
yield mock_boto_client


@pytest.fixture
def mock_client_method(session_cls):
# the boto3.Session().client(...) method
return session_cls.return_value.client


@pytest.fixture
def bedrock_client(session_cls):
def bedrock_client(session_cls, region="us-west-2"):
mock_client = session_cls.return_value.client.return_value
mock_client.meta = unittest.mock.MagicMock()
mock_client.meta.region_name = "us-west-2"
mock_client.meta.region_name = region
yield mock_client


Expand All @@ -44,7 +55,7 @@ def model_id():


@pytest.fixture
def model(bedrock_client, model_id):
def model(bedrock_client, mock_bedrock_inference_profiles, model_id):
_ = bedrock_client

return BedrockModel(model_id=model_id)
Expand Down Expand Up @@ -113,18 +124,18 @@ class TestOutputModel(pydantic.BaseModel):
return TestOutputModel


def test__init__default_model_id(bedrock_client):
def test__init__default_model_id(bedrock_client, mock_bedrock_inference_profiles):
"""Test that BedrockModel uses DEFAULT_MODEL_ID when no model_id is provided."""
_ = bedrock_client
model = BedrockModel()

tru_model_id = model.get_config().get("model_id")
exp_model_id = DEFAULT_BEDROCK_MODEL_ID

assert tru_model_id == exp_model_id
assert tru_model_id and exp_model_id in tru_model_id


def test__init__with_default_region(session_cls, mock_client_method):
def test__init__with_default_region(session_cls, mock_client_method, mock_bedrock_inference_profiles):
"""Test that BedrockModel uses the provided region."""
with unittest.mock.patch.object(os, "environ", {}):
BedrockModel()
Expand All @@ -133,7 +144,7 @@ def test__init__with_default_region(session_cls, mock_client_method):
)


def test__init__with_session_region(session_cls, mock_client_method):
def test__init__with_session_region(session_cls, mock_client_method, mock_bedrock_inference_profiles):
"""Test that BedrockModel uses the provided region."""
session_cls.return_value.region_name = "eu-blah-1"

Expand All @@ -142,22 +153,22 @@ def test__init__with_session_region(session_cls, mock_client_method):
mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY, endpoint_url=None)


def test__init__with_custom_region(mock_client_method):
def test__init__with_custom_region(mock_client_method, mock_bedrock_inference_profiles):
"""Test that BedrockModel uses the provided region."""
custom_region = "us-east-1"
BedrockModel(region_name=custom_region)
mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY, endpoint_url=None)


def test__init__with_default_environment_variable_region(mock_client_method):
def test__init__with_default_environment_variable_region(mock_client_method, mock_bedrock_inference_profiles):
"""Test that BedrockModel uses the AWS_REGION since we code that in."""
with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "eu-west-2"}):
BedrockModel()

mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY, endpoint_url=None)


def test__init__region_precedence(mock_client_method, session_cls):
def test__init__region_precedence(mock_client_method, session_cls, mock_bedrock_inference_profiles):
"""Test that BedrockModel uses the correct ordering of precedence when determining region."""
with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "us-environment-1"}) as mock_os_environ:
session_cls.return_value.region_name = "us-session-1"
Expand Down Expand Up @@ -204,7 +215,7 @@ def test__init__with_region_and_session_raises_value_error():
_ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1"))


def test__init__default_user_agent(bedrock_client):
def test__init__default_user_agent(bedrock_client, mock_bedrock_inference_profiles):
"""Set user agent when no boto_client_config is provided."""
with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls:
mock_session = mock_session_cls.return_value
Expand All @@ -218,7 +229,7 @@ def test__init__default_user_agent(bedrock_client):
assert kwargs["config"].user_agent_extra == "strands-agents"


def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client):
def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client, mock_bedrock_inference_profiles):
"""Set user agent when boto_client_config is provided without user_agent_extra."""
custom_config = BotocoreConfig(read_timeout=900)

Expand All @@ -235,7 +246,7 @@ def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client):
assert kwargs["config"].read_timeout == 900


def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client):
def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client, mock_bedrock_inference_profiles):
"""Append to existing user agent when boto_client_config is provided with user_agent_extra."""
custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900)

Expand All @@ -252,7 +263,7 @@ def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client):
assert kwargs["config"].read_timeout == 900


def test__init__model_config(bedrock_client):
def test__init__model_config(bedrock_client, mock_bedrock_inference_profiles):
_ = bedrock_client

model = BedrockModel(max_tokens=1)
Expand Down Expand Up @@ -614,7 +625,14 @@ async def test_stream_stream_output_guardrails(

@pytest.mark.asyncio
async def test_stream_output_guardrails_redacts_input_and_output(
bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist
bedrock_client,
mock_bedrock_inference_profiles,
model,
messages,
tool_spec,
model_id,
additional_request_fields,
alist,
):
model.update_config(guardrail_redact_output=True)
metadata_event = {
Expand Down Expand Up @@ -672,7 +690,14 @@ async def test_stream_output_guardrails_redacts_input_and_output(

@pytest.mark.asyncio
async def test_stream_output_no_blocked_guardrails_doesnt_redact(
bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist
bedrock_client,
mock_bedrock_inference_profiles,
model,
messages,
tool_spec,
model_id,
additional_request_fields,
alist,
):
metadata_event = {
"metadata": {
Expand Down Expand Up @@ -781,7 +806,7 @@ async def test_stream_output_no_guardrail_redact(


@pytest.mark.asyncio
async def test_stream_with_streaming_false(bedrock_client, alist, messages):
async def test_stream_with_streaming_false(bedrock_client, mock_bedrock_inference_profiles, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
Expand All @@ -806,7 +831,9 @@ async def test_stream_with_streaming_false(bedrock_client, alist, messages):


@pytest.mark.asyncio
async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages):
async def test_stream_with_streaming_false_and_tool_use(
bedrock_client, mock_bedrock_inference_profiles, alist, messages
):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {
Expand Down Expand Up @@ -837,7 +864,9 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, m


@pytest.mark.asyncio
async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages):
async def test_stream_with_streaming_false_and_reasoning(
bedrock_client, mock_bedrock_inference_profiles, alist, messages
):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {
Expand Down Expand Up @@ -875,7 +904,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist,


@pytest.mark.asyncio
async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages):
async def test_stream_and_reasoning_no_signature(bedrock_client, mock_bedrock_inference_profiles, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {
Expand Down Expand Up @@ -911,7 +940,9 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages


@pytest.mark.asyncio
async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages):
async def test_stream_with_streaming_false_with_metrics_and_usage(
bedrock_client, mock_bedrock_inference_profiles, alist, messages
):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
Expand Down Expand Up @@ -945,7 +976,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client


@pytest.mark.asyncio
async def test_stream_input_guardrails(bedrock_client, alist, messages):
async def test_stream_input_guardrails(bedrock_client, mock_bedrock_inference_profiles, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
Expand Down Expand Up @@ -995,7 +1026,7 @@ async def test_stream_input_guardrails(bedrock_client, alist, messages):


@pytest.mark.asyncio
async def test_stream_output_guardrails(bedrock_client, alist, messages):
async def test_stream_output_guardrails(bedrock_client, mock_bedrock_inference_profiles, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
Expand Down Expand Up @@ -1048,7 +1079,9 @@ async def test_stream_output_guardrails(bedrock_client, alist, messages):


@pytest.mark.asyncio
async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages):
async def test_stream_output_guardrails_redacts_output(
bedrock_client, mock_bedrock_inference_profiles, alist, messages
):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
Expand Down