1+ import asyncio
12import json
23import os
34import uuid
45from dataclasses import dataclass
56from datetime import timedelta
6- from typing import Any , Optional , Union , no_type_check
7+ from typing import Any , AsyncIterator , Optional , Union , no_type_check
78
89import nexusrpc
910import pytest
3940 HandoffOutputItem ,
4041 ToolCallItem ,
4142 ToolCallOutputItem ,
43+ TResponseStreamEvent ,
4244)
4345from openai import APIStatusError , AsyncOpenAI , BaseModel
4446from openai .types .responses import (
47+ EasyInputMessageParam ,
4548 ResponseFunctionToolCall ,
49+ ResponseFunctionToolCallParam ,
4650 ResponseFunctionWebSearch ,
51+ ResponseInputTextParam ,
4752 ResponseOutputMessage ,
4853 ResponseOutputText ,
4954)
5055from openai .types .responses .response_function_web_search import ActionSearch
56+ from openai .types .responses .response_input_item_param import Message
5157from openai .types .responses .response_prompt_param import ResponsePromptParam
5258from pydantic import ConfigDict , Field , TypeAdapter
5359
6167 TestModel ,
6268 TestModelProvider ,
6369)
70+ from temporalio .contrib .openai_agents ._temporal_model_stub import _extract_summary
6471from temporalio .contrib .pydantic import pydantic_data_converter
6572from temporalio .exceptions import ApplicationError , CancelledError
6673from temporalio .testing import WorkflowEnvironment
7077from tests .helpers import new_worker
7178from tests .helpers .nexus import create_nexus_endpoint , make_nexus_endpoint_name
7279
73- response_index : int = 0
74-
7580
7681class StaticTestModel (TestModel ):
7782 __test__ = False
7883 responses : list [ModelResponse ] = []
7984
80- def response (self ):
81- global response_index
82- response = self .responses [response_index ]
83- response_index += 1
84- return response
85-
8685 def __init__ (
8786 self ,
8887 ) -> None :
89- global response_index
90- response_index = 0
91- super ().__init__ (self .response )
88+ self ._responses = iter (self .responses )
89+ super ().__init__ (lambda : next (self ._responses ))
9290
9391
9492class TestHelloModel (StaticTestModel ):
@@ -687,7 +685,8 @@ async def test_research_workflow(client: Client, use_local_model: bool):
687685 new_config ["plugins" ] = [
688686 openai_agents .OpenAIAgentsPlugin (
689687 model_params = ModelActivityParameters (
690- start_to_close_timeout = timedelta (seconds = 30 )
688+ start_to_close_timeout = timedelta (seconds = 120 ),
689+ schedule_to_close_timeout = timedelta (seconds = 120 ),
691690 ),
692691 model_provider = TestModelProvider (TestResearchModel ())
693692 if use_local_model
@@ -1340,9 +1339,6 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
13401339 )
13411340
13421341
1343- guardrail_response_index : int = 0
1344-
1345-
13461342class InputGuardrailModel (OpenAIResponsesModel ):
13471343 __test__ = False
13481344 responses : list [ModelResponse ] = [
@@ -1431,11 +1427,9 @@ def __init__(
14311427 model : str ,
14321428 openai_client : AsyncOpenAI ,
14331429 ) -> None :
1434- global response_index
1435- response_index = 0
1436- global guardrail_response_index
1437- guardrail_response_index = 0
14381430 super ().__init__ (model , openai_client )
1431+ self ._responses = iter (self .responses )
1432+ self ._guardrail_responses = iter (self .guardrail_responses )
14391433
14401434 async def get_response (
14411435 self ,
@@ -1453,15 +1447,9 @@ async def get_response(
14531447 system_instructions
14541448 == "Check if the user is asking you to do their math homework."
14551449 ):
1456- global guardrail_response_index
1457- response = self .guardrail_responses [guardrail_response_index ]
1458- guardrail_response_index += 1
1459- return response
1450+ return next (self ._guardrail_responses )
14601451 else :
1461- global response_index
1462- response = self .responses [response_index ]
1463- response_index += 1
1464- return response
1452+ return next (self ._responses )
14651453
14661454
14671455### 1. An agent-based guardrail that is triggered if the user is asking to do math homework
@@ -1705,7 +1693,7 @@ class WorkflowToolModel(StaticTestModel):
17051693 id = "" ,
17061694 content = [
17071695 ResponseOutputText (
1708- text = "" ,
1696+ text = "Workflow tool was used " ,
17091697 annotations = [],
17101698 type = "output_text" ,
17111699 )
@@ -1876,3 +1864,117 @@ async def test_chat_completions_model(client: Client):
18761864 execution_timeout = timedelta (seconds = 10 ),
18771865 )
18781866 await workflow_handle .result ()
1867+
1868+
1869+ class WaitModel (Model ):
1870+ async def get_response (
1871+ self ,
1872+ system_instructions : Union [str , None ],
1873+ input : Union [str , list [TResponseInputItem ]],
1874+ model_settings : ModelSettings ,
1875+ tools : list [Tool ],
1876+ output_schema : Union [AgentOutputSchemaBase , None ],
1877+ handoffs : list [Handoff ],
1878+ tracing : ModelTracing ,
1879+ * ,
1880+ previous_response_id : Union [str , None ],
1881+ prompt : Union [ResponsePromptParam , None ] = None ,
1882+ ) -> ModelResponse :
1883+ activity .logger .info ("Waiting" )
1884+ await asyncio .sleep (1.0 )
1885+ activity .logger .info ("Returning" )
1886+ return ModelResponse (
1887+ output = [
1888+ ResponseOutputMessage (
1889+ id = "" ,
1890+ content = [
1891+ ResponseOutputText (
1892+ text = "test" , annotations = [], type = "output_text"
1893+ )
1894+ ],
1895+ role = "assistant" ,
1896+ status = "completed" ,
1897+ type = "message" ,
1898+ )
1899+ ],
1900+ usage = Usage (),
1901+ response_id = None ,
1902+ )
1903+
1904+ def stream_response (
1905+ self ,
1906+ system_instructions : Optional [str ],
1907+ input : Union [str , list [TResponseInputItem ]],
1908+ model_settings : ModelSettings ,
1909+ tools : list [Tool ],
1910+ output_schema : Optional [AgentOutputSchemaBase ],
1911+ handoffs : list [Handoff ],
1912+ tracing : ModelTracing ,
1913+ * ,
1914+ previous_response_id : Optional [str ],
1915+ prompt : Optional [ResponsePromptParam ],
1916+ ) -> AsyncIterator [TResponseStreamEvent ]:
1917+ raise NotImplementedError ()
1918+
1919+
1920+ async def test_heartbeat (client : Client , env : WorkflowEnvironment ):
1921+ if env .supports_time_skipping :
1922+ pytest .skip ("Relies on real timing, skip." )
1923+
1924+ new_config = client .config ()
1925+ new_config ["plugins" ] = [
1926+ openai_agents .OpenAIAgentsPlugin (
1927+ model_params = ModelActivityParameters (
1928+ heartbeat_timeout = timedelta (seconds = 0.5 ),
1929+ ),
1930+ model_provider = TestModelProvider (WaitModel ()),
1931+ )
1932+ ]
1933+ client = Client (** new_config )
1934+
1935+ async with new_worker (
1936+ client ,
1937+ HelloWorldAgent ,
1938+ ) as worker :
1939+ workflow_handle = await client .start_workflow (
1940+ HelloWorldAgent .run ,
1941+ "Tell me about recursion in programming." ,
1942+ id = f"workflow-tool-{ uuid .uuid4 ()} " ,
1943+ task_queue = worker .task_queue ,
1944+ execution_timeout = timedelta (seconds = 5.0 ),
1945+ )
1946+ await workflow_handle .result ()
1947+
1948+
1949+ def test_summary_extraction ():
1950+ input : list [TResponseInputItem ] = [
1951+ EasyInputMessageParam (
1952+ content = "First message" ,
1953+ role = "user" ,
1954+ )
1955+ ]
1956+
1957+ assert _extract_summary (input ) == "First message"
1958+
1959+ input .append (
1960+ Message (
1961+ content = [
1962+ ResponseInputTextParam (
1963+ text = "Second message" ,
1964+ type = "input_text" ,
1965+ )
1966+ ],
1967+ role = "user" ,
1968+ )
1969+ )
1970+ assert _extract_summary (input ) == "Second message"
1971+
1972+ input .append (
1973+ ResponseFunctionToolCallParam (
1974+ arguments = "" ,
1975+ call_id = "" ,
1976+ name = "" ,
1977+ type = "function_call" ,
1978+ )
1979+ )
1980+ assert _extract_summary (input ) == "Second message"
0 commit comments