Skip to content

Commit a1f559b

Browse files
Merge branch 'main' into ci/migrate-tasks-to-poe
2 parents a999a80 + e4df5e7 commit a1f559b

File tree

10 files changed

+523
-267
lines changed

10 files changed

+523
-267
lines changed

.github/workflows/build-binaries.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ on:
44
branches:
55
- main
66
- "releases/*"
7-
- fix-build-binaries
87

98
jobs:
109
# Compile the binaries and upload artifacts
@@ -66,7 +65,7 @@ jobs:
6665
if [ "$RUNNER_OS" = "Windows" ]; then
6766
bindir=Scripts
6867
fi
69-
./.venv/$bindir/pip install 'protobuf>=3.20,<6' 'types-protobuf>=3.20,<6' 'typing-extensions>=4.2.0,<5' pytest pytest_asyncio grpcio 'nexus-rpc>=1.1.0' pydantic opentelemetry-api opentelemetry-sdk python-dateutil 'openai-agents>=0.2.3,<=0.2.9'
68+
./.venv/$bindir/pip install 'protobuf>=3.20,<6' 'types-protobuf>=3.20,<6' 'typing-extensions>=4.2.0,<5' pytest pytest_asyncio grpcio 'nexus-rpc>=1.1.0' pydantic opentelemetry-api opentelemetry-sdk python-dateutil 'openai-agents>=0.2.3,<=0.2.9' 'googleapis-common-protos==1.70.0'
7069
./.venv/$bindir/pip install --no-index --find-links=../dist temporalio
7170
./.venv/$bindir/python -m pytest -s -k test_workflow_hello
7271

temporalio/client.py

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,13 @@
7070
WorkflowSerializationContext,
7171
)
7272
from temporalio.service import (
73+
ConnectConfig,
7374
HttpConnectProxyConfig,
7475
KeepAliveConfig,
7576
RetryConfig,
7677
RPCError,
7778
RPCStatusCode,
79+
ServiceClient,
7880
TLSConfig,
7981
)
8082

@@ -198,12 +200,14 @@ async def connect(
198200
http_connect_proxy_config=http_connect_proxy_config,
199201
)
200202

201-
root_plugin: Plugin = _RootPlugin()
203+
def make_lambda(plugin, next):
204+
return lambda config: plugin.connect_service_client(config, next)
205+
206+
next_function = ServiceClient.connect
202207
for plugin in reversed(plugins):
203-
plugin.init_client_plugin(root_plugin)
204-
root_plugin = plugin
208+
next_function = make_lambda(plugin, next_function)
205209

206-
service_client = await root_plugin.connect_service_client(connect_config)
210+
service_client = await next_function(connect_config)
207211

208212
return Client(
209213
service_client,
@@ -243,12 +247,10 @@ def __init__(
243247
plugins=plugins,
244248
)
245249

246-
root_plugin: Plugin = _RootPlugin()
247-
for plugin in reversed(plugins):
248-
plugin.init_client_plugin(root_plugin)
249-
root_plugin = plugin
250+
for plugin in plugins:
251+
config = plugin.configure_client(config)
250252

251-
self._init_from_config(root_plugin.configure_client(config))
253+
self._init_from_config(config)
252254

253255
def _init_from_config(self, config: ClientConfig):
254256
self._config = config
@@ -7541,20 +7543,6 @@ def name(self) -> str:
75417543
"""
75427544
return type(self).__module__ + "." + type(self).__qualname__
75437545

7544-
@abstractmethod
7545-
def init_client_plugin(self, next: Plugin) -> None:
7546-
"""Initialize this plugin in the plugin chain.
7547-
7548-
This method sets up the chain of responsibility pattern by providing a reference
7549-
to the next plugin in the chain. It is called during client creation to build
7550-
the plugin chain. Note, this may be called twice in the case of :py:meth:`connect`.
7551-
Implementations should store this reference and call the corresponding method
7552-
of the next plugin on method calls.
7553-
7554-
Args:
7555-
next: The next plugin in the chain to delegate to.
7556-
"""
7557-
75587546
@abstractmethod
75597547
def configure_client(self, config: ClientConfig) -> ClientConfig:
75607548
"""Hook called when creating a client to allow modification of configuration.
@@ -7572,8 +7560,10 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:
75727560

75737561
@abstractmethod
75747562
async def connect_service_client(
7575-
self, config: temporalio.service.ConnectConfig
7576-
) -> temporalio.service.ServiceClient:
7563+
self,
7564+
config: ConnectConfig,
7565+
next: Callable[[ConnectConfig], Awaitable[ServiceClient]],
7566+
) -> ServiceClient:
75777567
"""Hook called when connecting to the Temporal service.
75787568
75797569
This method is called during service client connection and allows plugins
@@ -7586,16 +7576,3 @@ async def connect_service_client(
75867576
Returns:
75877577
The connected service client.
75887578
"""
7589-
7590-
7591-
class _RootPlugin(Plugin):
7592-
def init_client_plugin(self, next: Plugin) -> None:
7593-
raise NotImplementedError()
7594-
7595-
def configure_client(self, config: ClientConfig) -> ClientConfig:
7596-
return config
7597-
7598-
async def connect_service_client(
7599-
self, config: temporalio.service.ConnectConfig
7600-
) -> temporalio.service.ServiceClient:
7601-
return await temporalio.service.ServiceClient.connect(config)

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from pydantic_core import to_json
3737
from typing_extensions import Required, TypedDict
3838

