diff --git a/temporalio/contrib/openai_agents/workflow.py b/temporalio/contrib/openai_agents/workflow.py index f5cc7f762..50eba0b9e 100644 --- a/temporalio/contrib/openai_agents/workflow.py +++ b/temporalio/contrib/openai_agents/workflow.py @@ -1,18 +1,27 @@ """Workflow-specific primitives for working with the OpenAI Agents SDK in a workflow context""" +import functools +import inspect import json from datetime import timedelta -from typing import Any, Callable, Optional, Type +from typing import Any, Callable, Optional, Type, Union, overload import nexusrpc from agents import ( + Agent, RunContextWrapper, Tool, ) -from agents.function_schema import function_schema +from agents.function_schema import DocstringStyle, function_schema from agents.tool import ( FunctionTool, + ToolErrorFunction, + ToolFunction, + ToolParams, + default_tool_error_function, + function_tool, ) +from agents.util._types import MaybeAwaitable from temporalio import activity from temporalio import workflow as temporal_workflow @@ -78,6 +87,25 @@ def activity_as_tool( "Bare function without tool and activity decorators is not supported", "invalid_tool", ) + if ret.name is None: + raise ApplicationError( + "Input activity must have a name to be made into a tool", + "invalid_tool", + ) + # If the provided callable has a first argument of `self`, partially apply it with the same metadata + # The actual instance will be picked up by the activity execution, the partially applied function will never actually be executed + params = list(inspect.signature(fn).parameters.keys()) + if len(params) > 0 and params[0] == "self": + partial = functools.partial(fn, None) + setattr(partial, "__name__", fn.__name__) + partial.__annotations__ = getattr(fn, "__annotations__") + setattr( + partial, + "__temporal_activity_definition", + getattr(fn, "__temporal_activity_definition"), + ) + partial.__doc__ = fn.__doc__ + fn = partial schema = function_schema(fn) async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any: @@ -94,9 +122,8 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any: # Add the context to the arguments if it takes that if schema.takes_context: args = [ctx] + args - result = await temporal_workflow.execute_activity( - fn, + ret.name, # type: ignore args=args, task_queue=task_queue, schedule_to_close_timeout=schedule_to_close_timeout, diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index fe8584afc..57dc5c252 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -195,6 +195,17 @@ async def get_weather_context(ctx: RunContextWrapper[str], city: str) -> Weather return Weather(city=city, temperature_range="14-20C", conditions=ctx.context) +class ActivityWeatherService: + @activity.defn + async def get_weather_method(self, city: str) -> Weather: + """ + Get the weather for a given city. + """ + return Weather( + city=city, temperature_range="14-20C", conditions="Sunny with wind." + ) + + @nexusrpc.service class WeatherService: get_weather_nexus_operation: nexusrpc.Operation[WeatherInput, Weather] @@ -269,6 +280,20 @@ class TestWeatherModel(StaticTestModel): usage=Usage(), response_id=None, ), + ModelResponse( + output=[ + ResponseFunctionToolCall( + arguments='{"city":"Tokyo"}', + call_id="call", + name="get_weather_method", + type="function_call", + id="id", + status="completed", + ) + ], + usage=Usage(), + response_id=None, + ), ModelResponse( output=[ ResponseOutputMessage( @@ -333,7 +358,7 @@ class TestNexusWeatherModel(StaticTestModel): class ToolsWorkflow: @workflow.run async def run(self, question: str) -> str: - agent = Agent( + agent: Agent = Agent( name="Tools Workflow", instructions="You are a helpful agent.", tools=[ @@ -349,8 +374,12 @@ async def run(self, question: str) -> str: openai_agents.workflow.activity_as_tool( get_weather_context, start_to_close_timeout=timedelta(seconds=10) ), + openai_agents.workflow.activity_as_tool( + ActivityWeatherService.get_weather_method, + start_to_close_timeout=timedelta(seconds=10), + ), ], - ) # type: Agent + ) result = await Runner.run( starting_agent=agent, input=question, context="Stormy" ) @@ -406,6 +435,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool): get_weather_object, get_weather_country, get_weather_context, + ActivityWeatherService().get_weather_method, ], interceptors=[OpenAIAgentsTracingInterceptor()], ) as worker: @@ -426,7 +456,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool): if e.HasField("activity_task_completed_event_attributes"): events.append(e) - assert len(events) == 9 + assert len(events) == 11 assert ( "function_call" in events[0] @@ -476,11 +506,23 @@ async def test_tool_workflow(client: Client, use_local_model: bool): .data.decode() ) assert ( - "Test weather result" + "function_call" in events[8] .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) + assert ( + "Sunny with wind" + in events[9] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Test weather result" + in events[10] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) @pytest.mark.parametrize("use_local_model", [True, False])