|
47 | 47 | DataConverter, |
48 | 48 | DefaultPayloadConverter, |
49 | 49 | ) |
| 50 | +from temporalio.plugin import Plugin, create_plugin |
50 | 51 | from temporalio.worker import ( |
51 | 52 | Replayer, |
52 | 53 | ReplayerConfig, |
53 | 54 | Worker, |
54 | 55 | WorkerConfig, |
55 | 56 | WorkflowReplayResult, |
| 57 | + WorkflowRunner, |
56 | 58 | ) |
57 | 59 | from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner |
58 | 60 |
|
@@ -172,226 +174,88 @@ def __init__(self) -> None: |
172 | 174 | super().__init__(ToJsonOptions(exclude_unset=True)) |
173 | 175 |
|
174 | 176 |
|
175 | | -class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): |
176 | | - """Temporal plugin for integrating OpenAI agents with Temporal workflows. |
177 | | -
|
178 | | - .. warning:: |
179 | | - This class is experimental and may change in future versions. |
180 | | - Use with caution in production environments. |
181 | | -
|
182 | | - This plugin provides seamless integration between the OpenAI Agents SDK and |
183 | | - Temporal workflows. It automatically configures the necessary interceptors, |
184 | | - activities, and data converters to enable OpenAI agents to run within |
185 | | - Temporal workflows with proper tracing and model execution. |
186 | | -
|
187 | | - The plugin: |
188 | | - 1. Configures the Pydantic data converter for type-safe serialization |
189 | | - 2. Sets up tracing interceptors for OpenAI agent interactions |
190 | | - 3. Registers model execution activities |
191 | | - 4. Automatically registers MCP server activities and manages their lifecycles |
192 | | - 5. Manages the OpenAI agent runtime overrides during worker execution |
| 177 | +def _data_converter(converter: Optional[DataConverter]) -> DataConverter: |
| 178 | + if converter is None: |
| 179 | + return DataConverter(payload_converter_class=OpenAIPayloadConverter) |
| 180 | + elif converter.payload_converter_class is DefaultPayloadConverter: |
| 181 | + return dataclasses.replace( |
| 182 | + converter, payload_converter_class=OpenAIPayloadConverter |
| 183 | + ) |
| 184 | + elif not isinstance(converter.payload_converter, OpenAIPayloadConverter): |
| 185 | + raise ValueError( |
| 186 | + "The payload converter must be of type OpenAIPayloadConverter." |
| 187 | + ) |
| 188 | + return converter |
| 189 | + |
| 190 | + |
| 191 | +def OpenAIAgentsPlugin( |
| 192 | + model_params: Optional[ModelActivityParameters] = None, |
| 193 | + model_provider: Optional[ModelProvider] = None, |
| 194 | + mcp_server_providers: Sequence[ |
| 195 | + Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] |
| 196 | + ] = (), |
| 197 | +) -> Plugin: |
| 198 | + """Create an OpenAI agents plugin. |
193 | 199 |
|
194 | 200 | Args: |
195 | 201 | model_params: Configuration parameters for Temporal activity execution |
196 | 202 | of model calls. If None, default parameters will be used. |
197 | 203 | model_provider: Optional model provider for custom model implementations. |
198 | 204 | Useful for testing or custom model integrations. |
199 | 205 | mcp_server_providers: Sequence of MCP servers to automatically register with the worker. |
200 | | - The plugin will wrap each server in a TemporalMCPServer if needed and |
201 | | - manage their connection lifecycles tied to the worker lifetime. This is |
202 | | - the recommended way to use MCP servers with Temporal workflows. |
203 | | -
|
204 | | - Example: |
205 | | - >>> from temporalio.client import Client |
206 | | - >>> from temporalio.worker import Worker |
207 | | - >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters, StatelessMCPServerProvider |
208 | | - >>> from agents.mcp import MCPServerStdio |
209 | | - >>> from datetime import timedelta |
210 | | - >>> |
211 | | - >>> # Configure model parameters |
212 | | - >>> model_params = ModelActivityParameters( |
213 | | - ... start_to_close_timeout=timedelta(seconds=30), |
214 | | - ... retry_policy=RetryPolicy(maximum_attempts=3) |
215 | | - ... ) |
216 | | - >>> |
217 | | - >>> # Create MCP servers |
218 | | - >>> filesystem_server = StatelessMCPServerProvider(MCPServerStdio( |
219 | | - ... name="Filesystem Server", |
220 | | - ... params={"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]} |
221 | | - ... )) |
222 | | - >>> |
223 | | - >>> # Create plugin with MCP servers |
224 | | - >>> plugin = OpenAIAgentsPlugin( |
225 | | - ... model_params=model_params, |
226 | | - ... mcp_server_providers=[filesystem_server] |
227 | | - ... ) |
228 | | - >>> |
229 | | - >>> # Use with client and worker |
230 | | - >>> client = await Client.connect( |
231 | | - ... "localhost:7233", |
232 | | - ... plugins=[plugin] |
233 | | - ... ) |
234 | | - >>> worker = Worker( |
235 | | - ... client, |
236 | | - ... task_queue="my-task-queue", |
237 | | - ... workflows=[MyWorkflow], |
238 | | - ... ) |
| 206 | + Each server will be wrapped in a TemporalMCPServer if not already wrapped, |
| 207 | + and their activities will be automatically registered with the worker. |
| 208 | + The plugin manages the connection lifecycle of these servers. |
239 | 209 | """ |
240 | | - |
241 | | - def __init__( |
242 | | - self, |
243 | | - model_params: Optional[ModelActivityParameters] = None, |
244 | | - model_provider: Optional[ModelProvider] = None, |
245 | | - mcp_server_providers: Sequence[ |
246 | | - Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] |
247 | | - ] = (), |
248 | | - ) -> None: |
249 | | - """Initialize the OpenAI agents plugin. |
250 | | -
|
251 | | - Args: |
252 | | - model_params: Configuration parameters for Temporal activity execution |
253 | | - of model calls. If None, default parameters will be used. |
254 | | - model_provider: Optional model provider for custom model implementations. |
255 | | - Useful for testing or custom model integrations. |
256 | | - mcp_server_providers: Sequence of MCP servers to automatically register with the worker. |
257 | | - Each server will be wrapped in a TemporalMCPServer if not already wrapped, |
258 | | - and their activities will be automatically registered with the worker. |
259 | | - The plugin manages the connection lifecycle of these servers. |
260 | | - """ |
261 | | - if model_params is None: |
262 | | - model_params = ModelActivityParameters() |
263 | | - |
264 | | - # For the default provider, we provide a default start_to_close_timeout of 60 seconds. |
265 | | - # Other providers will need to define their own. |
266 | | - if ( |
267 | | - model_params.start_to_close_timeout is None |
268 | | - and model_params.schedule_to_close_timeout is None |
269 | | - ): |
270 | | - if model_provider is None: |
271 | | - model_params.start_to_close_timeout = timedelta(seconds=60) |
272 | | - else: |
273 | | - raise ValueError( |
274 | | - "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" |
275 | | - ) |
276 | | - |
277 | | - self._model_params = model_params |
278 | | - self._model_provider = model_provider |
279 | | - self._mcp_server_providers = mcp_server_providers |
280 | | - |
281 | | - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: |
282 | | - """Set the next client plugin""" |
283 | | - self.next_client_plugin = next |
284 | | - |
285 | | - async def connect_service_client( |
286 | | - self, config: temporalio.service.ConnectConfig |
287 | | - ) -> temporalio.service.ServiceClient: |
288 | | - """No modifications to service client""" |
289 | | - return await self.next_client_plugin.connect_service_client(config) |
290 | | - |
291 | | - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: |
292 | | - """Set the next worker plugin""" |
293 | | - self.next_worker_plugin = next |
294 | | - |
295 | | - @staticmethod |
296 | | - def _data_converter(converter: Optional[DataConverter]) -> DataConverter: |
297 | | - if converter is None: |
298 | | - return DataConverter(payload_converter_class=OpenAIPayloadConverter) |
299 | | - elif converter.payload_converter_class is DefaultPayloadConverter: |
300 | | - return dataclasses.replace( |
301 | | - converter, payload_converter_class=OpenAIPayloadConverter |
302 | | - ) |
303 | | - elif not isinstance(converter.payload_converter, OpenAIPayloadConverter): |
| 210 | + if model_params is None: |
| 211 | + model_params = ModelActivityParameters() |
| 212 | + |
| 213 | + # For the default provider, we provide a default start_to_close_timeout of 60 seconds. |
| 214 | + # Other providers will need to define their own. |
| 215 | + if ( |
| 216 | + model_params.start_to_close_timeout is None |
| 217 | + and model_params.schedule_to_close_timeout is None |
| 218 | + ): |
| 219 | + if model_provider is None: |
| 220 | + model_params.start_to_close_timeout = timedelta(seconds=60) |
| 221 | + else: |
304 | 222 | raise ValueError( |
305 | | - "The payload converter must be of type OpenAIPayloadConverter." |
| 223 | + "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" |
306 | 224 | ) |
307 | | - return converter |
308 | | - |
309 | | - def configure_client(self, config: ClientConfig) -> ClientConfig: |
310 | | - """Configure the Temporal client for OpenAI agents integration. |
311 | | -
|
312 | | - This method sets up the Pydantic data converter to enable proper |
313 | | - serialization of OpenAI agent objects and responses. |
314 | | -
|
315 | | - Args: |
316 | | - config: The client configuration to modify. |
317 | | -
|
318 | | - Returns: |
319 | | - The modified client configuration. |
320 | | - """ |
321 | | - config["data_converter"] = self._data_converter(config["data_converter"]) |
322 | | - return self.next_client_plugin.configure_client(config) |
323 | 225 |
|
324 | | - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: |
325 | | - """Configure the Temporal worker for OpenAI agents integration. |
| 226 | + new_activities = [ModelActivity(model_provider).invoke_model_activity] |
326 | 227 |
|
327 | | - This method adds the necessary interceptors and activities for OpenAI |
328 | | - agent execution: |
329 | | - - Adds tracing interceptors for OpenAI agent interactions |
330 | | - - Registers model execution activities |
| 228 | + server_names = [server.name for server in mcp_server_providers] |
| 229 | + if len(server_names) != len(set(server_names)): |
| 230 | + raise ValueError( |
| 231 | + f"More than one mcp server registered with the same name. Please provide unique names." |
| 232 | + ) |
331 | 233 |
|
332 | | - Args: |
333 | | - config: The worker configuration to modify. |
| 234 | + for mcp_server in mcp_server_providers: |
| 235 | + new_activities.extend(mcp_server._get_activities()) |
334 | 236 |
|
335 | | - Returns: |
336 | | - The modified worker configuration. |
337 | | - """ |
338 | | - config["interceptors"] = list(config.get("interceptors") or []) + [ |
339 | | - OpenAIAgentsTracingInterceptor() |
340 | | - ] |
341 | | - new_activities = [ModelActivity(self._model_provider).invoke_model_activity] |
| 237 | + def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: |
| 238 | + if not runner: |
| 239 | + raise ValueError("No WorkflowRunner provided to the OpenAI plugin.") |
342 | 240 |
|
343 | | - server_names = [server.name for server in self._mcp_server_providers] |
344 | | - if len(server_names) != len(set(server_names)): |
345 | | - raise ValueError( |
346 | | - f"More than one mcp server registered with the same name. Please provide unique names." |
347 | | - ) |
348 | | - |
349 | | - for mcp_server in self._mcp_server_providers: |
350 | | - new_activities.extend(mcp_server._get_activities()) |
351 | | - config["activities"] = list(config.get("activities") or []) + new_activities |
352 | | - |
353 | | - runner = config.get("workflow_runner") |
| 241 | + # If in sandbox, add additional passthrough |
354 | 242 | if isinstance(runner, SandboxedWorkflowRunner): |
355 | | - config["workflow_runner"] = dataclasses.replace( |
| 243 | + return dataclasses.replace( |
356 | 244 | runner, |
357 | 245 | restrictions=runner.restrictions.with_passthrough_modules("mcp"), |
358 | 246 | ) |
359 | | - |
360 | | - config["workflow_failure_exception_types"] = list( |
361 | | - config.get("workflow_failure_exception_types") or [] |
362 | | - ) + [AgentsWorkflowError] |
363 | | - return self.next_worker_plugin.configure_worker(config) |
364 | | - |
365 | | - async def run_worker(self, worker: Worker) -> None: |
366 | | - """Run the worker with OpenAI agents temporal overrides. |
367 | | -
|
368 | | - This method sets up the necessary runtime overrides for OpenAI agents |
369 | | - to work within the Temporal worker context, including custom runners |
370 | | - and trace providers. |
371 | | -
|
372 | | - Args: |
373 | | - worker: The worker instance to run. |
374 | | - """ |
375 | | - with set_open_ai_agent_temporal_overrides(self._model_params): |
376 | | - await self.next_worker_plugin.run_worker(worker) |
377 | | - |
378 | | - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: |
379 | | - """Configure the replayer for OpenAI Agents.""" |
380 | | - config["interceptors"] = list(config.get("interceptors") or []) + [ |
381 | | - OpenAIAgentsTracingInterceptor() |
382 | | - ] |
383 | | - config["data_converter"] = self._data_converter(config.get("data_converter")) |
384 | | - return self.next_worker_plugin.configure_replayer(config) |
| 247 | + return runner |
385 | 248 |
|
386 | 249 | @asynccontextmanager |
387 | | - async def run_replayer( |
388 | | - self, |
389 | | - replayer: Replayer, |
390 | | - histories: AsyncIterator[temporalio.client.WorkflowHistory], |
391 | | - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: |
392 | | - """Set the OpenAI Overrides during replay""" |
393 | | - with set_open_ai_agent_temporal_overrides(self._model_params): |
394 | | - async with self.next_worker_plugin.run_replayer( |
395 | | - replayer, histories |
396 | | - ) as results: |
397 | | - yield results |
| 250 | + async def run_context() -> AsyncIterator[None]: |
| 251 | + with set_open_ai_agent_temporal_overrides(model_params): |
| 252 | + yield |
| 253 | + |
| 254 | + return create_plugin( |
| 255 | + data_converter=_data_converter, |
| 256 | + worker_interceptors=[OpenAIAgentsTracingInterceptor()], |
| 257 | + activities=new_activities, |
| 258 | + workflow_runner=workflow_runner, |
| 259 | + workflow_failure_exception_types=[AgentsWorkflowError], |
| 260 | + run_context=lambda: run_context(), |
| 261 | + ) |
0 commit comments