Skip to content

Commit fc1f1cf

Browse files
committed
Merge remote-tracking branch 'origin/main' into openai/summary_fixes
2 parents 241a0a9 + d863f5c commit fc1f1cf

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

temporalio/contrib/openai_agents/_heartbeat_decorator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
2424
if heartbeat_task:
2525
heartbeat_task.cancel()
2626
# Wait for heartbeat cancellation to complete
27-
await heartbeat_task
27+
try:
28+
await heartbeat_task
29+
except asyncio.CancelledError:
30+
pass
2831

2932
return cast(F, wrapper)
3033

tests/contrib/openai_agents/test_openai.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import asyncio
12
import json
23
import os
34
import uuid
45
from dataclasses import dataclass
56
from datetime import timedelta
6-
from typing import Any, Optional, Union, no_type_check
7+
from typing import Any, AsyncIterator, Optional, Union, no_type_check
78

89
import nexusrpc
910
import pytest
@@ -39,6 +40,7 @@
3940
HandoffOutputItem,
4041
ToolCallItem,
4142
ToolCallOutputItem,
43+
TResponseStreamEvent,
4244
)
4345
from openai import APIStatusError, AsyncOpenAI, BaseModel
4446
from openai.types.responses import (
@@ -1884,6 +1886,86 @@ async def test_chat_completions_model(client: Client):
18841886
await workflow_handle.result()
18851887

18861888

1889+
class WaitModel(Model):
1890+
async def get_response(
1891+
self,
1892+
system_instructions: Union[str, None],
1893+
input: Union[str, list[TResponseInputItem]],
1894+
model_settings: ModelSettings,
1895+
tools: list[Tool],
1896+
output_schema: Union[AgentOutputSchemaBase, None],
1897+
handoffs: list[Handoff],
1898+
tracing: ModelTracing,
1899+
*,
1900+
previous_response_id: Union[str, None],
1901+
prompt: Union[ResponsePromptParam, None] = None,
1902+
) -> ModelResponse:
1903+
activity.logger.info("Waiting")
1904+
await asyncio.sleep(1.0)
1905+
activity.logger.info("Returning")
1906+
return ModelResponse(
1907+
output=[
1908+
ResponseOutputMessage(
1909+
id="",
1910+
content=[
1911+
ResponseOutputText(
1912+
text="test", annotations=[], type="output_text"
1913+
)
1914+
],
1915+
role="assistant",
1916+
status="completed",
1917+
type="message",
1918+
)
1919+
],
1920+
usage=Usage(),
1921+
response_id=None,
1922+
)
1923+
1924+
def stream_response(
1925+
self,
1926+
system_instructions: Optional[str],
1927+
input: Union[str, list[TResponseInputItem]],
1928+
model_settings: ModelSettings,
1929+
tools: list[Tool],
1930+
output_schema: Optional[AgentOutputSchemaBase],
1931+
handoffs: list[Handoff],
1932+
tracing: ModelTracing,
1933+
*,
1934+
previous_response_id: Optional[str],
1935+
prompt: Optional[ResponsePromptParam],
1936+
) -> AsyncIterator[TResponseStreamEvent]:
1937+
raise NotImplementedError()
1938+
1939+
1940+
async def test_heartbeat(client: Client, env: WorkflowEnvironment):
1941+
if env.supports_time_skipping:
1942+
pytest.skip("Relies on real timing, skip.")
1943+
1944+
new_config = client.config()
1945+
new_config["plugins"] = [
1946+
openai_agents.OpenAIAgentsPlugin(
1947+
model_params=ModelActivityParameters(
1948+
heartbeat_timeout=timedelta(seconds=0.5),
1949+
),
1950+
model_provider=TestModelProvider(WaitModel()),
1951+
)
1952+
]
1953+
client = Client(**new_config)
1954+
1955+
async with new_worker(
1956+
client,
1957+
HelloWorldAgent,
1958+
) as worker:
1959+
workflow_handle = await client.start_workflow(
1960+
HelloWorldAgent.run,
1961+
"Tell me about recursion in programming.",
1962+
id=f"workflow-tool-{uuid.uuid4()}",
1963+
task_queue=worker.task_queue,
1964+
execution_timeout=timedelta(seconds=5.0),
1965+
)
1966+
await workflow_handle.result()
1967+
1968+
18871969
def test_summary_extraction():
18881970
input: list[TResponseInputItem] = [
18891971
EasyInputMessageParam(

0 commit comments

Comments
 (0)