Skip to content

Commit f303f76

Browse files
committed
Add open ai agents sdk plugin to enable temporal support
1 parent a208510 commit f303f76

File tree

4 files changed

+76
-31
lines changed

4 files changed

+76
-31
lines changed

src/agentex/lib/core/clients/temporal/temporal_client.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,12 @@
7171

7272

7373
class TemporalClient:
74-
def __init__(self, temporal_client: Client | None = None):
74+
def __init__(self, temporal_client: Client | None = None, enable_openai_agents_plugin: bool = False):
7575
self._client: Client = temporal_client
76+
self._enable_openai_agents_plugin = enable_openai_agents_plugin
7677

7778
@classmethod
78-
async def create(cls, temporal_address: str):
79+
async def create(cls, temporal_address: str, enable_openai_agents_plugin: bool = False):
7980
if temporal_address in [
8081
"false",
8182
"False",
@@ -88,8 +89,11 @@ async def create(cls, temporal_address: str):
8889
]:
8990
_client = None
9091
else:
91-
_client = await get_temporal_client(temporal_address)
92-
return cls(_client)
92+
_client = await get_temporal_client(
93+
temporal_address,
94+
enable_openai_agents_plugin=enable_openai_agents_plugin
95+
)
96+
return cls(_client, enable_openai_agents_plugin)
9397

9498
async def setup(self, temporal_address: str):
9599
self._client = await self._get_temporal_client(
@@ -109,7 +113,10 @@ async def _get_temporal_client(self, temporal_address: str) -> Client:
109113
]:
110114
return None
111115
else:
112-
return await get_temporal_client(temporal_address)
116+
return await get_temporal_client(
117+
temporal_address,
118+
enable_openai_agents_plugin=self._enable_openai_agents_plugin
119+
)
113120

114121
async def start_workflow(
115122
self,

src/agentex/lib/core/clients/temporal/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from temporalio.client import Client
22
from temporalio.contrib.pydantic import pydantic_data_converter
33
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
4+
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
45

56
# class DateTimeJSONEncoder(AdvancedJSONEncoder):
67
# def default(self, o: Any) -> Any:
@@ -38,12 +39,34 @@
3839
# )
3940

4041

41-
async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
42+
async def get_temporal_client(
43+
temporal_address: str,
44+
metrics_url: str = None,
45+
enable_openai_agents_plugin: bool = False
46+
) -> Client:
47+
"""
48+
Create a Temporal client with optional OpenAI Agents plugin integration.
49+
50+
Args:
51+
temporal_address: Temporal server address
52+
metrics_url: Optional metrics endpoint URL
53+
enable_openai_agents_plugin: Whether to enable OpenAI Agents plugin
54+
55+
Returns:
56+
Configured Temporal client
57+
"""
58+
plugins = []
59+
60+
# Add OpenAI Agents plugin if enabled
61+
if enable_openai_agents_plugin:
62+
plugins.append(OpenAIAgentsPlugin())
63+
4264
if not metrics_url:
4365
client = await Client.connect(
4466
target_host=temporal_address,
4567
# data_converter=custom_data_converter,
4668
data_converter=pydantic_data_converter,
69+
plugins=plugins if plugins else None,
4770
)
4871
else:
4972
runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url)))
@@ -52,5 +75,6 @@ async def get_temporal_client(temporal_address: str, metrics_url: str = None) ->
5275
# data_converter=custom_data_converter,
5376
data_converter=pydantic_data_converter,
5477
runtime=runtime,
78+
plugins=plugins if plugins else None,
5579
)
5680
return client

