diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py
index 421b06e52..667346232 100644
--- a/src/strands/models/llamaapi.py
+++ b/src/strands/models/llamaapi.py
@@ -86,7 +86,8 @@ def get_config(self) -> LlamaConfig:
         """
         return self.config

-    def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
+    @staticmethod
+    def _format_request_message_content(content: ContentBlock) -> dict[str, Any]:
         """Format a LlamaAPI content block.

         - NOTE: "reasoningContent" and "video" are not supported currently.
@@ -116,7 +117,8 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An

         raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

-    def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
+    @staticmethod
+    def _format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]:
         """Format a Llama API tool call.

         Args:
@@ -133,7 +135,8 @@ def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]
             "id": tool_use["toolUseId"],
         }

-    def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
+    @staticmethod
+    def _format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]:
         """Format a Llama API tool message.

         Args:
@@ -153,10 +156,11 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any
         return {
             "role": "tool",
             "tool_call_id": tool_result["toolUseId"],
-            "content": [self._format_request_message_content(content) for content in contents],
+            "content": [LlamaAPIModel._format_request_message_content(content) for content in contents],
         }

-    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+    @classmethod
+    def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
         """Format a LlamaAPI compatible messages array.

         Args:
@@ -174,17 +178,17 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s
             formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = ""
             formatted_contents = [
-                self._format_request_message_content(content)
+                cls._format_request_message_content(content=content)
                 for content in contents
                 if not any(block_type in content for block_type in ["toolResult", "toolUse"])
             ]
             formatted_tool_calls = [
-                self._format_request_message_tool_call(content["toolUse"])
+                cls._format_request_message_tool_call(tool_use=content["toolUse"])
                 for content in contents
                 if "toolUse" in content
             ]
             formatted_tool_messages = [
-                self._format_request_tool_message(content["toolResult"])
+                cls._format_request_tool_message(tool_result=content["toolResult"])
                 for content in contents
                 if "toolResult" in content
             ]
@@ -220,7 +224,7 @@ def format_request(
                 format.
         """
         request = {
-            "messages": self._format_request_messages(messages, system_prompt),
+            "messages": self.format_request_messages(messages, system_prompt),
             "model": self.config["model_id"],
             "stream": True,
             "tools": [
diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py
index 8855b6d64..3c55f1027 100644
--- a/src/strands/models/mistral.py
+++ b/src/strands/models/mistral.py
@@ -112,7 +112,8 @@ def get_config(self) -> MistralConfig:
         """
         return self.config

-    def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]:
+    @staticmethod
+    def _format_request_message_content(content: ContentBlock) -> Union[str, dict[str, Any]]:
         """Format a Mistral content block.

         Args:
@@ -141,7 +142,8 @@ def _format_request_message_content(self, content: ContentBlock) -> Union[str, d

         raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

-    def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
+    @staticmethod
+    def _format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]:
         """Format a Mistral tool call.

         Args:
@@ -159,7 +161,8 @@ def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]
             "type": "function",
         }

-    def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
+    @staticmethod
+    def _format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]:
         """Format a Mistral tool message.

         Args:
@@ -184,7 +187,8 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any
             "tool_call_id": tool_result["toolUseId"],
         }

-    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+    @classmethod
+    def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
         """Format a Mistral compatible messages array.

         Args:
@@ -209,13 +213,13 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s

             for content in contents:
                 if "text" in content:
-                    formatted_content = self._format_request_message_content(content)
+                    formatted_content = cls._format_request_message_content(content)
                     if isinstance(formatted_content, str):
                         text_contents.append(formatted_content)
                 elif "toolUse" in content:
-                    tool_calls.append(self._format_request_message_tool_call(content["toolUse"]))
+                    tool_calls.append(cls._format_request_message_tool_call(content["toolUse"]))
                 elif "toolResult" in content:
-                    tool_messages.append(self._format_request_tool_message(content["toolResult"]))
+                    tool_messages.append(cls._format_request_tool_message(content["toolResult"]))

             if text_contents or tool_calls:
                 formatted_message: dict[str, Any] = {
@@ -251,7 +255,7 @@ def format_request(
         """
         request: dict[str, Any] = {
             "model": self.config["model_id"],
-            "messages": self._format_request_messages(messages, system_prompt),
+            "messages": self.format_request_messages(messages, system_prompt),
         }

         if "max_tokens" in self.config:
diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py
index 9cfe27d9e..9bc2fb83d 100644
--- a/src/strands/models/sagemaker.py
+++ b/src/strands/models/sagemaker.py
@@ -1,10 +1,13 @@
 """Amazon SageMaker model provider."""

+import base64
 import json
 import logging
+import mimetypes
 import os
 from dataclasses import dataclass
-from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast
+from enum import Enum
+from typing import Any, AsyncGenerator, Callable, Literal, Optional, Type, TypedDict, TypeVar, Union, cast

 import boto3
 from botocore.config import Config as BotocoreConfig
@@ -14,14 +17,23 @@
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolResult, ToolSpec
-from .openai import OpenAIModel
+from ..types.tools import ToolResult, ToolSpec, ToolUse
+from .model import Model

 T = TypeVar("T", bound=BaseModel)

 logger = logging.getLogger(__name__)


+class ModelProvider(Enum):
+    """Supported model providers for prompt formatting."""
+
+    OPENAI = "openai"
+    MISTRAL = "mistral"
+    LLAMA = "llama"
+    # CUSTOM = "custom"
+
+
 @dataclass
 class UsageMetadata:
     """Usage metadata for the model.
@@ -86,10 +98,10 @@ def __init__(self, **kwargs: dict):
         self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""}))


-class SageMakerAIModel(OpenAIModel):
+class SageMakerAIModel(Model):
     """Amazon SageMaker model provider implementation."""

-    client: SageMakerRuntimeClient  # type: ignore[assignment]
+    client: SageMakerRuntimeClient

     class SageMakerAIPayloadSchema(TypedDict, total=False):
         """Payload schema for the Amazon SageMaker AI model.
@@ -120,7 +132,8 @@ class SageMakerAIEndpointConfig(TypedDict, total=False):
         Attributes:
             endpoint_name: The name of the SageMaker endpoint to invoke
             inference_component_name: The name of the inference component to use
-
+            model_provider: The model provider format to use for prompt formatting (defaults to openai)
+            custom_formatter: Custom formatter function when model_provider is CUSTOM
             additional_args: Other request parameters, as supported by https://bit.ly/sagemaker-invoke-endpoint-params
         """

@@ -129,6 +142,8 @@ class SageMakerAIEndpointConfig(TypedDict, total=False):
         inference_component_name: Union[str, None]
         target_model: Union[Optional[str], None]
         target_variant: Union[Optional[str], None]
+        model_provider: Optional[Union[ModelProvider, str]]
+        custom_formatter: Optional[Callable[[Messages, Optional[str]], list[dict[str, Any]]]]
         additional_args: Optional[dict[str, Any]]

     def __init__(
@@ -150,9 +165,21 @@ def __init__(
         payload_config.setdefault("tool_results_as_user_messages", False)
         self.endpoint_config = dict(endpoint_config)
         self.payload_config = dict(payload_config)
-        logger.debug(
-            "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config
-        )
+
+        # Set default model provider if not specified
+        if "model_provider" not in self.endpoint_config or self.endpoint_config["model_provider"] is None:
+            self.endpoint_config["model_provider"] = ModelProvider.OPENAI
+        elif isinstance(self.endpoint_config["model_provider"], str):
+            self.endpoint_config["model_provider"] = ModelProvider(self.endpoint_config["model_provider"])
+
+        # Validate custom formatter if using CUSTOM provider
+        # if self.endpoint_config["model_provider"] == ModelProvider.CUSTOM and not self.endpoint_config.get(
+        #     "custom_formatter"
+        # ):
+        #     raise ValueError("custom_formatter is required when model_provider is CUSTOM")
+        logger.debug(
+            "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config
+        )

         region = self.endpoint_config.get("region_name") or os.getenv("AWS_REGION") or "us-west-2"
         session = boto_session or boto3.Session(region_name=str(region))
@@ -183,7 +210,7 @@ def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) ->
         self.endpoint_config.update(endpoint_config)

     @override
-    def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig":  # type: ignore[override]
+    def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig":
         """Get the Amazon SageMaker model configuration.

         Returns:
@@ -191,7 +218,6 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig":  # type: i
         """
         return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config)

-    @override
     def format_request(
         self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
     ) -> dict[str, Any]:
@@ -205,7 +231,8 @@ def format_request(

         Returns:
             An Amazon SageMaker chat streaming request.
""" - formatted_messages = self.format_request_messages(messages, system_prompt) + # formatted_messages = self._format_messages_for_provider(messages, tool_specs, system_prompt) + formatted_messages = self._format_request_messages(messages, system_prompt) payload = { "messages": formatted_messages, @@ -276,6 +303,245 @@ def format_request( return request + def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages based on the selected model provider. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + Formatted messages array for the specified provider. + """ + provider = self.endpoint_config["model_provider"] + + if provider == ModelProvider.OPENAI: + return SageMakerAIModel._format_openai_messages(messages, system_prompt) + elif provider == ModelProvider.MISTRAL: + return SageMakerAIModel._format_mistral_messages(messages, system_prompt) + elif provider == ModelProvider.LLAMA: + return SageMakerAIModel._format_llama_messages(messages, system_prompt) + # elif provider == ModelProvider.CUSTOM: + # custom_formatter = self.endpoint_config["custom_formatter"] + # if custom_formatter is None: + # raise ValueError("custom_formatter is required when model_provider is CUSTOM") + # return custom_formatter(messages, system_prompt) + else: + # Default to OpenAI format + return SageMakerAIModel._format_openai_messages(messages, system_prompt) + + @staticmethod + def _format_openai_messages(messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for OpenAI-compatible models.""" + from strands.models.openai import OpenAIModel + + return OpenAIModel.format_request_messages(messages, system_prompt) + + @staticmethod + def _format_mistral_messages(messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for Mistral models.""" + from strands.models.mistral import MistralModel + + return MistralModel.format_request_messages(messages, system_prompt) + + @staticmethod + def _format_llama_messages(messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for Llama models.""" + from strands.models.llamaapi import LlamaAPIModel + + return LlamaAPIModel.format_request_messages(messages, system_prompt) + + @classmethod + def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + """Format a content block. + + Args: + content: Message content. + + Returns: + Formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to a SageMaker-compatible format. 
+ """ + if "reasoningContent" in content and content["reasoningContent"]: + return { + "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""), + "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""), + "type": "thinking", + } + elif not content.get("reasoningContent", None): + content.pop("reasoningContent", None) + + if "video" in content: + return { + "type": "video_url", + "video_url": { + "detail": "auto", + "url": content["video"]["source"]["bytes"], + }, + } + + if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + """Format a tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + Formatted tool call. + """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + Formatted messages array. + """ + formatted_messages: list[dict[str, Any]] + formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + + for message in messages: + contents = message["content"] + + formatted_contents = [ + cls.format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content + ] + formatted_tool_messages = [ + cls.format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format a response event into a standardized message chunk. + + Args: + event: A response event from the model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. 
+ """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + + if event["data_type"] == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + @override async def stream( self, @@ -474,7 +740,6 @@ async def stream( logger.debug("finished streaming response from model") - @override @classmethod def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: """Format a SageMaker compatible tool message. @@ -504,43 +769,6 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: "content": content_string, # String instead of list } - @override - @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: - """Format a content block. - - Args: - content: Message content. - - Returns: - Formatted content block. - - Raises: - TypeError: If the content block type cannot be converted to a SageMaker-compatible format. 
- """ - # if "text" in content and not isinstance(content["text"], str): - # return {"type": "text", "text": str(content["text"])} - - if "reasoningContent" in content and content["reasoningContent"]: - return { - "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""), - "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""), - "type": "thinking", - } - elif not content.get("reasoningContent", None): - content.pop("reasoningContent", None) - - if "video" in content: - return { - "type": "video_url", - "video_url": { - "detail": "auto", - "url": content["video"]["source"]["bytes"], - }, - } - - return super().format_request_message_content(content) - @override async def structured_output( self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index ba395b2d6..49e9d81a3 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -10,6 +10,7 @@ from strands.models.sagemaker import ( FunctionCall, + ModelProvider, SageMakerAIModel, ToolCall, UsageMetadata, @@ -52,7 +53,11 @@ def payload_config() -> Dict[str, Any]: @pytest.fixture -def model(boto_session, endpoint_config, payload_config): +def model( + boto_session, + endpoint_config: SageMakerAIModel.SageMakerAIEndpointConfig, + payload_config: SageMakerAIModel.SageMakerAIPayloadSchema, +): """SageMaker model instance with mocked boto session.""" return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session) @@ -92,8 +97,10 @@ class TestSageMakerAIModel: def test_init_default(self, boto_session): """Test initialization with default parameters.""" - endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} - payload_config = {"max_tokens": 1024} + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1" + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) model = SageMakerAIModel( endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session ) @@ -108,16 +115,10 @@ def test_init_default(self, boto_session): def test_init_with_all_params(self, boto_session): """Test initialization with all parameters.""" - endpoint_config = { - "endpoint_name": "test-endpoint", - "inference_component_name": "test-component", - "region_name": "us-west-2", - } - payload_config = { - "stream": False, - "max_tokens": 1024, - "temperature": 0.7, - } + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", inference_component_name="test-component", region_name="us-west-2" + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024, stream=False, temperature=0.7) client_config = BotocoreConfig(user_agent_extra="test-agent") model = SageMakerAIModel( @@ -140,8 +141,10 @@ def test_init_with_all_params(self, boto_session): def test_init_with_client_config(self, boto_session): """Test initialization with client configuration.""" - endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} - payload_config = {"max_tokens": 1024} + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1" + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) client_config = 

         SageMakerAIModel(
@@ -522,6 +525,200 @@ async def test_stream_non_streaming_with_tool_calls(self, sagemaker_client, mode
         assert usage_data["totalTokens"] == 30


+class TestModelProvider:
+    """Test suite for model provider functionality."""
+
+    def test_default_model_provider(self, boto_session):
+        """Test that default model provider is set to OpenAI."""
+        endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig(
+            endpoint_name="test-endpoint", region_name="us-east-1"
+        )
+        payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024)
+
+        model = SageMakerAIModel(
+            endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session
+        )
+
+        assert model.endpoint_config["model_provider"] == ModelProvider.OPENAI
+
+    def test_model_provider_enum(self, boto_session):
+        """Test setting model provider using enum."""
+        endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig(
+            endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.LLAMA
+        )
+        payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024)
+
+        model = SageMakerAIModel(
+            endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session
+        )
+
+        assert model.endpoint_config["model_provider"] == ModelProvider.LLAMA
+
+    def test_model_provider_string(self, boto_session):
+        """Test setting model provider using string."""
+        endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig(
+            endpoint_name="test-endpoint", region_name="us-east-1", model_provider="mistral"
+        )
+        payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024)
+
+        model = SageMakerAIModel(
+            endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session
+        )
+
+        assert model.endpoint_config["model_provider"] == ModelProvider.MISTRAL
+
+    def test_format_messages_openai_provider(self, boto_session, messages):
+        """Test message formatting with OpenAI provider."""
+        endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig(
+            endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.OPENAI
+        )
+        payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024)
+
+        model = SageMakerAIModel(
+            endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session
+        )
+
+        tru_request = model.format_request(messages)
+        exp_request = {
+            "EndpointName": "test-endpoint",
+            "Body": '{"messages": [{"role": "user", "content": [{'
+            '"text": "What is the capital of France?", '
+            '"type": "text"}]}], "max_tokens": 1024, "stream": true}',
+            "ContentType": "application/json",
+            "Accept": "application/json",
+        }
+        assert tru_request == exp_request
+
+    def test_format_messages_mistral_provider(self, boto_session, messages):
+        """Test message formatting with Mistral provider."""
+        endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig(
+            endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.MISTRAL
+        )
+        payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024)
+
+        model = SageMakerAIModel(
+            endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session
+        )
+
+        tru_request = model.format_request(messages)
+        exp_request = {
+            "EndpointName": "test-endpoint",
+            "Body": '{"messages": [{"role": "user", "content": "What is the capital of France?"}], '
+            '"max_tokens": 1024, "stream": true}',
+            "ContentType": "application/json",
+            "Accept": "application/json",
"application/json", + } + assert tru_request == exp_request + + def test_format_messages_llama_provider(self, boto_session, messages): + """Test message formatting with Llama provider.""" + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.LLAMA + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + tru_request = model.format_request(messages) + + exp_request = { + "EndpointName": "test-endpoint", + "Body": '{"messages": [{"role": "user", "content": [{' + '"text": "What is the capital of France?", ' + '"type": "text"}]}], "max_tokens": 1024, "stream": true}', + "ContentType": "application/json", + "Accept": "application/json", + } + assert tru_request == exp_request + + def test_format_messages_with_tool_use_llama(self, boto_session): + """Test Llama formatting with tool use messages.""" + + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "c1", + "name": "calculator", + "input": {"expression": "2+2"}, + }, + }, + ], + }, + ] + + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name="test-endpoint", region_name="us-east-1", model_provider=ModelProvider.LLAMA + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) + + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + tru_request = model.format_request(messages) + exp_request = { + "EndpointName": "test-endpoint", + "Body": '{"messages": [{"role": "assistant", ' + '"tool_calls": [{"function": {"arguments": "{\\"expression\\": \\"2+2\\"}", ' + '"name": "calculator"}, "id": "c1"}]}], "max_tokens": 1024, "stream": true}', + "ContentType": "application/json", + "Accept": "application/json", + } + assert tru_request == exp_request + + # def test_format_messages_custom_provider(self, boto_session, messages, system_prompt): + # """Test message formatting with custom provider.""" + # + # def custom_formatter(msgs, sys_prompt=None): + # result = [{"custom_format": True}] + # if sys_prompt: + # result.append({"system": sys_prompt}) + # for msg in msgs: + # result.append({"speaker": msg["role"], "text": msg["content"][0]["text"]}) + # return result + # + # endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + # endpoint_name="test-endpoint", + # region_name="us-east-1", + # model_provider=ModelProvider.CUSTOM, + # custom_formatter=custom_formatter, + # ) + # payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024) + # + # model = SageMakerAIModel( + # endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + # ) + # + # tru_request = model.format_request(messages) + # exp_request = { + # "EndpointName": "test-endpoint", + # "Body": '{"messages": [{"custom_format": true}, ' + # '{"speaker": "user", "text": "What is the capital of France?"}], ' + # '"max_tokens": 1024, "stream": true}', + # "ContentType": "application/json", + # "Accept": "application/json", + # } + # assert tru_request == exp_request + # + # def test_custom_formatter_required(self, boto_session): + # """Test that custom formatter is required when using CUSTOM provider.""" + # endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + # endpoint_name="test-endpoint", region_name="us-east-1", 
+    #         model_provider=ModelProvider.CUSTOM
+    #     )
+    #     payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024)
+    #
+    #     with pytest.raises(ValueError, match="custom_formatter is required when model_provider is CUSTOM"):
+    #         SageMakerAIModel(
+    #             endpoint_config=endpoint_config,
+    #             payload_config=payload_config,
+    #             boto_session=boto_session)
+
+
 class TestDataClasses:
     """Test suite for data classes."""

@@ -572,3 +769,26 @@ def test_tool_call(self):
         assert tool2.type == "function"
         assert tool2.function.name == "get_time"
         assert tool2.function.arguments == '{"timezone": "UTC"}'
+
+
+class TestModelProviderEnum:
+    """Test suite for ModelProvider enum."""
+
+    def test_model_provider_values(self):
+        """Test ModelProvider enum values."""
+        assert ModelProvider.OPENAI.value == "openai"
+        assert ModelProvider.MISTRAL.value == "mistral"
+        assert ModelProvider.LLAMA.value == "llama"
+        # assert ModelProvider.CUSTOM.value == "custom"
+
+    def test_model_provider_from_string(self):
+        """Test creating ModelProvider from string."""
+        assert ModelProvider("openai") == ModelProvider.OPENAI
+        assert ModelProvider("mistral") == ModelProvider.MISTRAL
+        assert ModelProvider("llama") == ModelProvider.LLAMA
+        # assert ModelProvider("custom") == ModelProvider.CUSTOM
+
+    def test_invalid_model_provider(self):
+        """Test invalid model provider string raises ValueError."""
+        with pytest.raises(ValueError):
+            ModelProvider("invalid_provider")
diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py
index 246879da7..e490c7bb0 100644
--- a/tests/strands/tools/test_decorator.py
+++ b/tests/strands/tools/test_decorator.py
@@ -1064,7 +1064,7 @@ async def _run_context_injection_test(context_tool: AgentTool):
             "content": [
                 {"text": "Tool 'context_tool' (ID: test-id)"},
                 {"text": "injected agent 'test_agent' processed: some_message"},
-                {"text": "context agent 'test_agent'"}
+                {"text": "context agent 'test_agent'"},
             ],
             "toolUseId": "test-id",
         }
@@ -1151,7 +1151,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> dict:
         assert len(tool_results) == 1
         tool_result = tool_results[0]
-
+
         # Should get a validation error because tool_context is required but not provided
         assert tool_result["status"] == "error"
         assert "tool_context" in tool_result["content"][0]["text"].lower()
@@ -1173,10 +1173,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str:
         tool_use={
             "toolUseId": "test-id-2",
             "name": "context_tool",
-            "input": {
-                "message": "some_message",
-                "tool_context": "my_custom_context_string"
-            },
+            "input": {"message": "some_message", "tool_context": "my_custom_context_string"},
         },
         invocation_state={
             "agent": Agent(name="test_agent"),
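
A minimal usage sketch of the `model_provider` option introduced above. The endpoint name here is a placeholder, not a value from this change; the string form of `model_provider` is coerced to the enum in `__init__`, as the new tests exercise:

```python
from strands.models.sagemaker import ModelProvider, SageMakerAIModel

model = SageMakerAIModel(
    endpoint_config={
        "endpoint_name": "my-llm-endpoint",  # placeholder; a real deployed endpoint is assumed
        "region_name": "us-west-2",
        "model_provider": "mistral",  # coerced to ModelProvider.MISTRAL in __init__
    },
    payload_config={"max_tokens": 1024},
)

# format_request now routes prompt formatting through MistralModel.format_request_messages.
assert model.get_config()["model_provider"] == ModelProvider.MISTRAL
```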
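And a sketch of the event mapping that the new `format_chunk` performs, reusing `model` from the sketch above. The `SimpleNamespace` is a stand-in for the usage object the streaming code passes in; only the three token-count attributes are assumed, since those are all `format_chunk` reads for a `"metadata"` chunk:

```python
from types import SimpleNamespace

chunk = model.format_chunk(
    {
        "chunk_type": "metadata",
        # Stand-in for the real usage payload; format_chunk only reads these fields.
        "data": SimpleNamespace(prompt_tokens=20, completion_tokens=10, total_tokens=30),
    }
)
assert chunk["metadata"]["usage"]["totalTokens"] == 30
```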