From b0c887797f063996ecec03e38dbb62cd061af14c Mon Sep 17 00:00:00 2001 From: Dan Ferguson Date: Mon, 11 Aug 2025 11:21:09 -0400 Subject: [PATCH 1/3] Introduced a method of selecting a prompt formatting option at prompt time for SageMakerAIModel objects --- src/strands/models/sagemaker.py | 383 ++++++++++++++++++++++--- tests/strands/models/test_sagemaker.py | 268 +++++++++++++++++ 2 files changed, 606 insertions(+), 45 deletions(-) diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 9cfe27d9e..bb67ff52e 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -1,10 +1,13 @@ """Amazon SageMaker model provider.""" +import base64 import json import logging +import mimetypes import os from dataclasses import dataclass -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast +from enum import Enum +from typing import Any, AsyncGenerator, Callable, Literal, Optional, Type, TypedDict, TypeVar, Union, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -14,14 +17,24 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec -from .openai import OpenAIModel +from ..types.tools import ToolResult, ToolSpec, ToolUse +from .model import Model T = TypeVar("T", bound=BaseModel) logger = logging.getLogger(__name__) +class ModelProvider(Enum): + """Supported model providers for prompt formatting.""" + + OPENAI = "openai" + MISTRAL = "mistral" + LLAMA = "llama" + ANTHROPIC = "anthropic" + CUSTOM = "custom" + + @dataclass class UsageMetadata: """Usage metadata for the model. @@ -86,7 +99,7 @@ def __init__(self, **kwargs: dict): self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""})) -class SageMakerAIModel(OpenAIModel): +class SageMakerAIModel(Model): """Amazon SageMaker model provider implementation.""" client: SageMakerRuntimeClient # type: ignore[assignment] @@ -120,7 +133,8 @@ class SageMakerAIEndpointConfig(TypedDict, total=False): Attributes: endpoint_name: The name of the SageMaker endpoint to invoke inference_component_name: The name of the inference component to use - + model_provider: The model provider format to use for prompt formatting (defaults to openai) + custom_formatter: Custom formatter function when model_provider is CUSTOM additional_args: Other request parameters, as supported by https://bit.ly/sagemaker-invoke-endpoint-params """ @@ -129,6 +143,8 @@ class SageMakerAIEndpointConfig(TypedDict, total=False): inference_component_name: Union[str, None] target_model: Union[Optional[str], None] target_variant: Union[Optional[str], None] + model_provider: Optional[Union[ModelProvider, str]] + custom_formatter: Optional[Callable[[Messages, Optional[str]], list[dict[str, Any]]]] additional_args: Optional[dict[str, Any]] def __init__( @@ -150,6 +166,18 @@ def __init__( payload_config.setdefault("tool_results_as_user_messages", False) self.endpoint_config = dict(endpoint_config) self.payload_config = dict(payload_config) + + # Set default model provider if not specified + if "model_provider" not in self.endpoint_config or self.endpoint_config["model_provider"] is None: + self.endpoint_config["model_provider"] = ModelProvider.OPENAI + elif isinstance(self.endpoint_config["model_provider"], str): + self.endpoint_config["model_provider"] = ModelProvider(self.endpoint_config["model_provider"]) + + # Validate custom formatter if using CUSTOM provider + if 
self.endpoint_config["model_provider"] == ModelProvider.CUSTOM and not self.endpoint_config.get( + "custom_formatter" + ): + raise ValueError("custom_formatter is required when model_provider is CUSTOM") logger.debug( "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config ) @@ -191,7 +219,7 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i """ return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config) - @override + # @override def format_request( self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None ) -> dict[str, Any]: @@ -205,7 +233,7 @@ def format_request( Returns: An Amazon SageMaker chat streaming request. """ - formatted_messages = self.format_request_messages(messages, system_prompt) + formatted_messages = self._format_messages_for_provider(messages, system_prompt) payload = { "messages": formatted_messages, @@ -276,6 +304,309 @@ def format_request( return request + def _format_messages_for_provider( + self, messages: Messages, system_prompt: Optional[str] = None + ) -> list[dict[str, Any]]: + """Format messages based on the selected model provider. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + Formatted messages array for the specified provider. + """ + provider = self.endpoint_config["model_provider"] + + if provider == ModelProvider.OPENAI: + return self._format_openai_messages(messages, system_prompt) + elif provider == ModelProvider.MISTRAL: + return self._format_mistral_messages(messages, system_prompt) + elif provider == ModelProvider.LLAMA: + return self._format_llama_messages(messages, system_prompt) + elif provider == ModelProvider.ANTHROPIC: + return self._format_anthropic_messages(messages, system_prompt) + elif provider == ModelProvider.CUSTOM: + custom_formatter = self.endpoint_config["custom_formatter"] + return custom_formatter(messages, system_prompt) + else: + # Default to OpenAI format + return self._format_openai_messages(messages, system_prompt) + + def _format_openai_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for OpenAI-compatible models.""" + return self.format_request_messages(messages, system_prompt) + + def _format_mistral_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for Mistral models.""" + # Mistral uses similar format to OpenAI but with some differences + formatted_messages = self.format_request_messages(messages, system_prompt) + # Add Mistral-specific formatting here if needed + return formatted_messages + + def _format_llama_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for Llama models.""" + # Llama often uses a different conversation format + formatted_messages = [] + + # Add system prompt if provided + if system_prompt: + formatted_messages.append({"role": "system", "content": system_prompt}) + + # Process messages with Llama-specific formatting + for message in messages: + formatted_message = {"role": message["role"], "content": self._format_content_for_llama(message["content"])} + formatted_messages.append(formatted_message) + + return formatted_messages + + def _format_anthropic_messages( + self, messages: Messages, system_prompt: Optional[str] = None + ) -> list[dict[str, 
Any]]: + """Format messages for Anthropic Claude models.""" + # Anthropic has specific requirements for message formatting + formatted_messages = [] + + # System prompt is handled separately in Anthropic + for message in messages: + formatted_message = { + "role": message["role"], + "content": self._format_content_for_anthropic(message["content"]), + } + formatted_messages.append(formatted_message) + + return formatted_messages + + def _format_content_for_llama(self, content: list[ContentBlock]) -> str: + """Format content blocks for Llama models (typically expects string content).""" + text_parts = [] + for block in content: + if "text" in block: + text_parts.append(block["text"]) + elif "toolUse" in block: + # Handle tool use for Llama + tool_use = block["toolUse"] + text_parts.append(f"[TOOL_CALL: {tool_use['name']}({json.dumps(tool_use['input'])})]") + elif "toolResult" in block: + # Handle tool results for Llama + tool_result = block["toolResult"] + result_text = " ".join([c.get("text", str(c)) for c in tool_result["content"]]) + text_parts.append(f"[TOOL_RESULT: {result_text}]") + return " ".join(text_parts) + + def _format_content_for_anthropic(self, content: list[ContentBlock]) -> list[dict[str, Any]]: + """Format content blocks for Anthropic models.""" + formatted_content = [] + for block in content: + if "text" in block: + formatted_content.append({"type": "text", "text": block["text"]}) + elif "image" in block: + # Anthropic image format + image_data = base64.b64encode(block["image"]["source"]["bytes"]).decode("utf-8") + formatted_content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": f"image/{block['image']['format']}", + "data": image_data, + }, + } + ) + # Add other content types as needed + return formatted_content + + @classmethod + def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + """Format a content block. + + Args: + content: Message content. + + Returns: + Formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to a SageMaker-compatible format. 
+ """ + if "reasoningContent" in content and content["reasoningContent"]: + return { + "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""), + "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""), + "type": "thinking", + } + elif not content.get("reasoningContent", None): + content.pop("reasoningContent", None) + + if "video" in content: + return { + "type": "video_url", + "video_url": { + "detail": "auto", + "url": content["video"]["source"]["bytes"], + }, + } + + if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + """Format a tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + Formatted tool call. + """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + Formatted messages array. + """ + formatted_messages: list[dict[str, Any]] + formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + + for message in messages: + contents = message["content"] + + formatted_contents = [ + cls.format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content + ] + formatted_tool_messages = [ + cls.format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format a response event into a standardized message chunk. + + Args: + event: A response event from the model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. 
+        """
+        match event["chunk_type"]:
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_start":
+                if event["data_type"] == "tool":
+                    return {
+                        "contentBlockStart": {
+                            "start": {
+                                "toolUse": {
+                                    "name": event["data"].function.name,
+                                    "toolUseId": event["data"].id,
+                                }
+                            }
+                        }
+                    }
+
+                return {"contentBlockStart": {"start": {}}}
+
+            case "content_delta":
+                if event["data_type"] == "tool":
+                    return {
+                        "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}}
+                    }
+
+                if event["data_type"] == "reasoning_content":
+                    return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}}
+
+                return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+            case "content_stop":
+                return {"contentBlockStop": {}}
+
+            case "message_stop":
+                match event["data"]:
+                    case "tool_calls":
+                        return {"messageStop": {"stopReason": "tool_use"}}
+                    case "length":
+                        return {"messageStop": {"stopReason": "max_tokens"}}
+                    case _:
+                        return {"messageStop": {"stopReason": "end_turn"}}
+
+            case "metadata":
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": event["data"].prompt_tokens,
+                            "outputTokens": event["data"].completion_tokens,
+                            "totalTokens": event["data"].total_tokens,
+                        },
+                        "metrics": {
+                            "latencyMs": 0,  # TODO
+                        },
+                    },
+                }
+
+            case _:
+                raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")
+
     @override
     async def stream(
         self,
@@ -474,7 +805,6 @@ async def stream(
         logger.debug("finished streaming response from model")
 
-    @override
     @classmethod
     def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
         """Format a SageMaker compatible tool message.
@@ -504,43 +834,6 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
             "content": content_string,  # String instead of list
         }
 
