diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 29cb40d40..0292df6ec 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -18,7 +18,7 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolChoice, ToolSpec from .model import Model logger = logging.getLogger(__name__) @@ -192,7 +192,11 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: return formatted_messages def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, ) -> dict[str, Any]: """Format an Anthropic streaming request. @@ -200,6 +204,7 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Returns: An Anthropic streaming request. @@ -220,6 +225,7 @@ def format_request( } for tool_spec in tool_specs or [] ], + **({"tool_choice": tool_choice} if tool_choice else {}), **({"system": system_prompt} if system_prompt else {}), **(self.config.get("params") or {}), } @@ -347,6 +353,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Anthropic model. @@ -355,6 +362,7 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -365,7 +373,7 @@ async def stream( ModelThrottledException: If the request is throttled by Anthropic. 
""" logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") @@ -407,7 +415,13 @@ async def structured_output( """ tool_spec = convert_pydantic_to_tool_spec(output_model) - response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs) + response = self.stream( + messages=prompt, + tool_specs=[tool_spec], + system_prompt=system_prompt, + tool_choice=cast(ToolChoice, {"any": {}}), + **kwargs, + ) async for event in process_stream(response): yield event diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ace35640a..8182156f4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -7,7 +7,7 @@ import json import logging import os -from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union +from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -20,7 +20,7 @@ from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec +from ..types.tools import ToolChoice, ToolResult, ToolSpec from .model import Model logger = logging.getLogger(__name__) @@ -168,6 +168,7 @@ def format_request( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, ) -> dict[str, Any]: """Format a Bedrock converse stream request. @@ -175,6 +176,7 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Returns: A Bedrock converse stream request. @@ -197,7 +199,7 @@ def format_request( else [] ), ], - "toolChoice": {"auto": {}}, + **({"toolChoice": tool_choice} if tool_choice else {}), } } if tool_specs @@ -355,6 +357,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Bedrock model. @@ -366,6 +369,7 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. 
Yields: @@ -384,7 +388,7 @@ def callback(event: Optional[StreamEvent] = None) -> None: loop = asyncio.get_event_loop() queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue() - thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt) + thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt, tool_choice) task = asyncio.create_task(thread) while True: @@ -402,6 +406,7 @@ def _stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, ) -> None: """Stream conversation with the Bedrock model. @@ -413,6 +418,7 @@ def _stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Raises: ContextWindowOverflowException: If the input exceeds the model's context window. @@ -420,7 +426,7 @@ def _stream( """ try: logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") @@ -624,7 +630,13 @@ async def structured_output( """ tool_spec = convert_pydantic_to_tool_spec(output_model) - response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs) + response = self.stream( + messages=prompt, + tool_specs=[tool_spec], + system_prompt=system_prompt, + tool_choice=cast(ToolChoice, {"any": {}}), + **kwargs, + ) async for event in streaming.process_stream(response): yield event diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index c1e99f1a2..2cfd654d0 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -14,7 +14,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolChoice, ToolSpec from .openai import OpenAIModel logger = logging.getLogger(__name__) @@ -109,6 +109,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -117,6 +118,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. 
Yields: diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 421b06e52..f93d7aa0a 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -18,7 +18,7 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage -from ..types.tools import ToolResult, ToolSpec, ToolUse +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from .model import Model logger = logging.getLogger(__name__) @@ -327,6 +327,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LlamaAPI model. @@ -335,6 +336,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 8855b6d64..d9a4ed033 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -15,7 +15,7 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from .model import Model logger = logging.getLogger(__name__) @@ -394,6 +394,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Mistral model. @@ -402,6 +403,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index cb24b704d..6fc46891c 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -8,7 +8,7 @@ from ..types.content import Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolChoice, ToolSpec logger = logging.getLogger(__name__) @@ -70,6 +70,7 @@ def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: """Stream conversation with the model. @@ -84,6 +85,7 @@ def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. 
+ tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 76cd87d72..5a7d71011 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -13,7 +13,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StopReason, StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolChoice, ToolSpec from .model import Model logger = logging.getLogger(__name__) @@ -284,6 +284,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Ollama model. @@ -292,6 +293,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 1076fbae4..3ace6bf0c 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -16,7 +16,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from .model import Model logger = logging.getLogger(__name__) @@ -171,6 +171,27 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: "content": [cls.format_request_message_content(content) for content in contents], } + @classmethod + def format_request_tool_choice(cls, tool_choice: ToolChoice) -> Union[str, dict[str, Any]]: + """Format a tool choice for OpenAI compatibility. + + Args: + tool_choice: Tool choice configuration in Bedrock format. + + Returns: + OpenAI compatible tool choice format. + """ + match tool_choice: + case {"auto": _}: + return "auto" # OpenAI SDK doesn't define constants for these values + case {"any": _}: + return "required" + case {"tool": {"name": tool_name}}: + return {"type": "function", "function": {"name": tool_name}} + case _: + # This should not happen with proper typing, but handle gracefully + return "auto" + @classmethod def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: """Format an OpenAI compatible messages array. @@ -213,7 +234,11 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, ) -> dict[str, Any]: """Format an OpenAI compatible chat streaming request. @@ -221,6 +246,7 @@ def format_request( messages: List of message objects to be processed by the model. 
tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Returns: An OpenAI compatible chat streaming request. @@ -245,6 +271,7 @@ def format_request( } for tool_spec in tool_specs or [] ], + **({"tool_choice": self.format_request_tool_choice(tool_choice)} if tool_choice else {}), **cast(dict[str, Any], self.config.get("params", {})), } @@ -326,6 +353,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the OpenAI model. @@ -334,13 +362,14 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. """ logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("formatted request=<%s>", request) logger.debug("invoking model") diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 9cfe27d9e..8f7cbfdf8 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -14,7 +14,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec +from ..types.tools import ToolChoice, ToolResult, ToolSpec from .openai import OpenAIModel T = TypeVar("T", bound=BaseModel) @@ -193,7 +193,11 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i @override def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, ) -> dict[str, Any]: """Format an Amazon SageMaker chat streaming request. @@ -201,6 +205,8 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** Returns: An Amazon SageMaker chat streaming request. @@ -282,6 +288,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the SageMaker model. @@ -290,6 +297,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. 
**Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index f6a3da3d8..ff9272ba2 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -16,7 +16,7 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from .model import Model logger = logging.getLogger(__name__) @@ -352,6 +352,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[ToolChoice] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Writer model. @@ -360,6 +361,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index bb7c874f6..4a261c7a3 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -142,10 +142,15 @@ class ToolContext: agent: "Agent" +# Individual ToolChoice type aliases +ToolChoiceAutoDict = dict[Literal["auto"], ToolChoiceAuto] +ToolChoiceAnyDict = dict[Literal["any"], ToolChoiceAny] +ToolChoiceToolDict = dict[Literal["tool"], ToolChoiceTool] + ToolChoice = Union[ - dict[Literal["auto"], ToolChoiceAuto], - dict[Literal["any"], ToolChoiceAny], - dict[Literal["tool"], ToolChoiceTool], + ToolChoiceAutoDict, + ToolChoiceAnyDict, + ToolChoiceToolDict, ] """ Configuration for how the model should choose tools. 
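For reference, here is a minimal sketch (not part of the diff) of how the new `tool_choice` values flow through the providers, assuming the package layout shown in the paths above. The three Bedrock-style `ToolChoice` shapes come straight from the union defined in `types/tools.py`, and the OpenAI mapping mirrors the new `OpenAIModel.format_request_tool_choice` classmethod and the tests below:

```python
from typing import cast

from strands.models.openai import OpenAIModel
from strands.types.tools import ToolChoice

# Bedrock-style ToolChoice values (the union defined above).
auto_choice = cast(ToolChoice, {"auto": {}})    # model decides whether to call a tool
any_choice = cast(ToolChoice, {"any": {}})      # model must call some tool
named_choice = cast(ToolChoice, {"tool": {"name": "test_tool"}})  # model must call test_tool

# OpenAIModel translates these into the OpenAI chat-completions format:
assert OpenAIModel.format_request_tool_choice(auto_choice) == "auto"
assert OpenAIModel.format_request_tool_choice(any_choice) == "required"
assert OpenAIModel.format_request_tool_choice(named_choice) == {
    "type": "function",
    "function": {"name": "test_tool"},
}

# Anthropic and Bedrock requests embed the ToolChoice dictionary as-is
# ("tool_choice" / "toolChoice") and only include it when a value is given;
# structured_output forces tool use by passing {"any": {}}.
```

Providers that do not yet support the parameter (LiteLLM, LlamaAPI, Mistral, Ollama, SageMaker, Writer) accept it for interface consistency and ignore it, as their docstrings note.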
diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 5e8d69ea7..1e5fa01d5 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -417,6 +417,72 @@ def test_format_request_with_empty_content(model, model_id, max_tokens): assert tru_request == exp_request +def test_format_request_tool_choice_auto(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tool_choice = {"auto": {}} + + tru_request = model.format_request(messages, tool_specs, tool_choice=tool_choice) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [ + { + "name": "test_tool", + "description": "test tool", + "input_schema": {"key": "value"}, + } + ], + "tool_choice": tool_choice, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_any(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tool_choice = {"any": {}} + + tru_request = model.format_request(messages, tool_specs, tool_choice=tool_choice) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [ + { + "name": "test_tool", + "description": "test tool", + "input_schema": {"key": "value"}, + } + ], + "tool_choice": tool_choice, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_tool(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tool_choice = {"tool": {"name": "test_tool"}} + + tru_request = model.format_request(messages, tool_specs, tool_choice=tool_choice) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [ + { + "name": "test_tool", + "description": "test tool", + "input_schema": {"key": "value"}, + } + ], + "tool_choice": tool_choice, + } + + assert tru_request == exp_request + + def test_format_chunk_message_start(model): event = {"type": "message_start"} diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 09e508845..8fc67f7fb 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -376,7 +376,57 @@ def test_format_request_tool_specs(model, messages, model_id, tool_spec): "system": [], "toolConfig": { "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, + }, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_auto(model, messages, model_id, tool_spec): + tool_choice = {"auto": {}} + tru_request = model.format_request(messages, [tool_spec], tool_choice=tool_choice) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": tool_choice, + }, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_any(model, messages, model_id, tool_spec): + tool_choice = {"any": {}} + tru_request = model.format_request(messages, [tool_spec], tool_choice=tool_choice) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + 
"system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": tool_choice, + }, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_tool(model, messages, model_id, tool_spec): + tool_choice = {"tool": {"name": "test_tool"}} + tru_request = model.format_request(messages, [tool_spec], tool_choice=tool_choice) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": tool_choice, }, } @@ -396,7 +446,6 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): {"toolSpec": tool_spec}, {"cachePoint": {"type": cache_type}}, ], - "toolChoice": {"auto": {}}, }, } @@ -470,7 +519,6 @@ async def test_stream(bedrock_client, model, messages, tool_spec, model_id, addi "system": [], "toolConfig": { "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, }, } @@ -521,7 +569,6 @@ async def test_stream_stream_input_guardrails( "system": [], "toolConfig": { "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, }, } @@ -578,7 +625,6 @@ async def test_stream_stream_output_guardrails( "system": [], "toolConfig": { "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, }, } @@ -635,7 +681,6 @@ async def test_stream_output_guardrails_redacts_input_and_output( "system": [], "toolConfig": { "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, }, } @@ -692,7 +737,6 @@ async def test_stream_output_no_blocked_guardrails_doesnt_redact( "system": [], "toolConfig": { "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, }, } @@ -745,7 +789,6 @@ async def test_stream_output_no_guardrail_redact( "system": [], "toolConfig": { "tools": [{"toolSpec": tool_spec}], - "toolChoice": {"auto": {}}, }, } diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index a7c97701c..ff6dc121f 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -179,6 +179,30 @@ def test_format_request_tool_message(): assert tru_result == exp_result +def test_format_request_tool_choice_auto(): + tool_choice = {"auto": {}} + + tru_result = OpenAIModel.format_request_tool_choice(tool_choice) + exp_result = "auto" + assert tru_result == exp_result + + +def test_format_request_tool_choice_any(): + tool_choice = {"any": {}} + + tru_result = OpenAIModel.format_request_tool_choice(tool_choice) + exp_result = "required" + assert tru_result == exp_result + + +def test_format_request_tool_choice_tool(): + tool_choice = {"tool": {"name": "test_tool"}} + + tru_result = OpenAIModel.format_request_tool_choice(tool_choice) + exp_result = {"type": "function", "function": {"name": "test_tool"}} + assert tru_result == exp_result + + def test_format_request_messages(system_prompt): messages = [ { @@ -278,6 +302,123 @@ def test_format_request(model, messages, tool_specs, system_prompt): assert tru_request == exp_request +def test_format_request_with_tool_choice_auto(model, messages, tool_specs, system_prompt): + tool_choice = {"auto": {}} + tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": 
"test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "tool_choice": "auto", + "max_tokens": 1, + } + assert tru_request == exp_request + + +def test_format_request_with_tool_choice_any(model, messages, tool_specs, system_prompt): + tool_choice = {"any": {}} + tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "tool_choice": "required", + "max_tokens": 1, + } + assert tru_request == exp_request + + +def test_format_request_with_tool_choice_tool(model, messages, tool_specs, system_prompt): + tool_choice = {"tool": {"name": "test_tool"}} + tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "tool_choice": {"type": "function", "function": {"name": "test_tool"}}, + "max_tokens": 1, + } + assert tru_request == exp_request + + @pytest.mark.parametrize( ("event", "exp_chunk"), [