diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index 4074d1ebd..b3d4395e5 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -21,15 +21,13 @@ from temporalio.contrib.openai_agents._temporal_openai_agents import ( OpenAIAgentsPlugin, OpenAIPayloadConverter, - TestModel, - TestModelProvider, ) from temporalio.contrib.openai_agents._trace_interceptor import ( OpenAIAgentsTracingInterceptor, ) from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError -from . import workflow +from . import testing, workflow __all__ = [ "AgentsWorkflowError", @@ -38,7 +36,6 @@ "OpenAIPayloadConverter", "StatelessMCPServerProvider", "StatefulMCPServerProvider", - "TestModel", - "TestModelProvider", + "testing", "workflow", ] diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 9df481d00..b7bbcc541 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -6,19 +6,7 @@ from datetime import timedelta from typing import AsyncIterator, Callable, Optional, Sequence, Union -from agents import ( - AgentOutputSchemaBase, - Handoff, - Model, - ModelProvider, - ModelResponse, - ModelSettings, - ModelTracing, - Tool, - TResponseInputItem, - set_trace_provider, -) -from agents.items import TResponseStreamEvent +from agents import ModelProvider, set_trace_provider from agents.run import get_default_agent_runner, set_default_agent_runner from agents.tracing import get_trace_provider from agents.tracing.provider import DefaultTraceProvider @@ -103,58 +91,6 @@ def set_open_ai_agent_temporal_overrides( set_trace_provider(previous_trace_provider or DefaultTraceProvider()) -class TestModelProvider(ModelProvider): - """Test model provider which simply returns the given module.""" - - __test__ = False - - def __init__(self, model: Model): - """Initialize a test model provider with a model.""" - self._model = model - - def get_model(self, model_name: Union[str, None]) -> Model: - """Get a model from the model provider.""" - return self._model - - -class TestModel(Model): - """Test model for use mocking model responses.""" - - __test__ = False - - def __init__(self, fn: Callable[[], ModelResponse]) -> None: - """Initialize a test model with a callable.""" - self.fn = fn - - async def get_response( - self, - system_instructions: Union[str, None], - input: Union[str, list[TResponseInputItem]], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: Union[AgentOutputSchemaBase, None], - handoffs: list[Handoff], - tracing: ModelTracing, - **kwargs, - ) -> ModelResponse: - """Get a response from the model.""" - return self.fn() - - def stream_response( - self, - system_instructions: Optional[str], - input: Union[str, list[TResponseInputItem]], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: Optional[AgentOutputSchemaBase], - handoffs: list[Handoff], - tracing: ModelTracing, - **kwargs, - ) -> AsyncIterator[TResponseStreamEvent]: - """Get a streamed response from the model. Unimplemented.""" - raise NotImplementedError() - - class OpenAIPayloadConverter(PydanticPayloadConverter): """PayloadConverter for OpenAI agents.""" diff --git a/temporalio/contrib/openai_agents/testing.py b/temporalio/contrib/openai_agents/testing.py new file mode 100644 index 000000000..5788a9df2 --- /dev/null +++ b/temporalio/contrib/openai_agents/testing.py @@ -0,0 +1,175 @@ +"""Testing utilities for OpenAI agents.""" + +from typing import AsyncIterator, Callable, Optional, Union + +from agents import ( + AgentOutputSchemaBase, + Handoff, + Model, + ModelProvider, + ModelResponse, + ModelSettings, + ModelTracing, + Tool, + TResponseInputItem, + Usage, +) +from agents.items import TResponseOutputItem, TResponseStreamEvent +from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseOutputText, +) + + +class ResponseBuilders: + """Builders for creating model responses for testing. + + .. warning:: + This API is experimental and may change in the future. + """ + + @staticmethod + def model_response(output: TResponseOutputItem) -> ModelResponse: + """Create a ModelResponse with the given output. + + .. warning:: + This API is experimental and may change in the future. + """ + return ModelResponse( + output=[output], + usage=Usage(), + response_id=None, + ) + + @staticmethod + def response_output_message(text: str) -> ResponseOutputMessage: + """Create a ResponseOutputMessage with text content. + + .. warning:: + This API is experimental and may change in the future. + """ + return ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text=text, + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + + @staticmethod + def tool_call(arguments: str, name: str) -> ModelResponse: + """Create a ModelResponse with a function tool call. + + .. warning:: + This API is experimental and may change in the future. + """ + return ResponseBuilders.model_response( + ResponseFunctionToolCall( + arguments=arguments, + call_id="call", + name=name, + type="function_call", + id="id", + status="completed", + ) + ) + + @staticmethod + def output_message(text: str) -> ModelResponse: + """Create a ModelResponse with an output message. + + .. warning:: + This API is experimental and may change in the future. + """ + return ResponseBuilders.model_response( + ResponseBuilders.response_output_message(text) + ) + + +class TestModelProvider(ModelProvider): + """Test model provider which simply returns the given module. + + .. warning:: + This API is experimental and may change in the future. + """ + + __test__ = False + + def __init__(self, model: Model): + """Initialize a test model provider with a model. + + .. warning:: + This API is experimental and may change in the future. + """ + self._model = model + + def get_model(self, model_name: Union[str, None]) -> Model: + """Get a model from the model provider. + + .. warning:: + This API is experimental and may change in the future. + """ + return self._model + + +class TestModel(Model): + """Test model for use mocking model responses. + + .. warning:: + This API is experimental and may change in the future. + """ + + __test__ = False + + def __init__(self, fn: Callable[[], ModelResponse]) -> None: + """Initialize a test model with a callable. + + .. warning:: + This API is experimental and may change in the future. + """ + self.fn = fn + + async def get_response( + self, + system_instructions: Union[str, None], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Union[AgentOutputSchemaBase, None], + handoffs: list[Handoff], + tracing: ModelTracing, + **kwargs, + ) -> ModelResponse: + """Get a response from the mocked model, by calling the callable passed to the constructor.""" + return self.fn() + + def stream_response( + self, + system_instructions: Optional[str], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: list[Handoff], + tracing: ModelTracing, + **kwargs, + ) -> AsyncIterator[TResponseStreamEvent]: + """Get a streamed response from the model. Unimplemented.""" + raise NotImplementedError() + + @staticmethod + def returning_responses(responses: list[ModelResponse]) -> "TestModel": + """Create a mock model which sequentially returns responses from a list. + + .. warning:: + This API is experimental and may change in the future. + """ + i = iter(responses) + return TestModel(lambda: next(i)) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 812731be2..f65f26903 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -13,7 +13,6 @@ Sequence, Union, cast, - no_type_check, ) import nexusrpc @@ -92,8 +91,6 @@ from temporalio.contrib import openai_agents from temporalio.contrib.openai_agents import ( ModelActivityParameters, - TestModel, - TestModelProvider, ) from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider from temporalio.contrib.openai_agents._openai_runner import _convert_agent @@ -101,6 +98,11 @@ _extract_summary, _TemporalModelStub, ) +from temporalio.contrib.openai_agents.testing import ( + ResponseBuilders, + TestModel, + TestModelProvider, +) from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import ApplicationError, CancelledError, TemporalError from temporalio.testing import WorkflowEnvironment @@ -112,64 +114,8 @@ from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name -class StaticTestModel(TestModel): - __test__ = False - responses: list[ModelResponse] = [] - - def __init__( - self, - ) -> None: - self._responses = iter(self.responses) - super().__init__(lambda: next(self._responses)) - - -class ResponseBuilders: - @staticmethod - def model_response(output: TResponseOutputItem) -> ModelResponse: - return ModelResponse( - output=[output], - usage=Usage(), - response_id=None, - ) - - @staticmethod - def response_output_message(text: str) -> ResponseOutputMessage: - return ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text=text, - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - - @staticmethod - def tool_call(arguments: str, name: str) -> ModelResponse: - return ResponseBuilders.model_response( - ResponseFunctionToolCall( - arguments=arguments, - call_id="call", - name=name, - type="function_call", - id="id", - status="completed", - ) - ) - - @staticmethod - def output_message(text: str) -> ModelResponse: - return ResponseBuilders.model_response( - ResponseBuilders.response_output_message(text) - ) - - -class TestHelloModel(StaticTestModel): - responses = [ResponseBuilders.output_message("test")] +def hello_mock_model(): + return TestModel.returning_responses([ResponseBuilders.output_message("test")]) @workflow.defn @@ -194,7 +140,7 @@ async def test_hello_world_agent(client: Client, use_local_model: bool): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(TestHelloModel()) + model_provider=TestModelProvider(hello_mock_model()) if use_local_model else None, ) @@ -286,26 +232,32 @@ async def get_weather_nexus_operation( ) -class TestWeatherModel(StaticTestModel): - responses = [ - ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather"), - ResponseBuilders.tool_call('{"input":{"city":"Tokyo"}}', "get_weather_object"), - ResponseBuilders.tool_call( - '{"city":"Tokyo","country":"Japan"}', "get_weather_country" - ), - ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_context"), - ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_method"), - ResponseBuilders.output_message("Test weather result"), - ] +def weather_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather"), + ResponseBuilders.tool_call( + '{"input":{"city":"Tokyo"}}', "get_weather_object" + ), + ResponseBuilders.tool_call( + '{"city":"Tokyo","country":"Japan"}', "get_weather_country" + ), + ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_context"), + ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_method"), + ResponseBuilders.output_message("Test weather result"), + ] + ) -class TestNexusWeatherModel(StaticTestModel): - responses = [ - ResponseBuilders.tool_call( - '{"input":{"city":"Tokyo"}}', "get_weather_nexus_operation" - ), - ResponseBuilders.output_message("Test nexus weather result"), - ] +def nexus_weather_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.tool_call( + '{"input":{"city":"Tokyo"}}', "get_weather_nexus_operation" + ), + ResponseBuilders.output_message("Test nexus weather result"), + ] + ) @workflow.defn @@ -376,7 +328,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(TestWeatherModel()) + model_provider=TestModelProvider(weather_mock_model()) if use_local_model else None, ) @@ -488,10 +440,12 @@ async def get_weather_failure(city: str) -> Weather: raise ApplicationError("No weather", non_retryable=True) -class TestWeatherFailureModel(StaticTestModel): - responses = [ - ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_failure"), - ] +def weather_failure_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_failure"), + ] + ) async def test_tool_failure_workflow(client: Client): @@ -501,7 +455,7 @@ async def test_tool_failure_workflow(client: Client): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(TestWeatherFailureModel()), + model_provider=TestModelProvider(weather_failure_mock_model()), ) ] client = Client(**new_config) @@ -543,7 +497,7 @@ async def test_nexus_tool_workflow( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(TestNexusWeatherModel()) + model_provider=TestModelProvider(nexus_weather_mock_model()) if use_local_model else None, ) @@ -597,14 +551,13 @@ async def test_nexus_tool_workflow( ) -@no_type_check -class TestResearchModel(StaticTestModel): +def research_mock_model(): responses = [ ResponseBuilders.output_message( '{"searches":[{"query":"best Caribbean surfing spots April","reason":"Identify locations with optimal surfing conditions in the Caribbean during April."},{"query":"top Caribbean islands for hiking April","reason":"Find Caribbean islands with excellent hiking opportunities that are ideal in April."},{"query":"Caribbean water sports destinations April","reason":"Locate Caribbean destinations offering a variety of water sports activities in April."},{"query":"surfing conditions Caribbean April","reason":"Understand the surfing conditions and which islands are suitable for surfing in April."},{"query":"Caribbean adventure travel hiking surfing","reason":"Explore adventure travel options that combine hiking and surfing in the Caribbean."},{"query":"best beaches for surfing Caribbean April","reason":"Identify which Caribbean beaches are renowned for surfing in April."},{"query":"Caribbean islands with national parks hiking","reason":"Find islands with national parks or reserves that offer hiking trails."},{"query":"Caribbean weather April surfing conditions","reason":"Research the weather conditions in April affecting surfing in the Caribbean."},{"query":"Caribbean water sports rentals April","reason":"Look for places where water sports equipment can be rented in the Caribbean during April."},{"query":"Caribbean multi-activity vacation packages","reason":"Look for vacation packages that offer a combination of surfing, hiking, and water sports."}]}' ) ] - for i in range(10): + for _ in range(10): responses.append( ModelResponse( output=[ @@ -625,6 +578,7 @@ class TestResearchModel(StaticTestModel): '{"follow_up_questions":[], "markdown_report":"report", "short_summary":"rep"}' ) ) + return TestModel.returning_responses(responses) @workflow.defn @@ -646,7 +600,7 @@ async def test_research_workflow(client: Client, use_local_model: bool): start_to_close_timeout=timedelta(seconds=120), schedule_to_close_timeout=timedelta(seconds=120), ), - model_provider=TestModelProvider(TestResearchModel()) + model_provider=TestModelProvider(research_mock_model()) if use_local_model else None, ) @@ -774,17 +728,19 @@ async def run(self, msg: str) -> str: return synthesizer_result.final_output -class AgentAsToolsModel(StaticTestModel): - responses = [ - ResponseBuilders.tool_call('{"input":"I am full"}', "translate_to_spanish"), - ResponseBuilders.output_message("Estoy lleno."), - ResponseBuilders.output_message( - 'The translation to Spanish is: "Estoy lleno."' - ), - ResponseBuilders.output_message( - 'The translation to Spanish is: "Estoy lleno."' - ), - ] +def agent_as_tools_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.tool_call('{"input":"I am full"}', "translate_to_spanish"), + ResponseBuilders.output_message("Estoy lleno."), + ResponseBuilders.output_message( + 'The translation to Spanish is: "Estoy lleno."' + ), + ResponseBuilders.output_message( + 'The translation to Spanish is: "Estoy lleno."' + ), + ] + ) @pytest.mark.parametrize("use_local_model", [True, False]) @@ -797,7 +753,7 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(AgentAsToolsModel()) + model_provider=TestModelProvider(agent_as_tools_mock_model()) if use_local_model else None, ) @@ -966,25 +922,28 @@ class ProcessUserMessageInput(BaseModel): chat_length: int -class CustomerServiceModel(StaticTestModel): - responses = [ - ResponseBuilders.output_message("Hi there! How can I assist you today?"), - ResponseBuilders.tool_call("{}", "transfer_to_seat_booking_agent"), - ResponseBuilders.output_message( - "Could you please provide your confirmation number?" - ), - ResponseBuilders.output_message( - "Thanks! What seat number would you like to change to?" - ), - ResponseBuilders.tool_call( - '{"confirmation_number":"11111","new_seat":"window seat"}', "update_seat" - ), - ResponseBuilders.output_message( - "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" - ), - ResponseBuilders.tool_call("{}", "transfer_to_triage_agent"), - ResponseBuilders.output_message("You're welcome!"), - ] +def customer_service_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.output_message("Hi there! How can I assist you today?"), + ResponseBuilders.tool_call("{}", "transfer_to_seat_booking_agent"), + ResponseBuilders.output_message( + "Could you please provide your confirmation number?" + ), + ResponseBuilders.output_message( + "Thanks! What seat number would you like to change to?" + ), + ResponseBuilders.tool_call( + '{"confirmation_number":"11111","new_seat":"window seat"}', + "update_seat", + ), + ResponseBuilders.output_message( + "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" + ), + ResponseBuilders.tool_call("{}", "transfer_to_triage_agent"), + ResponseBuilders.output_message("You're welcome!"), + ] + ) @workflow.defn @@ -1061,7 +1020,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(CustomerServiceModel()) + model_provider=TestModelProvider(customer_service_mock_model()) if use_local_model else None, ) @@ -1322,12 +1281,14 @@ async def test_input_guardrail(client: Client, use_local_model: bool): assert result[1] == "Sorry, I can't help you with your math homework." -class OutputGuardrailModel(StaticTestModel): - responses = [ - ResponseBuilders.output_message( - '{"reasoning":"The phone number\'s area code (650) is associated with a region. However, the exact location is not definitive, but it\'s commonly linked to the San Francisco Peninsula in California, including cities like San Mateo, Palo Alto, and parts of Silicon Valley. It\'s important to note that area codes don\'t always guarantee a specific location due to mobile number portability.","response":"The area code 650 is typically associated with California, particularly the San Francisco Peninsula, including cities like Palo Alto and San Mateo.","user_name":null}' - ) - ] +def output_guardrail_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.output_message( + '{"reasoning":"The phone number\'s area code (650) is associated with a region. However, the exact location is not definitive, but it\'s commonly linked to the San Francisco Peninsula in California, including cities like San Mateo, Palo Alto, and parts of Silicon Valley. It\'s important to note that area codes don\'t always guarantee a specific location due to mobile number portability.","response":"The area code 650 is typically associated with California, particularly the San Francisco Peninsula, including cities like Palo Alto and San Mateo.","user_name":null}' + ) + ] + ) # The agent's output type @@ -1390,7 +1351,7 @@ async def test_output_guardrail(client: Client, use_local_model: bool): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(OutputGuardrailModel()) + model_provider=TestModelProvider(output_guardrail_mock_model()) if use_local_model else None, ) @@ -1413,11 +1374,13 @@ async def test_output_guardrail(client: Client, use_local_model: bool): assert not result -class WorkflowToolModel(StaticTestModel): - responses = [ - ResponseBuilders.tool_call("{}", "run_tool"), - ResponseBuilders.output_message("Workflow tool was used"), - ] +def workflow_tool_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.tool_call("{}", "run_tool"), + ResponseBuilders.output_message("Workflow tool was used"), + ] + ) @workflow.defn @@ -1447,7 +1410,7 @@ async def test_workflow_method_tools(client: Client): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(WorkflowToolModel()), + model_provider=TestModelProvider(workflow_tool_mock_model()), ) ] client = Client(**new_config) @@ -1624,7 +1587,7 @@ async def run(self, prompt: str) -> str: class CheckModelNameProvider(ModelProvider): def get_model(self, model_name: Optional[str]) -> Model: assert model_name == "test_model" - return TestHelloModel() + return hello_mock_model() async def test_alternative_model(client: Client): @@ -1735,7 +1698,7 @@ async def test_session(client: Client): new_config = client.config() new_config["plugins"] = [ openai_agents.OpenAIAgentsPlugin( - model_provider=TestModelProvider(TestHelloModel()), + model_provider=TestModelProvider(hello_mock_model()), ) ] client = Client(**new_config) @@ -1799,26 +1762,28 @@ async def test_lite_llm(client: Client, env: WorkflowEnvironment): await workflow_handle.result() -class FileSearchToolModel(StaticTestModel): - responses = [ - ModelResponse( - output=[ - ResponseFileSearchToolCall( - queries=["side character in the Iliad"], - type="file_search_call", - id="id", - status="completed", - results=[ - Result(text="Some scene"), - Result(text="Other scene"), - ], - ), - ResponseBuilders.response_output_message("Patroclus"), - ], - usage=Usage(), - response_id=None, - ), - ] +def file_search_tool_mock_model(): + return TestModel.returning_responses( + [ + ModelResponse( + output=[ + ResponseFileSearchToolCall( + queries=["side character in the Iliad"], + type="file_search_call", + id="id", + status="completed", + results=[ + Result(text="Some scene"), + Result(text="Other scene"), + ], + ), + ResponseBuilders.response_output_message("Patroclus"), + ], + usage=Usage(), + response_id=None, + ), + ] + ) @workflow.defn @@ -1858,7 +1823,7 @@ async def test_file_search_tool(client: Client, use_local_model): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(FileSearchToolModel()) + model_provider=TestModelProvider(file_search_tool_mock_model()) if use_local_model else None, ) @@ -1881,21 +1846,23 @@ async def test_file_search_tool(client: Client, use_local_model): assert result == "Patroclus" -class ImageGenerationModel(StaticTestModel): - responses = [ - ModelResponse( - output=[ - ImageGenerationCall( - type="image_generation_call", - id="id", - status="completed", - ), - ResponseBuilders.response_output_message("Patroclus"), - ], - usage=Usage(), - response_id=None, - ), - ] +def image_generation_mock_model(): + return TestModel.returning_responses( + [ + ModelResponse( + output=[ + ImageGenerationCall( + type="image_generation_call", + id="id", + status="completed", + ), + ResponseBuilders.response_output_message("Patroclus"), + ], + usage=Usage(), + response_id=None, + ), + ] + ) @workflow.defn @@ -1934,7 +1901,7 @@ async def test_image_generation_tool(client: Client, use_local_model): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(ImageGenerationModel()) + model_provider=TestModelProvider(image_generation_mock_model()) if use_local_model else None, ) @@ -1955,23 +1922,25 @@ async def test_image_generation_tool(client: Client, use_local_model): result = await workflow_handle.result() -class CodeInterpreterModel(StaticTestModel): - responses = [ - ModelResponse( - output=[ - ResponseCodeInterpreterToolCall( - container_id="", - code="some code", - type="code_interpreter_call", - id="id", - status="completed", - ), - ResponseBuilders.response_output_message("Over 9000"), - ], - usage=Usage(), - response_id=None, - ), - ] +def code_interpreter_mock_model(): + return TestModel.returning_responses( + [ + ModelResponse( + output=[ + ResponseCodeInterpreterToolCall( + container_id="", + code="some code", + type="code_interpreter_call", + id="id", + status="completed", + ), + ResponseBuilders.response_output_message("Over 9000"), + ], + usage=Usage(), + response_id=None, + ), + ] + ) @workflow.defn @@ -2007,7 +1976,7 @@ async def test_code_interpreter_tool(client: Client): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=60) ), - model_provider=TestModelProvider(CodeInterpreterModel()), + model_provider=TestModelProvider(code_interpreter_mock_model()), ) ] client = Client(**new_config) @@ -2027,37 +1996,39 @@ async def test_code_interpreter_tool(client: Client): assert result == "Over 9000" -class HostedMCPModel(StaticTestModel): - responses = [ - ModelResponse( - output=[ - McpApprovalRequest( - arguments="", - name="", - server_label="gitmcp", - type="mcp_approval_request", - id="id", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - McpCall( - arguments="", - name="", - server_label="", - type="mcp_call", - id="id", - output="Mcp output", - ), - ResponseBuilders.response_output_message("Some language"), - ], - usage=Usage(), - response_id=None, - ), - ] +def hosted_mcp_mock_model(): + return TestModel.returning_responses( + [ + ModelResponse( + output=[ + McpApprovalRequest( + arguments="", + name="", + server_label="gitmcp", + type="mcp_approval_request", + id="id", + ) + ], + usage=Usage(), + response_id=None, + ), + ModelResponse( + output=[ + McpCall( + arguments="", + name="", + server_label="", + type="mcp_call", + id="id", + output="Mcp output", + ), + ResponseBuilders.response_output_message("Some language"), + ], + usage=Usage(), + response_id=None, + ), + ] + ) @workflow.defn @@ -2102,7 +2073,7 @@ async def test_hosted_mcp_tool(client: Client): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=120) ), - model_provider=TestModelProvider(HostedMCPModel()), + model_provider=TestModelProvider(hosted_mcp_mock_model()), ) ] client = Client(**new_config) @@ -2134,13 +2105,15 @@ def get_model(self, model_name: Union[str, None]) -> Model: return self._model -class MultipleModelsModel(StaticTestModel): - responses = [ - ResponseBuilders.tool_call("{}", "transfer_to_underling"), - ResponseBuilders.output_message( - "I'm here to help! Was there a specific task you needed assistance with regarding the storeroom?" - ), - ] +def multiple_models_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.tool_call("{}", "transfer_to_underling"), + ResponseBuilders.output_message( + "I'm here to help! Was there a specific task you needed assistance with regarding the storeroom?" + ), + ] + ) @workflow.defn @@ -2167,7 +2140,7 @@ async def run(self, use_run_config: bool): async def test_multiple_models(client: Client): - provider = AssertDifferentModelProvider(MultipleModelsModel()) + provider = AssertDifferentModelProvider(multiple_models_mock_model()) new_config = client.config() new_config["plugins"] = [ openai_agents.OpenAIAgentsPlugin( @@ -2195,7 +2168,7 @@ async def test_multiple_models(client: Client): async def test_run_config_models(client: Client): - provider = AssertDifferentModelProvider(MultipleModelsModel()) + provider = AssertDifferentModelProvider(multiple_models_mock_model()) new_config = client.config() new_config["plugins"] = [ openai_agents.OpenAIAgentsPlugin( @@ -2241,7 +2214,7 @@ def provide( start_to_close_timeout=timedelta(seconds=120), summary_override=SummaryProvider(), ), - model_provider=TestModelProvider(TestHelloModel()), + model_provider=TestModelProvider(hello_mock_model()), ) ] client = Client(**new_config) @@ -2284,12 +2257,14 @@ async def run(self) -> OutputType: return result.final_output -class OutputTypeModel(StaticTestModel): - responses = [ - ResponseBuilders.output_message( - '{"answer": "My answer"}', - ), - ] +def output_type_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.output_message( + '{"answer": "My answer"}', + ), + ] + ) async def test_output_type(client: Client): @@ -2299,7 +2274,7 @@ async def test_output_type(client: Client): model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=120), ), - model_provider=TestModelProvider(OutputTypeModel()), + model_provider=TestModelProvider(output_type_mock_model()), ) ] client = Client(**new_config) @@ -2362,18 +2337,20 @@ async def run(self, timeout: timedelta, factory_argument: Optional[Any]) -> str: return result.final_output -class TrackingMCPModel(StaticTestModel): - responses = [ - ResponseBuilders.tool_call( - arguments='{"name":"Tom"}', - name="Say-Hello", - ), - ResponseBuilders.tool_call( - arguments='{"name":"Tim"}', - name="Say-Hello", - ), - ResponseBuilders.output_message("Hi Tom and Tim!"), - ] +def tracking_mcp_mock_model(): + return TestModel.returning_responses( + [ + ResponseBuilders.tool_call( + arguments='{"name":"Tom"}', + name="Say-Hello", + ), + ResponseBuilders.tool_call( + arguments='{"name":"Tim"}', + name="Say-Hello", + ), + ResponseBuilders.output_message("Hi Tom and Tim!"), + ] + ) def get_tracking_server(name: str): @@ -2475,7 +2452,7 @@ async def test_mcp_server( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=120) ), - model_provider=TestModelProvider(TrackingMCPModel()) + model_provider=TestModelProvider(tracking_mcp_mock_model()) if use_local_model else None, mcp_server_providers=[server], @@ -2585,7 +2562,7 @@ def factory(args: Optional[Any]) -> MCPServer: model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=120) ), - model_provider=TestModelProvider(TrackingMCPModel()), + model_provider=TestModelProvider(tracking_mcp_mock_model()), mcp_server_providers=[server], ) ] @@ -2651,7 +2628,7 @@ def override_get_activities() -> Sequence[Callable]: model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=120) ), - model_provider=TestModelProvider(TrackingMCPModel()), + model_provider=TestModelProvider(tracking_mcp_mock_model()), mcp_server_providers=[server], ) ] diff --git a/tests/contrib/openai_agents/test_openai_tracing.py b/tests/contrib/openai_agents/test_openai_tracing.py index c8ad366e6..44ffc6a16 100644 --- a/tests/contrib/openai_agents/test_openai_tracing.py +++ b/tests/contrib/openai_agents/test_openai_tracing.py @@ -7,10 +7,13 @@ from temporalio.client import Client from temporalio.contrib import openai_agents -from temporalio.contrib.openai_agents import ( +from temporalio.contrib.openai_agents.testing import ( TestModelProvider, ) -from tests.contrib.openai_agents.test_openai import ResearchWorkflow, TestResearchModel +from tests.contrib.openai_agents.test_openai import ( + ResearchWorkflow, + research_mock_model, +) from tests.helpers import new_worker @@ -42,7 +45,7 @@ async def test_tracing(client: Client): new_config = client.config() new_config["plugins"] = [ openai_agents.OpenAIAgentsPlugin( - model_provider=TestModelProvider(TestResearchModel()) + model_provider=TestModelProvider(research_mock_model()) ) ] client = Client(**new_config)