Skip to content

Commit 5a8503e

Browse files
committed
Introduce summary provider, don't mutate user provided agents
1 parent a0f78be commit 5a8503e

File tree

4 files changed

+88
-25
lines changed

4 files changed

+88
-25
lines changed

temporalio/contrib/openai_agents/_model_parameters.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Parameters for configuring Temporal activity execution for model calls."""
22

3+
from abc import ABC, abstractmethod
34
from dataclasses import dataclass
45
from datetime import timedelta
56
from typing import Any, Callable, Optional, Union
@@ -10,6 +11,22 @@
1011
from temporalio.workflow import ActivityCancellationType, VersioningIntent
1112

1213

14+
class ModelSummaryProvider(ABC):
15+
"""Abstract base class for providing model summaries. Essentially just a callable,
16+
but the arguments are sufficiently complex to benefit from names.
17+
"""
18+
19+
@abstractmethod
20+
def provide(
21+
self,
22+
agent: Optional[Agent[Any]],
23+
instructions: Optional[str],
24+
input: Union[str, list[TResponseInputItem]],
25+
) -> str:
26+
"""Given the provided information, produce a summary for the model invocation activity."""
27+
pass
28+
29+
1330
@dataclass
1431
class ModelActivityParameters:
1532
"""Parameters for configuring Temporal activity execution for model calls.
@@ -46,14 +63,7 @@ class ModelActivityParameters:
4663
summary_override: Optional[
4764
Union[
4865
str,
49-
Callable[
50-
[
51-
Optional[Agent[Any]],
52-
Optional[str],
53-
Union[str, list[TResponseInputItem]],
54-
],
55-
str,
56-
],
66+
ModelSummaryProvider,
5767
]
5868
] = None
5969
"""Summary for the activity execution."""

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,37 +92,50 @@ async def run(
9292
)
9393

9494
# Recursively replace models in all agents
95-
def convert_agent(agent: Agent[Any]) -> None:
96-
# Short circuit if this model was already replaced to prevent looping from circular handoffs
97-
if isinstance(agent.model, _TemporalModelStub):
98-
return
95+
def convert_agent(
96+
agent: Agent[Any], seen: Optional[set[int]] = None
97+
) -> Agent[Any]:
98+
if seen is None:
99+
seen = set()
100+
101+
# Short circuit if this model was already seen to prevent looping from circular handoffs
102+
if id(agent) in seen:
103+
return agent
104+
seen.add(id(agent))
99105

100106
name = _model_name(agent)
101-
agent.model = _TemporalModelStub(
102-
model_name=name,
103-
model_params=self.model_params,
104-
agent=agent,
105-
)
106107

108+
new_handoffs: list[Union[Agent, Handoff]] = []
107109
for handoff in agent.handoffs:
108110
if isinstance(handoff, Agent):
109-
convert_agent(handoff)
111+
new_handoffs.append(convert_agent(handoff))
110112
elif isinstance(handoff, Handoff):
111113
original_invoke = handoff.on_invoke_handoff
112114

113115
async def on_invoke(
114116
context: RunContextWrapper[Any], args: str
115-
) -> Agent[Any]:
117+
) -> Agent:
116118
handoff_agent = await original_invoke(context, args)
117-
convert_agent(handoff_agent)
118-
return handoff_agent
119+
return convert_agent(handoff_agent, seen)
119120

120-
handoff.on_invoke_handoff = on_invoke
121+
new_handoffs.append(
122+
dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
123+
)
124+
else:
125+
raise ValueError(f"Unknown handoff type: {type(handoff)}")
121126

122-
convert_agent(starting_agent)
127+
return dataclasses.replace(
128+
agent,
129+
model=_TemporalModelStub(
130+
model_name=name,
131+
model_params=self.model_params,
132+
agent=agent,
133+
),
134+
handoffs=new_handoffs,
135+
)
123136

124137
return await self._runner.run(
125-
starting_agent=starting_agent,
138+
starting_agent=convert_agent(starting_agent),
126139
input=input,
127140
context=context,
128141
max_turns=max_turns,

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
142142
self.model_params.summary_override
143143
if isinstance(self.model_params.summary_override, str)
144144
else (
145-
self.model_params.summary_override(
145+
self.model_params.summary_override.provide(
146146
self.agent, system_instructions, input
147147
)
148148
)

tests/contrib/openai_agents/test_openai.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
TestModel,
8787
TestModelProvider,
8888
)
89+
from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider
8990
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
9091
from temporalio.contrib.pydantic import pydantic_data_converter
9192
from temporalio.exceptions import ApplicationError, CancelledError
@@ -2144,3 +2145,42 @@ async def test_run_config_models(client: Client):
21442145

21452146
# Only the model from the runconfig override is used
21462147
assert provider.model_names == {"gpt-4o"}
2148+
2149+
2150+
async def test_summary_provider(client: Client):
2151+
class SummaryProvider(ModelSummaryProvider):
2152+
def provide(
2153+
self,
2154+
agent: Optional[Agent[Any]],
2155+
instructions: Optional[str],
2156+
input: Union[str, list[TResponseInputItem]],
2157+
) -> str:
2158+
return "My summary"
2159+
2160+
new_config = client.config()
2161+
new_config["plugins"] = [
2162+
openai_agents.OpenAIAgentsPlugin(
2163+
model_params=ModelActivityParameters(
2164+
start_to_close_timeout=timedelta(seconds=120),
2165+
summary_override=SummaryProvider(),
2166+
),
2167+
model_provider=TestModelProvider(TestHelloModel()),
2168+
)
2169+
]
2170+
client = Client(**new_config)
2171+
2172+
async with new_worker(
2173+
client,
2174+
HelloWorldAgent,
2175+
) as worker:
2176+
workflow_handle = await client.start_workflow(
2177+
HelloWorldAgent.run,
2178+
"Prompt",
2179+
id=f"summary-provider-model-{uuid.uuid4()}",
2180+
task_queue=worker.task_queue,
2181+
execution_timeout=timedelta(seconds=10),
2182+
)
2183+
result = await workflow_handle.result()
2184+
async for e in workflow_handle.fetch_history_events():
2185+
if e.HasField("activity_task_scheduled_event_attributes"):
2186+
assert e.user_metadata.summary.data == b'"My summary"'

0 commit comments

Comments
 (0)