39-
from temporalio import activity
39+
from temporalio import activity, workflow
4040
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
4141
from temporalio.exceptions import ApplicationError
4242

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 61 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,7 @@
2222
from agents.run import get_default_agent_runner, set_default_agent_runner
2323
from agents.tracing import get_trace_provider
2424
from agents.tracing.provider import DefaultTraceProvider
25-
from openai.types.responses import ResponsePromptParam
2625

27-
import temporalio.client
28-
import temporalio.worker
29-
from temporalio.client import ClientConfig
3026
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
3127
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
3228
from temporalio.contrib.openai_agents._openai_runner import (
@@ -47,13 +43,8 @@
4743
DataConverter,
4844
DefaultPayloadConverter,
4945
)
50-
from temporalio.worker import (
51-
Replayer,
52-
ReplayerConfig,
53-
Worker,
54-
WorkerConfig,
55-
WorkflowReplayResult,
56-
)
46+
from temporalio.plugin import SimplePlugin
47+
from temporalio.worker import WorkflowRunner
5748
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
5849

5950
# Unsupported on python 3.9
@@ -172,7 +163,21 @@ def __init__(self) -> None:
172163
super().__init__(ToJsonOptions(exclude_unset=True))
173164

174165

175-
class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
166+
def _data_converter(converter: Optional[DataConverter]) -> DataConverter:
167+
if converter is None:
168+
return DataConverter(payload_converter_class=OpenAIPayloadConverter)
169+
elif converter.payload_converter_class is DefaultPayloadConverter:
170+
return dataclasses.replace(
171+
converter, payload_converter_class=OpenAIPayloadConverter
172+
)
173+
elif not isinstance(converter.payload_converter, OpenAIPayloadConverter):
174+
raise ValueError(
175+
"The payload converter must be of type OpenAIPayloadConverter."
176+
)
177+
return converter
178+
179+
180+
class OpenAIAgentsPlugin(SimplePlugin):
176181
"""Temporal plugin for integrating OpenAI agents with Temporal workflows.
177182
178183
.. warning::
@@ -245,6 +250,7 @@ def __init__(
245250
mcp_server_providers: Sequence[
246251
Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"]
247252
] = (),
253+
register_activities: bool = True,
248254
) -> None:
249255
"""Initialize the OpenAI agents plugin.
250256
@@ -257,6 +263,9 @@ def __init__(
257263
Each server will be wrapped in a TemporalMCPServer if not already wrapped,
258264
and their activities will be automatically registered with the worker.
259265
The plugin manages the connection lifecycle of these servers.
266+
register_activities: Whether to register activities during the worker execution.
267+
This can be disabled on some workers to allow a separation of workflows and activities
268+
but should not be disabled on all workers, or agents will not be able to progress.
260269
"""
261270
if model_params is None:
262271
model_params = ModelActivityParameters()
@@ -274,124 +283,48 @@ def __init__(
274283
"When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
275284
)
276285

277-
self._model_params = model_params
278-
self._model_provider = model_provider
279-
self._mcp_server_providers = mcp_server_providers
280-
281-
def init_client_plugin(self, next: temporalio.client.Plugin) -> None:
282-
"""Set the next client plugin"""
283-
self.next_client_plugin = next
284-
285-
async def connect_service_client(
286-
self, config: temporalio.service.ConnectConfig
287-
) -> temporalio.service.ServiceClient:
288-
"""No modifications to service client"""
289-
return await self.next_client_plugin.connect_service_client(config)
290-
291-
def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
292-
"""Set the next worker plugin"""
293-
self.next_worker_plugin = next
294-
295-
@staticmethod
296-
def _data_converter(converter: Optional[DataConverter]) -> DataConverter:
297-
if converter is None:
298-
return DataConverter(payload_converter_class=OpenAIPayloadConverter)
299-
elif converter.payload_converter_class is DefaultPayloadConverter:
300-
return dataclasses.replace(
301-
converter, payload_converter_class=OpenAIPayloadConverter
302-
)
303-
elif not isinstance(converter.payload_converter, OpenAIPayloadConverter):
304-
raise ValueError(
305-
"The payload converter must be of type OpenAIPayloadConverter."
306-
)
307-
return converter
308-
309-
def configure_client(self, config: ClientConfig) -> ClientConfig:
310-
"""Configure the Temporal client for OpenAI agents integration.
311-
312-
This method sets up the Pydantic data converter to enable proper
313-
serialization of OpenAI agent objects and responses.
314-
315-
Args:
316-
config: The client configuration to modify.
317-
318-
Returns:
319-
The modified client configuration.
320-
"""
321-
config["data_converter"] = self._data_converter(config["data_converter"])
322-
return self.next_client_plugin.configure_client(config)
286+
# Delay activity construction until they are actually needed
287+
def add_activities(
288+
activities: Optional[Sequence[Callable]],
289+
) -> Sequence[Callable]:
290+
if not register_activities:
291+
return activities or []
323292

324-
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
325-
"""Configure the Temporal worker for OpenAI agents integration.
293+
new_activities = [ModelActivity(model_provider).invoke_model_activity]
326294

327-
This method adds the necessary interceptors and activities for OpenAI
328-
agent execution:
329-
- Adds tracing interceptors for OpenAI agent interactions
330-
- Registers model execution activities
295+
server_names = [server.name for server in mcp_server_providers]
296+
if len(server_names) != len(set(server_names)):
297+
raise ValueError(
298+
f"More than one mcp server registered with the same name. Please provide unique names."
299+
)
331300

332-
Args:
333-
config: The worker configuration to modify.
301+
for mcp_server in mcp_server_providers:
302+
new_activities.extend(mcp_server._get_activities())
303+
return list(activities or []) + new_activities
334304

335-
Returns:
336-
The modified worker configuration.
337-
"""
338-
config["interceptors"] = list(config.get("interceptors") or []) + [
339-
OpenAIAgentsTracingInterceptor()
340-
]
341-
new_activities = [ModelActivity(self._model_provider).invoke_model_activity]
342-
343-
server_names = [server.name for server in self._mcp_server_providers]
344-
if len(server_names) != len(set(server_names)):
345-
raise ValueError(
346-
f"More than one mcp server registered with the same name. Please provide unique names."
347-
)
348-
349-
for mcp_server in self._mcp_server_providers:
350-
new_activities.extend(mcp_server._get_activities())
351-
config["activities"] = list(config.get("activities") or []) + new_activities
352-
353-
runner = config.get("workflow_runner")
354-
if isinstance(runner, SandboxedWorkflowRunner):
355-
config["workflow_runner"] = dataclasses.replace(
356-
runner,
357-
restrictions=runner.restrictions.with_passthrough_modules("mcp"),
358-
)
359-
360-
config["workflow_failure_exception_types"] = list(
361-
config.get("workflow_failure_exception_types") or []
362-
) + [AgentsWorkflowError]
363-
return self.next_worker_plugin.configure_worker(config)
364-
365-
async def run_worker(self, worker: Worker) -> None:
366-
"""Run the worker with OpenAI agents temporal overrides.
367-
368-
This method sets up the necessary runtime overrides for OpenAI agents
369-
to work within the Temporal worker context, including custom runners
370-
and trace providers.
305+
def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner:
306+
if not runner:
307+
raise ValueError("No WorkflowRunner provided to the OpenAI plugin.")
371308

372-
Args:
373-
worker: The worker instance to run.
374-
"""
375-
with set_open_ai_agent_temporal_overrides(self._model_params):
376-
await self.next_worker_plugin.run_worker(worker)
377-
378-
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
379-
"""Configure the replayer for OpenAI Agents."""
380-
config["interceptors"] = list(config.get("interceptors") or []) + [
381-
OpenAIAgentsTracingInterceptor()
382-
]
383-
config["data_converter"] = self._data_converter(config.get("data_converter"))
384-
return self.next_worker_plugin.configure_replayer(config)
385-
386-
@asynccontextmanager
387-
async def run_replayer(
388-
self,
389-
replayer: Replayer,
390-
histories: AsyncIterator[temporalio.client.WorkflowHistory],
391-
) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]:
392-
"""Set the OpenAI Overrides during replay"""
393-
with set_open_ai_agent_temporal_overrides(self._model_params):
394-
async with self.next_worker_plugin.run_replayer(
395-
replayer, histories
396-
) as results:
397-
yield results
309+
# If in sandbox, add additional passthrough
310+
if isinstance(runner, SandboxedWorkflowRunner):
311+
return dataclasses.replace(
312+
runner,
313+
restrictions=runner.restrictions.with_passthrough_modules("mcp"),
314+
)
315+
return runner
316+
317+
@asynccontextmanager
318+
async def run_context() -> AsyncIterator[None]:
319+
with set_open_ai_agent_temporal_overrides(model_params):
320+
yield
321+
322+
super().__init__(
323+
name="OpenAIAgentsPlugin",
324+
data_converter=_data_converter,
325+
worker_interceptors=[OpenAIAgentsTracingInterceptor()],
326+
activities=add_activities,
327+
workflow_runner=workflow_runner,
328+
workflow_failure_exception_types=[AgentsWorkflowError],
329+
run_context=lambda: run_context(),
330+
)

0 commit comments

Comments
 (0)