Skip to content

Commit b5a3250

Browse files
committed
Add open ai agents sdk plugin to enable temporal support
1 parent a208510 commit b5a3250

File tree

7 files changed

+69
-19
lines changed

7 files changed

+69
-19
lines changed

src/agentex/lib/core/clients/temporal/temporal_client.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,12 @@
7171

7272

7373
class TemporalClient:
74-
def __init__(self, temporal_client: Client | None = None):
74+
def __init__(self, temporal_client: Client | None = None, plugins: list[Any] = []):
7575
self._client: Client = temporal_client
76+
self._plugins = plugins
7677

7778
@classmethod
78-
async def create(cls, temporal_address: str):
79+
async def create(cls, temporal_address: str, plugins: list[Any] = []):
7980
if temporal_address in [
8081
"false",
8182
"False",
@@ -88,8 +89,11 @@ async def create(cls, temporal_address: str):
8889
]:
8990
_client = None
9091
else:
91-
_client = await get_temporal_client(temporal_address)
92-
return cls(_client)
92+
_client = await get_temporal_client(
93+
temporal_address,
94+
plugins=plugins
95+
)
96+
return cls(_client, plugins)
9397

9498
async def setup(self, temporal_address: str):
9599
self._client = await self._get_temporal_client(
@@ -109,7 +113,10 @@ async def _get_temporal_client(self, temporal_address: str) -> Client:
109113
]:
110114
return None
111115
else:
112-
return await get_temporal_client(temporal_address)
116+
return await get_temporal_client(
117+
temporal_address,
118+
plugins=self._plugins
119+
)
113120

114121
async def start_workflow(
115122
self,

src/agentex/lib/core/clients/temporal/utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Any
12
from temporalio.client import Client
23
from temporalio.contrib.pydantic import pydantic_data_converter
34
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
@@ -38,12 +39,28 @@
3839
# )
3940

4041

41-
async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
42+
async def get_temporal_client(
43+
temporal_address: str,
44+
metrics_url: str = None,
45+
plugins: list[Any] = []
46+
) -> Client:
47+
"""
48+
Create a Temporal client with plugin integration.
49+
50+
Args:
51+
temporal_address: Temporal server address
52+
metrics_url: Optional metrics endpoint URL
53+
plugins: List of Temporal plugins to include
54+
55+
Returns:
56+
Configured Temporal client
57+
"""
4258
if not metrics_url:
4359
client = await Client.connect(
4460
target_host=temporal_address,
4561
# data_converter=custom_data_converter,
4662
data_converter=pydantic_data_converter,
63+
plugins=plugins,
4764
)
4865
else:
4966
runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url)))
@@ -52,5 +69,6 @@ async def get_temporal_client(temporal_address: str, metrics_url: str = None) ->
5269
# data_converter=custom_data_converter,
5370
data_converter=pydantic_data_converter,
5471
runtime=runtime,
72+
plugins=plugins,
5573
)
5674
return client

src/agentex/lib/core/temporal/services/temporal_task_service.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
self._temporal_client = temporal_client
2424
self._env_vars = env_vars
2525

