Skip to content

Commit 4f79082

Browse files
authored
Merge branch 'main' into plugins
2 parents 9238e7d + 33b4a43 commit 4f79082

File tree

3 files changed

+81
-31
lines changed

3 files changed

+81
-31
lines changed

temporalio/client.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -964,9 +964,6 @@ async def execute_update_with_start_workflow(
964964
the call will not return successfully until the update has been delivered to a
965965
worker.
966966
967-
.. warning::
968-
This API is experimental
969-
970967
Args:
971968
update: Update function or name on the workflow. arg: Single argument to the
972969
update.
@@ -5383,11 +5380,7 @@ class StartWorkflowUpdateInput:
53835380

53845381
@dataclass
53855382
class UpdateWithStartUpdateWorkflowInput:
5386-
"""Update input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`.
5387-
5388-
.. warning::
5389-
This API is experimental
5390-
"""
5383+
"""Update input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`."""
53915384

53925385
update_id: Optional[str]
53935386
update: str
@@ -5401,11 +5394,7 @@ class UpdateWithStartUpdateWorkflowInput:
54015394

54025395
@dataclass
54035396
class UpdateWithStartStartWorkflowInput:
5404-
"""StartWorkflow input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`.
5405-
5406-
.. warning::
5407-
This API is experimental
5408-
"""
5397+
"""StartWorkflow input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`."""
54095398

54105399
# Similar to StartWorkflowInput but without e.g. run_id, start_signal,
54115400
# start_signal_args, request_eager_start.
@@ -5441,11 +5430,7 @@ class UpdateWithStartStartWorkflowInput:
54415430

54425431
@dataclass
54435432
class StartWorkflowUpdateWithStartInput:
5444-
"""Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`.
5445-
5446-
.. warning::
5447-
This API is experimental
5448-
"""
5433+
"""Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`."""
54495434

54505435
start_workflow_input: UpdateWithStartStartWorkflowInput
54515436
update_workflow_input: UpdateWithStartUpdateWorkflowInput
@@ -5719,11 +5704,7 @@ async def start_workflow_update(
57195704
async def start_update_with_start_workflow(
57205705
self, input: StartWorkflowUpdateWithStartInput
57215706
) -> WorkflowUpdateHandle[Any]:
5722-
"""Called for every :py:meth:`Client.start_update_with_start_workflow` and :py:meth:`Client.execute_update_with_start_workflow` call.
5723-
5724-
.. warning::
5725-
This API is experimental
5726-
"""
5707+
"""Called for every :py:meth:`Client.start_update_with_start_workflow` and :py:meth:`Client.execute_update_with_start_workflow` call."""
57275708
return await self.next.start_update_with_start_workflow(input)
57285709

57295710
### Async activity calls

temporalio/contrib/openai_agents/workflow.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
"""Workflow-specific primitives for working with the OpenAI Agents SDK in a workflow context"""
22

3+
import functools
4+
import inspect
35
import json
46
from datetime import timedelta
5-
from typing import Any, Callable, Optional, Type
7+
from typing import Any, Callable, Optional, Type, Union, overload
68

79
import nexusrpc
810
from agents import (
11+
Agent,
912
RunContextWrapper,
1013
Tool,
1114
)
12-
from agents.function_schema import function_schema
15+
from agents.function_schema import DocstringStyle, function_schema
1316
from agents.tool import (
1417
FunctionTool,
18+
ToolErrorFunction,
19+
ToolFunction,
20+
ToolParams,
21+
default_tool_error_function,
22+
function_tool,
1523
)
24+
from agents.util._types import MaybeAwaitable
1625

1726
from temporalio import activity
1827
from temporalio import workflow as temporal_workflow
@@ -78,6 +87,25 @@ def activity_as_tool(
7887
"Bare function without tool and activity decorators is not supported",
7988
"invalid_tool",
8089
)
90+
if ret.name is None:
91+
raise ApplicationError(
92+
"Input activity must have a name to be made into a tool",
93+
"invalid_tool",
94+
)
95+
# If the provided callable has a first argument of `self`, partially apply it with the same metadata
96+
# The actual instance will be picked up by the activity execution, the partially applied function will never actually be executed
97+
params = list(inspect.signature(fn).parameters.keys())
98+
if len(params) > 0 and params[0] == "self":
99+
partial = functools.partial(fn, None)
100+
setattr(partial, "__name__", fn.__name__)
101+
partial.__annotations__ = getattr(fn, "__annotations__")
102+
setattr(
103+
partial,
104+
"__temporal_activity_definition",
105+
getattr(fn, "__temporal_activity_definition"),
106+
)
107+
partial.__doc__ = fn.__doc__
108+
fn = partial
81109
schema = function_schema(fn)
82110

83111
async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
@@ -94,9 +122,8 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
94122
# Add the context to the arguments if it takes that
95123
if schema.takes_context:
96124
args = [ctx] + args
97-
98125
result = await temporal_workflow.execute_activity(
99-
fn,
126+
ret.name, # type: ignore
100127
args=args,
101128
task_queue=task_queue,
102129
schedule_to_close_timeout=schedule_to_close_timeout,

tests/contrib/openai_agents/test_openai.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,17 @@ async def get_weather_context(ctx: RunContextWrapper[str], city: str) -> Weather
195195
return Weather(city=city, temperature_range="14-20C", conditions=ctx.context)
196196

197197

198+
class ActivityWeatherService:
199+
@activity.defn
200+
async def get_weather_method(self, city: str) -> Weather:
201+
"""
202+
Get the weather for a given city.
203+
"""
204+
return Weather(
205+
city=city, temperature_range="14-20C", conditions="Sunny with wind."
206+
)
207+
208+
198209
@nexusrpc.service
199210
class WeatherService:
200211
get_weather_nexus_operation: nexusrpc.Operation[WeatherInput, Weather]
@@ -269,6 +280,20 @@ class TestWeatherModel(StaticTestModel):
269280
usage=Usage(),
270281
response_id=None,
271282
),
283+
ModelResponse(
284+
output=[
285+
ResponseFunctionToolCall(
286+
arguments='{"city":"Tokyo"}',
287+
call_id="call",
288+
name="get_weather_method",
289+
type="function_call",
290+
id="id",
291+
status="completed",
292+
)
293+
],
294+
usage=Usage(),
295+
response_id=None,
296+
),
272297
ModelResponse(
273298
output=[
274299
ResponseOutputMessage(
@@ -333,7 +358,7 @@ class TestNexusWeatherModel(StaticTestModel):
333358
class ToolsWorkflow:
334359
@workflow.run
335360
async def run(self, question: str) -> str:
336-
agent = Agent(
361+
agent: Agent = Agent(
337362
name="Tools Workflow",
338363
instructions="You are a helpful agent.",
339364
tools=[
@@ -349,8 +374,12 @@ async def run(self, question: str) -> str:
349374
openai_agents.workflow.activity_as_tool(
350375
get_weather_context, start_to_close_timeout=timedelta(seconds=10)
351376
),
377+
openai_agents.workflow.activity_as_tool(
378+
ActivityWeatherService.get_weather_method,
379+
start_to_close_timeout=timedelta(seconds=10),
380+
),
352381
],
353-
) # type: Agent
382+
)
354383
result = await Runner.run(
355384
starting_agent=agent, input=question, context="Stormy"
356385
)
@@ -406,6 +435,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
406435
get_weather_object,
407436
get_weather_country,
408437
get_weather_context,
438+
ActivityWeatherService().get_weather_method,
409439
],
410440
interceptors=[OpenAIAgentsTracingInterceptor()],
411441
) as worker:
@@ -426,7 +456,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
426456
if e.HasField("activity_task_completed_event_attributes"):
427457
events.append(e)
428458

429-
assert len(events) == 9
459+
assert len(events) == 11
430460
assert (
431461
"function_call"
432462
in events[0]
@@ -476,11 +506,23 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
476506
.data.decode()
477507
)
478508
assert (
479-
"Test weather result"
509+
"function_call"
480510
in events[8]
481511
.activity_task_completed_event_attributes.result.payloads[0]
482512
.data.decode()
483513
)
514+
assert (
515+
"Sunny with wind"
516+
in events[9]
517+
.activity_task_completed_event_attributes.result.payloads[0]
518+
.data.decode()
519+
)
520+
assert (
521+
"Test weather result"
522+
in events[10]
523+
.activity_task_completed_event_attributes.result.payloads[0]
524+
.data.decode()
525+
)
484526

485527

486528
@pytest.mark.parametrize("use_local_model", [True, False])

0 commit comments

Comments
 (0)