Skip to content

Commit beb9c9d

Browse files
authored
Use the agent's model name if not present in runconfig (#996)
* Use the agent's model name if not present in runconfig * Format
1 parent b6b0973 commit beb9c9d

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,15 @@ async def run(
7070
if run_config is None:
7171
run_config = RunConfig()
7272

73-
if run_config.model is not None and not isinstance(run_config.model, str):
73+
model_name = run_config.model or starting_agent.model
74+
if model_name is not None and not isinstance(model_name, str):
7475
raise ValueError(
75-
"Temporal workflows require a model name to be a string in the run config."
76+
"Temporal workflows require a model name to be a string in the run config and/or agent."
7677
)
7778
updated_run_config = replace(
7879
run_config,
7980
model=_TemporalModelStub(
80-
run_config.model,
81+
model_name=model_name,
8182
model_params=self.model_params,
8283
),
8384
)

tests/contrib/openai_agents/test_openai.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1934,6 +1934,51 @@ def stream_response(
19341934
raise NotImplementedError()
19351935

19361936

1937+
@workflow.defn
1938+
class AlternateModelAgent:
1939+
@workflow.run
1940+
async def run(self, prompt: str) -> str:
1941+
agent = Agent[None](
1942+
name="Assistant",
1943+
instructions="You only respond in haikus.",
1944+
model="test_model",
1945+
)
1946+
result = await Runner.run(starting_agent=agent, input=prompt)
1947+
return result.final_output
1948+
1949+
1950+
class CheckModelNameProvider(ModelProvider):
1951+
def get_model(self, model_name: Optional[str]) -> Model:
1952+
assert model_name == "test_model"
1953+
return TestHelloModel()
1954+
1955+
1956+
async def test_alternative_model(client: Client):
1957+
new_config = client.config()
1958+
new_config["plugins"] = [
1959+
openai_agents.OpenAIAgentsPlugin(
1960+
model_params=ModelActivityParameters(
1961+
start_to_close_timeout=timedelta(seconds=30)
1962+
),
1963+
model_provider=CheckModelNameProvider(),
1964+
)
1965+
]
1966+
client = Client(**new_config)
1967+
1968+
async with new_worker(
1969+
client,
1970+
AlternateModelAgent,
1971+
) as worker:
1972+
workflow_handle = await client.start_workflow(
1973+
AlternateModelAgent.run,
1974+
"Hello",
1975+
id=f"alternative-model-{uuid.uuid4()}",
1976+
task_queue=worker.task_queue,
1977+
execution_timeout=timedelta(seconds=10),
1978+
)
1979+
await workflow_handle.result()
1980+
1981+
19371982
async def test_heartbeat(client: Client, env: WorkflowEnvironment):
19381983
if env.supports_time_skipping:
19391984
pytest.skip("Relies on real timing, skip.")

0 commit comments

Comments
 (0)