diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 21defd4a8..73b9723d0 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -77,11 +77,7 @@ def set_open_ai_agent_temporal_overrides( Returns: A context manager that yields the configured TemporalTraceProvider. - """ - if model_params is None: - model_params = ModelActivityParameters() - previous_runner = get_default_agent_runner() previous_trace_provider = get_trace_provider() provider = TemporalTraceProvider( diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index ed5e1ffa4..9df7b43af 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -49,6 +49,7 @@ HandoffOutputItem, ToolCallItem, ToolCallOutputItem, + TResponseOutputItem, TResponseStreamEvent, ) from openai import APIStatusError, AsyncOpenAI, BaseModel @@ -106,26 +107,53 @@ def __init__( super().__init__(lambda: next(self._responses)) -class TestHelloModel(StaticTestModel): - responses = [ - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="test", annotations=[], type="output_text" - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], +class ResponseBuilders: + @staticmethod + def model_response(output: TResponseOutputItem) -> ModelResponse: + return ModelResponse( + output=[output], usage=Usage(), response_id=None, ) - ] + + @staticmethod + def response_output_message(text: str) -> ResponseOutputMessage: + return ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text=text, + annotations=[], + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ) + + @staticmethod + def tool_call(arguments: str, name: str) -> ModelResponse: + return ResponseBuilders.model_response( + ResponseFunctionToolCall( + arguments=arguments, + call_id="call", + name=name, + type="function_call", + id="id", + status="completed", + ) + ) + + @staticmethod + def output_message(text: str) -> ModelResponse: + return ResponseBuilders.model_response( + ResponseBuilders.response_output_message(text) + ) + + +class TestHelloModel(StaticTestModel): + responses = [ResponseBuilders.output_message("test")] @workflow.defn @@ -244,133 +272,23 @@ async def get_weather_nexus_operation( class TestWeatherModel(StaticTestModel): responses = [ - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments='{"city":"Tokyo"}', - call_id="call", - name="get_weather", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments='{"input":{"city":"Tokyo"}}', - call_id="call", - name="get_weather_object", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments='{"city":"Tokyo","country":"Japan"}', - call_id="call", - name="get_weather_country", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments='{"city":"Tokyo"}', - call_id="call", - name="get_weather_context", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments='{"city":"Tokyo"}', - call_id="call", - name="get_weather_method", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Test weather result", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather"), + ResponseBuilders.tool_call('{"input":{"city":"Tokyo"}}', "get_weather_object"), + ResponseBuilders.tool_call( + '{"city":"Tokyo","country":"Japan"}', "get_weather_country" ), + ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_context"), + ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_method"), + ResponseBuilders.output_message("Test weather result"), ] class TestNexusWeatherModel(StaticTestModel): responses = [ - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments='{"input":{"city":"Tokyo"}}', - call_id="call", - name="get_weather_nexus_operation", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Test nexus weather result", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.tool_call( + '{"input":{"city":"Tokyo"}}', "get_weather_nexus_operation" ), + ResponseBuilders.output_message("Test nexus weather result"), ] @@ -615,24 +533,8 @@ async def test_nexus_tool_workflow( @no_type_check class TestResearchModel(StaticTestModel): responses = [ - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text='{"searches":[{"query":"best Caribbean surfing spots April","reason":"Identify locations with optimal surfing conditions in the Caribbean during April."},{"query":"top Caribbean islands for hiking April","reason":"Find Caribbean islands with excellent hiking opportunities that are ideal in April."},{"query":"Caribbean water sports destinations April","reason":"Locate Caribbean destinations offering a variety of water sports activities in April."},{"query":"surfing conditions Caribbean April","reason":"Understand the surfing conditions and which islands are suitable for surfing in April."},{"query":"Caribbean adventure travel hiking surfing","reason":"Explore adventure travel options that combine hiking and surfing in the Caribbean."},{"query":"best beaches for surfing Caribbean April","reason":"Identify which Caribbean beaches are renowned for surfing in April."},{"query":"Caribbean islands with national parks hiking","reason":"Find islands with national parks or reserves that offer hiking trails."},{"query":"Caribbean weather April surfing conditions","reason":"Research the weather conditions in April affecting surfing in the Caribbean."},{"query":"Caribbean water sports rentals April","reason":"Look for places where water sports equipment can be rented in the Caribbean during April."},{"query":"Caribbean multi-activity vacation packages","reason":"Look for vacation packages that offer a combination of surfing, hiking, and water sports."}]}', - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.output_message( + '{"searches":[{"query":"best Caribbean surfing spots April","reason":"Identify locations with optimal surfing conditions in the Caribbean during April."},{"query":"top Caribbean islands for hiking April","reason":"Find Caribbean islands with excellent hiking opportunities that are ideal in April."},{"query":"Caribbean water sports destinations April","reason":"Locate Caribbean destinations offering a variety of water sports activities in April."},{"query":"surfing conditions Caribbean April","reason":"Understand the surfing conditions and which islands are suitable for surfing in April."},{"query":"Caribbean adventure travel hiking surfing","reason":"Explore adventure travel options that combine hiking and surfing in the Caribbean."},{"query":"best beaches for surfing Caribbean April","reason":"Identify which Caribbean beaches are renowned for surfing in April."},{"query":"Caribbean islands with national parks hiking","reason":"Find islands with national parks or reserves that offer hiking trails."},{"query":"Caribbean weather April surfing conditions","reason":"Research the weather conditions in April affecting surfing in the Caribbean."},{"query":"Caribbean water sports rentals April","reason":"Look for places where water sports equipment can be rented in the Caribbean during April."},{"query":"Caribbean multi-activity vacation packages","reason":"Look for vacation packages that offer a combination of surfing, hiking, and water sports."}]}' ) ] for i in range(10): @@ -645,43 +547,15 @@ class TestResearchModel(StaticTestModel): type="web_search_call", action=ActionSearch(query="", type="search"), ), - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Granada", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ), + ResponseBuilders.response_output_message("Granada"), ], usage=Usage(), response_id=None, ) ) responses.append( - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text='{"follow_up_questions":[], "markdown_report":"report", "short_summary":"rep"}', - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.output_message( + '{"follow_up_questions":[], "markdown_report":"report", "short_summary":"rep"}' ) ) @@ -835,76 +709,13 @@ async def run(self, msg: str) -> str: class AgentAsToolsModel(StaticTestModel): responses = [ - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments='{"input":"I am full"}', - call_id="call", - name="translate_to_spanish", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Estoy lleno.", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.tool_call('{"input":"I am full"}', "translate_to_spanish"), + ResponseBuilders.output_message("Estoy lleno."), + ResponseBuilders.output_message( + 'The translation to Spanish is: "Estoy lleno."' ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text='The translation to Spanish is: "Estoy lleno."', - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text='The translation to Spanish is: "Estoy lleno."', - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.output_message( + 'The translation to Spanish is: "Estoy lleno."' ), ] @@ -1087,109 +898,19 @@ class ProcessUserMessageInput(BaseModel): class CustomerServiceModel(StaticTestModel): responses = [ - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Hi there! How can I assist you today?", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.output_message("Hi there! How can I assist you today?"), + ResponseBuilders.tool_call("{}", "transfer_to_seat_booking_agent"), + ResponseBuilders.output_message( + "Could you please provide your confirmation number?" ), - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments="{}", - call_id="call", - name="transfer_to_seat_booking_agent", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.output_message( + "Thanks! What seat number would you like to change to?" ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Could you please provide your confirmation number?", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.tool_call( + '{"confirmation_number":"11111","new_seat":"window seat"}', "update_seat" ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Thanks! What seat number would you like to change to?", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments='{"confirmation_number":"11111","new_seat":"window seat"}', - call_id="call", - name="update_seat", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.output_message( + "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" ), ] @@ -1359,83 +1080,15 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): class InputGuardrailModel(OpenAIResponsesModel): __test__ = False responses: list[ModelResponse] = [ - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="The capital of California is Sacramento.", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="x=3", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, - ), + ResponseBuilders.output_message("The capital of California is Sacramento."), + ResponseBuilders.output_message("x=3"), ] guardrail_responses = [ - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text='{"is_math_homework":false,"reasoning":"The question asked is about the capital of California, which is a geography-related query, not math."}', - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.output_message( + '{"is_math_homework":false,"reasoning":"The question asked is about the capital of California, which is a geography-related query, not math."}' ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text='{"is_math_homework":true,"reasoning":"The question involves solving an equation for a variable, which is a typical math homework problem."}', - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.output_message( + '{"is_math_homework":true,"reasoning":"The question involves solving an equation for a variable, which is a typical math homework problem."}' ), ] @@ -1583,24 +1236,8 @@ async def test_input_guardrail(client: Client, use_local_model: bool): class OutputGuardrailModel(StaticTestModel): responses = [ - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text='{"reasoning":"The phone number\'s area code (650) is associated with a region. However, the exact location is not definitive, but it\'s commonly linked to the San Francisco Peninsula in California, including cities like San Mateo, Palo Alto, and parts of Silicon Valley. It\'s important to note that area codes don\'t always guarantee a specific location due to mobile number portability.","response":"The area code 650 is typically associated with California, particularly the San Francisco Peninsula, including cities like Palo Alto and San Mateo.","user_name":null}', - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, + ResponseBuilders.output_message( + '{"reasoning":"The phone number\'s area code (650) is associated with a region. However, the exact location is not definitive, but it\'s commonly linked to the San Francisco Peninsula in California, including cities like San Mateo, Palo Alto, and parts of Silicon Valley. It\'s important to note that area codes don\'t always guarantee a specific location due to mobile number portability.","response":"The area code 650 is typically associated with California, particularly the San Francisco Peninsula, including cities like Palo Alto and San Mateo.","user_name":null}' ) ] @@ -1690,39 +1327,8 @@ async def test_output_guardrail(client: Client, use_local_model: bool): class WorkflowToolModel(StaticTestModel): responses = [ - ModelResponse( - output=[ - ResponseFunctionToolCall( - arguments="{}", - call_id="call", - name="run_tool", - type="function_call", - id="id", - status="completed", - ) - ], - usage=Usage(), - response_id=None, - ), - ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Workflow tool was used", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, - ), + ResponseBuilders.tool_call("{}", "run_tool"), + ResponseBuilders.output_message("Workflow tool was used"), ] @@ -1900,23 +1506,7 @@ async def get_response( activity.logger.info("Waiting") await asyncio.sleep(1.0) activity.logger.info("Returning") - return ModelResponse( - output=[ - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="test", annotations=[], type="output_text" - ) - ], - role="assistant", - status="completed", - type="message", - ) - ], - usage=Usage(), - response_id=None, - ) + return ResponseBuilders.output_message("test") def stream_response( self, @@ -2130,19 +1720,7 @@ class FileSearchToolModel(StaticTestModel): Result(text="Other scene"), ], ), - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Patroclus", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ), + ResponseBuilders.response_output_message("Patroclus"), ], usage=Usage(), response_id=None, @@ -2219,19 +1797,7 @@ class ImageGenerationModel(StaticTestModel): id="id", status="completed", ), - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Patroclus", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ), + ResponseBuilders.response_output_message("Patroclus"), ], usage=Usage(), response_id=None, @@ -2307,19 +1873,7 @@ class CodeInterpreterModel(StaticTestModel): id="id", status="completed", ), - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Over 9000", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ), + ResponseBuilders.response_output_message("Over 9000"), ], usage=Usage(), response_id=None, @@ -2412,19 +1966,7 @@ class HostedMCPModel(StaticTestModel): id="id", output="Mcp output", ), - ResponseOutputMessage( - id="", - content=[ - ResponseOutputText( - text="Some language", - annotations=[], - type="output_text", - ) - ], - role="assistant", - status="completed", - type="message", - ), + ResponseBuilders.response_output_message("Some language"), ], usage=Usage(), response_id=None,