Skip to content

Commit e8d8a48

Browse files
committed
Use the agent's model name if not present in runconfig
1 parent da6616a commit e8d8a48

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,15 @@ async def run(
6565
if run_config is None:
6666
run_config = RunConfig()
6767

68-
if run_config.model is not None and not isinstance(run_config.model, str):
68+
model_name = run_config.model or starting_agent.model
69+
if model_name is not None and not isinstance(model_name, str):
6970
raise ValueError(
70-
"Temporal workflows require a model name to be a string in the run config."
71+
"Temporal workflows require a model name to be a string in the run config and/or agent."
7172
)
7273
updated_run_config = replace(
7374
run_config,
7475
model=_TemporalModelStub(
75-
run_config.model,
76+
model_name=model_name,
7677
model_params=self.model_params,
7778
),
7879
)

tests/contrib/openai_agents/test_openai.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,6 +1917,49 @@ def stream_response(
19171917
raise NotImplementedError()
19181918

19191919

1920+
@workflow.defn
1921+
class AlternateModelAgent:
1922+
@workflow.run
1923+
async def run(self, prompt: str) -> str:
1924+
agent = Agent[None](
1925+
name="Assistant",
1926+
instructions="You only respond in haikus.",
1927+
model="test_model",
1928+
)
1929+
result = await Runner.run(starting_agent=agent, input=prompt)
1930+
return result.final_output
1931+
1932+
class CheckModelNameProvider(ModelProvider):
1933+
def get_model(self, model_name: Optional[str]) -> Model:
1934+
assert model_name == "test_model"
1935+
return TestHelloModel()
1936+
1937+
async def test_alternative_model(client: Client):
1938+
new_config = client.config()
1939+
new_config["plugins"] = [
1940+
openai_agents.OpenAIAgentsPlugin(
1941+
model_params=ModelActivityParameters(
1942+
start_to_close_timeout=timedelta(seconds=30)
1943+
),
1944+
model_provider=CheckModelNameProvider(),
1945+
)
1946+
]
1947+
client = Client(**new_config)
1948+
1949+
async with new_worker(
1950+
client,
1951+
AlternateModelAgent,
1952+
) as worker:
1953+
workflow_handle = await client.start_workflow(
1954+
AlternateModelAgent.run,
1955+
"Hello",
1956+
id=f"alternative-model-{uuid.uuid4()}",
1957+
task_queue=worker.task_queue,
1958+
execution_timeout=timedelta(seconds=10),
1959+
)
1960+
await workflow_handle.result()
1961+
1962+
19201963
async def test_heartbeat(client: Client, env: WorkflowEnvironment):
19211964
if env.supports_time_skipping:
19221965
pytest.skip("Relies on real timing, skip.")

0 commit comments

Comments
 (0)