Skip to content

Commit 110e58e

Browse files
committed
Add open ai agents sdk plugin to enable temporal support
1 parent a208510 commit 110e58e

File tree

4 files changed

+80
-9
lines changed

4 files changed

+80
-9
lines changed

src/agentex/lib/core/clients/temporal/temporal_client.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,12 @@
7171

7272

7373
class TemporalClient:
74-
def __init__(self, temporal_client: Client | None = None):
74+
def __init__(self, temporal_client: Client | None = None, enable_openai_agents_plugin: bool = False):
7575
self._client: Client = temporal_client
76+
self._enable_openai_agents_plugin = enable_openai_agents_plugin
7677

7778
@classmethod
78-
async def create(cls, temporal_address: str):
79+
async def create(cls, temporal_address: str, enable_openai_agents_plugin: bool = False):
7980
if temporal_address in [
8081
"false",
8182
"False",
@@ -88,8 +89,11 @@ async def create(cls, temporal_address: str):
8889
]:
8990
_client = None
9091
else:
91-
_client = await get_temporal_client(temporal_address)
92-
return cls(_client)
92+
_client = await get_temporal_client(
93+
temporal_address,
94+
enable_openai_agents_plugin=enable_openai_agents_plugin
95+
)
96+
return cls(_client, enable_openai_agents_plugin)
9397

9498
async def setup(self, temporal_address: str):
9599
self._client = await self._get_temporal_client(
@@ -109,7 +113,10 @@ async def _get_temporal_client(self, temporal_address: str) -> Client:
109113
]:
110114
return None
111115
else:
112-
return await get_temporal_client(temporal_address)
116+
return await get_temporal_client(
117+
temporal_address,
118+
enable_openai_agents_plugin=self._enable_openai_agents_plugin
119+
)
113120

114121
async def start_workflow(
115122
self,

src/agentex/lib/core/clients/temporal/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from temporalio.client import Client
22
from temporalio.contrib.pydantic import pydantic_data_converter
33
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
4+
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
45

56
# class DateTimeJSONEncoder(AdvancedJSONEncoder):
67
# def default(self, o: Any) -> Any:
@@ -38,12 +39,34 @@
3839
# )
3940

4041

41-
async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
42+
async def get_temporal_client(
43+
temporal_address: str,
44+
metrics_url: str = None,
45+
enable_openai_agents_plugin: bool = False
46+
) -> Client:
47+
"""
48+
Create a Temporal client with optional OpenAI Agents plugin integration.
49+
50+
Args:
51+
temporal_address: Temporal server address
52+
metrics_url: Optional metrics endpoint URL
53+
enable_openai_agents_plugin: Whether to enable OpenAI Agents plugin
54+
55+
Returns:
56+
Configured Temporal client
57+
"""
58+
plugins = []
59+
60+
# Add OpenAI Agents plugin if enabled
61+
if enable_openai_agents_plugin:
62+
plugins.append(OpenAIAgentsPlugin())
63+
4264
if not metrics_url:
4365
client = await Client.connect(
4466
target_host=temporal_address,
4567
# data_converter=custom_data_converter,
4668
data_converter=pydantic_data_converter,
69+
plugins=plugins if plugins else None,
4770
)
4871
else:
4972
runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url)))
@@ -52,5 +75,6 @@ async def get_temporal_client(temporal_address: str, metrics_url: str = None) ->
5275
# data_converter=custom_data_converter,
5376
data_converter=pydantic_data_converter,
5477
runtime=runtime,
78+
plugins=plugins if plugins else None,
5579
)
5680
return client

src/agentex/lib/core/temporal/services/temporal_task_service.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@ def __init__(
2323
self._temporal_client = temporal_client
2424
self._env_vars = env_vars
2525

26+
@classmethod
27+
async def create(
28+
cls,
29+
env_vars: EnvironmentVariables,
30+
enable_openai_agents_plugin: bool = False,
31+
):
32+
"""
33+
Create a TemporalTaskService with optional OpenAI Agents plugin integration.
34+
35+
Args:
36+
env_vars: Environment variables configuration
37+
enable_openai_agents_plugin: Whether to enable OpenAI Agents plugin
38+
39+
Returns:
40+
Configured TemporalTaskService instance
41+
"""
42+
temporal_client = await TemporalClient.create(
43+
temporal_address=env_vars.TEMPORAL_ADDRESS or "localhost:7233",
44+
enable_openai_agents_plugin=enable_openai_agents_plugin,
45+
)
46+
return cls(temporal_client, env_vars)
47+
2648
async def submit_task(self, agent: Agent, task: Task, params: dict[str, Any] | None) -> str:
2749
"""
2850
Submit a task to the async runtime for execution.

src/agentex/lib/core/temporal/workers/worker.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, overload
88

99
from aiohttp import web
10-
from temporalio.client import Client
10+
from temporalio.client import Client, Plugin as ClientPlugin
1111
from temporalio.converter import (
1212
AdvancedJSONEncoder,
1313
CompositePayloadConverter,
@@ -19,6 +19,7 @@
1919
)
2020
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
2121
from temporalio.worker import (
22+
Plugin as WorkerPlugin,
2223
UnsandboxedWorkflowRunner,
2324
Worker,
2425
)
@@ -65,11 +66,25 @@ def __init__(self) -> None:
6566
payload_converter_class=DateTimePayloadConverter,
6667
)
6768

69+
def _validate_plugins(plugins: list) -> None:
70+
"""Validate that all items in the plugins list are valid Temporal plugins."""
71+
for i, plugin in enumerate(plugins):
72+
if not isinstance(plugin, (ClientPlugin, WorkerPlugin)):
73+
raise TypeError(
74+
f"Plugin at index {i} must be an instance of temporalio.client.Plugin "
75+
f"or temporalio.worker.Plugin, got {type(plugin).__name__}"
76+
)
77+
78+
79+
80+
async def get_temporal_client(temporal_address: str, metrics_url: str = None, plugins: list = []) -> Client:
81+
82+
if plugins != []: # We don't need to validate the plugins if they are empty
83+
_validate_plugins(plugins)
6884

69-
async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
7085
if not metrics_url:
7186
client = await Client.connect(
72-
target_host=temporal_address, data_converter=custom_data_converter
87+
target_host=temporal_address, data_converter=custom_data_converter, plugins=plugins
7388
)
7489
else:
7590
runtime = Runtime(
@@ -90,6 +105,7 @@ def __init__(
90105
max_workers: int = 10,
91106
max_concurrent_activities: int = 10,
92107
health_check_port: int = 80,
108+
plugins: list = [],
93109
):
94110
self.task_queue = task_queue
95111
self.activity_handles = []
@@ -98,6 +114,7 @@ def __init__(
98114
self.health_check_server_running = False
99115
self.healthy = False
100116
self.health_check_port = health_check_port
117+
self.plugins = plugins
101118

102119
@overload
103120
async def run(
@@ -126,6 +143,7 @@ async def run(
126143
await self._register_agent()
127144
temporal_client = await get_temporal_client(
128145
temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"),
146+
plugins=self.plugins,
129147
)
130148

131149
# Enable debug mode if AgentEx debug is enabled (disables deadlock detection)

0 commit comments

Comments
 (0)