Skip to content

Commit bf481ed

Browse files
authored
Merge branch 'main' into test_custom_slot_supplier
2 parents dcd74d6 + fd51efa commit bf481ed

File tree

4 files changed

+115
-71
lines changed

4 files changed

+115
-71
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ jobs:
5555
with:
5656
submodules: recursive
5757
- uses: dtolnay/rust-toolchain@stable
58+
with:
59+
components: "clippy"
5860
- uses: Swatinem/rust-cache@v2
5961
with:
6062
workspaces: temporalio/bridge -> target

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import dataclasses
2-
import json
32
import typing
43
from typing import Any, Optional, Union
54

@@ -17,14 +16,62 @@
1716
TResponseInputItem,
1817
)
1918
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
20-
from pydantic_core import to_json
2119

2220
from temporalio import workflow
2321
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
2422
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
2523
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
2624

2725

26+
# Recursively replace models in all agents
27+
def _convert_agent(
28+
model_params: ModelActivityParameters,
29+
agent: Agent[Any],
30+
seen: Optional[dict[int, Agent]],
31+
) -> Agent[Any]:
32+
if seen is None:
33+
seen = dict()
34+
35+
# Short circuit if this model was already seen to prevent looping from circular handoffs
36+
if id(agent) in seen:
37+
return seen[id(agent)]
38+
39+
# This agent has already been processed in some other run
40+
if isinstance(agent.model, _TemporalModelStub):
41+
return agent
42+
43+
# Save the new version of the agent so that we can replace loops
44+
new_agent = dataclasses.replace(agent)
45+
seen[id(agent)] = new_agent
46+
47+
name = _model_name(agent)
48+
49+
new_handoffs: list[Union[Agent, Handoff]] = []
50+
for handoff in agent.handoffs:
51+
if isinstance(handoff, Agent):
52+
new_handoffs.append(_convert_agent(model_params, handoff, seen))
53+
elif isinstance(handoff, Handoff):
54+
original_invoke = handoff.on_invoke_handoff
55+
56+
async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent:
57+
handoff_agent = await original_invoke(context, args)
58+
return _convert_agent(model_params, handoff_agent, seen)
59+
60+
new_handoffs.append(
61+
dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
62+
)
63+
else:
64+
raise TypeError(f"Unknown handoff type: {type(handoff)}")
65+
66+
new_agent.model = _TemporalModelStub(
67+
model_name=name,
68+
model_params=model_params,
69+
agent=agent,
70+
)
71+
new_agent.handoffs = new_handoffs
72+
return new_agent
73+
74+
2875
class TemporalOpenAIRunner(AgentRunner):
2976
"""Temporal Runner for OpenAI agents.
3077
@@ -101,54 +148,9 @@ async def run(
101148
),
102149
)
103150

104-
# Recursively replace models in all agents
105-
def convert_agent(agent: Agent[Any], seen: Optional[set[int]]) -> Agent[Any]:
106-
if seen is None:
107-
seen = set()
108-
109-
# Short circuit if this model was already seen to prevent looping from circular handoffs
110-
if id(agent) in seen:
111-
return agent
112-
seen.add(id(agent))
113-
114-
# This agent has already been processed in some other run
115-
if isinstance(agent.model, _TemporalModelStub):
116-
return agent
117-
118-
name = _model_name(agent)
119-
120-
new_handoffs: list[Union[Agent, Handoff]] = []
121-
for handoff in agent.handoffs:
122-
if isinstance(handoff, Agent):
123-
new_handoffs.append(convert_agent(handoff, seen))
124-
elif isinstance(handoff, Handoff):
125-
original_invoke = handoff.on_invoke_handoff
126-
127-
async def on_invoke(
128-
context: RunContextWrapper[Any], args: str
129-
) -> Agent:
130-
handoff_agent = await original_invoke(context, args)
131-
return convert_agent(handoff_agent, seen)
132-
133-
new_handoffs.append(
134-
dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
135-
)
136-
else:
137-
raise ValueError(f"Unknown handoff type: {type(handoff)}")
138-
139-
return dataclasses.replace(
140-
agent,
141-
model=_TemporalModelStub(
142-
model_name=name,
143-
model_params=self.model_params,
144-
agent=agent,
145-
),
146-
handoffs=new_handoffs,
147-
)
148-
149151
try:
150152
return await self._runner.run(
151-
starting_agent=convert_agent(starting_agent, None),
153+
starting_agent=_convert_agent(self.model_params, starting_agent, None),
152154
input=input,
153155
context=context,
154156
max_turns=max_turns,

tests/contrib/openai_agents/test_openai.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,13 @@
9696
TestModelProvider,
9797
)
9898
from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider
99-
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
99+
from temporalio.contrib.openai_agents._openai_runner import _convert_agent
100+
from temporalio.contrib.openai_agents._temporal_model_stub import (
101+
_extract_summary,
102+
_TemporalModelStub,
103+
)
100104
from temporalio.contrib.pydantic import pydantic_data_converter
101-
from temporalio.exceptions import ApplicationError, CancelledError
105+
from temporalio.exceptions import ApplicationError, CancelledError, TemporalError
102106
from temporalio.testing import WorkflowEnvironment
103107
from temporalio.workflow import ActivityConfig
104108
from tests.contrib.openai_agents.research_agents.research_manager import (
@@ -897,7 +901,10 @@ async def update_seat(
897901
async def on_seat_booking_handoff(
898902
context: RunContextWrapper[AirlineAgentContext],
899903
) -> None:
900-
flight_number = f"FLT-{workflow.random().randint(100, 999)}"
904+
try:
905+
flight_number = f"FLT-{workflow.random().randint(100, 999)}"
906+
except TemporalError:
907+
flight_number = "FLT-100"
901908
context.context.flight_number = flight_number
902909

903910

@@ -975,6 +982,8 @@ class CustomerServiceModel(StaticTestModel):
975982
ResponseBuilders.output_message(
976983
"Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!"
977984
),
985+
ResponseBuilders.tool_call("{}", "transfer_to_triage_agent"),
986+
ResponseBuilders.output_message("You're welcome!"),
978987
]
979988

980989

@@ -988,10 +997,7 @@ def __init__(self, input_items: list[TResponseInputItem] = []):
988997

989998
@workflow.run
990999
async def run(self, input_items: list[TResponseInputItem] = []):
991-
await workflow.wait_condition(
992-
lambda: workflow.info().is_continue_as_new_suggested()
993-
and workflow.all_handlers_finished()
994-
)
1000+
await workflow.wait_condition(lambda: False)
9951001
workflow.continue_as_new(self.input_items)
9961002

9971003
@workflow.query
@@ -1062,7 +1068,13 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
10621068
]
10631069
client = Client(**new_config)
10641070

1065-
questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"]
1071+
questions = [
1072+
"Hello",
1073+
"Book me a flight to PDX",
1074+
"11111",
1075+
"Any window seat",
1076+
"Take me back to the triage agent to say goodbye",
1077+
]
10661078

10671079
async with new_worker(
10681080
client,
@@ -1101,7 +1113,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
11011113
if e.HasField("activity_task_completed_event_attributes"):
11021114
events.append(e)
11031115

1104-
assert len(events) == 6
1116+
assert len(events) == 8
11051117
assert (
11061118
"Hi there! How can I assist you today?"
11071119
in events[0]
@@ -1138,6 +1150,18 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
11381150
.activity_task_completed_event_attributes.result.payloads[0]
11391151
.data.decode()
11401152
)
1153+
assert (
1154+
"transfer_to_triage_agent"
1155+
in events[6]
1156+
.activity_task_completed_event_attributes.result.payloads[0]
1157+
.data.decode()
1158+
)
1159+
assert (
1160+
"You're welcome!"
1161+
in events[7]
1162+
.activity_task_completed_event_attributes.result.payloads[0]
1163+
.data.decode()
1164+
)
11411165

11421166

11431167
class InputGuardrailModel(OpenAIResponsesModel):
@@ -2571,3 +2595,17 @@ def override_get_activities() -> Sequence[Callable]:
25712595
err.value.cause.message
25722596
== "MCP Stateful Server Worker failed to schedule activity."
25732597
)
2598+
2599+
2600+
async def test_model_conversion_loops():
2601+
agent = init_agents()
2602+
converted = _convert_agent(ModelActivityParameters(), agent, None)
2603+
seat_booking_handoff = converted.handoffs[1]
2604+
assert isinstance(seat_booking_handoff, Handoff)
2605+
context: RunContextWrapper[AirlineAgentContext] = RunContextWrapper(
2606+
context=AirlineAgentContext() # type: ignore
2607+
)
2608+
seat_booking_agent = await seat_booking_handoff.on_invoke_handoff(context, "")
2609+
triage_agent = seat_booking_agent.handoffs[0]
2610+
assert isinstance(triage_agent, Agent)
2611+
assert isinstance(triage_agent.model, _TemporalModelStub)

tests/worker/test_workflow.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2946,12 +2946,21 @@ async def waiting_signal() -> bool:
29462946
task_queue=task_queue,
29472947
)
29482948

2949+
# Need to wait until it has gotten halfway through, otherwise the post_patch workflow may never complete
2950+
async def waiting_signal() -> bool:
2951+
return await post_patch_handle.query(
2952+
PatchMemoizedWorkflowPatched.waiting_signal
2953+
)
2954+
2955+
await assert_eq_eventually(True, waiting_signal)
2956+
29492957
# Send signal to both and check results
29502958
await pre_patch_handle.signal(PatchMemoizedWorkflowUnpatched.signal)
29512959
await post_patch_handle.signal(PatchMemoizedWorkflowPatched.signal)
29522960

29532961
# Confirm expected values
29542962
assert ["some-value"] == await pre_patch_handle.result()
2963+
29552964
assert [
29562965
"pre-patch",
29572966
"some-value",
@@ -6091,22 +6100,21 @@ def __init__(
60916100
self.main_workflow_returns_before_signal_completions = (
60926101
main_workflow_returns_before_signal_completions
60936102
)
6094-
self.ping_pong_val = 1
6095-
self.ping_pong_counter = 0
6096-
self.ping_pong_max_count = 4
6103+
self.run_finished = False
60976104

60986105
@workflow.run
60996106
async def run(self) -> str:
61006107
await workflow.wait_condition(
61016108
lambda: self.seen_first_signal and self.seen_second_signal
61026109
)
6110+
self.run_finished = True
61036111
return "workflow-result"
61046112

61056113
@workflow.signal
61066114
async def this_signal_executes_first(self):
61076115
self.seen_first_signal = True
61086116
if self.main_workflow_returns_before_signal_completions:
6109-
await self.ping_pong(lambda: self.ping_pong_val > 0)
6117+
await workflow.wait_condition(lambda: self.run_finished)
61106118
raise ApplicationError(
61116119
"Client should see this error unless doing ping-pong "
61126120
"(in which case main coroutine returns first)"
@@ -6117,18 +6125,12 @@ async def this_signal_executes_second(self):
61176125
await workflow.wait_condition(lambda: self.seen_first_signal)
61186126
self.seen_second_signal = True
61196127
if self.main_workflow_returns_before_signal_completions:
6120-
await self.ping_pong(lambda: self.ping_pong_val < 0)
6128+
await workflow.wait_condition(lambda: self.run_finished)
61216129
raise ApplicationError("Client should never see this error!")
61226130

6123-
async def ping_pong(self, cond: Callable[[], bool]):
6124-
while self.ping_pong_counter < self.ping_pong_max_count:
6125-
await workflow.wait_condition(cond)
6126-
self.ping_pong_val = -self.ping_pong_val
6127-
self.ping_pong_counter += 1
6128-
61296131

61306132
@workflow.defn
6131-
class FirstCompletionCommandIsHonoredPingPongWorkflow(
6133+
class FirstCompletionCommandIsHonoredSignalWaitWorkflow(
61326134
FirstCompletionCommandIsHonoredWorkflow
61336135
):
61346136
def __init__(self) -> None:
@@ -6157,10 +6159,10 @@ async def _do_first_completion_command_is_honored_test(
61576159
client: Client, main_workflow_returns_before_signal_completions: bool
61586160
):
61596161
workflow_cls: Union[
6160-
Type[FirstCompletionCommandIsHonoredPingPongWorkflow],
6162+
Type[FirstCompletionCommandIsHonoredSignalWaitWorkflow],
61616163
Type[FirstCompletionCommandIsHonoredWorkflow],
61626164
] = (
6163-
FirstCompletionCommandIsHonoredPingPongWorkflow
6165+
FirstCompletionCommandIsHonoredSignalWaitWorkflow
61646166
if main_workflow_returns_before_signal_completions
61656167
else FirstCompletionCommandIsHonoredWorkflow
61666168
)

0 commit comments

Comments
 (0)