diff --git a/README.md b/README.md index 3ff0ec2e4..b6e22910b 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,7 @@ Built-in providers: - [MistralAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/mistral/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) + - [OpenAI Responses API](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) - [SageMaker](https://strandsagents.com/latest/user-guide/concepts/model-providers/sagemaker/) - [Writer](https://strandsagents.com/latest/user-guide/concepts/model-providers/writer/) diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py new file mode 100644 index 000000000..fae95833c --- /dev/null +++ b/src/strands/models/openai_responses.py @@ -0,0 +1,529 @@ +"""OpenAI model provider using the Responses API. + +- Docs: https://platform.openai.com/docs/overview +""" + +import base64 +import json +import logging +import mimetypes +from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast + +import openai +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import ContentBlock, Messages +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.streaming import StreamEvent +from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class Client(Protocol): + """Protocol defining the OpenAI Responses API interface for the underlying provider client.""" + + @property + # pragma: no cover + def responses(self) -> Any: + """Responses interface.""" + ... + + +class OpenAIResponsesModel(Model): + """OpenAI Responses API model provider implementation.""" + + client: Client + client_args: dict[str, Any] + + class OpenAIResponsesConfig(TypedDict, total=False): + """Configuration options for OpenAI Responses API models. + + Attributes: + model_id: Model ID (e.g., "gpt-4o"). + For a complete list of supported models, see https://platform.openai.com/docs/models. + params: Model parameters (e.g., max_output_tokens, temperature, etc.). + For a complete list of supported parameters, see + https://platform.openai.com/docs/api-reference/responses/create. + """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__( + self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIResponsesConfig] + ) -> None: + """Initialize provider instance. + + Args: + client_args: Arguments for the OpenAI client. + For a complete list of supported arguments, see https://pypi.org/project/openai/. + **model_config: Configuration options for the OpenAI Responses API model. + """ + validate_config_keys(model_config, self.OpenAIResponsesConfig) + self.config = dict(model_config) + self.client_args = client_args or {} + + logger.debug("config=<%s> | initializing", self.config) + + @override + def update_config(self, **model_config: Unpack[OpenAIResponsesConfig]) -> None: # type: ignore[override] + """Update the OpenAI Responses API model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. 
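+
+                For example, model.update_config(params={"temperature": 0.2}) replaces only the
+                params entry; model_id and any other existing settings are left unchanged.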
+ """ + validate_config_keys(model_config, self.OpenAIResponsesConfig) + self.config.update(model_config) + + @override + def get_config(self) -> OpenAIResponsesConfig: + """Get the OpenAI Responses API model configuration. + + Returns: + The OpenAI Responses API model configuration. + """ + return cast(OpenAIResponsesModel.OpenAIResponsesConfig, self.config) + + @classmethod + def _format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + """Format an OpenAI compatible content block. + + Args: + content: Message content. + + Returns: + OpenAI compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to an OpenAI-compatible format. + """ + if "document" in content: + # only PDF type supported + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "type": "input_file", + "file_url": f"data:{mime_type};base64,{file_data}", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + + return { + "type": "input_image", + "image_url": f"data:{mime_type};base64,{image_data}", + } + + if "text" in content: + return {"type": "input_text", "text": content["text"]} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def _format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + """Format an OpenAI compatible tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + OpenAI compatible tool call. + """ + return { + "type": "function_call", + "call_id": tool_use["toolUseId"], + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + } + + @classmethod + def _format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format an OpenAI compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + OpenAI compatible tool message. + """ + output_parts = [] + + for content in tool_result["content"]: + if "json" in content: + output_parts.append(json.dumps(content["json"])) + elif "text" in content: + output_parts.append(content["text"]) + + return { + "type": "function_call_output", + "call_id": tool_result["toolUseId"], + "output": "\n".join(output_parts) if output_parts else "", + } + + @classmethod + def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + + Returns: + An OpenAI compatible messages array. 
+ """ + formatted_messages: list[dict[str, Any]] = [] + + for message in messages: + role = message["role"] + if role == "system": + continue # type: ignore[unreachable] + + contents = message["content"] + + formatted_contents = [ + cls._format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + + formatted_tool_calls = [ + cls._format_request_message_tool_call(content["toolUse"]) + for content in contents + if "toolUse" in content + ] + + formatted_tool_messages = [ + cls._format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + if formatted_contents: + formatted_messages.append( + { + "role": role, # "user" | "assistant" + "content": formatted_contents, + } + ) + + formatted_messages.extend(formatted_tool_calls) + formatted_messages.extend(formatted_tool_messages) + + return [ + message + for message in formatted_messages + if message.get("content") or message.get("type") in ["function_call", "function_call_output"] + ] + + def _format_request( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + ) -> dict[str, Any]: + """Format an OpenAI Responses API compatible response streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI Responses API compatible response streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible + format. + """ + input_items = self._format_request_messages(messages) + request = { + "model": self.config["model_id"], + "input": input_items, + "stream": True, + **cast(dict[str, Any], self.config.get("params", {})), + } + + if system_prompt: + request["instructions"] = system_prompt + + # Add tools if provided + if tool_specs: + request["tools"] = [ + { + "type": "function", + "name": tool_spec["name"], + "description": tool_spec.get("description", ""), + "parameters": tool_spec["inputSchema"]["json"], + } + for tool_spec in tool_specs + ] + + return request + + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. 
+ """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + + if event["data_type"] == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the OpenAI Responses API model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the request is throttled by OpenAI (rate limits). 
+ """ + logger.debug("formatting request for OpenAI Responses API") + request = self._format_request(messages, tool_specs, system_prompt) + logger.debug("formatted request=<%s>", request) + + logger.debug("invoking OpenAI Responses API model") + + async with openai.AsyncOpenAI(**self.client_args) as client: + try: + response = await client.responses.create(**request) + except openai.APIError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning("OpenAI Responses API threw context window overflow error") + raise ContextWindowOverflowException(str(e)) from e + elif hasattr(e, "code") and e.code == "rate_limit_exceeded": + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e + else: + raise + + logger.debug("got response from OpenAI Responses API model") + + yield self._format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[str, dict[str, Any]] = {} + final_usage = None + has_text_content = False + + try: + async for event in response: + if hasattr(event, "type"): + if event.type == "response.output_text.delta": + # Text content streaming + if not has_text_content: + yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) + has_text_content = True + if hasattr(event, "delta") and isinstance(event.delta, str): + has_text_content = True + yield self._format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": event.delta} + ) + + elif event.type == "response.output_item.added": + # Tool call started + if ( + hasattr(event, "item") + and hasattr(event.item, "type") + and event.item.type == "function_call" + ): + call_id = getattr(event.item, "call_id", "unknown") + tool_calls[call_id] = { + "name": getattr(event.item, "name", ""), + "arguments": "", + "call_id": call_id, + "item_id": getattr(event.item, "id", ""), + } + + elif event.type == "response.function_call_arguments.delta": + # Tool arguments streaming - match by item_id + if hasattr(event, "delta") and hasattr(event, "item_id"): + for _call_id, call_info in tool_calls.items(): + if call_info["item_id"] == event.item_id: + call_info["arguments"] += event.delta + break + + elif event.type == "response.completed": + # Response complete + if hasattr(event, "response") and hasattr(event.response, "usage"): + final_usage = event.response.usage + break + except openai.APIError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning("OpenAI Responses API threw context window overflow error") + raise ContextWindowOverflowException(str(e)) from e + elif hasattr(e, "code") and e.code == "rate_limit_exceeded": + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e + else: + raise + + # Close text content if we had any + if has_text_content: + yield self._format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + + # Yield tool calls if any + for call_info in tool_calls.values(): + mock_tool_call = type( + "MockToolCall", + (), + { + "function": type( + "MockFunction", (), {"name": call_info["name"], "arguments": call_info["arguments"]} + )(), + "id": call_info["call_id"], + }, + )() + + yield self._format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call}) + yield self._format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call}) + yield self._format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + finish_reason = "tool_calls" if 
tool_calls else "stop" + yield self._format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + if final_usage: + usage_data = type( + "Usage", + (), + { + "prompt_tokens": getattr(final_usage, "input_tokens", 0), + "completion_tokens": getattr(final_usage, "output_tokens", 0), + "total_tokens": getattr(final_usage, "total_tokens", 0), + }, + )() + yield self._format_chunk({"chunk_type": "metadata", "data": usage_data}) + + logger.debug("finished streaming response from OpenAI Responses API model") + + @override + async def structured_output( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + """Get structured output from the OpenAI Responses API model. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the request is throttled by OpenAI (rate limits). + """ + async with openai.AsyncOpenAI(**self.client_args) as client: + try: + response = await client.responses.parse( + model=self.get_config()["model_id"], + input=self._format_request(prompt, system_prompt=system_prompt)["input"], + text_format=output_model, + ) + except openai.BadRequestError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning("OpenAI Responses API threw context window overflow error") + raise ContextWindowOverflowException(str(e)) from e + raise + except openai.RateLimitError as e: + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e + except openai.APIError as e: + # Handle streaming API errors that come as APIError + error_message = str(e).lower() + if "context window" in error_message or "exceeds the context" in error_message: + logger.warning("OpenAI Responses API threw context window overflow error") + raise ContextWindowOverflowException(str(e)) from e + elif "rate limit" in error_message or "tokens per min" in error_message: + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e + raise + + if response.output_parsed: + yield {"output": response.output_parsed} + else: + raise ValueError("No valid parsed output found in the OpenAI Responses API response.") diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py new file mode 100644 index 000000000..eb78217d8 --- /dev/null +++ b/tests/strands/models/test_openai_responses.py @@ -0,0 +1,538 @@ +import unittest.mock + +import openai +import pydantic +import pytest + +import strands +from strands.models.openai_responses import OpenAIResponsesModel +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException + + +@pytest.fixture +def openai_client(): + with unittest.mock.patch.object(strands.models.openai_responses.openai, "AsyncOpenAI") as mock_client_cls: + mock_client = unittest.mock.AsyncMock() + mock_client_cls.return_value.__aenter__.return_value = mock_client + yield mock_client + + +@pytest.fixture +def model_id(): + return "gpt-4o" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + return 
OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + +def test__init__(model_id): + model = OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) + + tru_config = model.get_config() + exp_config = {"model_id": "gpt-4o", "params": {"max_output_tokens": 100}} + + assert tru_config == exp_config + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +@pytest.mark.parametrize( + "content, exp_result", + [ + # Document + ( + { + "document": { + "format": "pdf", + "name": "test doc", + "source": {"bytes": b"document"}, + }, + }, + { + "type": "input_file", + "file_url": "data:application/pdf;base64,ZG9jdW1lbnQ=", + }, + ), + # Image + ( + { + "image": { + "format": "jpg", + "source": {"bytes": b"image"}, + }, + }, + { + "type": "input_image", + "image_url": "data:image/jpeg;base64,aW1hZ2U=", + }, + ), + # Text + ( + {"text": "hello"}, + {"type": "input_text", "text": "hello"}, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = OpenAIResponsesModel._format_request_message_content(content) + assert tru_result == exp_result + + +def test_format_request_message_content_unsupported_type(): + content = {"unsupported": {}} + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_message_tool_call(): + tool_use = { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + } + + tru_result = OpenAIResponsesModel._format_request_message_tool_call(tool_use) + exp_result = { + "type": "function_call", + "call_id": "c1", + "name": "calculator", + "arguments": '{"expression": "2+2"}', + } + assert tru_result == exp_result + + +def test_format_request_tool_message(): + tool_result = { + "content": [{"text": "4"}, {"json": ["4"]}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + exp_result = { + "type": "function_call_output", + "call_id": "c1", + "output": '4\n["4"]', + } + assert tru_result == exp_result + + +def test_format_request_messages(system_prompt): + messages = [ + { + "content": [], + "role": "user", + }, + { + "content": [{"text": "hello"}], + "role": "user", + }, + { + "content": [ + {"text": "call tool"}, + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + "role": "assistant", + }, + { + "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"text": "4"}]}}], + "role": "user", + }, + ] + + tru_result = OpenAIResponsesModel._format_request_messages(messages) + exp_result = [ + { + "role": "user", + "content": [{"type": "input_text", "text": "hello"}], + }, + { + "role": "assistant", + "content": [{"type": 
"input_text", "text": "call tool"}], + }, + { + "type": "function_call", + "call_id": "c1", + "name": "calculator", + "arguments": '{"expression": "2+2"}', + }, + { + "type": "function_call_output", + "call_id": "c1", + "output": "4", + }, + ] + assert tru_result == exp_result + + +def test_format_request(model, messages, tool_specs, system_prompt): + tru_request = model._format_request(messages, tool_specs, system_prompt) + exp_request = { + "model": "gpt-4o", + "input": [ + { + "role": "user", + "content": [{"type": "input_text", "text": "test"}], + } + ], + "stream": True, + "instructions": system_prompt, + "tools": [ + { + "type": "function", + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + ], + "max_output_tokens": 100, + } + assert tru_request == exp_request + + +@pytest.mark.parametrize( + ("event", "exp_chunk"), + [ + # Message start + ( + {"chunk_type": "message_start"}, + {"messageStart": {"role": "assistant"}}, + ), + # Content Start - Tool Use + ( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": unittest.mock.Mock(**{"function.name": "calculator", "id": "c1"}), + }, + {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, + ), + # Content Start - Text + ( + {"chunk_type": "content_start", "data_type": "text"}, + {"contentBlockStart": {"start": {}}}, + ), + # Content Delta - Tool Use + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + ), + # Content Delta - Tool Use - None + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}}, + ), + # Content Delta - Reasoning Text + ( + {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "I'm thinking"}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "I'm thinking"}}}}, + ), + # Content Delta - Text + ( + {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, + {"contentBlockDelta": {"delta": {"text": "hello"}}}, + ), + # Content Stop + ( + {"chunk_type": "content_stop"}, + {"contentBlockStop": {}}, + ), + # Message Stop - Tool Use + ( + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"messageStop": {"stopReason": "tool_use"}}, + ), + # Message Stop - Max Tokens + ( + {"chunk_type": "message_stop", "data": "length"}, + {"messageStop": {"stopReason": "max_tokens"}}, + ), + # Message Stop - End Turn + ( + {"chunk_type": "message_stop", "data": "stop"}, + {"messageStop": {"stopReason": "end_turn"}}, + ), + # Metadata + ( + { + "chunk_type": "metadata", + "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), + }, + { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + }, + "metrics": { + "latencyMs": 0, + }, + }, + }, + ), + ], +) +def test_format_chunk(event, exp_chunk, model): + tru_chunk = model._format_chunk(event) + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown_type(model): + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model._format_chunk(event) + + +@pytest.mark.asyncio +async def 
test_stream(openai_client, model_id, model, agenerator, alist):
+    # Mock response events
+    mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello")
+    mock_complete_event = unittest.mock.Mock(
+        type="response.completed",
+        response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
+    )
+
+    openai_client.responses.create = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_text_event, mock_complete_event])
+    )
+
+    messages = [{"role": "user", "content": [{"text": "test"}]}]
+    response = model.stream(messages)
+    tru_events = await alist(response)
+
+    exp_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockStart": {"start": {}}},
+        {"contentBlockDelta": {"delta": {"text": "Hello"}}},
+        {"contentBlockStop": {}},
+        {"messageStop": {"stopReason": "end_turn"}},
+        {
+            "metadata": {
+                "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15},
+                "metrics": {"latencyMs": 0},
+            }
+        },
+    ]
+
+    assert tru_events == exp_events
+
+    expected_request = {
+        "model": model_id,
+        "input": [{"role": "user", "content": [{"type": "input_text", "text": "test"}]}],
+        "stream": True,
+        "max_output_tokens": 100,
+    }
+    openai_client.responses.create.assert_called_once_with(**expected_request)
+
+
+@pytest.mark.asyncio
+async def test_stream_with_tool_calls(openai_client, model, agenerator, alist):
+    # Mock tool call events. Mock's `name` kwarg is reserved, so set the attribute after creation.
+    mock_tool_item = unittest.mock.Mock(type="function_call", call_id="call_123", id="item_456")
+    mock_tool_item.name = "calculator"
+    mock_tool_event = unittest.mock.Mock(type="response.output_item.added", item=mock_tool_item)
+    mock_args_event = unittest.mock.Mock(
+        type="response.function_call_arguments.delta", delta='{"expression": "2+2"}', item_id="item_456"
+    )
+    mock_complete_event = unittest.mock.Mock(
+        type="response.completed",
+        response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
+    )
+
+    openai_client.responses.create = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_tool_event, mock_args_event, mock_complete_event])
+    )
+
+    messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}]
+    response = model.stream(messages)
+    tru_events = await alist(response)
+
+    # Should include tool call events
+    assert {
+        "contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "call_123"}}}
+    } in tru_events
+    assert {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}} in tru_events
+    assert {"messageStop": {"stopReason": "tool_use"}} in tru_events
+
+
+@pytest.mark.asyncio
+async def test_structured_output(openai_client, model, test_output_model_cls, alist):
+    messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
+
+    mock_parsed_instance = test_output_model_cls(name="John", age=30)
+    mock_response = unittest.mock.Mock(output_parsed=mock_parsed_instance)
+
+    openai_client.responses.parse = unittest.mock.AsyncMock(return_value=mock_response)
+
+    stream = model.structured_output(test_output_model_cls, messages)
+    events = await alist(stream)
+
+    tru_result = events[-1]
+    exp_result = {"output": test_output_model_cls(name="John", age=30)}
+    assert tru_result == exp_result
+
+
+@pytest.mark.asyncio
+async def test_stream_context_overflow_exception(openai_client, model, messages):
+    """Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException."""
+    mock_error = openai.BadRequestError(
+        message="This model's maximum context length is 4096 tokens.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "context_length_exceeded"}},
+    )
+    mock_error.code = "context_length_exceeded"
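+    # Simulate the SDK raising the error before any stream events are produced.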
openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_stream_rate_limit_as_throttle(openai_client, model, messages): + """Test that rate limit errors are converted to ModelThrottledException.""" + mock_error = openai.RateLimitError( + message="Rate limit exceeded", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + mock_error.code = "rate_limit_exceeded" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Rate limit exceeded" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls): + """Test that structured output handles context overflow properly.""" + mock_error = openai.BadRequestError( + message="This model's maximum context length is 4096 tokens.", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_rate_limit_as_throttle(openai_client, model, messages, test_output_model_cls): + """Test that structured output handles rate limit errors properly.""" + mock_error = openai.RateLimitError( + message="Rate limit exceeded", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + mock_error.code = "rate_limit_exceeded" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert "Rate limit exceeded" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +def test_config_validation_warns_on_unknown_keys(openai_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + OpenAIResponsesModel({"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index c1f442b2a..8ee147ab3 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -16,6 +16,7 @@ from strands.models.mistral import MistralModel from strands.models.ollama import OllamaModel from strands.models.openai import 
OpenAIModel +from strands.models.openai_responses import OpenAIResponsesModel from strands.models.writer import WriterModel @@ -118,6 +119,16 @@ def __init__(self): }, ), ) +openai_responses = ProviderInfo( + id="openai_responses", + environment_variable="OPENAI_API_KEY", + factory=lambda: OpenAIResponsesModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ), +) writer = ProviderInfo( id="writer", environment_variable="WRITER_API_KEY", @@ -149,5 +160,6 @@ def __init__(self): litellm, mistral, openai, + openai_responses, writer, ] diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 115a0819d..e448739b6 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -7,6 +7,7 @@ import strands from strands import Agent, tool from strands.models.openai import OpenAIModel +from strands.models.openai_responses import OpenAIResponsesModel from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from tests_integ.models import providers @@ -14,10 +15,16 @@ pytestmark = providers.openai.mark -@pytest.fixture -def model(): - return OpenAIModel( - model_id="gpt-4o", +@pytest.fixture( + params=[ + ("openai", OpenAIModel, "gpt-4o"), + ("openai_responses", OpenAIResponsesModel, "gpt-4o"), + ] +) +def model(request): + model_name, model_class, model_id = request.param + return model_class( + model_id=model_id, client_args={ "api_key": os.getenv("OPENAI_API_KEY"), }, @@ -73,7 +80,7 @@ def test_image_path(request): return request.config.rootpath / "tests_integ" / "test_image.png" -def test_agent_invoke(agent): +def test_agent_invoke(agent, model): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -81,7 +88,7 @@ def test_agent_invoke(agent): @pytest.mark.asyncio -async def test_agent_invoke_async(agent): +async def test_agent_invoke_async(agent, model): result = await agent.invoke_async("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -89,7 +96,7 @@ async def test_agent_invoke_async(agent): @pytest.mark.asyncio -async def test_agent_stream_async(agent): +async def test_agent_stream_async(agent, model): stream = agent.stream_async("What is the time and weather in New York?") async for event in stream: _ = event @@ -171,15 +178,22 @@ def tool_with_image_return(): agent("Run the the tool and analyze the image") -def test_context_window_overflow_integration(): +@pytest.mark.parametrize( + "model_class,model_id", + [ + (OpenAIModel, "gpt-4o-mini-2024-07-18"), + (OpenAIResponsesModel, "gpt-4o-mini-2024-07-18"), + ], +) +def test_context_window_overflow_integration(model_class, model_id): """Integration test for context window overflow with OpenAI. This test verifies that when a request exceeds the model's context window, the OpenAI model properly raises a ContextWindowOverflowException. 
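+
+    Both the Chat Completions provider (OpenAIModel) and the Responses API provider are
+    exercised via the model_class parametrization.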
""" # Use gpt-4o-mini which has a smaller context window to make this test more reliable - mini_model = OpenAIModel( - model_id="gpt-4o-mini-2024-07-18", + mini_model = model_class( + model_id=model_id, client_args={ "api_key": os.getenv("OPENAI_API_KEY"), }, @@ -199,7 +213,14 @@ def test_context_window_overflow_integration(): agent(long_text) -def test_rate_limit_throttling_integration_no_retries(model): +@pytest.mark.parametrize( + "model_class,model_id", + [ + (OpenAIModel, "gpt-4o"), + (OpenAIResponsesModel, "gpt-4o"), + ], +) +def test_rate_limit_throttling_integration_no_retries(model_class, model_id): """Integration test for rate limit handling with retries disabled. This test verifies that when a request exceeds OpenAI's rate limits, @@ -208,6 +229,12 @@ def test_rate_limit_throttling_integration_no_retries(model): """ # Patch the event loop constants to disable retries for this test with unittest.mock.patch("strands.event_loop.event_loop.MAX_ATTEMPTS", 1): + model = model_class( + model_id=model_id, + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ) agent = Agent(model=model) # Create a message that's very long to trigger token-per-minute rate limits