Skip to content

Commit 2b5de91

Browse files
authored
Move model stub creation from RunConfig to Agent (#1029)
* Move model override from the runconfig to the agents * Add dedicated test that models are configured correctly * Linting * Cleanup and fix when model is provided via run_config * Fix test isolation * Introduce summary provider, don't mutate user provided agents * Fixing recursion bugs * Revert timeout debugging change
1 parent a09bb85 commit 2b5de91

File tree

4 files changed

+259
-19
lines changed

4 files changed

+259
-19
lines changed

temporalio/contrib/openai_agents/_model_parameters.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,32 @@
11
"""Parameters for configuring Temporal activity execution for model calls."""
22

3+
from abc import ABC, abstractmethod
34
from dataclasses import dataclass
45
from datetime import timedelta
5-
from typing import Optional
6+
from typing import Any, Callable, Optional, Union
7+
8+
from agents import Agent, TResponseInputItem
69

710
from temporalio.common import Priority, RetryPolicy
811
from temporalio.workflow import ActivityCancellationType, VersioningIntent
912

1013

14+
class ModelSummaryProvider(ABC):
15+
"""Abstract base class for providing model summaries. Essentially just a callable,
16+
but the arguments are sufficiently complex to benefit from names.
17+
"""
18+
19+
@abstractmethod
20+
def provide(
21+
self,
22+
agent: Optional[Agent[Any]],
23+
instructions: Optional[str],
24+
input: Union[str, list[TResponseInputItem]],
25+
) -> str:
26+
"""Given the provided information, produce a summary for the model invocation activity."""
27+
pass
28+
29+
1130
@dataclass
1231
class ModelActivityParameters:
1332
"""Parameters for configuring Temporal activity execution for model calls.
@@ -41,7 +60,12 @@ class ModelActivityParameters:
4160
versioning_intent: Optional[VersioningIntent] = None
4261
"""Versioning intent for the activity."""
4362

44-
summary_override: Optional[str] = None
63+
summary_override: Optional[
64+
Union[
65+
str,
66+
ModelSummaryProvider,
67+
]
68+
] = None
4569
"""Summary for the activity execution."""
4670

4771
priority: Priority = Priority.default

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import dataclasses
12
import json
23
import typing
3-
from dataclasses import replace
4-
from typing import Any, Union
4+
from typing import Any, Optional, Union
55

66
from agents import (
77
Agent,
8+
Handoff,
89
RunConfig,
10+
RunContextWrapper,
911
RunResult,
1012
RunResultStreaming,
1113
SQLiteSession,
@@ -77,26 +79,70 @@ async def run(
7779
if run_config is None:
7880
run_config = RunConfig()
7981

80-
model_name = run_config.model or starting_agent.model
81-
if model_name is not None and not isinstance(model_name, str):
82-
raise ValueError(
83-
"Temporal workflows require a model name to be a string in the run config and/or agent."
82+
if run_config.model:
83+
if not isinstance(run_config.model, str):
84+
raise ValueError(
85+
"Temporal workflows require a model name to be a string in the run config."
86+
)
87+
run_config = dataclasses.replace(
88+
run_config,
89+
model=_TemporalModelStub(
90+
run_config.model, model_params=self.model_params, agent=None
91+
),
92+
)
93+
94+
# Recursively replace models in all agents
95+
def convert_agent(agent: Agent[Any], seen: Optional[set[int]]) -> Agent[Any]:
96+
if seen is None:
97+
seen = set()
98+
99+
# Short circuit if this model was already seen to prevent looping from circular handoffs
100+
if id(agent) in seen:
101+
return agent
102+
seen.add(id(agent))
103+
104+
# This agent has already been processed in some other run
105+
if isinstance(agent.model, _TemporalModelStub):
106+
return agent
107+
108+
name = _model_name(agent)
109+
110+
new_handoffs: list[Union[Agent, Handoff]] = []
111+
for handoff in agent.handoffs:
112+
if isinstance(handoff, Agent):
113+
new_handoffs.append(convert_agent(handoff, seen))
114+
elif isinstance(handoff, Handoff):
115+
original_invoke = handoff.on_invoke_handoff
116+
117+
async def on_invoke(
118+
context: RunContextWrapper[Any], args: str
119+
) -> Agent:
120+
handoff_agent = await original_invoke(context, args)
121+
return convert_agent(handoff_agent, seen)
122+
123+
new_handoffs.append(
124+
dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
125+
)
126+
else:
127+
raise ValueError(f"Unknown handoff type: {type(handoff)}")
128+
129+
return dataclasses.replace(
130+
agent,
131+
model=_TemporalModelStub(
132+
model_name=name,
133+
model_params=self.model_params,
134+
agent=agent,
135+
),
136+
handoffs=new_handoffs,
84137
)
85-
updated_run_config = replace(
86-
run_config,
87-
model=_TemporalModelStub(
88-
model_name=model_name,
89-
model_params=self.model_params,
90-
),
91-
)
92138

93139
return await self._runner.run(
94-
starting_agent=starting_agent,
140+
starting_agent=convert_agent(starting_agent, None),
95141
input=input,
96142
context=context,
97143
max_turns=max_turns,
98144
hooks=hooks,
99-
run_config=updated_run_config,
145+
run_config=run_config,
100146
previous_response_id=previous_response_id,
101147
session=session,
102148
)
@@ -130,3 +176,12 @@ def run_streamed(
130176
**kwargs,
131177
)
132178
raise RuntimeError("Temporal workflows do not support streaming.")
179+
180+
181+
def _model_name(agent: Agent[Any]) -> Optional[str]:
182+
name = agent.model
183+
if name is not None and not isinstance(name, str):
184+
raise ValueError(
185+
"Temporal workflows require a model name to be a string in the agent."
186+
)
187+
return name

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Any, AsyncIterator, Union, cast
1212

1313
from agents import (
14+
Agent,
1415
AgentOutputSchema,
1516
AgentOutputSchemaBase,
1617
CodeInterpreterTool,
@@ -50,9 +51,11 @@ def __init__(
5051
model_name: Optional[str],
5152
*,
5253
model_params: ModelActivityParameters,
54+
agent: Optional[Agent[Any]],
5355
) -> None:
5456
self.model_name = model_name
5557
self.model_params = model_params
58+
self.agent = agent
5659

5760
async def get_response(
5861
self,
@@ -124,7 +127,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
124127
activity_input = ActivityModelInput(
125128
model_name=self.model_name,
126129
system_instructions=system_instructions,
127-
input=cast(Union[str, list[TResponseInputItem]], input),
130+
input=input,
128131
model_settings=model_settings,
129132
tools=tool_infos,
130133
output_schema=output_schema_input,
@@ -134,10 +137,25 @@ def make_tool_info(tool: Tool) -> ToolInput:
134137
prompt=prompt,
135138
)
136139

140+
if self.model_params.summary_override:
141+
summary = (
142+
self.model_params.summary_override
143+
if isinstance(self.model_params.summary_override, str)
144+
else (
145+
self.model_params.summary_override.provide(
146+
self.agent, system_instructions, input
147+
)
148+
)
149+
)
150+
elif self.agent:
151+
summary = self.agent.name
152+
else:
153+
summary = None
154+
137155
return await workflow.execute_activity_method(
138156
ModelActivity.invoke_model_activity,
139157
activity_input,
140-
summary=self.model_params.summary_override or _extract_summary(input),
158+
summary=summary,
141159
task_queue=self.model_params.task_queue,
142160
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
143161
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,

tests/contrib/openai_agents/test_openai.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
OpenAIChatCompletionsModel,
3333
OpenAIResponsesModel,
3434
OutputGuardrailTripwireTriggered,
35+
RunConfig,
3536
RunContextWrapper,
3637
Runner,
3738
SQLiteSession,
@@ -85,6 +86,7 @@
8586
TestModel,
8687
TestModelProvider,
8788
)
89+
from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider
8890
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
8991
from temporalio.contrib.pydantic import pydantic_data_converter
9092
from temporalio.exceptions import ApplicationError, CancelledError
@@ -2041,3 +2043,144 @@ async def test_hosted_mcp_tool(client: Client, use_local_model):
20412043
result = await workflow_handle.result()
20422044
if use_local_model:
20432045
assert result == "Some language"
2046+
2047+
2048+
class AssertDifferentModelProvider(ModelProvider):
2049+
model_names: set[Optional[str]]
2050+
2051+
def __init__(self, model: Model):
2052+
self._model = model
2053+
self.model_names = set()
2054+
2055+
def get_model(self, model_name: Union[str, None]) -> Model:
2056+
self.model_names.add(model_name)
2057+
return self._model
2058+
2059+
2060+
class MultipleModelsModel(StaticTestModel):
2061+
responses = [
2062+
ResponseBuilders.tool_call("{}", "transfer_to_underling"),
2063+
ResponseBuilders.output_message(
2064+
"I'm here to help! Was there a specific task you needed assistance with regarding the storeroom?"
2065+
),
2066+
]
2067+
2068+
2069+
@workflow.defn
2070+
class MultipleModelWorkflow:
2071+
@workflow.run
2072+
async def run(self, use_run_config: bool):
2073+
underling = Agent[None](
2074+
name="Underling",
2075+
instructions="You do all the work you are told.",
2076+
)
2077+
2078+
starting_agent = Agent[None](
2079+
name="Lazy Assistant",
2080+
model="gpt-4o-mini",
2081+
instructions="You delegate all your work to another agent.",
2082+
handoffs=[underling],
2083+
)
2084+
result = await Runner.run(
2085+
starting_agent=starting_agent,
2086+
input="Have you cleaned the store room yet?",
2087+
run_config=RunConfig(model="gpt-4o") if use_run_config else None,
2088+
)
2089+
return result.final_output
2090+
2091+
2092+
async def test_multiple_models(client: Client):
2093+
provider = AssertDifferentModelProvider(MultipleModelsModel())
2094+
new_config = client.config()
2095+
new_config["plugins"] = [
2096+
openai_agents.OpenAIAgentsPlugin(
2097+
model_params=ModelActivityParameters(
2098+
start_to_close_timeout=timedelta(seconds=120)
2099+
),
2100+
model_provider=provider,
2101+
)
2102+
]
2103+
client = Client(**new_config)
2104+
2105+
async with new_worker(
2106+
client,
2107+
MultipleModelWorkflow,
2108+
) as worker:
2109+
workflow_handle = await client.start_workflow(
2110+
MultipleModelWorkflow.run,
2111+
False,
2112+
id=f"multiple-model-{uuid.uuid4()}",
2113+
task_queue=worker.task_queue,
2114+
execution_timeout=timedelta(seconds=10),
2115+
)
2116+
result = await workflow_handle.result()
2117+
assert provider.model_names == {None, "gpt-4o-mini"}
2118+
2119+
2120+
async def test_run_config_models(client: Client):
2121+
provider = AssertDifferentModelProvider(MultipleModelsModel())
2122+
new_config = client.config()
2123+
new_config["plugins"] = [
2124+
openai_agents.OpenAIAgentsPlugin(
2125+
model_params=ModelActivityParameters(
2126+
start_to_close_timeout=timedelta(seconds=120)
2127+
),
2128+
model_provider=provider,
2129+
)
2130+
]
2131+
client = Client(**new_config)
2132+
2133+
async with new_worker(
2134+
client,
2135+
MultipleModelWorkflow,
2136+
) as worker:
2137+
workflow_handle = await client.start_workflow(
2138+
MultipleModelWorkflow.run,
2139+
True,
2140+
id=f"run-config-model-{uuid.uuid4()}",
2141+
task_queue=worker.task_queue,
2142+
execution_timeout=timedelta(seconds=10),
2143+
)
2144+
result = await workflow_handle.result()
2145+
2146+
# Only the model from the runconfig override is used
2147+
assert provider.model_names == {"gpt-4o"}
2148+
2149+
2150+
async def test_summary_provider(client: Client):
2151+
class SummaryProvider(ModelSummaryProvider):
2152+
def provide(
2153+
self,
2154+
agent: Optional[Agent[Any]],
2155+
instructions: Optional[str],
2156+
input: Union[str, list[TResponseInputItem]],
2157+
) -> str:
2158+
return "My summary"
2159+
2160+
new_config = client.config()
2161+
new_config["plugins"] = [
2162+
openai_agents.OpenAIAgentsPlugin(
2163+
model_params=ModelActivityParameters(
2164+
start_to_close_timeout=timedelta(seconds=120),
2165+
summary_override=SummaryProvider(),
2166+
),
2167+
model_provider=TestModelProvider(TestHelloModel()),
2168+
)
2169+
]
2170+
client = Client(**new_config)
2171+
2172+
async with new_worker(
2173+
client,
2174+
HelloWorldAgent,
2175+
) as worker:
2176+
workflow_handle = await client.start_workflow(
2177+
HelloWorldAgent.run,
2178+
"Prompt",
2179+
id=f"summary-provider-model-{uuid.uuid4()}",
2180+
task_queue=worker.task_queue,
2181+
execution_timeout=timedelta(seconds=10),
2182+
)
2183+
result = await workflow_handle.result()
2184+
async for e in workflow_handle.fetch_history_events():
2185+
if e.HasField("activity_task_scheduled_event_attributes"):
2186+
assert e.user_metadata.summary.data == b'"My summary"'

0 commit comments

Comments
 (0)