-    @override
-    @classmethod
-    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
-        """Format a content block.
-
-        Args:
-            content: Message content.
-
-        Returns:
-            Formatted content block.
-
-        Raises:
-            TypeError: If the content block type cannot be converted to a SageMaker-compatible format.
- """ - # if "text" in content and not isinstance(content["text"], str): - # return {"type": "text", "text": str(content["text"])} - - if "reasoningContent" in content and content["reasoningContent"]: - return { - "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""), - "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""), - "type": "thinking", - } - elif not content.get("reasoningContent", None): - content.pop("reasoningContent", None) - - if "video" in content: - return { - "type": "video_url", - "video_url": { - "detail": "auto", - "url": content["video"]["source"]["bytes"], - }, - } - - return super().format_request_message_content(content) - @override async def structured_output( self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index ba395b2d6..fe5e05734 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -10,6 +10,7 @@ from strands.models.sagemaker import ( FunctionCall, + ModelProvider, SageMakerAIModel, ToolCall, UsageMetadata, @@ -522,6 +523,248 @@ async def test_stream_non_streaming_with_tool_calls(self, sagemaker_client, mode assert usage_data["totalTokens"] == 30 +class TestModelProvider: + """Test suite for model provider functionality.""" + + def test_default_model_provider(self, boto_session): + """Test that default model provider is set to OpenAI.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} + payload_config = {"max_tokens": 1024} + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + assert model.endpoint_config["model_provider"] == ModelProvider.OPENAI + + def test_model_provider_enum(self, boto_session): + """Test setting model provider using enum.""" + endpoint_config = { + "endpoint_name": "test-endpoint", + "region_name": "us-east-1", + "model_provider": ModelProvider.LLAMA, + } + payload_config = {"max_tokens": 1024} + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + assert model.endpoint_config["model_provider"] == ModelProvider.LLAMA + + def test_model_provider_string(self, boto_session): + """Test setting model provider using string.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1", "model_provider": "mistral"} + payload_config = {"max_tokens": 1024} + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + assert model.endpoint_config["model_provider"] == ModelProvider.MISTRAL + + def test_custom_formatter_required(self, boto_session): + """Test that custom formatter is required when using CUSTOM provider.""" + endpoint_config = { + "endpoint_name": "test-endpoint", + "region_name": "us-east-1", + "model_provider": ModelProvider.CUSTOM, + } + payload_config = {"max_tokens": 1024} + + with pytest.raises(ValueError, match="custom_formatter is required when model_provider is CUSTOM"): + SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session) + + def test_custom_formatter_with_function(self, boto_session): + """Test using custom formatter with a function.""" + + def custom_formatter(messages, system_prompt=None): + formatted = [] + if system_prompt: + 
formatted.append({"role": "system", "text": system_prompt}) + for msg in messages: + formatted.append({"role": msg["role"], "text": msg["content"][0]["text"]}) + return formatted + + endpoint_config = { + "endpoint_name": "test-endpoint", + "region_name": "us-east-1", + "model_provider": ModelProvider.CUSTOM, + "custom_formatter": custom_formatter, + } + payload_config = {"max_tokens": 1024} + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + assert model.endpoint_config["model_provider"] == ModelProvider.CUSTOM + assert model.endpoint_config["custom_formatter"] == custom_formatter + + def test_format_messages_openai_provider(self, boto_session, messages, system_prompt): + """Test message formatting with OpenAI provider.""" + endpoint_config = { + "endpoint_name": "test-endpoint", + "region_name": "us-east-1", + "model_provider": ModelProvider.OPENAI, + } + payload_config = {"max_tokens": 1024} + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + formatted = model._format_messages_for_provider(messages, system_prompt) + + assert len(formatted) == 2 # system + user message + assert formatted[0]["role"] == "system" + assert formatted[0]["content"] == system_prompt + assert formatted[1]["role"] == "user" + assert len(formatted[1]["content"]) == 1 + assert formatted[1]["content"][0]["text"] == "What is the capital of France?" + + def test_format_messages_llama_provider(self, boto_session, messages, system_prompt): + """Test message formatting with Llama provider.""" + endpoint_config = { + "endpoint_name": "test-endpoint", + "region_name": "us-east-1", + "model_provider": ModelProvider.LLAMA, + } + payload_config = {"max_tokens": 1024} + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + formatted = model._format_messages_for_provider(messages, system_prompt) + + assert len(formatted) == 2 # system + user message + assert formatted[0]["role"] == "system" + assert formatted[0]["content"] == system_prompt + assert formatted[1]["role"] == "user" + assert isinstance(formatted[1]["content"], str) # Llama uses string content + assert "What is the capital of France?" in formatted[1]["content"] + + def test_format_messages_custom_provider(self, boto_session, messages, system_prompt): + """Test message formatting with custom provider.""" + + def custom_formatter(msgs, sys_prompt=None): + result = [{"custom_format": True}] + if sys_prompt: + result.append({"system": sys_prompt}) + for msg in msgs: + result.append({"speaker": msg["role"], "text": msg["content"][0]["text"]}) + return result + + endpoint_config = { + "endpoint_name": "test-endpoint", + "region_name": "us-east-1", + "model_provider": ModelProvider.CUSTOM, + "custom_formatter": custom_formatter, + } + payload_config = {"max_tokens": 1024} + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + formatted = model._format_messages_for_provider(messages, system_prompt) + + assert len(formatted) == 3 # custom marker + system + user + assert formatted[0]["custom_format"] is True + assert formatted[1]["system"] == system_prompt + assert formatted[2]["speaker"] == "user" + assert formatted[2]["text"] == "What is the capital of France?" 
+ + def test_format_messages_with_tool_use_llama(self, boto_session, system_prompt): + """Test Llama formatting with tool use messages.""" + messages_with_tools = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "tool123", "name": "get_weather", "input": {"location": "Paris"}}} + ], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "tool123", "content": [{"text": "Sunny, 25°C"}]}}], + }, + ] + + endpoint_config = { + "endpoint_name": "test-endpoint", + "region_name": "us-east-1", + "model_provider": ModelProvider.LLAMA, + } + payload_config = {"max_tokens": 1024} + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + formatted = model._format_messages_for_provider(messages_with_tools, system_prompt) + + assert len(formatted) == 4 # system + 3 messages + # Check that tool calls are formatted as strings for Llama + assistant_msg = formatted[2] + assert isinstance(assistant_msg["content"], str) + assert "TOOL_CALL" in assistant_msg["content"] + assert "get_weather" in assistant_msg["content"] + + # Check that tool results are formatted as strings + tool_result_msg = formatted[3] + assert isinstance(tool_result_msg["content"], str) + assert "TOOL_RESULT" in tool_result_msg["content"] + assert "Sunny, 25°C" in tool_result_msg["content"] + + def test_format_messages_anthropic_with_images(self, boto_session): + """Test Anthropic formatting with image content.""" + import base64 + + # Create mock image data + image_bytes = b"fake_image_data" + messages_with_image = [ + { + "role": "user", + "content": [ + {"text": "What's in this image?"}, + {"image": {"format": "png", "source": {"bytes": image_bytes}}}, + ], + } + ] + + endpoint_config = { + "endpoint_name": "test-endpoint", + "region_name": "us-east-1", + "model_provider": ModelProvider.ANTHROPIC, + } + payload_config = {"max_tokens": 1024} + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + formatted = model._format_messages_for_provider(messages_with_image) + + assert len(formatted) == 1 + user_msg = formatted[0] + assert user_msg["role"] == "user" + assert len(user_msg["content"]) == 2 + + # Check text content + text_content = user_msg["content"][0] + assert text_content["type"] == "text" + assert text_content["text"] == "What's in this image?" 
+ + # Check image content + image_content = user_msg["content"][1] + assert image_content["type"] == "image" + assert image_content["source"]["type"] == "base64" + assert image_content["source"]["media_type"] == "image/png" + assert image_content["source"]["data"] == base64.b64encode(image_bytes).decode("utf-8") + + class TestDataClasses: """Test suite for data classes.""" @@ -572,3 +815,28 @@ def test_tool_call(self): assert tool2.type == "function" assert tool2.function.name == "get_time" assert tool2.function.arguments == '{"timezone": "UTC"}' + + +class TestModelProviderEnum: + """Test suite for ModelProvider enum.""" + + def test_model_provider_values(self): + """Test ModelProvider enum values.""" + assert ModelProvider.OPENAI.value == "openai" + assert ModelProvider.MISTRAL.value == "mistral" + assert ModelProvider.LLAMA.value == "llama" + assert ModelProvider.ANTHROPIC.value == "anthropic" + assert ModelProvider.CUSTOM.value == "custom" + + def test_model_provider_from_string(self): + """Test creating ModelProvider from string.""" + assert ModelProvider("openai") == ModelProvider.OPENAI + assert ModelProvider("mistral") == ModelProvider.MISTRAL + assert ModelProvider("llama") == ModelProvider.LLAMA + assert ModelProvider("anthropic") == ModelProvider.ANTHROPIC + assert ModelProvider("custom") == ModelProvider.CUSTOM + + def test_invalid_model_provider(self): + """Test invalid model provider string raises ValueError.""" + with pytest.raises(ValueError): + ModelProvider("invalid_provider") From d5480d96cf0f5c70c2baaf31ad4bf7dca04b3818 Mon Sep 17 00:00:00 2001 From: Dan Ferguson Date: Fri, 15 Aug 2025 13:38:51 -0400 Subject: [PATCH 2/3] Updated the code needed to use flags for different models accessed by SageMakerAIModels --- src/strands/models/llamaapi.py | 22 +- src/strands/models/mistral.py | 20 +- src/strands/models/sagemaker.py | 103 ++------ tests/strands/models/test_sagemaker.py | 313 +++++++++++-------------- tests/strands/tools/test_decorator.py | 9 +- 5 files changed, 177 insertions(+), 290 deletions(-) diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 421b06e52..667346232 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -86,7 +86,8 @@ def get_config(self) -> LlamaConfig: """ return self.config - def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + @staticmethod + def _format_request_message_content(content: ContentBlock) -> dict[str, Any]: """Format a LlamaAPI content block. - NOTE: "reasoningContent" and "video" are not supported currently. @@ -116,7 +117,8 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: + @staticmethod + def _format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]: """Format a Llama API tool call. Args: @@ -133,7 +135,8 @@ def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any] "id": tool_use["toolUseId"], } - def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]: + @staticmethod + def _format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]: """Format a Llama API tool message. 
Args: @@ -153,10 +156,11 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any return { "role": "tool", "tool_call_id": tool_result["toolUseId"], - "content": [self._format_request_message_content(content) for content in contents], + "content": [LlamaAPIModel._format_request_message_content(content) for content in contents], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: """Format a LlamaAPI compatible messages array. Args: @@ -174,17 +178,17 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = "" formatted_contents = [ - self._format_request_message_content(content) + cls._format_request_message_content(content=content) for content in contents if not any(block_type in content for block_type in ["toolResult", "toolUse"]) ] formatted_tool_calls = [ - self._format_request_message_tool_call(content["toolUse"]) + cls._format_request_message_tool_call(tool_use=content["toolUse"]) for content in contents if "toolUse" in content ] formatted_tool_messages = [ - self._format_request_tool_message(content["toolResult"]) + cls._format_request_tool_message(tool_result=content["toolResult"]) for content in contents if "toolResult" in content ] @@ -220,7 +224,7 @@ def format_request( format. """ request = { - "messages": self._format_request_messages(messages, system_prompt), + "messages": self.format_request_messages(messages, system_prompt), "model": self.config["model_id"], "stream": True, "tools": [ diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 8855b6d64..3c55f1027 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -112,7 +112,8 @@ def get_config(self) -> MistralConfig: """ return self.config - def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]: + @staticmethod + def _format_request_message_content(content: ContentBlock) -> Union[str, dict[str, Any]]: """Format a Mistral content block. Args: @@ -141,7 +142,8 @@ def _format_request_message_content(self, content: ContentBlock) -> Union[str, d raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: + @staticmethod + def _format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]: """Format a Mistral tool call. Args: @@ -159,7 +161,8 @@ def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any] "type": "function", } - def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]: + @staticmethod + def _format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]: """Format a Mistral tool message. Args: @@ -184,7 +187,8 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "tool_call_id": tool_result["toolUseId"], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: """Format a Mistral compatible messages array. 
Args: @@ -209,13 +213,13 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s for content in contents: if "text" in content: - formatted_content = self._format_request_message_content(content) + formatted_content = cls._format_request_message_content(content) if isinstance(formatted_content, str): text_contents.append(formatted_content) elif "toolUse" in content: - tool_calls.append(self._format_request_message_tool_call(content["toolUse"])) + tool_calls.append(cls._format_request_message_tool_call(content["toolUse"])) elif "toolResult" in content: - tool_messages.append(self._format_request_tool_message(content["toolResult"])) + tool_messages.append(cls._format_request_tool_message(content["toolResult"])) if text_contents or tool_calls: formatted_message: dict[str, Any] = { @@ -251,7 +255,7 @@ def format_request( """ request: dict[str, Any] = { "model": self.config["model_id"], - "messages": self._format_request_messages(messages, system_prompt), + "messages": self.format_request_messages(messages, system_prompt), } if "max_tokens" in self.config: diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index bb67ff52e..bea80d812 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -31,7 +31,6 @@ class ModelProvider(Enum): OPENAI = "openai" MISTRAL = "mistral" LLAMA = "llama" - ANTHROPIC = "anthropic" CUSTOM = "custom" @@ -219,7 +218,6 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i """ return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config) - # @override def format_request( self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None ) -> dict[str, Any]: @@ -233,7 +231,8 @@ def format_request( Returns: An Amazon SageMaker chat streaming request. """ - formatted_messages = self._format_messages_for_provider(messages, system_prompt) + # formatted_messages = self._format_messages_for_provider(messages, tool_specs, system_prompt) + formatted_messages = self._format_request_messages(messages, system_prompt) payload = { "messages": formatted_messages, @@ -304,9 +303,7 @@ def format_request( return request - def _format_messages_for_provider( - self, messages: Messages, system_prompt: Optional[str] = None - ) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: """Format messages based on the selected model provider. 
Args: @@ -319,13 +316,11 @@ def _format_messages_for_provider( provider = self.endpoint_config["model_provider"] if provider == ModelProvider.OPENAI: - return self._format_openai_messages(messages, system_prompt) + return SageMakerAIModel._format_openai_messages(messages, system_prompt) elif provider == ModelProvider.MISTRAL: return self._format_mistral_messages(messages, system_prompt) elif provider == ModelProvider.LLAMA: return self._format_llama_messages(messages, system_prompt) - elif provider == ModelProvider.ANTHROPIC: - return self._format_anthropic_messages(messages, system_prompt) elif provider == ModelProvider.CUSTOM: custom_formatter = self.endpoint_config["custom_formatter"] return custom_formatter(messages, system_prompt) @@ -333,88 +328,26 @@ def _format_messages_for_provider( # Default to OpenAI format return self._format_openai_messages(messages, system_prompt) - def _format_openai_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + @staticmethod + def _format_openai_messages(messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: """Format messages for OpenAI-compatible models.""" - return self.format_request_messages(messages, system_prompt) - - def _format_mistral_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format messages for Mistral models.""" - # Mistral uses similar format to OpenAI but with some differences - formatted_messages = self.format_request_messages(messages, system_prompt) - # Add Mistral-specific formatting here if needed - return formatted_messages - - def _format_llama_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format messages for Llama models.""" - # Llama often uses a different conversation format - formatted_messages = [] + from strands.models.openai import OpenAIModel - # Add system prompt if provided - if system_prompt: - formatted_messages.append({"role": "system", "content": system_prompt}) + return OpenAIModel.format_request_messages(messages, system_prompt) - # Process messages with Llama-specific formatting - for message in messages: - formatted_message = {"role": message["role"], "content": self._format_content_for_llama(message["content"])} - formatted_messages.append(formatted_message) - - return formatted_messages + @staticmethod + def _format_mistral_messages(messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for Mistral models.""" + from strands.models.mistral import MistralModel - def _format_anthropic_messages( - self, messages: Messages, system_prompt: Optional[str] = None - ) -> list[dict[str, Any]]: - """Format messages for Anthropic Claude models.""" - # Anthropic has specific requirements for message formatting - formatted_messages = [] + return MistralModel.format_request_messages(messages, system_prompt) - # System prompt is handled separately in Anthropic - for message in messages: - formatted_message = { - "role": message["role"], - "content": self._format_content_for_anthropic(message["content"]), - } - formatted_messages.append(formatted_message) + @staticmethod + def _format_llama_messages(messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for Llama models.""" + from strands.models.llamaapi import LlamaAPIModel - return formatted_messages - - def _format_content_for_llama(self, content: list[ContentBlock]) -> str: - """Format content 
blocks for Llama models (typically expects string content).""" - text_parts = [] - for block in content: - if "text" in block: - text_parts.append(block["text"]) - elif "toolUse" in block: - # Handle tool use for Llama - tool_use = block["toolUse"] - text_parts.append(f"[TOOL_CALL: {tool_use['name']}({json.dumps(tool_use['input'])})]") - elif "toolResult" in block: - # Handle tool results for Llama - tool_result = block["toolResult"] - result_text = " ".join([c.get("text", str(c)) for c in tool_result["content"]]) - text_parts.append(f"[TOOL_RESULT: {result_text}]") - return " ".join(text_parts) - - def _format_content_for_anthropic(self, content: list[ContentBlock]) -> list[dict[str, Any]]: - """Format content blocks for Anthropic models.""" - formatted_content = [] - for block in content: - if "text" in block: - formatted_content.append({"type": "text", "text": block["text"]}) - elif "image" in block: - # Anthropic image format - image_data = base64.b64encode(block["image"]["source"]["bytes"]).decode("utf-8") - formatted_content.append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": f"image/{block['image']['format']}", - "data": image_data, - }, - } - ) - # Add other content types as needed - return formatted_content + return LlamaAPIModel.format_request_messages(messages, system_prompt) @classmethod def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index fe5e05734..e7f2c20c7 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -53,7 +53,11 @@ def payload_config() -> Dict[str, Any]: @pytest.fixture -def model(boto_session, endpoint_config, payload_config): +def model( + boto_session, + endpoint_config: SageMakerAIModel.SageMakerAIEndpointConfig, + payload_config: SageMakerAIModel.SageMakerAIPayloadSchema, +): """SageMaker model instance with mocked boto session.""" return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session) @@ -93,8 +97,10 @@ class TestSageMakerAIModel: def test_init_default(self, boto_session): """Test initialization with default parameters.""" - endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} - payload_config = {"max_tokens": 1024} + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1" + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) model = SageMakerAIModel( endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session ) @@ -109,16 +115,10 @@ def test_init_default(self, boto_session): def test_init_with_all_params(self, boto_session): """Test initialization with all parameters.""" - endpoint_config = { - "endpoint_name": "test-endpoint", - "inference_component_name": "test-component", - "region_name": "us-west-2", - } - payload_config = { - "stream": False, - "max_tokens": 1024, - "temperature": 0.7, - } + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", inference_component_name="test-component", region_name="us-west-2" + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024, stream=False, temperature=0.7) client_config = BotocoreConfig(user_agent_extra="test-agent") model = SageMakerAIModel( @@ -141,8 +141,10 @@ def test_init_with_all_params(self, boto_session): def 
test_init_with_client_config(self, boto_session): """Test initialization with client configuration.""" - endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} - payload_config = {"max_tokens": 1024} + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1" + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) client_config = BotocoreConfig(user_agent_extra="test-agent") SageMakerAIModel( @@ -528,8 +530,10 @@ class TestModelProvider: def test_default_model_provider(self, boto_session): """Test that default model provider is set to OpenAI.""" - endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} - payload_config = {"max_tokens": 1024} + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1" + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) model = SageMakerAIModel( endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session @@ -539,12 +543,10 @@ def test_default_model_provider(self, boto_session): def test_model_provider_enum(self, boto_session): """Test setting model provider using enum.""" - endpoint_config = { - "endpoint_name": "test-endpoint", - "region_name": "us-east-1", - "model_provider": ModelProvider.LLAMA, - } - payload_config = {"max_tokens": 1024} + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.LLAMA + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) model = SageMakerAIModel( endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session @@ -554,8 +556,10 @@ def test_model_provider_enum(self, boto_session): def test_model_provider_string(self, boto_session): """Test setting model provider using string.""" - endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1", "model_provider": "mistral"} - payload_config = {"max_tokens": 1024} + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.MISTRAL + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) model = SageMakerAIModel( endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session @@ -563,206 +567,153 @@ def test_model_provider_string(self, boto_session): assert model.endpoint_config["model_provider"] == ModelProvider.MISTRAL - def test_custom_formatter_required(self, boto_session): - """Test that custom formatter is required when using CUSTOM provider.""" - endpoint_config = { - "endpoint_name": "test-endpoint", - "region_name": "us-east-1", - "model_provider": ModelProvider.CUSTOM, - } - payload_config = {"max_tokens": 1024} - - with pytest.raises(ValueError, match="custom_formatter is required when model_provider is CUSTOM"): - SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session) - - def test_custom_formatter_with_function(self, boto_session): - """Test using custom formatter with a function.""" - - def custom_formatter(messages, system_prompt=None): - formatted = [] - if system_prompt: - formatted.append({"role": "system", "text": system_prompt}) - for msg in messages: - formatted.append({"role": msg["role"], "text": msg["content"][0]["text"]}) - return formatted - - 
endpoint_config = {
-            "endpoint_name": "test-endpoint",
-            "region_name": "us-east-1",
-            "model_provider": ModelProvider.CUSTOM,
-            "custom_formatter": custom_formatter,
-        }
-        payload_config = {"max_tokens": 1024}
+    def test_format_messages_openai_provider(self, boto_session, messages):
+        """Test message formatting with OpenAI provider."""
+        endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig(
+            endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.OPENAI
+        )
+        payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024)
 
         model = SageMakerAIModel(
             endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session
         )
 
-        assert model.endpoint_config["model_provider"] == ModelProvider.CUSTOM
-        assert model.endpoint_config["custom_formatter"] == custom_formatter
+        tru_request = model.format_request(messages)
+        exp_request = {
+            "EndpointName": "test-endpoint",
+            "Body": '{"messages": [{"role": "user", "content": [{'
+            '"text": "What is the capital of France?", '
+            '"type": "text"}]}], "max_tokens": 1024, "stream": true}',
+            "ContentType": "application/json",
+            "Accept": "application/json",
+        }
+        assert tru_request == exp_request
 
-    def test_format_messages_openai_provider(self, boto_session, messages, system_prompt):
+    def test_format_messages_mistral_provider(self, boto_session, messages):
-        """Test message formatting with OpenAI provider."""
+        """Test message formatting with Mistral provider."""
-        endpoint_config = {
-            "endpoint_name": "test-endpoint",
-            "region_name": "us-east-1",
-            "model_provider": ModelProvider.OPENAI,
-        }
-        payload_config = {"max_tokens": 1024}
+        endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig(
+            endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.MISTRAL
+        )
+        payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024)
 
         model = SageMakerAIModel(
             endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session
         )
 
-        formatted = model._format_messages_for_provider(messages, system_prompt)
-
-        assert len(formatted) == 2  # system + user message
-        assert formatted[0]["role"] == "system"
-        assert formatted[0]["content"] == system_prompt
-        assert formatted[1]["role"] == "user"
-        assert len(formatted[1]["content"]) == 1
-        assert formatted[1]["content"][0]["text"] == "What is the capital of France?"
- - def test_format_messages_llama_provider(self, boto_session, messages, system_prompt): - """Test message formatting with Llama provider.""" - endpoint_config = { - "endpoint_name": "test-endpoint", - "region_name": "us-east-1", - "model_provider": ModelProvider.LLAMA, + tru_request = model.format_request(messages) + exp_request = { + "EndpointName": "test-endpoint", + "Body": '{"messages": [{"role": "user", "content": "What is the capital of France?"}], ' + '"max_tokens": 1024, "stream": true}', + "ContentType": "application/json", + "Accept": "application/json", } - payload_config = {"max_tokens": 1024} + assert tru_request == exp_request - model = SageMakerAIModel( - endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + def test_format_messages_llama_provider(self, boto_session, messages): + """Test message formatting with Llama provider.""" + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.LLAMA ) - - formatted = model._format_messages_for_provider(messages, system_prompt) - - assert len(formatted) == 2 # system + user message - assert formatted[0]["role"] == "system" - assert formatted[0]["content"] == system_prompt - assert formatted[1]["role"] == "user" - assert isinstance(formatted[1]["content"], str) # Llama uses string content - assert "What is the capital of France?" in formatted[1]["content"] - - def test_format_messages_custom_provider(self, boto_session, messages, system_prompt): - """Test message formatting with custom provider.""" - - def custom_formatter(msgs, sys_prompt=None): - result = [{"custom_format": True}] - if sys_prompt: - result.append({"system": sys_prompt}) - for msg in msgs: - result.append({"speaker": msg["role"], "text": msg["content"][0]["text"]}) - return result - - endpoint_config = { - "endpoint_name": "test-endpoint", - "region_name": "us-east-1", - "model_provider": ModelProvider.CUSTOM, - "custom_formatter": custom_formatter, - } - payload_config = {"max_tokens": 1024} + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) model = SageMakerAIModel( endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session ) - formatted = model._format_messages_for_provider(messages, system_prompt) + tru_request = model.format_request(messages) - assert len(formatted) == 3 # custom marker + system + user - assert formatted[0]["custom_format"] is True - assert formatted[1]["system"] == system_prompt - assert formatted[2]["speaker"] == "user" - assert formatted[2]["text"] == "What is the capital of France?" 
+ exp_request = { + "EndpointName": "test-endpoint", + "Body": '{"messages": [{"role": "user", "content": [{' + '"text": "What is the capital of France?", ' + '"type": "text"}]}], "max_tokens": 1024, "stream": true}', + "ContentType": "application/json", + "Accept": "application/json", + } + assert tru_request == exp_request - def test_format_messages_with_tool_use_llama(self, boto_session, system_prompt): + def test_format_messages_with_tool_use_llama(self, boto_session): """Test Llama formatting with tool use messages.""" - messages_with_tools = [ - {"role": "user", "content": [{"text": "What's the weather?"}]}, + + messages = [ { "role": "assistant", "content": [ - {"toolUse": {"toolUseId": "tool123", "name": "get_weather", "input": {"location": "Paris"}}} + { + "toolUse": { + "toolUseId": "c1", + "name": "calculator", + "input": {"expression": "2+2"}, + }, + }, ], }, - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "tool123", "content": [{"text": "Sunny, 25°C"}]}}], - }, ] - endpoint_config = { - "endpoint_name": "test-endpoint", - "region_name": "us-east-1", - "model_provider": ModelProvider.LLAMA, - } - payload_config = {"max_tokens": 1024} + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.LLAMA + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) model = SageMakerAIModel( endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session ) - formatted = model._format_messages_for_provider(messages_with_tools, system_prompt) - - assert len(formatted) == 4 # system + 3 messages - # Check that tool calls are formatted as strings for Llama - assistant_msg = formatted[2] - assert isinstance(assistant_msg["content"], str) - assert "TOOL_CALL" in assistant_msg["content"] - assert "get_weather" in assistant_msg["content"] - - # Check that tool results are formatted as strings - tool_result_msg = formatted[3] - assert isinstance(tool_result_msg["content"], str) - assert "TOOL_RESULT" in tool_result_msg["content"] - assert "Sunny, 25°C" in tool_result_msg["content"] + tru_request = model.format_request(messages) + exp_request = { + "EndpointName": "test-endpoint", + "Body": '{"messages": [{"role": "assistant", ' + '"tool_calls": [{"function": {"arguments": "{\\"expression\\": \\"2+2\\"}", ' + '"name": "calculator"}, "id": "c1"}]}], "max_tokens": 1024, "stream": true}', + "ContentType": "application/json", + "Accept": "application/json", + } + assert tru_request == exp_request - def test_format_messages_anthropic_with_images(self, boto_session): - """Test Anthropic formatting with image content.""" - import base64 + def test_format_messages_custom_provider(self, boto_session, messages, system_prompt): + """Test message formatting with custom provider.""" - # Create mock image data - image_bytes = b"fake_image_data" - messages_with_image = [ - { - "role": "user", - "content": [ - {"text": "What's in this image?"}, - {"image": {"format": "png", "source": {"bytes": image_bytes}}}, - ], - } - ] + def custom_formatter(msgs, sys_prompt=None): + result = [{"custom_format": True}] + if sys_prompt: + result.append({"system": sys_prompt}) + for msg in msgs: + result.append({"speaker": msg["role"], "text": msg["content"][0]["text"]}) + return result - endpoint_config = { - "endpoint_name": "test-endpoint", - "region_name": "us-east-1", - "model_provider": ModelProvider.ANTHROPIC, - } - payload_config = {"max_tokens": 1024} + 
endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", + region_name="us-east-1", + model_provider=ModelProvider.CUSTOM, + custom_formatter=custom_formatter, + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) model = SageMakerAIModel( endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session ) - formatted = model._format_messages_for_provider(messages_with_image) - - assert len(formatted) == 1 - user_msg = formatted[0] - assert user_msg["role"] == "user" - assert len(user_msg["content"]) == 2 + tru_request = model.format_request(messages) + exp_request = { + "EndpointName": "test-endpoint", + "Body": '{"messages": [{"custom_format": true}, ' + '{"speaker": "user", "text": "What is the capital of France?"}], ' + '"max_tokens": 1024, "stream": true}', + "ContentType": "application/json", + "Accept": "application/json", + } + assert tru_request == exp_request - # Check text content - text_content = user_msg["content"][0] - assert text_content["type"] == "text" - assert text_content["text"] == "What's in this image?" + def test_custom_formatter_required(self, boto_session): + """Test that custom formatter is required when using CUSTOM provider.""" + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.CUSTOM + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) - # Check image content - image_content = user_msg["content"][1] - assert image_content["type"] == "image" - assert image_content["source"]["type"] == "base64" - assert image_content["source"]["media_type"] == "image/png" - assert image_content["source"]["data"] == base64.b64encode(image_bytes).decode("utf-8") + with pytest.raises(ValueError, match="custom_formatter is required when model_provider is CUSTOM"): + SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session) class TestDataClasses: @@ -825,7 +776,6 @@ def test_model_provider_values(self): assert ModelProvider.OPENAI.value == "openai" assert ModelProvider.MISTRAL.value == "mistral" assert ModelProvider.LLAMA.value == "llama" - assert ModelProvider.ANTHROPIC.value == "anthropic" assert ModelProvider.CUSTOM.value == "custom" def test_model_provider_from_string(self): @@ -833,7 +783,6 @@ def test_model_provider_from_string(self): assert ModelProvider("openai") == ModelProvider.OPENAI assert ModelProvider("mistral") == ModelProvider.MISTRAL assert ModelProvider("llama") == ModelProvider.LLAMA - assert ModelProvider("anthropic") == ModelProvider.ANTHROPIC assert ModelProvider("custom") == ModelProvider.CUSTOM def test_invalid_model_provider(self): diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 246879da7..e490c7bb0 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1064,7 +1064,7 @@ async def _run_context_injection_test(context_tool: AgentTool): "content": [ {"text": "Tool 'context_tool' (ID: test-id)"}, {"text": "injected agent 'test_agent' processed: some_message"}, - {"text": "context agent 'test_agent'"} + {"text": "context agent 'test_agent'"}, ], "toolUseId": "test-id", } @@ -1151,7 +1151,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> dict: assert len(tool_results) == 1 tool_result = tool_results[0] - + # Should get a validation error because tool_context is required but not provided 
assert tool_result["status"] == "error" assert "tool_context" in tool_result["content"][0]["text"].lower() @@ -1173,10 +1173,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str: tool_use={ "toolUseId": "test-id-2", "name": "context_tool", - "input": { - "message": "some_message", - "tool_context": "my_custom_context_string" - }, + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, }, invocation_state={ "agent": Agent(name="test_agent"), From 328c6899423ca4ee4522bb87e7acc83571d74187 Mon Sep 17 00:00:00 2001 From: Dan Ferguson Date: Fri, 15 Aug 2025 14:05:25 -0400 Subject: [PATCH 3/3] Removed custom provider for now. --- src/strands/models/sagemaker.py | 34 +++++----- tests/strands/models/test_sagemaker.py | 93 +++++++++++++------------- 2 files changed, 66 insertions(+), 61 deletions(-) diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index bea80d812..9bc2fb83d 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -31,7 +31,7 @@ class ModelProvider(Enum): OPENAI = "openai" MISTRAL = "mistral" LLAMA = "llama" - CUSTOM = "custom" + # CUSTOM = "custom" @dataclass @@ -101,7 +101,7 @@ def __init__(self, **kwargs: dict): class SageMakerAIModel(Model): """Amazon SageMaker model provider implementation.""" - client: SageMakerRuntimeClient # type: ignore[assignment] + client: SageMakerRuntimeClient class SageMakerAIPayloadSchema(TypedDict, total=False): """Payload schema for the Amazon SageMaker AI model. @@ -173,13 +173,13 @@ def __init__( self.endpoint_config["model_provider"] = ModelProvider(self.endpoint_config["model_provider"]) # Validate custom formatter if using CUSTOM provider - if self.endpoint_config["model_provider"] == ModelProvider.CUSTOM and not self.endpoint_config.get( - "custom_formatter" - ): - raise ValueError("custom_formatter is required when model_provider is CUSTOM") - logger.debug( - "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config - ) + # if self.endpoint_config["model_provider"] == ModelProvider.CUSTOM and not self.endpoint_config.get( + # "custom_formatter" + # ): + # raise ValueError("custom_formatter is required when model_provider is CUSTOM") + # logger.debug( + # "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config + # ) region = self.endpoint_config.get("region_name") or os.getenv("AWS_REGION") or "us-west-2" session = boto_session or boto3.Session(region_name=str(region)) @@ -210,7 +210,7 @@ def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) -> self.endpoint_config.update(endpoint_config) @override - def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: ignore[override] + def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": """Get the Amazon SageMaker model configuration. 
Returns: @@ -318,15 +318,17 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s if provider == ModelProvider.OPENAI: return SageMakerAIModel._format_openai_messages(messages, system_prompt) elif provider == ModelProvider.MISTRAL: - return self._format_mistral_messages(messages, system_prompt) + return SageMakerAIModel._format_mistral_messages(messages, system_prompt) elif provider == ModelProvider.LLAMA: - return self._format_llama_messages(messages, system_prompt) - elif provider == ModelProvider.CUSTOM: - custom_formatter = self.endpoint_config["custom_formatter"] - return custom_formatter(messages, system_prompt) + return SageMakerAIModel._format_llama_messages(messages, system_prompt) + # elif provider == ModelProvider.CUSTOM: + # custom_formatter = self.endpoint_config["custom_formatter"] + # if custom_formatter is None: + # raise ValueError("custom_formatter is required when model_provider is CUSTOM") + # return custom_formatter(messages, system_prompt) else: # Default to OpenAI format - return self._format_openai_messages(messages, system_prompt) + return SageMakerAIModel._format_openai_messages(messages, system_prompt) @staticmethod def _format_openai_messages(messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index e7f2c20c7..49e9d81a3 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -671,49 +671,52 @@ def test_format_messages_with_tool_use_llama(self, boto_session): } assert tru_request == exp_request - def test_format_messages_custom_provider(self, boto_session, messages, system_prompt): - """Test message formatting with custom provider.""" - - def custom_formatter(msgs, sys_prompt=None): - result = [{"custom_format": True}] - if sys_prompt: - result.append({"system": sys_prompt}) - for msg in msgs: - result.append({"speaker": msg["role"], "text": msg["content"][0]["text"]}) - return result - - endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( - endpoint_name="test-endpoint", - region_name="us-east-1", - model_provider=ModelProvider.CUSTOM, - custom_formatter=custom_formatter, - ) - payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) - - model = SageMakerAIModel( - endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session - ) - - tru_request = model.format_request(messages) - exp_request = { - "EndpointName": "test-endpoint", - "Body": '{"messages": [{"custom_format": true}, ' - '{"speaker": "user", "text": "What is the capital of France?"}], ' - '"max_tokens": 1024, "stream": true}', - "ContentType": "application/json", - "Accept": "application/json", - } - assert tru_request == exp_request - - def test_custom_formatter_required(self, boto_session): - """Test that custom formatter is required when using CUSTOM provider.""" - endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( - endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.CUSTOM - ) - payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) - - with pytest.raises(ValueError, match="custom_formatter is required when model_provider is CUSTOM"): - SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session) + # def test_format_messages_custom_provider(self, boto_session, messages, system_prompt): + # """Test message formatting with custom 
provider.""" + # + # def custom_formatter(msgs, sys_prompt=None): + # result = [{"custom_format": True}] + # if sys_prompt: + # result.append({"system": sys_prompt}) + # for msg in msgs: + # result.append({"speaker": msg["role"], "text": msg["content"][0]["text"]}) + # return result + # + # endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + # endpoint_name="test-endpoint", + # region_name="us-east-1", + # model_provider=ModelProvider.CUSTOM, + # custom_formatter=custom_formatter, + # ) + # payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) + # + # model = SageMakerAIModel( + # endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + # ) + # + # tru_request = model.format_request(messages) + # exp_request = { + # "EndpointName": "test-endpoint", + # "Body": '{"messages": [{"custom_format": true}, ' + # '{"speaker": "user", "text": "What is the capital of France?"}], ' + # '"max_tokens": 1024, "stream": true}', + # "ContentType": "application/json", + # "Accept": "application/json", + # } + # assert tru_request == exp_request + # + # def test_custom_formatter_required(self, boto_session): + # """Test that custom formatter is required when using CUSTOM provider.""" + # endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + # endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.CUSTOM + # ) + # payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) + # + # with pytest.raises(ValueError, match="custom_formatter is required when model_provider is CUSTOM"): + # SageMakerAIModel( + # endpoint_config=endpoint_config, + # payload_config=payload_config, + # boto_session=boto_session) class TestDataClasses: @@ -776,14 +779,14 @@ def test_model_provider_values(self): assert ModelProvider.OPENAI.value == "openai" assert ModelProvider.MISTRAL.value == "mistral" assert ModelProvider.LLAMA.value == "llama" - assert ModelProvider.CUSTOM.value == "custom" + # assert ModelProvider.CUSTOM.value == "custom" def test_model_provider_from_string(self): """Test creating ModelProvider from string.""" assert ModelProvider("openai") == ModelProvider.OPENAI assert ModelProvider("mistral") == ModelProvider.MISTRAL assert ModelProvider("llama") == ModelProvider.LLAMA - assert ModelProvider("custom") == ModelProvider.CUSTOM + # assert ModelProvider("custom") == ModelProvider.CUSTOM def test_invalid_model_provider(self): """Test invalid model provider string raises ValueError."""