26+
2627
async def submit_task(self, agent: Agent, task: Task, params: dict[str, Any] | None) -> str:
2728
"""
2829
Submit a task to the async runtime for execution.

src/agentex/lib/core/temporal/workers/worker.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, overload
88

99
from aiohttp import web
10-
from temporalio.client import Client
10+
from temporalio.client import Client, Plugin as ClientPlugin
1111
from temporalio.converter import (
1212
AdvancedJSONEncoder,
1313
CompositePayloadConverter,
@@ -19,6 +19,7 @@
1919
)
2020
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
2121
from temporalio.worker import (
22+
Plugin as WorkerPlugin,
2223
UnsandboxedWorkflowRunner,
2324
Worker,
2425
)
@@ -65,11 +66,25 @@ def __init__(self) -> None:
6566
payload_converter_class=DateTimePayloadConverter,
6667
)
6768

69+
def _validate_plugins(plugins: list) -> None:
70+
"""Validate that all items in the plugins list are valid Temporal plugins."""
71+
for i, plugin in enumerate(plugins):
72+
if not isinstance(plugin, (ClientPlugin, WorkerPlugin)):
73+
raise TypeError(
74+
f"Plugin at index {i} must be an instance of temporalio.client.Plugin "
75+
f"or temporalio.worker.Plugin, got {type(plugin).__name__}"
76+
)
77+
78+
79+
80+
async def get_temporal_client(temporal_address: str, metrics_url: str = None, plugins: list = []) -> Client:
81+
82+
if plugins != []: # We don't need to validate the plugins if they are empty
83+
_validate_plugins(plugins)
6884

69-
async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
7085
if not metrics_url:
7186
client = await Client.connect(
72-
target_host=temporal_address, data_converter=custom_data_converter
87+
target_host=temporal_address, data_converter=custom_data_converter, plugins=plugins
7388
)
7489
else:
7590
runtime = Runtime(
@@ -90,6 +105,7 @@ def __init__(
90105
max_workers: int = 10,
91106
max_concurrent_activities: int = 10,
92107
health_check_port: int = 80,
108+
plugins: list = [],
93109
):
94110
self.task_queue = task_queue
95111
self.activity_handles = []
@@ -98,6 +114,7 @@ def __init__(
98114
self.health_check_server_running = False
99115
self.healthy = False
100116
self.health_check_port = health_check_port
117+
self.plugins = plugins
101118

102119
@overload
103120
async def run(
@@ -126,6 +143,7 @@ async def run(
126143
await self._register_agent()
127144
temporal_client = await get_temporal_client(
128145
temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"),
146+
plugins=self.plugins,
129147
)
130148

131149
# Enable debug mode if AgentEx debug is enabled (disables deadlock detection)

src/agentex/lib/sdk/fastacp/fastacp.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import inspect
2-
import json
32
import os
43
from pathlib import Path
54

6-
from typing import Literal
5+
from typing import Any, Literal
76
from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer
87
from agentex.lib.sdk.fastacp.impl.agentic_base_acp import AgenticBaseACP
98
from agentex.lib.sdk.fastacp.impl.sync_acp import SyncACP
@@ -50,10 +49,12 @@ def create_agentic_acp(config: AgenticACPConfig, **kwargs) -> BaseACPServer:
5049
implementation_class = AGENTIC_ACP_IMPLEMENTATIONS[config.type]
5150
# Handle temporal-specific configuration
5251
if config.type == "temporal":
53-
# Extract temporal_address from config if it's a TemporalACPConfig
52+
# Extract temporal_address and plugins from config if it's a TemporalACPConfig
5453
temporal_config = kwargs.copy()
5554
if hasattr(config, "temporal_address"):
5655
temporal_config["temporal_address"] = config.temporal_address
56+
if hasattr(config, "plugins"):
57+
temporal_config["plugins"] = config.plugins
5758
return implementation_class.create(**temporal_config)
5859
else:
5960
return implementation_class.create(**kwargs)
@@ -68,7 +69,9 @@ def locate_build_info_path() -> None:
6869

6970
@staticmethod
7071
def create(
71-
acp_type: Literal["sync", "agentic"], config: BaseACPConfig | None = None, **kwargs
72+
acp_type: Literal["sync", "agentic"],
73+
config: BaseACPConfig | None = None,
74+
**kwargs
7275
) -> BaseACPServer | SyncACP | AgenticBaseACP | TemporalACP:
7376
"""Main factory method to create any ACP type
7477

src/agentex/lib/sdk/fastacp/impl/temporal_acp.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from contextlib import asynccontextmanager
2-
from typing import AsyncGenerator, Callable
2+
from typing import Any, AsyncGenerator, Callable
33

44
from fastapi import FastAPI
55

@@ -24,18 +24,19 @@ class TemporalACP(BaseACPServer):
2424
"""
2525

2626
def __init__(
27-
self, temporal_address: str, temporal_task_service: TemporalTaskService | None = None
27+
self, temporal_address: str, temporal_task_service: TemporalTaskService | None = None, plugins: list[Any] | None = None
2828
):
2929
super().__init__()
3030
self._temporal_task_service = temporal_task_service
3131
self._temporal_address = temporal_address
32+
self._plugins = plugins or []
3233

3334
@classmethod
34-
def create(cls, temporal_address: str) -> "TemporalACP":
35+
def create(cls, temporal_address: str, plugins: list[Any] | None = None) -> "TemporalACP":
3536
logger.info("Initializing TemporalACP instance")
3637

3738
# Create instance without temporal client initially
38-
temporal_acp = cls(temporal_address=temporal_address)
39+
temporal_acp = cls(temporal_address=temporal_address, plugins=plugins)
3940
temporal_acp._setup_handlers()
4041
logger.info("TemporalACP instance initialized now")
4142
return temporal_acp
@@ -51,7 +52,8 @@ async def lifespan(app: FastAPI):
5152
if self._temporal_task_service is None:
5253
env_vars = EnvironmentVariables.refresh()
5354
temporal_client = await TemporalClient.create(
54-
temporal_address=self._temporal_address
55+
temporal_address=self._temporal_address,
56+
plugins=self._plugins
5557
)
5658
self._temporal_task_service = TemporalTaskService(
5759
temporal_client=temporal_client,

src/agentex/lib/types/fastacp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal
1+
from typing import Any, Literal
22

33
from pydantic import BaseModel, Field
44

@@ -49,6 +49,7 @@ class TemporalACPConfig(AgenticACPConfig):
4949
temporal_address: str = Field(
5050
default="temporal-frontend.temporal.svc.cluster.local:7233", frozen=True
5151
)
52+
plugins: list[Any] = Field(default=[], frozen=True)
5253

5354

5455
class AgenticBaseACPConfig(AgenticACPConfig):

0 commit comments

Comments
 (0)