diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index 7a5153141..1ccbc5f4d 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -1,3 +1,4 @@ +import typing from dataclasses import replace from typing import Any, Union @@ -7,6 +8,7 @@ RunResult, RunResultStreaming, TContext, + Tool, TResponseInputItem, ) from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner @@ -42,6 +44,13 @@ async def run( **kwargs, ) + tool_types = typing.get_args(Tool) + for t in starting_agent.tools: + if not isinstance(t, tool_types): + raise ValueError( + "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool." + ) + context = kwargs.get("context") max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) hooks = kwargs.get("hooks") @@ -63,16 +72,15 @@ async def run( ), ) - with workflow.unsafe.imports_passed_through(): - return await self._runner.run( - starting_agent=starting_agent, - input=input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=updated_run_config, - previous_response_id=previous_response_id, - ) + return await self._runner.run( + starting_agent=starting_agent, + input=input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=updated_run_config, + previous_response_id=previous_response_id, + ) def run_sync( self,