Skip to content

Commit 6d6aba4

Browse files
committed
Move model override from the runconfig to the agents
1 parent e1016bc commit 6d6aba4

File tree

4 files changed

+67
-19
lines changed

4 files changed

+67
-19
lines changed

temporalio/contrib/openai_agents/_model_parameters.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
from dataclasses import dataclass
44
from datetime import timedelta
5-
from typing import Optional
5+
from typing import Optional, Union, Callable, Any
6+
7+
from agents import Agent, TResponseInputItem
68

79
from temporalio.common import Priority, RetryPolicy
810
from temporalio.workflow import ActivityCancellationType, VersioningIntent
@@ -41,7 +43,7 @@ class ModelActivityParameters:
4143
versioning_intent: Optional[VersioningIntent] = None
4244
"""Versioning intent for the activity."""
4345

44-
summary_override: Optional[str] = None
46+
summary_override: Optional[Union[str, Callable[[Agent[Any], Optional[str], Union[str, list[TResponseInputItem]]], str]]] = None
4547
"""Summary for the activity execution."""
4648

4749
priority: Priority = Priority.default

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
SQLiteSession,
1212
TContext,
1313
Tool,
14-
TResponseInputItem,
14+
TResponseInputItem, Handoff, RunContextWrapper,
1515
)
1616
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
1717
from pydantic_core import to_json
@@ -77,26 +77,50 @@ async def run(
7777
if run_config is None:
7878
run_config = RunConfig()
7979

80-
model_name = run_config.model or starting_agent.model
81-
if model_name is not None and not isinstance(model_name, str):
82-
raise ValueError(
83-
"Temporal workflows require a model name to be a string in the run config and/or agent."
84-
)
85-
updated_run_config = replace(
86-
run_config,
87-
model=_TemporalModelStub(
88-
model_name=model_name,
80+
def model_name(agent: Agent[Any]) -> str:
81+
name = run_config.model or agent.model
82+
if name is not None and not isinstance(name, str):
83+
print("Name: ", name, " Agent: ", agent)
84+
raise ValueError(
85+
"Temporal workflows require a model name to be a string in the run config and/or agent."
86+
)
87+
return name
88+
89+
def convert_agent(agent: Agent[Any]) -> None:
90+
print("Model: ", agent.model)
91+
92+
# Short circuit if this model was already replaced to prevent looping from circular handoffs
93+
if isinstance(agent.model, _TemporalModelStub):
94+
return
95+
96+
name = model_name(agent)
97+
agent.model = _TemporalModelStub(
98+
model_name=name,
8999
model_params=self.model_params,
90-
),
91-
)
100+
agent=agent,
101+
)
102+
print("Model after replace: ", agent.model)
103+
104+
for handoff in agent.handoffs:
105+
if isinstance(handoff, Agent):
106+
convert_agent(handoff)
107+
elif isinstance(handoff, Handoff):
108+
original_invoke = handoff.on_invoke_handoff
109+
async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent[Any]:
110+
handoff_agent = await original_invoke(context, args)
111+
convert_agent(handoff_agent)
112+
return handoff_agent
113+
handoff.on_invoke_handoff = on_invoke
114+
115+
convert_agent(starting_agent)
92116

93117
return await self._runner.run(
94118
starting_agent=starting_agent,
95119
input=input,
96120
context=context,
97121
max_turns=max_turns,
98122
hooks=hooks,
99-
run_config=updated_run_config,
123+
run_config=run_config,
100124
previous_response_id=previous_response_id,
101125
session=session,
102126
)

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
ModelTracing,
2626
Tool,
2727
TResponseInputItem,
28-
WebSearchTool,
28+
WebSearchTool, Agent,
2929
)
3030
from agents.items import TResponseStreamEvent
3131
from openai.types.responses.response_prompt_param import ResponsePromptParam
@@ -50,9 +50,11 @@ def __init__(
5050
model_name: Optional[str],
5151
*,
5252
model_params: ModelActivityParameters,
53+
agent: Agent[Any],
5354
) -> None:
5455
self.model_name = model_name
5556
self.model_params = model_params
57+
self.agent = agent
5658

5759
async def get_response(
5860
self,
@@ -124,7 +126,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
124126
activity_input = ActivityModelInput(
125127
model_name=self.model_name,
126128
system_instructions=system_instructions,
127-
input=cast(Union[str, list[TResponseInputItem]], input),
129+
input=input,
128130
model_settings=model_settings,
129131
tools=tool_infos,
130132
output_schema=output_schema_input,
@@ -134,10 +136,16 @@ def make_tool_info(tool: Tool) -> ToolInput:
134136
prompt=prompt,
135137
)
136138

139+
if self.model_params.summary_override:
140+
summary = self.model_params.summary_override if isinstance(self.model_params.summary_override, str) else (
141+
self.model_params.summary_override(self.agent, system_instructions, input))
142+
else:
143+
summary = self.agent.name
144+
137145
return await workflow.execute_activity_method(
138146
ModelActivity.invoke_model_activity,
139147
activity_input,
140-
summary=self.model_params.summary_override or _extract_summary(input),
148+
summary=summary,
141149
task_queue=self.model_params.task_queue,
142150
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
143151
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,

tests/contrib/openai_agents/test_openai.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,7 @@ def init_agents() -> Agent[AirlineAgentContext]:
861861

862862
seat_booking_agent = Agent[AirlineAgentContext](
863863
name="Seat Booking Agent",
864+
model="gpt-4o-mini",
864865
handoff_description="A helpful agent that can update a seat on a flight.",
865866
instructions=f"""{RECOMMENDED_PROMPT_PREFIX}
866867
You are a seat booking agent. If you are speaking to a customer, you probably were transferred to from the triage agent.
@@ -915,6 +916,17 @@ class CustomerServiceModel(StaticTestModel):
915916
]
916917

917918

919+
class AssertDifferentModelProvider(ModelProvider):
920+
model_names = set()
921+
922+
def __init__(self, model: Model):
923+
self._model = model
924+
925+
def get_model(self, model_name: Union[str, None]) -> Model:
926+
self.model_names.add(model_name)
927+
return self._model
928+
929+
918930
@workflow.defn
919931
class CustomerServiceWorkflow:
920932
def __init__(self, input_items: list[TResponseInputItem] = []):
@@ -987,12 +999,13 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
987999
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
9881000
pytest.skip("No openai API key")
9891001
new_config = client.config()
1002+
provider = AssertDifferentModelProvider(CustomerServiceModel())
9901003
new_config["plugins"] = [
9911004
openai_agents.OpenAIAgentsPlugin(
9921005
model_params=ModelActivityParameters(
9931006
start_to_close_timeout=timedelta(seconds=30)
9941007
),
995-
model_provider=TestModelProvider(CustomerServiceModel())
1008+
model_provider= provider
9961009
if use_local_model
9971010
else None,
9981011
)
@@ -1076,6 +1089,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
10761089
.data.decode()
10771090
)
10781091

1092+
assert len(provider.model_names) == 2
10791093

10801094
class InputGuardrailModel(OpenAIResponsesModel):
10811095
__test__ = False

0 commit comments

Comments
 (0)