src/agentex/lib/core/temporal/services/temporal_task_service.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@ def __init__(
2323
self._temporal_client = temporal_client
2424
self._env_vars = env_vars
2525

26+
@classmethod
27+
async def create(
28+
cls,
29+
env_vars: EnvironmentVariables,
30+
enable_openai_agents_plugin: bool = False,
31+
):
32+
"""
33+
Create a TemporalTaskService with optional OpenAI Agents plugin integration.
34+
35+
Args:
36+
env_vars: Environment variables configuration
37+
enable_openai_agents_plugin: Whether to enable OpenAI Agents plugin
38+
39+
Returns:
40+
Configured TemporalTaskService instance
41+
"""
42+
temporal_client = await TemporalClient.create(
43+
temporal_address=env_vars.TEMPORAL_ADDRESS or "localhost:7233",
44+
enable_openai_agents_plugin=enable_openai_agents_plugin,
45+
)
46+
return cls(temporal_client, env_vars)
47+
2648
async def submit_task(self, agent: Agent, task: Task, params: dict[str, Any] | None) -> str:
2749
"""
2850
Submit a task to the async runtime for execution.

src/agentex/lib/core/temporal/workers/worker.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from collections.abc import Callable
66
from concurrent.futures import ThreadPoolExecutor
7-
from typing import Any, overload
7+
from typing import Any
88

99
from aiohttp import web
1010
from temporalio.client import Client
@@ -22,6 +22,7 @@
2222
UnsandboxedWorkflowRunner,
2323
Worker,
2424
)
25+
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
2526

2627
from agentex.lib.utils.logging import make_logger
2728
from agentex.lib.utils.registration import register_agent
@@ -66,10 +67,18 @@ def __init__(self) -> None:
6667
)
6768

6869

69-
async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
70+
async def get_temporal_client(temporal_address: str, metrics_url: str = None, enable_openai_agents_plugin: bool = False) -> Client:
71+
plugins = []
72+
73+
# Add OpenAI Agents plugin if enabled
74+
if enable_openai_agents_plugin:
75+
plugins.append(OpenAIAgentsPlugin())
76+
7077
if not metrics_url:
7178
client = await Client.connect(
72-
target_host=temporal_address, data_converter=custom_data_converter
79+
target_host=temporal_address,
80+
data_converter=custom_data_converter,
81+
plugins=plugins if plugins else None,
7382
)
7483
else:
7584
runtime = Runtime(
@@ -79,6 +88,7 @@ async def get_temporal_client(temporal_address: str, metrics_url: str = None) ->
7988
target_host=temporal_address,
8089
data_converter=custom_data_converter,
8190
runtime=runtime,
91+
plugins=plugins if plugins else None,
8292
)
8393
return client
8494

@@ -90,6 +100,7 @@ def __init__(
90100
max_workers: int = 10,
91101
max_concurrent_activities: int = 10,
92102
health_check_port: int = 80,
103+
enable_openai_agents_plugin: bool = False,
93104
):
94105
self.task_queue = task_queue
95106
self.activity_handles = []
@@ -98,49 +109,30 @@ def __init__(
98109
self.health_check_server_running = False
99110
self.healthy = False
100111
self.health_check_port = health_check_port
112+
self.enable_openai_agents_plugin = enable_openai_agents_plugin
101113

102-
@overload
103114
async def run(
104115
self,
105116
activities: list[Callable],
106-
*,
107117
workflow: type,
108-
) -> None: ...
109-
110-
@overload
111-
async def run(
112-
self,
113-
activities: list[Callable],
114-
*,
115-
workflows: list[type],
116-
) -> None: ...
117-
118-
async def run(
119-
self,
120-
activities: list[Callable],
121-
*,
122-
workflow: type | None = None,
123-
workflows: list[type] | None = None,
124118
):
125119
await self.start_health_check_server()
126120
await self._register_agent()
127121
temporal_client = await get_temporal_client(
128122
temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"),
123+
enable_openai_agents_plugin=self.enable_openai_agents_plugin,
129124
)
130125

131126
# Enable debug mode if AgentEx debug is enabled (disables deadlock detection)
132127
debug_enabled = os.environ.get("AGENTEX_DEBUG_ENABLED", "false").lower() == "true"
133128
if debug_enabled:
134129
logger.info("🐛 [WORKER] Temporal debug mode enabled - deadlock detection disabled")
135130

136-
if workflow is None and workflows is None:
137-
raise ValueError("Either workflow or workflows must be provided")
138-
139131
worker = Worker(
140132
client=temporal_client,
141133
task_queue=self.task_queue,
142134
activity_executor=ThreadPoolExecutor(max_workers=self.max_workers),
143-
workflows=[workflow] if workflows is None else workflows,
135+
workflows=[workflow],
144136
activities=activities,
145137
workflow_runner=UnsandboxedWorkflowRunner(),
146138
max_concurrent_activities=self.max_concurrent_activities,

0 commit comments

Comments
 (0)