Skip to content

Commit aeac7e2

Browse files
committed
Convert OpenAI plugin as example
1 parent 60eb145 commit aeac7e2

File tree

3 files changed

+109
-219
lines changed

3 files changed

+109
-219
lines changed

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 66 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,14 @@
4747
DataConverter,
4848
DefaultPayloadConverter,
4949
)
50+
from temporalio.plugin import Plugin, create_plugin
5051
from temporalio.worker import (
5152
Replayer,
5253
ReplayerConfig,
5354
Worker,
5455
WorkerConfig,
5556
WorkflowReplayResult,
57+
WorkflowRunner,
5658
)
5759
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
5860

@@ -172,226 +174,88 @@ def __init__(self) -> None:
172174
super().__init__(ToJsonOptions(exclude_unset=True))
173175

174176

175-
class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
176-
"""Temporal plugin for integrating OpenAI agents with Temporal workflows.
177-
178-
.. warning::
179-
This class is experimental and may change in future versions.
180-
Use with caution in production environments.
181-
182-
This plugin provides seamless integration between the OpenAI Agents SDK and
183-
Temporal workflows. It automatically configures the necessary interceptors,
184-
activities, and data converters to enable OpenAI agents to run within
185-
Temporal workflows with proper tracing and model execution.
186-
187-
The plugin:
188-
1. Configures the Pydantic data converter for type-safe serialization
189-
2. Sets up tracing interceptors for OpenAI agent interactions
190-
3. Registers model execution activities
191-
4. Automatically registers MCP server activities and manages their lifecycles
192-
5. Manages the OpenAI agent runtime overrides during worker execution
177+
def _data_converter(converter: Optional[DataConverter]) -> DataConverter:
178+
if converter is None:
179+
return DataConverter(payload_converter_class=OpenAIPayloadConverter)
180+
elif converter.payload_converter_class is DefaultPayloadConverter:
181+
return dataclasses.replace(
182+
converter, payload_converter_class=OpenAIPayloadConverter
183+
)
184+
elif not isinstance(converter.payload_converter, OpenAIPayloadConverter):
185+
raise ValueError(
186+
"The payload converter must be of type OpenAIPayloadConverter."
187+
)
188+
return converter
189+
190+
191+
def OpenAIAgentsPlugin(
192+
model_params: Optional[ModelActivityParameters] = None,
193+
model_provider: Optional[ModelProvider] = None,
194+
mcp_server_providers: Sequence[
195+
Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"]
196+
] = (),
197+
) -> Plugin:
198+
"""Create an OpenAI agents plugin.
193199
194200
Args:
195201
model_params: Configuration parameters for Temporal activity execution
196202
of model calls. If None, default parameters will be used.
197203
model_provider: Optional model provider for custom model implementations.
198204
Useful for testing or custom model integrations.
199205
mcp_server_providers: Sequence of MCP servers to automatically register with the worker.
200-
The plugin will wrap each server in a TemporalMCPServer if needed and
201-
manage their connection lifecycles tied to the worker lifetime. This is
202-
the recommended way to use MCP servers with Temporal workflows.
203-
204-
Example:
205-
>>> from temporalio.client import Client
206-
>>> from temporalio.worker import Worker
207-
>>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters, StatelessMCPServerProvider
208-
>>> from agents.mcp import MCPServerStdio
209-
>>> from datetime import timedelta
210-
>>>
211-
>>> # Configure model parameters
212-
>>> model_params = ModelActivityParameters(
213-
... start_to_close_timeout=timedelta(seconds=30),
214-
... retry_policy=RetryPolicy(maximum_attempts=3)
215-
... )
216-
>>>
217-
>>> # Create MCP servers
218-
>>> filesystem_server = StatelessMCPServerProvider(MCPServerStdio(
219-
... name="Filesystem Server",
220-
... params={"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]}
221-
... ))
222-
>>>
223-
>>> # Create plugin with MCP servers
224-
>>> plugin = OpenAIAgentsPlugin(
225-
... model_params=model_params,
226-
... mcp_server_providers=[filesystem_server]
227-
... )
228-
>>>
229-
>>> # Use with client and worker
230-
>>> client = await Client.connect(
231-
... "localhost:7233",
232-
... plugins=[plugin]
233-
... )
234-
>>> worker = Worker(
235-
... client,
236-
... task_queue="my-task-queue",
237-
... workflows=[MyWorkflow],
238-
... )
206+
Each server will be wrapped in a TemporalMCPServer if not already wrapped,
207+
and their activities will be automatically registered with the worker.
208+
The plugin manages the connection lifecycle of these servers.
239209
"""
240-
241-
def __init__(
242-
self,
243-
model_params: Optional[ModelActivityParameters] = None,
244-
model_provider: Optional[ModelProvider] = None,
245-
mcp_server_providers: Sequence[
246-
Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"]
247-
] = (),
248-
) -> None:
249-
"""Initialize the OpenAI agents plugin.
250-
251-
Args:
252-
model_params: Configuration parameters for Temporal activity execution
253-
of model calls. If None, default parameters will be used.
254-
model_provider: Optional model provider for custom model implementations.
255-
Useful for testing or custom model integrations.
256-
mcp_server_providers: Sequence of MCP servers to automatically register with the worker.
257-
Each server will be wrapped in a TemporalMCPServer if not already wrapped,
258-
and their activities will be automatically registered with the worker.
259-
The plugin manages the connection lifecycle of these servers.
260-
"""
261-
if model_params is None:
262-
model_params = ModelActivityParameters()
263-
264-
# For the default provider, we provide a default start_to_close_timeout of 60 seconds.
265-
# Other providers will need to define their own.
266-
if (
267-
model_params.start_to_close_timeout is None
268-
and model_params.schedule_to_close_timeout is None
269-
):
270-
if model_provider is None:
271-
model_params.start_to_close_timeout = timedelta(seconds=60)
272-
else:
273-
raise ValueError(
274-
"When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
275-
)
276-
277-
self._model_params = model_params
278-
self._model_provider = model_provider
279-
self._mcp_server_providers = mcp_server_providers
280-
281-
def init_client_plugin(self, next: temporalio.client.Plugin) -> None:
282-
"""Set the next client plugin"""
283-
self.next_client_plugin = next
284-
285-
async def connect_service_client(
286-
self, config: temporalio.service.ConnectConfig
287-
) -> temporalio.service.ServiceClient:
288-
"""No modifications to service client"""
289-
return await self.next_client_plugin.connect_service_client(config)
290-
291-
def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
292-
"""Set the next worker plugin"""
293-
self.next_worker_plugin = next
294-
295-
@staticmethod
296-
def _data_converter(converter: Optional[DataConverter]) -> DataConverter:
297-
if converter is None:
298-
return DataConverter(payload_converter_class=OpenAIPayloadConverter)
299-
elif converter.payload_converter_class is DefaultPayloadConverter:
300-
return dataclasses.replace(
301-
converter, payload_converter_class=OpenAIPayloadConverter
302-
)
303-
elif not isinstance(converter.payload_converter, OpenAIPayloadConverter):
210+
if model_params is None:
211+
model_params = ModelActivityParameters()
212+
213+
# For the default provider, we provide a default start_to_close_timeout of 60 seconds.
214+
# Other providers will need to define their own.
215+
if (
216+
model_params.start_to_close_timeout is None
217+
and model_params.schedule_to_close_timeout is None
218+
):
219+
if model_provider is None:
220+
model_params.start_to_close_timeout = timedelta(seconds=60)
221+
else:
304222
raise ValueError(
305-
"The payload converter must be of type OpenAIPayloadConverter."
223+
"When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
306224
)
307-
return converter
308-
309-
def configure_client(self, config: ClientConfig) -> ClientConfig:
310-
"""Configure the Temporal client for OpenAI agents integration.
311-
312-
This method sets up the Pydantic data converter to enable proper
313-
serialization of OpenAI agent objects and responses.
314-
315-
Args:
316-
config: The client configuration to modify.
317-
318-
Returns:
319-
The modified client configuration.
320-
"""
321-
config["data_converter"] = self._data_converter(config["data_converter"])
322-
return self.next_client_plugin.configure_client(config)
323225

324-
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
325-
"""Configure the Temporal worker for OpenAI agents integration.
226+
new_activities = [ModelActivity(model_provider).invoke_model_activity]
326227

327-
This method adds the necessary interceptors and activities for OpenAI
328-
agent execution:
329-
- Adds tracing interceptors for OpenAI agent interactions
330-
- Registers model execution activities
228+
server_names = [server.name for server in mcp_server_providers]
229+
if len(server_names) != len(set(server_names)):
230+
raise ValueError(
231+
f"More than one mcp server registered with the same name. Please provide unique names."
232+
)
331233

332-
Args:
333-
config: The worker configuration to modify.
234+
for mcp_server in mcp_server_providers:
235+
new_activities.extend(mcp_server._get_activities())
334236

335-
Returns:
336-
The modified worker configuration.
337-
"""
338-
config["interceptors"] = list(config.get("interceptors") or []) + [
339-
OpenAIAgentsTracingInterceptor()
340-
]
341-
new_activities = [ModelActivity(self._model_provider).invoke_model_activity]
237+
def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner:
238+
if not runner:
239+
raise ValueError("No WorkflowRunner provided to the OpenAI plugin.")
342240

343-
server_names = [server.name for server in self._mcp_server_providers]
344-
if len(server_names) != len(set(server_names)):
345-
raise ValueError(
346-
f"More than one mcp server registered with the same name. Please provide unique names."
347-
)
348-
349-
for mcp_server in self._mcp_server_providers:
350-
new_activities.extend(mcp_server._get_activities())
351-
config["activities"] = list(config.get("activities") or []) + new_activities
352-
353-
runner = config.get("workflow_runner")
241+
# If in sandbox, add additional passthrough
354242
if isinstance(runner, SandboxedWorkflowRunner):
355-
config["workflow_runner"] = dataclasses.replace(
243+
return dataclasses.replace(
356244
runner,
357245
restrictions=runner.restrictions.with_passthrough_modules("mcp"),
358246
)
359-
360-
config["workflow_failure_exception_types"] = list(
361-
config.get("workflow_failure_exception_types") or []
362-
) + [AgentsWorkflowError]
363-
return self.next_worker_plugin.configure_worker(config)
364-
365-
async def run_worker(self, worker: Worker) -> None:
366-
"""Run the worker with OpenAI agents temporal overrides.
367-
368-
This method sets up the necessary runtime overrides for OpenAI agents
369-
to work within the Temporal worker context, including custom runners
370-
and trace providers.
371-
372-
Args:
373-
worker: The worker instance to run.
374-
"""
375-
with set_open_ai_agent_temporal_overrides(self._model_params):
376-
await self.next_worker_plugin.run_worker(worker)
377-
378-
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
379-
"""Configure the replayer for OpenAI Agents."""
380-
config["interceptors"] = list(config.get("interceptors") or []) + [
381-
OpenAIAgentsTracingInterceptor()
382-
]
383-
config["data_converter"] = self._data_converter(config.get("data_converter"))
384-
return self.next_worker_plugin.configure_replayer(config)
247+
return runner
385248

386249
@asynccontextmanager
387-
async def run_replayer(
388-
self,
389-
replayer: Replayer,
390-
histories: AsyncIterator[temporalio.client.WorkflowHistory],
391-
) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]:
392-
"""Set the OpenAI Overrides during replay"""
393-
with set_open_ai_agent_temporal_overrides(self._model_params):
394-
async with self.next_worker_plugin.run_replayer(
395-
replayer, histories
396-
) as results:
397-
yield results
250+
async def run_context() -> AsyncIterator[None]:
251+
with set_open_ai_agent_temporal_overrides(model_params):
252+
yield
253+
254+
return create_plugin(
255+
data_converter=_data_converter,
256+
worker_interceptors=[OpenAIAgentsTracingInterceptor()],
257+
activities=new_activities,
258+
workflow_runner=workflow_runner,
259+
workflow_failure_exception_types=[AgentsWorkflowError],
260+
run_context=lambda: run_context(),
261+
)

0 commit comments

Comments
 (0)