Skip to content

Commit 419de19

Browse files
authored
Fix additional_args passing in SageMakerAIModel (#983)
* fix(sagemaker): additional_args dict issue

  Fix error where passing an additional_args dict to SageMakerAIModel would raise an AttributeError because Python dicts have no '__dict__' attribute. Fixes #982

* fix(sagemaker): typing for endpoint_config

  Fix typing for SageMakerAIModel.endpoint_config, which was previously being treated as an arbitrary dictionary due to init assignment.

* fix(sagemaker): Typing for payload_config

  Fix typing for SageMakerAIModel.payload_config, which was previously being treated as a plain dict due to init assignment.

* test(sagemaker): tests for ep additional_args

  Add a test to check for insertion of endpoint config additional_args

* fix(sagemaker): include payload additional_args

  Copy SageMakerAIPayloadSchema's additional_args into request payloads where provided - previously these were being ignored. Includes unit tests.
1 parent 9632ed5 commit 419de19

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

src/strands/models/sagemaker.py

Lines changed: 22 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import logging
55
import os
66
from dataclasses import dataclass
7-
from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast
7+
from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union
88

99
import boto3
1010
from botocore.config import Config as BotocoreConfig
@@ -151,8 +151,8 @@ def __init__(
151151
validate_config_keys(payload_config, self.SageMakerAIPayloadSchema)
152152
payload_config.setdefault("stream", True)
153153
payload_config.setdefault("tool_results_as_user_messages", False)
154-
self.endpoint_config = dict(endpoint_config)
155-
self.payload_config = dict(payload_config)
154+
self.endpoint_config = self.SageMakerAIEndpointConfig(**endpoint_config)
155+
self.payload_config = self.SageMakerAIPayloadSchema(**payload_config)
156156
logger.debug(
157157
"endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config
158158
)
@@ -193,7 +193,7 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i
193193
Returns:
194194
The Amazon SageMaker model configuration.
195195
"""
196-
return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config)
196+
return self.endpoint_config
197197

198198
@override
199199
def format_request(
@@ -238,6 +238,10 @@ def format_request(
238238
},
239239
}
240240

241+
payload_additional_args = self.payload_config.get("additional_args")
242+
if payload_additional_args:
243+
payload.update(payload_additional_args)
244+
241245
# Remove tools and tool_choice if tools = []
242246
if not payload["tools"]:
243247
payload.pop("tools")
@@ -273,16 +277,20 @@ def format_request(
273277
}
274278

275279
# Add optional SageMaker parameters if provided
276-
if self.endpoint_config.get("inference_component_name"):
277-
request["InferenceComponentName"] = self.endpoint_config["inference_component_name"]
278-
if self.endpoint_config.get("target_model"):
279-
request["TargetModel"] = self.endpoint_config["target_model"]
280-
if self.endpoint_config.get("target_variant"):
281-
request["TargetVariant"] = self.endpoint_config["target_variant"]
282-
283-
# Add additional args if provided
284-
if self.endpoint_config.get("additional_args"):
285-
request.update(self.endpoint_config["additional_args"].__dict__)
280+
inf_component_name = self.endpoint_config.get("inference_component_name")
281+
if inf_component_name:
282+
request["InferenceComponentName"] = inf_component_name
283+
target_model = self.endpoint_config.get("target_model")
284+
if target_model:
285+
request["TargetModel"] = target_model
286+
target_variant = self.endpoint_config.get("target_variant")
287+
if target_variant:
288+
request["TargetVariant"] = target_variant
289+
290+
# Add additional request args if provided
291+
additional_args = self.endpoint_config.get("additional_args")
292+
if additional_args:
293+
request.update(additional_args)
286294

287295
return request
288296

tests/strands/models/test_sagemaker.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -112,11 +112,13 @@ def test_init_with_all_params(self, boto_session):
112112
"endpoint_name": "test-endpoint",
113113
"inference_component_name": "test-component",
114114
"region_name": "us-west-2",
115+
"additional_args": {"test_req_arg_name": "test_req_arg_value"},
115116
}
116117
payload_config = {
117118
"stream": False,
118119
"max_tokens": 1024,
119120
"temperature": 0.7,
121+
"additional_args": {"test_payload_arg_name": "test_payload_arg_value"},
120122
}
121123
client_config = BotocoreConfig(user_agent_extra="test-agent")
122124

@@ -129,9 +131,11 @@ def test_init_with_all_params(self, boto_session):
129131

130132
assert model.endpoint_config["endpoint_name"] == "test-endpoint"
131133
assert model.endpoint_config["inference_component_name"] == "test-component"
134+
assert model.endpoint_config["additional_args"]["test_req_arg_name"] == "test_req_arg_value"
132135
assert model.payload_config["stream"] is False
133136
assert model.payload_config["max_tokens"] == 1024
134137
assert model.payload_config["temperature"] == 0.7
138+
assert model.payload_config["additional_args"]["test_payload_arg_name"] == "test_payload_arg_value"
135139

136140
boto_session.client.assert_called_once_with(
137141
service_name="sagemaker-runtime",
@@ -239,6 +243,30 @@ def test_get_config(self, model, endpoint_config):
239243
# assert "tools" in payload
240244
# assert payload["tools"] == []
241245

246+
def test_format_request_with_additional_args(self, boto_session, endpoint_config, messages, payload_config):
247+
"""Test formatting a request's `additional_args` where provided"""
248+
endpoint_config_ext = {
249+
**endpoint_config,
250+
"additional_args": {
251+
"extra_request_key": "extra_request_value",
252+
},
253+
}
254+
payload_config_ext = {
255+
**payload_config,
256+
"additional_args": {
257+
"extra_payload_key": "extra_payload_value",
258+
},
259+
}
260+
model = SageMakerAIModel(
261+
boto_session=boto_session,
262+
endpoint_config=endpoint_config_ext,
263+
payload_config=payload_config_ext,
264+
)
265+
request = model.format_request(messages)
266+
assert request.get("extra_request_key") == "extra_request_value"
267+
payload = json.loads(request["Body"])
268+
assert payload.get("extra_payload_key") == "extra_payload_value"
269+
242270
@pytest.mark.asyncio
243271
async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages):
244272
"""Test streaming response with streaming enabled."""

0 commit comments

Comments
 (0)