Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions temporalio/contrib/openai_agents/_openai_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,15 @@ async def run(
if run_config is None:
run_config = RunConfig()

if run_config.model is not None and not isinstance(run_config.model, str):
model_name = run_config.model or starting_agent.model
if model_name is not None and not isinstance(model_name, str):
raise ValueError(
"Temporal workflows require a model name to be a string in the run config."
"Temporal workflows require a model name to be a string in the run config and/or agent."
)
updated_run_config = replace(
run_config,
model=_TemporalModelStub(
run_config.model,
model_name=model_name,
model_params=self.model_params,
),
)
Expand Down
45 changes: 45 additions & 0 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,51 @@ def stream_response(
raise NotImplementedError()


@workflow.defn
class AlternateModelAgent:
@workflow.run
async def run(self, prompt: str) -> str:
agent = Agent[None](
name="Assistant",
instructions="You only respond in haikus.",
model="test_model",
)
result = await Runner.run(starting_agent=agent, input=prompt)
return result.final_output


class CheckModelNameProvider(ModelProvider):
def get_model(self, model_name: Optional[str]) -> Model:
assert model_name == "test_model"
return TestHelloModel()


async def test_alternative_model(client: Client):
new_config = client.config()
new_config["plugins"] = [
openai_agents.OpenAIAgentsPlugin(
model_params=ModelActivityParameters(
start_to_close_timeout=timedelta(seconds=30)
),
model_provider=CheckModelNameProvider(),
)
]
client = Client(**new_config)

async with new_worker(
client,
AlternateModelAgent,
) as worker:
workflow_handle = await client.start_workflow(
AlternateModelAgent.run,
"Hello",
id=f"alternative-model-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=10),
)
await workflow_handle.result()


async def test_heartbeat(client: Client, env: WorkflowEnvironment):
if env.supports_time_skipping:
pytest.skip("Relies on real timing, skip.")
Expand Down