|
22 | 22 | from agents.run import get_default_agent_runner, set_default_agent_runner |
23 | 23 | from agents.tracing import get_trace_provider |
24 | 24 | from agents.tracing.provider import DefaultTraceProvider |
25 | | -from openai.types.responses import ResponsePromptParam |
26 | 25 |
|
27 | | -import temporalio.client |
28 | | -import temporalio.worker |
29 | | -from temporalio.client import ClientConfig |
30 | 26 | from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity |
31 | 27 | from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters |
32 | 28 | from temporalio.contrib.openai_agents._openai_runner import ( |
|
47 | 43 | DataConverter, |
48 | 44 | DefaultPayloadConverter, |
49 | 45 | ) |
50 | | -from temporalio.plugin import Plugin, create_plugin |
51 | | -from temporalio.worker import ( |
52 | | - Replayer, |
53 | | - ReplayerConfig, |
54 | | - Worker, |
55 | | - WorkerConfig, |
56 | | - WorkflowReplayResult, |
57 | | - WorkflowRunner, |
58 | | -) |
| 46 | +from temporalio.plugin import SimplePlugin |
| 47 | +from temporalio.worker import WorkflowRunner |
59 | 48 | from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner |
60 | 49 |
|
61 | 50 | # Unsupported on python 3.9 |
@@ -188,75 +177,142 @@ def _data_converter(converter: Optional[DataConverter]) -> DataConverter: |
188 | 177 | return converter |
189 | 178 |
|
190 | 179 |
|
191 | | -def OpenAIAgentsPlugin( |
192 | | - model_params: Optional[ModelActivityParameters] = None, |
193 | | - model_provider: Optional[ModelProvider] = None, |
194 | | - mcp_server_providers: Sequence[ |
195 | | - Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] |
196 | | - ] = (), |
197 | | -) -> Plugin: |
198 | | - """Create an OpenAI agents plugin. |
| 180 | +class OpenAIAgentsPlugin(SimplePlugin): |
| 181 | + """Temporal plugin for integrating OpenAI agents with Temporal workflows. |
| 182 | +
|
| 183 | + .. warning:: |
| 184 | + This class is experimental and may change in future versions. |
| 185 | + Use with caution in production environments. |
| 186 | +
|
| 187 | + This plugin provides seamless integration between the OpenAI Agents SDK and |
| 188 | + Temporal workflows. It automatically configures the necessary interceptors, |
| 189 | + activities, and data converters to enable OpenAI agents to run within |
| 190 | + Temporal workflows with proper tracing and model execution. |
| 191 | +
|
| 192 | + The plugin: |
| 193 | + 1. Configures the Pydantic data converter for type-safe serialization |
| 194 | + 2. Sets up tracing interceptors for OpenAI agent interactions |
| 195 | + 3. Registers model execution activities |
| 196 | + 4. Automatically registers MCP server activities and manages their lifecycles |
| 197 | + 5. Manages the OpenAI agent runtime overrides during worker execution |
199 | 198 |
|
200 | 199 | Args: |
201 | 200 | model_params: Configuration parameters for Temporal activity execution |
202 | 201 | of model calls. If None, default parameters will be used. |
203 | 202 | model_provider: Optional model provider for custom model implementations. |
204 | 203 | Useful for testing or custom model integrations. |
205 | 204 | mcp_server_providers: Sequence of MCP servers to automatically register with the worker. |
206 | | - Each server will be wrapped in a TemporalMCPServer if not already wrapped, |
207 | | - and their activities will be automatically registered with the worker. |
208 | | - The plugin manages the connection lifecycle of these servers. |
| 205 | + The plugin will wrap each server in a TemporalMCPServer if needed and |
| 206 | + manage their connection lifecycles tied to the worker lifetime. This is |
| 207 | + the recommended way to use MCP servers with Temporal workflows. |
| 208 | +
|
| 209 | + Example: |
| 210 | + >>> from temporalio.client import Client |
| 211 | + >>> from temporalio.worker import Worker |
| 212 | + >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters, StatelessMCPServerProvider |
| 213 | + >>> from agents.mcp import MCPServerStdio |
| 214 | + >>> from datetime import timedelta |
| 215 | + >>> |
| 216 | + >>> # Configure model parameters |
| 217 | + >>> model_params = ModelActivityParameters( |
| 218 | + ... start_to_close_timeout=timedelta(seconds=30), |
| 219 | + ... retry_policy=RetryPolicy(maximum_attempts=3) |
| 220 | + ... ) |
| 221 | + >>> |
| 222 | + >>> # Create MCP servers |
| 223 | + >>> filesystem_server = StatelessMCPServerProvider(MCPServerStdio( |
| 224 | + ... name="Filesystem Server", |
| 225 | + ... params={"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]} |
| 226 | + ... )) |
| 227 | + >>> |
| 228 | + >>> # Create plugin with MCP servers |
| 229 | + >>> plugin = OpenAIAgentsPlugin( |
| 230 | + ... model_params=model_params, |
| 231 | + ... mcp_server_providers=[filesystem_server] |
| 232 | + ... ) |
| 233 | + >>> |
| 234 | + >>> # Use with client and worker |
| 235 | + >>> client = await Client.connect( |
| 236 | + ... "localhost:7233", |
| 237 | + ... plugins=[plugin] |
| 238 | + ... ) |
| 239 | + >>> worker = Worker( |
| 240 | + ... client, |
| 241 | + ... task_queue="my-task-queue", |
| 242 | + ... workflows=[MyWorkflow], |
| 243 | + ... ) |
209 | 244 | """ |
210 | | - if model_params is None: |
211 | | - model_params = ModelActivityParameters() |
212 | | - |
213 | | - # For the default provider, we provide a default start_to_close_timeout of 60 seconds. |
214 | | - # Other providers will need to define their own. |
215 | | - if ( |
216 | | - model_params.start_to_close_timeout is None |
217 | | - and model_params.schedule_to_close_timeout is None |
| 245 | + |
| 246 | + def __init__( |
| 247 | + self, |
| 248 | + model_params: Optional[ModelActivityParameters] = None, |
| 249 | + model_provider: Optional[ModelProvider] = None, |
| 250 | + mcp_server_providers: Sequence[ |
| 251 | + Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] |
| 252 | + ] = (), |
218 | 253 | ): |
219 | | - if model_provider is None: |
220 | | - model_params.start_to_close_timeout = timedelta(seconds=60) |
221 | | - else: |
| 254 | + """Create an OpenAI agents plugin. |
| 255 | +
|
| 256 | + Args: |
| 257 | + model_params: Configuration parameters for Temporal activity execution |
| 258 | + of model calls. If None, default parameters will be used. |
| 259 | + model_provider: Optional model provider for custom model implementations. |
| 260 | + Useful for testing or custom model integrations. |
| 261 | + mcp_server_providers: Sequence of MCP servers to automatically register with the worker. |
| 262 | + Each server will be wrapped in a TemporalMCPServer if not already wrapped, |
| 263 | + and their activities will be automatically registered with the worker. |
| 264 | + The plugin manages the connection lifecycle of these servers. |
| 265 | + """ |
| 266 | + if model_params is None: |
| 267 | + model_params = ModelActivityParameters() |
| 268 | + |
| 269 | + # For the default provider, we provide a default start_to_close_timeout of 60 seconds. |
| 270 | + # Other providers will need to define their own. |
| 271 | + if ( |
| 272 | + model_params.start_to_close_timeout is None |
| 273 | + and model_params.schedule_to_close_timeout is None |
| 274 | + ): |
| 275 | + if model_provider is None: |
| 276 | + model_params.start_to_close_timeout = timedelta(seconds=60) |
| 277 | + else: |
| 278 | + raise ValueError( |
| 279 | + "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" |
| 280 | + ) |
| 281 | + |
| 282 | + new_activities = [ModelActivity(model_provider).invoke_model_activity] |
| 283 | + |
| 284 | + server_names = [server.name for server in mcp_server_providers] |
| 285 | + if len(server_names) != len(set(server_names)): |
222 | 286 | raise ValueError( |
223 | | - "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" |
| 287 | + f"More than one mcp server registered with the same name. Please provide unique names." |
224 | 288 | ) |
225 | 289 |
|
226 | | - new_activities = [ModelActivity(model_provider).invoke_model_activity] |
227 | | - |
228 | | - server_names = [server.name for server in mcp_server_providers] |
229 | | - if len(server_names) != len(set(server_names)): |
230 | | - raise ValueError( |
231 | | - f"More than one mcp server registered with the same name. Please provide unique names." |
| 290 | + for mcp_server in mcp_server_providers: |
| 291 | + new_activities.extend(mcp_server._get_activities()) |
| 292 | + |
| 293 | + def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: |
| 294 | + if not runner: |
| 295 | + raise ValueError("No WorkflowRunner provided to the OpenAI plugin.") |
| 296 | + |
| 297 | + # If in sandbox, add additional passthrough |
| 298 | + if isinstance(runner, SandboxedWorkflowRunner): |
| 299 | + return dataclasses.replace( |
| 300 | + runner, |
| 301 | + restrictions=runner.restrictions.with_passthrough_modules("mcp"), |
| 302 | + ) |
| 303 | + return runner |
| 304 | + |
| 305 | + @asynccontextmanager |
| 306 | + async def run_context() -> AsyncIterator[None]: |
| 307 | + with set_open_ai_agent_temporal_overrides(model_params): |
| 308 | + yield |
| 309 | + |
| 310 | + super().__init__( |
| 311 | + name="OpenAIAgentsPlugin", |
| 312 | + data_converter=_data_converter, |
| 313 | + worker_interceptors=[OpenAIAgentsTracingInterceptor()], |
| 314 | + activities=new_activities, |
| 315 | + workflow_runner=workflow_runner, |
| 316 | + workflow_failure_exception_types=[AgentsWorkflowError], |
| 317 | + run_context=lambda: run_context(), |
232 | 318 | ) |
233 | | - |
234 | | - for mcp_server in mcp_server_providers: |
235 | | - new_activities.extend(mcp_server._get_activities()) |
236 | | - |
237 | | - def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: |
238 | | - if not runner: |
239 | | - raise ValueError("No WorkflowRunner provided to the OpenAI plugin.") |
240 | | - |
241 | | - # If in sandbox, add additional passthrough |
242 | | - if isinstance(runner, SandboxedWorkflowRunner): |
243 | | - return dataclasses.replace( |
244 | | - runner, |
245 | | - restrictions=runner.restrictions.with_passthrough_modules("mcp"), |
246 | | - ) |
247 | | - return runner |
248 | | - |
249 | | - @asynccontextmanager |
250 | | - async def run_context() -> AsyncIterator[None]: |
251 | | - with set_open_ai_agent_temporal_overrides(model_params): |
252 | | - yield |
253 | | - |
254 | | - return create_plugin( |
255 | | - name="OpenAIAgentsPlugin", |
256 | | - data_converter=_data_converter, |
257 | | - worker_interceptors=[OpenAIAgentsTracingInterceptor()], |
258 | | - activities=new_activities, |
259 | | - workflow_runner=workflow_runner, |
260 | | - workflow_failure_exception_types=[AgentsWorkflowError], |
261 | | - run_context=lambda: run_context(), |
262 | | - ) |
|
0 commit comments