From e8d8a48c2c39278245d8a5131c1efce341b386b0 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 24 Jul 2025 16:16:14 -0700 Subject: [PATCH 1/2] Use the agent's model name if not present in runconfig --- .../contrib/openai_agents/_openai_runner.py | 7 +-- tests/contrib/openai_agents/test_openai.py | 43 +++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index fb07b6062..8be1d2f12 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -65,14 +65,15 @@ async def run( if run_config is None: run_config = RunConfig() - if run_config.model is not None and not isinstance(run_config.model, str): + model_name = run_config.model or starting_agent.model + if model_name is not None and not isinstance(model_name, str): raise ValueError( - "Temporal workflows require a model name to be a string in the run config." + "Temporal workflows require a model name to be a string in the run config and/or agent." ) updated_run_config = replace( run_config, model=_TemporalModelStub( - run_config.model, + model_name=model_name, model_params=self.model_params, ), ) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 7147d3ca6..77200f0a7 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -1917,6 +1917,49 @@ def stream_response( raise NotImplementedError() +@workflow.defn +class AlternateModelAgent: + @workflow.run + async def run(self, prompt: str) -> str: + agent = Agent[None]( + name="Assistant", + instructions="You only respond in haikus.", + model="test_model", + ) + result = await Runner.run(starting_agent=agent, input=prompt) + return result.final_output + +class CheckModelNameProvider(ModelProvider): + def get_model(self, model_name: Optional[str]) -> Model: + assert model_name == "test_model" + return TestHelloModel() + +async def test_alternative_model(client: Client): + new_config = client.config() + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=CheckModelNameProvider(), + ) + ] + client = Client(**new_config) + + async with new_worker( + client, + AlternateModelAgent, + ) as worker: + workflow_handle = await client.start_workflow( + AlternateModelAgent.run, + "Hello", + id=f"alternative-model-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() + + async def test_heartbeat(client: Client, env: WorkflowEnvironment): if env.supports_time_skipping: pytest.skip("Relies on real timing, skip.") From 1d1eeb9f5e7be592be52ad468309861443edfc2b Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 24 Jul 2025 16:24:05 -0700 Subject: [PATCH 2/2] Format --- tests/contrib/openai_agents/test_openai.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 77200f0a7..51d78ac39 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -1929,11 +1929,13 @@ async def run(self, prompt: str) -> str: result = await Runner.run(starting_agent=agent, input=prompt) return result.final_output + class CheckModelNameProvider(ModelProvider): def get_model(self, model_name: Optional[str]) -> Model: assert model_name == "test_model" return TestHelloModel() + async def test_alternative_model(client: Client): new_config = client.config() new_config["plugins"] = [