2222from agents .run import get_default_agent_runner , set_default_agent_runner
2323from agents .tracing import get_trace_provider
2424from agents .tracing .provider import DefaultTraceProvider
25- from openai .types .responses import ResponsePromptParam
2625
27- import temporalio .client
28- import temporalio .worker
29- from temporalio .client import ClientConfig
3026from temporalio .contrib .openai_agents ._invoke_model_activity import ModelActivity
3127from temporalio .contrib .openai_agents ._model_parameters import ModelActivityParameters
3228from temporalio .contrib .openai_agents ._openai_runner import (
4743 DataConverter ,
4844 DefaultPayloadConverter ,
4945)
50- from temporalio .worker import (
51- Replayer ,
52- ReplayerConfig ,
53- Worker ,
54- WorkerConfig ,
55- WorkflowReplayResult ,
56- )
46+ from temporalio .plugin import SimplePlugin
47+ from temporalio .worker import WorkflowRunner
5748from temporalio .worker .workflow_sandbox import SandboxedWorkflowRunner
5849
5950# Unsupported on python 3.9
@@ -172,7 +163,21 @@ def __init__(self) -> None:
172163 super ().__init__ (ToJsonOptions (exclude_unset = True ))
173164
174165
175- class OpenAIAgentsPlugin (temporalio .client .Plugin , temporalio .worker .Plugin ):
166+ def _data_converter (converter : Optional [DataConverter ]) -> DataConverter :
167+ if converter is None :
168+ return DataConverter (payload_converter_class = OpenAIPayloadConverter )
169+ elif converter .payload_converter_class is DefaultPayloadConverter :
170+ return dataclasses .replace (
171+ converter , payload_converter_class = OpenAIPayloadConverter
172+ )
173+ elif not isinstance (converter .payload_converter , OpenAIPayloadConverter ):
174+ raise ValueError (
175+ "The payload converter must be of type OpenAIPayloadConverter."
176+ )
177+ return converter
178+
179+
180+ class OpenAIAgentsPlugin (SimplePlugin ):
176181 """Temporal plugin for integrating OpenAI agents with Temporal workflows.
177182
178183 .. warning::
@@ -245,6 +250,7 @@ def __init__(
245250 mcp_server_providers : Sequence [
246251 Union ["StatelessMCPServerProvider" , "StatefulMCPServerProvider" ]
247252 ] = (),
253+ register_activities : bool = True ,
248254 ) -> None :
249255 """Initialize the OpenAI agents plugin.
250256
@@ -257,6 +263,9 @@ def __init__(
257263 Each server will be wrapped in a TemporalMCPServer if not already wrapped,
258264 and their activities will be automatically registered with the worker.
259265 The plugin manages the connection lifecycle of these servers.
266+ register_activities: Whether to register activities during the worker execution.
267+ This can be disabled on some workers to allow a separation of workflows and activities
268+ but should not be disabled on all workers, or agents will not be able to progress.
260269 """
261270 if model_params is None :
262271 model_params = ModelActivityParameters ()
@@ -274,124 +283,48 @@ def __init__(
274283 "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
275284 )
276285
277- self ._model_params = model_params
278- self ._model_provider = model_provider
279- self ._mcp_server_providers = mcp_server_providers
280-
281- def init_client_plugin (self , next : temporalio .client .Plugin ) -> None :
282- """Set the next client plugin"""
283- self .next_client_plugin = next
284-
285- async def connect_service_client (
286- self , config : temporalio .service .ConnectConfig
287- ) -> temporalio .service .ServiceClient :
288- """No modifications to service client"""
289- return await self .next_client_plugin .connect_service_client (config )
290-
291- def init_worker_plugin (self , next : temporalio .worker .Plugin ) -> None :
292- """Set the next worker plugin"""
293- self .next_worker_plugin = next
294-
295- @staticmethod
296- def _data_converter (converter : Optional [DataConverter ]) -> DataConverter :
297- if converter is None :
298- return DataConverter (payload_converter_class = OpenAIPayloadConverter )
299- elif converter .payload_converter_class is DefaultPayloadConverter :
300- return dataclasses .replace (
301- converter , payload_converter_class = OpenAIPayloadConverter
302- )
303- elif not isinstance (converter .payload_converter , OpenAIPayloadConverter ):
304- raise ValueError (
305- "The payload converter must be of type OpenAIPayloadConverter."
306- )
307- return converter
308-
309- def configure_client (self , config : ClientConfig ) -> ClientConfig :
310- """Configure the Temporal client for OpenAI agents integration.
311-
312- This method sets up the Pydantic data converter to enable proper
313- serialization of OpenAI agent objects and responses.
314-
315- Args:
316- config: The client configuration to modify.
317-
318- Returns:
319- The modified client configuration.
320- """
321- config ["data_converter" ] = self ._data_converter (config ["data_converter" ])
322- return self .next_client_plugin .configure_client (config )
286+ # Delay activity construction until they are actually needed
287+ def add_activities (
288+ activities : Optional [Sequence [Callable ]],
289+ ) -> Sequence [Callable ]:
290+ if not register_activities :
291+ return activities or []
323292
324- def configure_worker (self , config : WorkerConfig ) -> WorkerConfig :
325- """Configure the Temporal worker for OpenAI agents integration.
293+ new_activities = [ModelActivity (model_provider ).invoke_model_activity ]
326294
327- This method adds the necessary interceptors and activities for OpenAI
328- agent execution:
329- - Adds tracing interceptors for OpenAI agent interactions
330- - Registers model execution activities
295+ server_names = [server .name for server in mcp_server_providers ]
296+ if len (server_names ) != len (set (server_names )):
297+ raise ValueError (
298+ f"More than one mcp server registered with the same name. Please provide unique names."
299+ )
331300
332- Args:
333- config: The worker configuration to modify.
301+ for mcp_server in mcp_server_providers :
302+ new_activities .extend (mcp_server ._get_activities ())
303+ return list (activities or []) + new_activities
334304
335- Returns:
336- The modified worker configuration.
337- """
338- config ["interceptors" ] = list (config .get ("interceptors" ) or []) + [
339- OpenAIAgentsTracingInterceptor ()
340- ]
341- new_activities = [ModelActivity (self ._model_provider ).invoke_model_activity ]
342-
343- server_names = [server .name for server in self ._mcp_server_providers ]
344- if len (server_names ) != len (set (server_names )):
345- raise ValueError (
346- f"More than one mcp server registered with the same name. Please provide unique names."
347- )
348-
349- for mcp_server in self ._mcp_server_providers :
350- new_activities .extend (mcp_server ._get_activities ())
351- config ["activities" ] = list (config .get ("activities" ) or []) + new_activities
352-
353- runner = config .get ("workflow_runner" )
354- if isinstance (runner , SandboxedWorkflowRunner ):
355- config ["workflow_runner" ] = dataclasses .replace (
356- runner ,
357- restrictions = runner .restrictions .with_passthrough_modules ("mcp" ),
358- )
359-
360- config ["workflow_failure_exception_types" ] = list (
361- config .get ("workflow_failure_exception_types" ) or []
362- ) + [AgentsWorkflowError ]
363- return self .next_worker_plugin .configure_worker (config )
364-
365- async def run_worker (self , worker : Worker ) -> None :
366- """Run the worker with OpenAI agents temporal overrides.
367-
368- This method sets up the necessary runtime overrides for OpenAI agents
369- to work within the Temporal worker context, including custom runners
370- and trace providers.
305+ def workflow_runner (runner : Optional [WorkflowRunner ]) -> WorkflowRunner :
306+ if not runner :
307+ raise ValueError ("No WorkflowRunner provided to the OpenAI plugin." )
371308
372- Args:
373- worker: The worker instance to run.
374- """
375- with set_open_ai_agent_temporal_overrides (self ._model_params ):
376- await self .next_worker_plugin .run_worker (worker )
377-
378- def configure_replayer (self , config : ReplayerConfig ) -> ReplayerConfig :
379- """Configure the replayer for OpenAI Agents."""
380- config ["interceptors" ] = list (config .get ("interceptors" ) or []) + [
381- OpenAIAgentsTracingInterceptor ()
382- ]
383- config ["data_converter" ] = self ._data_converter (config .get ("data_converter" ))
384- return self .next_worker_plugin .configure_replayer (config )
385-
386- @asynccontextmanager
387- async def run_replayer (
388- self ,
389- replayer : Replayer ,
390- histories : AsyncIterator [temporalio .client .WorkflowHistory ],
391- ) -> AsyncIterator [AsyncIterator [WorkflowReplayResult ]]:
392- """Set the OpenAI Overrides during replay"""
393- with set_open_ai_agent_temporal_overrides (self ._model_params ):
394- async with self .next_worker_plugin .run_replayer (
395- replayer , histories
396- ) as results :
397- yield results
309+ # If in sandbox, add additional passthrough
310+ if isinstance (runner , SandboxedWorkflowRunner ):
311+ return dataclasses .replace (
312+ runner ,
313+ restrictions = runner .restrictions .with_passthrough_modules ("mcp" ),
314+ )
315+ return runner
316+
317+ @asynccontextmanager
318+ async def run_context () -> AsyncIterator [None ]:
319+ with set_open_ai_agent_temporal_overrides (model_params ):
320+ yield
321+
322+ super ().__init__ (
323+ name = "OpenAIAgentsPlugin" ,
324+ data_converter = _data_converter ,
325+ worker_interceptors = [OpenAIAgentsTracingInterceptor ()],
326+ activities = add_activities ,
327+ workflow_runner = workflow_runner ,
328+ workflow_failure_exception_types = [AgentsWorkflowError ],
329+ run_context = lambda : run_context (),
330+ )
0 commit comments