
Commit 05ebff7

Cleanup and fix when model is provided via run_config
1 parent 14c818b commit 05ebff7

4 files changed: +69 additions, -20 deletions


temporalio/contrib/openai_agents/_model_parameters.py

Lines changed: 6 additions & 1 deletion
@@ -47,7 +47,12 @@ class ModelActivityParameters:
         Union[
             str,
             Callable[
-                [Agent[Any], Optional[str], Union[str, list[TResponseInputItem]]], str
+                [
+                    Optional[Agent[Any]],
+                    Optional[str],
+                    Union[str, list[TResponseInputItem]],
+                ],
+                str,
             ],
         ]
     ] = None
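
This hunk widens the Callable in the annotation so its first parameter is Optional[Agent[Any]]: the callback can now be invoked for a model stub created from the run config, where there is no owning agent. Below is a minimal sketch of a callback compatible with the new signature; the parameter roles (agent, instructions, model input) and the assumption that this annotation backs a summary-override field are my reading of the hunk, not stated in it.

from typing import Any, Optional, Union

from agents import Agent, TResponseInputItem


def summarize_model_call(
    agent: Optional[Agent[Any]],
    instructions: Optional[str],
    model_input: Union[str, list[TResponseInputItem]],
) -> str:
    # The agent may now be None when the model was supplied via the run
    # config rather than on an individual agent (see the runner change below).
    if agent is not None:
        return f"model call for agent {agent.name}"
    return "model call for run-config model"

Callbacks written against the old signature keep working positionally; only their type hints need to accept Optional[Agent[Any]].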

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 21 additions & 10 deletions
@@ -1,3 +1,4 @@
+import dataclasses
 import json
 import typing
 from typing import Any, Optional, Union
@@ -78,29 +79,30 @@ async def run(
         if run_config is None:
             run_config = RunConfig()
 
-        def model_name(agent: Agent[Any]) -> Optional[str]:
-            name = run_config.model or agent.model
-            if name is not None and not isinstance(name, str):
-                print("Name: ", name, " Agent: ", agent)
+        if run_config.model:
+            if not isinstance(run_config.model, str):
                 raise ValueError(
-                    "Temporal workflows require a model name to be a string in the run config and/or agent."
+                    "Temporal workflows require a model name to be a string in the run config."
                 )
-            return name
+            run_config = dataclasses.replace(
+                run_config,
+                model=_TemporalModelStub(
+                    run_config.model, model_params=self.model_params, agent=None
+                ),
+            )
 
+        # Recursively replace models in all agents
         def convert_agent(agent: Agent[Any]) -> None:
-            print("Model: ", agent.model)
-
             # Short circuit if this model was already replaced to prevent looping from circular handoffs
             if isinstance(agent.model, _TemporalModelStub):
                 return
 
-            name = model_name(agent)
+            name = _model_name(agent)
             agent.model = _TemporalModelStub(
                 model_name=name,
                 model_params=self.model_params,
                 agent=agent,
            )
-            print("Model after replace: ", agent.model)
 
             for handoff in agent.handoffs:
                 if isinstance(handoff, Agent):
@@ -159,3 +161,12 @@ def run_streamed(
             **kwargs,
         )
         raise RuntimeError("Temporal workflows do not support streaming.")
+
+
+def _model_name(agent: Agent[Any]) -> Optional[str]:
+    name = agent.model
+    if name is not None and not isinstance(name, str):
+        raise ValueError(
+            "Temporal workflows require a model name to be a string in the agent."
+        )
+    return name
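
The core of the fix is that a string model on the run config is wrapped in a _TemporalModelStub once, up front, using dataclasses.replace so the caller's RunConfig is not mutated, while per-agent models are still replaced recursively by convert_agent. A self-contained sketch of that replace pattern with stand-in types (FakeRunConfig and FakeModelStub are illustrative, not the real RunConfig or _TemporalModelStub):

import dataclasses
from typing import Optional, Union


@dataclasses.dataclass
class FakeModelStub:
    # Stand-in for _TemporalModelStub: records only the model name.
    model_name: Optional[str]


@dataclasses.dataclass
class FakeRunConfig:
    # Stand-in for RunConfig: the model may be a name, a model object, or unset.
    model: Union[str, FakeModelStub, None] = None


def wrap_run_config_model(run_config: FakeRunConfig) -> FakeRunConfig:
    # Mirrors the runner change: only a string model is accepted, and
    # dataclasses.replace returns a copy, leaving the original untouched.
    if not run_config.model:
        return run_config
    if not isinstance(run_config.model, str):
        raise ValueError("model must be a string")
    return dataclasses.replace(run_config, model=FakeModelStub(run_config.model))


original = FakeRunConfig(model="gpt-4o")
wrapped = wrap_run_config_model(original)
assert isinstance(wrapped.model, FakeModelStub)
assert original.model == "gpt-4o"  # caller's config is unchanged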

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 4 additions & 4 deletions
@@ -51,7 +51,7 @@ def __init__(
         model_name: Optional[str],
         *,
         model_params: ModelActivityParameters,
-        agent: Agent[Any],
+        agent: Optional[Agent[Any]],
     ) -> None:
         self.model_name = model_name
         self.model_params = model_params
@@ -70,8 +70,6 @@ async def get_response(
         previous_response_id: Optional[str],
         prompt: Optional[ResponsePromptParam],
     ) -> ModelResponse:
-        print("Model stub invocation:", self.model_name)
-
         def make_tool_info(tool: Tool) -> ToolInput:
             if isinstance(
                 tool,
@@ -149,8 +147,10 @@ def make_tool_info(tool: Tool) -> ToolInput:
                     )
                 )
             )
-        else:
+        elif self.agent:
             summary = self.agent.name
+        else:
+            summary = None
 
         return await workflow.execute_activity_method(
             ModelActivity.invoke_model_activity,
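
With agent now optional on the stub, the activity-summary fallback becomes a three-way choice: whatever the preceding branch computes (outside this hunk), then the owning agent's name, then None for a stub created from the run config. A simplified stand-in for that selection, not the stub's actual code:

from typing import Optional


def pick_summary(explicit: Optional[str], agent_name: Optional[str]) -> Optional[str]:
    # Simplified: an explicit summary wins, then the agent's name, and
    # finally None when the stub has no agent (run-config-created stub).
    if explicit:
        return explicit
    if agent_name:
        return agent_name
    return None


assert pick_summary(None, "Boss") == "Boss"            # agent-owned stub
assert pick_summary(None, None) is None                # run-config stub
assert pick_summary("seat change", "Boss") == "seat change"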

tests/contrib/openai_agents/test_openai.py

Lines changed: 38 additions & 5 deletions
@@ -32,6 +32,7 @@
     OpenAIChatCompletionsModel,
     OpenAIResponsesModel,
     OutputGuardrailTripwireTriggered,
+    RunConfig,
     RunContextWrapper,
     Runner,
     SQLiteSession,
@@ -861,7 +862,6 @@ def init_agents() -> Agent[AirlineAgentContext]:
 
     seat_booking_agent = Agent[AirlineAgentContext](
         name="Seat Booking Agent",
-        model="gpt-4o-mini",
         handoff_description="A helpful agent that can update a seat on a flight.",
         instructions=f"""{RECOMMENDED_PROMPT_PREFIX}
     You are a seat booking agent. If you are speaking to a customer, you probably were transferred to from the triage agent.
@@ -988,13 +988,14 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
     if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
         pytest.skip("No openai API key")
     new_config = client.config()
-    provider = TestModelProvider(CustomerServiceModel())
     new_config["plugins"] = [
         openai_agents.OpenAIAgentsPlugin(
             model_params=ModelActivityParameters(
                 start_to_close_timeout=timedelta(seconds=30)
             ),
-            model_provider=provider if use_local_model else None,
+            model_provider=TestModelProvider(CustomerServiceModel())
+            if use_local_model
+            else None,
         )
     ]
     client = Client(**new_config)
@@ -2066,7 +2067,7 @@ class MultipleModelsModel(StaticTestModel):
 @workflow.defn
 class MultipleModelWorkflow:
     @workflow.run
-    async def run(self):
+    async def run(self, use_run_config: bool):
         underling = Agent[None](
             name="Underling",
             instructions="You do all the work you are told.",
@@ -2081,6 +2082,7 @@ async def run(self):
         result = await Runner.run(
             starting_agent=starting_agent,
             input="Have you cleaned the store room yet?",
+            run_config=RunConfig(model="gpt-4o") if use_run_config else None,
         )
         return result.final_output
 
@@ -2104,9 +2106,40 @@ async def test_multiple_models(client: Client):
     ) as worker:
         workflow_handle = await client.start_workflow(
             MultipleModelWorkflow.run,
+            False,
             id=f"multiple-model-{uuid.uuid4()}",
             task_queue=worker.task_queue,
-            execution_timeout=timedelta(seconds=120),
+            execution_timeout=timedelta(seconds=10),
         )
         result = await workflow_handle.result()
     assert provider.model_names == {None, "gpt-4o-mini"}
+
+
+async def test_run_config_models(client: Client):
+    provider = AssertDifferentModelProvider(MultipleModelsModel())
+    new_config = client.config()
+    new_config["plugins"] = [
+        openai_agents.OpenAIAgentsPlugin(
+            model_params=ModelActivityParameters(
+                start_to_close_timeout=timedelta(seconds=120)
+            ),
+            model_provider=provider,
+        )
+    ]
+    client = Client(**new_config)
+
+    async with new_worker(
+        client,
+        MultipleModelWorkflow,
+    ) as worker:
+        workflow_handle = await client.start_workflow(
+            MultipleModelWorkflow.run,
+            True,
+            id=f"run-config-model-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+            execution_timeout=timedelta(seconds=10),
+        )
+        result = await workflow_handle.result()
+
+    # Only the model from the run config override is used
+    assert provider.model_names == {"gpt-4o"}
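
For reference, the user-facing behavior the new test pins down: a string model passed via RunConfig now overrides any per-agent model inside a Temporal workflow, and the call still goes through the Temporal model activity stub. A usage sketch (agent name, instructions, and prompt are illustrative only, not from this repo):

from agents import Agent, RunConfig, Runner
from temporalio import workflow


@workflow.defn
class RunConfigModelWorkflow:
    @workflow.run
    async def run(self, prompt: str) -> str:
        agent = Agent[None](
            name="Assistant",
            instructions="Answer briefly.",
            model="gpt-4o-mini",  # would normally be used...
        )
        result = await Runner.run(
            starting_agent=agent,
            input=prompt,
            # ...but the run-config model takes precedence for the whole run.
            run_config=RunConfig(model="gpt-4o"),
        )
        return result.final_output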
