Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def make_tool(tool: ToolInput) -> Tool:
raise UserError(f"Unknown tool type: {tool.name}")

tools = [make_tool(x) for x in input.get("tools", [])]
handoffs = [
handoffs: list[Handoff[Any, Any]] = [
Handoff(
tool_name=x.tool_name,
tool_description=x.tool_description,
Expand Down
34 changes: 19 additions & 15 deletions temporalio/contrib/openai_agents/_openai_runner.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
import typing
from dataclasses import replace
from datetime import timedelta
from typing import Optional, Union
from typing import Union

from agents import (
Agent,
RunConfig,
RunHooks,
RunResult,
RunResultStreaming,
TContext,
Tool,
TResponseInputItem,
)
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner

from temporalio import workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
from temporalio.workflow import ActivityCancellationType, VersioningIntent


class TemporalOpenAIRunner(AgentRunner):
Expand Down Expand Up @@ -46,6 +44,13 @@ async def run(
**kwargs,
)

tool_types = typing.get_args(Tool)
for t in starting_agent.tools:
if isinstance(t, tool_types):
raise ValueError(
"Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool."
)

context = kwargs.get("context")
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
hooks = kwargs.get("hooks")
Expand All @@ -67,16 +72,15 @@ async def run(
),
)

with workflow.unsafe.imports_passed_through():
return await self._runner.run(
starting_agent=starting_agent,
input=input,
context=context,
max_turns=max_turns,
hooks=hooks,
run_config=updated_run_config,
previous_response_id=previous_response_id,
)
return await self._runner.run(
starting_agent=starting_agent,
input=input,
context=context,
max_turns=max_turns,
hooks=hooks,
run_config=updated_run_config,
previous_response_id=previous_response_id,
)

def run_sync(
self,
Expand Down
Loading