|
4 | 4 | import logging
|
5 | 5 | import os
|
6 | 6 | from dataclasses import dataclass
|
7 |
| -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast |
| 7 | +from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union |
8 | 8 |
|
9 | 9 | import boto3
|
10 | 10 | from botocore.config import Config as BotocoreConfig
|
@@ -151,8 +151,8 @@ def __init__(
|
151 | 151 | validate_config_keys(payload_config, self.SageMakerAIPayloadSchema)
|
152 | 152 | payload_config.setdefault("stream", True)
|
153 | 153 | payload_config.setdefault("tool_results_as_user_messages", False)
|
154 |
| - self.endpoint_config = dict(endpoint_config) |
155 |
| - self.payload_config = dict(payload_config) |
| 154 | + self.endpoint_config = self.SageMakerAIEndpointConfig(**endpoint_config) |
| 155 | + self.payload_config = self.SageMakerAIPayloadSchema(**payload_config) |
156 | 156 | logger.debug(
|
157 | 157 | "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config
|
158 | 158 | )
|
@@ -193,7 +193,7 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i
|
193 | 193 | Returns:
|
194 | 194 | The Amazon SageMaker model configuration.
|
195 | 195 | """
|
196 |
| - return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config) |
| 196 | + return self.endpoint_config |
197 | 197 |
|
198 | 198 | @override
|
199 | 199 | def format_request(
|
@@ -238,6 +238,10 @@ def format_request(
|
238 | 238 | },
|
239 | 239 | }
|
240 | 240 |
|
| 241 | + payload_additional_args = self.payload_config.get("additional_args") |
| 242 | + if payload_additional_args: |
| 243 | + payload.update(payload_additional_args) |
| 244 | + |
241 | 245 | # Remove tools and tool_choice if tools = []
|
242 | 246 | if not payload["tools"]:
|
243 | 247 | payload.pop("tools")
|
@@ -273,16 +277,20 @@ def format_request(
|
273 | 277 | }
|
274 | 278 |
|
275 | 279 | # Add optional SageMaker parameters if provided
|
276 |
| - if self.endpoint_config.get("inference_component_name"): |
277 |
| - request["InferenceComponentName"] = self.endpoint_config["inference_component_name"] |
278 |
| - if self.endpoint_config.get("target_model"): |
279 |
| - request["TargetModel"] = self.endpoint_config["target_model"] |
280 |
| - if self.endpoint_config.get("target_variant"): |
281 |
| - request["TargetVariant"] = self.endpoint_config["target_variant"] |
282 |
| - |
283 |
| - # Add additional args if provided |
284 |
| - if self.endpoint_config.get("additional_args"): |
285 |
| - request.update(self.endpoint_config["additional_args"].__dict__) |
| 280 | + inf_component_name = self.endpoint_config.get("inference_component_name") |
| 281 | + if inf_component_name: |
| 282 | + request["InferenceComponentName"] = inf_component_name |
| 283 | + target_model = self.endpoint_config.get("target_model") |
| 284 | + if target_model: |
| 285 | + request["TargetModel"] = target_model |
| 286 | + target_variant = self.endpoint_config.get("target_variant") |
| 287 | + if target_variant: |
| 288 | + request["TargetVariant"] = target_variant |
| 289 | + |
| 290 | + # Add additional request args if provided |
| 291 | + additional_args = self.endpoint_config.get("additional_args") |
| 292 | + if additional_args: |
| 293 | + request.update(additional_args) |
286 | 294 |
|
287 | 295 | return request
|
288 | 296 |
|
|
0 commit comments