diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index c05026983..6f0d74064 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -70,14 +70,15 @@ async def run( if run_config is None: run_config = RunConfig() - if run_config.model is not None and not isinstance(run_config.model, str): + model_name = run_config.model or starting_agent.model + if model_name is not None and not isinstance(model_name, str): raise ValueError( - "Temporal workflows require a model name to be a string in the run config." + "Temporal workflows require a model name to be a string in the run config and/or agent." ) updated_run_config = replace( run_config, model=_TemporalModelStub( - run_config.model, + model_name=model_name, model_params=self.model_params, ), ) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 17603e49f..ed5e1ffa4 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -1934,6 +1934,51 @@ def stream_response( raise NotImplementedError() +@workflow.defn +class AlternateModelAgent: + @workflow.run + async def run(self, prompt: str) -> str: + agent = Agent[None]( + name="Assistant", + instructions="You only respond in haikus.", + model="test_model", + ) + result = await Runner.run(starting_agent=agent, input=prompt) + return result.final_output + + +class CheckModelNameProvider(ModelProvider): + def get_model(self, model_name: Optional[str]) -> Model: + assert model_name == "test_model" + return TestHelloModel() + + +async def test_alternative_model(client: Client): + new_config = client.config() + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=CheckModelNameProvider(), + ) + ] + client = Client(**new_config) + + async with new_worker( + client, + AlternateModelAgent, + ) as worker: + workflow_handle = await client.start_workflow( + AlternateModelAgent.run, + "Hello", + id=f"alternative-model-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() + + async def test_heartbeat(client: Client, env: WorkflowEnvironment): if env.supports_time_skipping: pytest.skip("Relies on real timing, skip.")