Skip to content

Commit 8428da6

Browse files
authored
Merge branch 'main' into openai/early_failure
2 parents 7c417f2 + 79f2900 commit 8428da6

File tree

4 files changed

+170
-56
lines changed

4 files changed

+170
-56
lines changed

temporalio/contrib/openai_agents/_heartbeat_decorator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
2424
if heartbeat_task:
2525
heartbeat_task.cancel()
2626
# Wait for heartbeat cancellation to complete
27-
await heartbeat_task
27+
try:
28+
await heartbeat_task
29+
except asyncio.CancelledError:
30+
pass
2831

2932
return cast(F, wrapper)
3033

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
logger = logging.getLogger(__name__)
1010

11-
from typing import Any, AsyncIterator, Sequence, Union, cast
11+
from typing import Any, AsyncIterator, Union, cast
1212

1313
from agents import (
1414
AgentOutputSchema,
@@ -54,7 +54,7 @@ def __init__(
5454
async def get_response(
5555
self,
5656
system_instructions: Optional[str],
57-
input: Union[str, list[TResponseInputItem], dict[str, str]],
57+
input: Union[str, list[TResponseInputItem]],
5858
model_settings: ModelSettings,
5959
tools: list[Tool],
6060
output_schema: Optional[AgentOutputSchemaBase],
@@ -64,28 +64,6 @@ async def get_response(
6464
previous_response_id: Optional[str],
6565
prompt: Optional[ResponsePromptParam],
6666
) -> ModelResponse:
67-
def get_summary(
68-
input: Union[str, list[TResponseInputItem], dict[str, str]],
69-
) -> str:
70-
### Activity summary shown in the UI
71-
try:
72-
max_size = 100
73-
if isinstance(input, str):
74-
return input[:max_size]
75-
elif isinstance(input, list):
76-
seq_input = cast(Sequence[Any], input)
77-
last_item = seq_input[-1]
78-
if isinstance(last_item, dict):
79-
return last_item.get("content", "")[:max_size]
80-
elif hasattr(last_item, "content"):
81-
return str(getattr(last_item, "content"))[:max_size]
82-
return str(last_item)[:max_size]
83-
elif isinstance(input, dict):
84-
return input.get("content", "")[:max_size]
85-
except Exception as e:
86-
logger.error(f"Error getting summary: {e}")
87-
return ""
88-
8967
def make_tool_info(tool: Tool) -> ToolInput:
9068
if isinstance(tool, (FileSearchTool, WebSearchTool)):
9169
return tool
@@ -150,7 +128,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
150128
return await workflow.execute_activity_method(
151129
ModelActivity.invoke_model_activity,
152130
activity_input,
153-
summary=self.model_params.summary_override or get_summary(input),
131+
summary=self.model_params.summary_override or _extract_summary(input),
154132
task_queue=self.model_params.task_queue,
155133
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
156134
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
@@ -176,3 +154,34 @@ def stream_response(
176154
prompt: ResponsePromptParam | None,
177155
) -> AsyncIterator[TResponseStreamEvent]:
178156
raise NotImplementedError("Temporal model doesn't support streams yet")
157+
158+
159+
def _extract_summary(input: Union[str, list[TResponseInputItem]]) -> str:
    """Build the short activity summary shown in the UI for a model call.

    A plain string input is truncated as-is. For a list of input items, the
    summary comes from the last item that is a message (an item whose
    ``type`` is ``"message"``, or that carries no ``type`` at all); other
    items, such as function calls, are skipped. When the message content has
    several parts, the last part is used, preferring its ``text`` field when
    one is present.

    Args:
        input: The model input, either raw text or a list of input items.

    Returns:
        At most 100 characters of summary text, or ``""`` when nothing
        summarizable is found or extraction fails.
    """
    max_size = 100

    def field(item: Any, key: str, default: Any = None) -> Any:
        # Items are normally dicts, but accept attribute-style objects too so
        # they yield a summary instead of raising AttributeError on ``.get``.
        if isinstance(item, dict):
            return item.get(key, default)
        return getattr(item, key, default)

    try:
        if isinstance(input, str):
            return input[:max_size]
        if not isinstance(input, list):
            return ""

        # Find all message inputs, which are reasonably summarizable
        messages = [
            item for item in input if field(item, "type", "message") == "message"
        ]
        if not messages:
            return ""

        content: Any = field(messages[-1], "content", "")

        # In the case of multiple contents, take the last one
        if isinstance(content, list):
            if not content:
                return ""
            content = content[-1]

        # Take the text field from the content if present
        text = field(content, "text")
        if text is not None:
            content = text

        # An explicit None content would otherwise render as the text "None"
        if content is None:
            return ""
        return str(content)[:max_size]
    except Exception as e:
        # Best effort: swallow the failure and fall back to an empty summary
        logger.error("Error getting summary: %s", e)
    return ""

temporalio/contrib/openai_agents/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
134134
cancellation_type=cancellation_type,
135135
activity_id=activity_id,
136136
versioning_intent=versioning_intent,
137-
summary=summary,
137+
summary=summary or schema.description,
138138
priority=priority,
139139
)
140140
try:

tests/contrib/openai_agents/test_openai.py

Lines changed: 131 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import asyncio
12
import json
23
import os
34
import uuid
45
from dataclasses import dataclass
56
from datetime import timedelta
6-
from typing import Any, Optional, Union, no_type_check
7+
from typing import Any, AsyncIterator, Optional, Union, no_type_check
78

89
import nexusrpc
910
import pytest
@@ -39,15 +40,20 @@
3940
HandoffOutputItem,
4041
ToolCallItem,
4142
ToolCallOutputItem,
43+
TResponseStreamEvent,
4244
)
4345
from openai import APIStatusError, AsyncOpenAI, BaseModel
4446
from openai.types.responses import (
47+
EasyInputMessageParam,
4548
ResponseFunctionToolCall,
49+
ResponseFunctionToolCallParam,
4650
ResponseFunctionWebSearch,
51+
ResponseInputTextParam,
4752
ResponseOutputMessage,
4853
ResponseOutputText,
4954
)
5055
from openai.types.responses.response_function_web_search import ActionSearch
56+
from openai.types.responses.response_input_item_param import Message
5157
from openai.types.responses.response_prompt_param import ResponsePromptParam
5258
from pydantic import ConfigDict, Field, TypeAdapter
5359

@@ -61,6 +67,7 @@
6167
TestModel,
6268
TestModelProvider,
6369
)
70+
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
6471
from temporalio.contrib.pydantic import pydantic_data_converter
6572
from temporalio.exceptions import ApplicationError, CancelledError
6673
from temporalio.testing import WorkflowEnvironment
@@ -70,25 +77,16 @@
7077
from tests.helpers import new_worker
7178
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
7279

73-
response_index: int = 0
74-
7580

7681
class StaticTestModel(TestModel):
7782
__test__ = False
7883
responses: list[ModelResponse] = []
7984

80-
def response(self):
81-
global response_index
82-
response = self.responses[response_index]
83-
response_index += 1
84-
return response
85-
8685
def __init__(
8786
self,
8887
) -> None:
89-
global response_index
90-
response_index = 0
91-
super().__init__(self.response)
88+
self._responses = iter(self.responses)
89+
super().__init__(lambda: next(self._responses))
9290

9391

9492
class TestHelloModel(StaticTestModel):
@@ -687,7 +685,8 @@ async def test_research_workflow(client: Client, use_local_model: bool):
687685
new_config["plugins"] = [
688686
openai_agents.OpenAIAgentsPlugin(
689687
model_params=ModelActivityParameters(
690-
start_to_close_timeout=timedelta(seconds=30)
688+
start_to_close_timeout=timedelta(seconds=120),
689+
schedule_to_close_timeout=timedelta(seconds=120),
691690
),
692691
model_provider=TestModelProvider(TestResearchModel())
693692
if use_local_model
@@ -1340,9 +1339,6 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
13401339
)
13411340

13421341

1343-
guardrail_response_index: int = 0
1344-
1345-
13461342
class InputGuardrailModel(OpenAIResponsesModel):
13471343
__test__ = False
13481344
responses: list[ModelResponse] = [
@@ -1431,11 +1427,9 @@ def __init__(
14311427
model: str,
14321428
openai_client: AsyncOpenAI,
14331429
) -> None:
1434-
global response_index
1435-
response_index = 0
1436-
global guardrail_response_index
1437-
guardrail_response_index = 0
14381430
super().__init__(model, openai_client)
1431+
self._responses = iter(self.responses)
1432+
self._guardrail_responses = iter(self.guardrail_responses)
14391433

14401434
async def get_response(
14411435
self,
@@ -1453,15 +1447,9 @@ async def get_response(
14531447
system_instructions
14541448
== "Check if the user is asking you to do their math homework."
14551449
):
1456-
global guardrail_response_index
1457-
response = self.guardrail_responses[guardrail_response_index]
1458-
guardrail_response_index += 1
1459-
return response
1450+
return next(self._guardrail_responses)
14601451
else:
1461-
global response_index
1462-
response = self.responses[response_index]
1463-
response_index += 1
1464-
return response
1452+
return next(self._responses)
14651453

14661454

14671455
### 1. An agent-based guardrail that is triggered if the user is asking to do math homework
@@ -1705,7 +1693,7 @@ class WorkflowToolModel(StaticTestModel):
17051693
id="",
17061694
content=[
17071695
ResponseOutputText(
1708-
text="",
1696+
text="Workflow tool was used",
17091697
annotations=[],
17101698
type="output_text",
17111699
)
@@ -1876,3 +1864,117 @@ async def test_chat_completions_model(client: Client):
18761864
execution_timeout=timedelta(seconds=10),
18771865
)
18781866
await workflow_handle.result()
1867+
1868+
1869+
class WaitModel(Model):
    """Test model that pauses for one second before returning a canned reply."""

    async def get_response(
        self,
        system_instructions: Union[str, None],
        input: Union[str, list[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: Union[AgentOutputSchemaBase, None],
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: Union[str, None],
        prompt: Union[ResponsePromptParam, None] = None,
    ) -> ModelResponse:
        """Sleep for one second, then return one assistant message saying "test".

        All arguments are ignored; the reply is always the same.
        """
        activity.logger.info("Waiting")
        await asyncio.sleep(1.0)
        activity.logger.info("Returning")

        # Assemble the canned response from the inside out.
        reply_text = ResponseOutputText(
            text="test", annotations=[], type="output_text"
        )
        reply = ResponseOutputMessage(
            id="",
            content=[reply_text],
            role="assistant",
            status="completed",
            type="message",
        )
        return ModelResponse(output=[reply], usage=Usage(), response_id=None)

    def stream_response(
        self,
        system_instructions: Optional[str],
        input: Union[str, list[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: Optional[AgentOutputSchemaBase],
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: Optional[str],
        prompt: Optional[ResponsePromptParam],
    ) -> AsyncIterator[TResponseStreamEvent]:
        """Streaming is not implemented for this test model; always raises."""
        raise NotImplementedError()
1918+
1919+
1920+
async def test_heartbeat(client: Client, env: WorkflowEnvironment):
    """Run HelloWorldAgent against WaitModel with a 0.5s heartbeat timeout.

    WaitModel sleeps for one second, which is longer than the configured
    heartbeat timeout, and the workflow is awaited to completion.
    NOTE(review): presumably this passes only if the model activity
    heartbeats while the model call is in flight; confirm against the
    heartbeat decorator.
    """
    # Skipped on time-skipping environments because the test needs real timing.
    if env.supports_time_skipping:
        pytest.skip("Relies on real timing, skip.")

    config = client.config()
    plugin = openai_agents.OpenAIAgentsPlugin(
        model_params=ModelActivityParameters(
            heartbeat_timeout=timedelta(seconds=0.5),
        ),
        model_provider=TestModelProvider(WaitModel()),
    )
    config["plugins"] = [plugin]
    client = Client(**config)

    async with new_worker(client, HelloWorldAgent) as worker:
        handle = await client.start_workflow(
            HelloWorldAgent.run,
            "Tell me about recursion in programming.",
            id=f"workflow-tool-{uuid.uuid4()}",
            task_queue=worker.task_queue,
            execution_timeout=timedelta(seconds=5.0),
        )
        await handle.result()
1947+
1948+
1949+
def test_summary_extraction():
    """Check _extract_summary against a list of input items as it grows."""
    first = EasyInputMessageParam(content="First message", role="user")
    items: list[TResponseInputItem] = [first]
    assert _extract_summary(items) == "First message"

    # A message whose content is a list of parts rather than a bare string.
    second = Message(
        content=[ResponseInputTextParam(text="Second message", type="input_text")],
        role="user",
    )
    items.append(second)
    assert _extract_summary(items) == "Second message"

    # A trailing function call must not displace the last message.
    tool_call = ResponseFunctionToolCallParam(
        arguments="",
        call_id="",
        name="",
        type="function_call",
    )
    items.append(tool_call)
    assert _extract_summary(items) == "Second message"

0 commit comments

Comments
 (0)