Skip to content

Commit b18763b

Browse files
committed
Add open ai agents sdk plugin to enable temporal support
1 parent a208510 commit b18763b

File tree

7 files changed

+102
-21
lines changed

7 files changed

+102
-21
lines changed

src/agentex/lib/core/clients/temporal/temporal_client.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,12 @@
7171

7272

7373
class TemporalClient:
74-
def __init__(self, temporal_client: Client | None = None):
74+
def __init__(self, temporal_client: Client | None = None, plugins: list[Any] = []):
7575
self._client: Client = temporal_client
76+
self._plugins = plugins
7677

7778
@classmethod
78-
async def create(cls, temporal_address: str):
79+
async def create(cls, temporal_address: str, plugins: list[Any] = []):
7980
if temporal_address in [
8081
"false",
8182
"False",
@@ -88,8 +89,11 @@ async def create(cls, temporal_address: str):
8889
]:
8990
_client = None
9091
else:
91-
_client = await get_temporal_client(temporal_address)
92-
return cls(_client)
92+
_client = await get_temporal_client(
93+
temporal_address,
94+
plugins=plugins
95+
)
96+
return cls(_client, plugins)
9397

9498
async def setup(self, temporal_address: str):
9599
self._client = await self._get_temporal_client(
@@ -109,7 +113,10 @@ async def _get_temporal_client(self, temporal_address: str) -> Client:
109113
]:
110114
return None
111115
else:
112-
return await get_temporal_client(temporal_address)
116+
return await get_temporal_client(
117+
temporal_address,
118+
plugins=self._plugins
119+
)
113120

114121
async def start_workflow(
115122
self,

src/agentex/lib/core/clients/temporal/utils.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from temporalio.client import Client
1+
from typing import Any
2+
from temporalio.client import Client, Plugin as ClientPlugin
23
from temporalio.contrib.pydantic import pydantic_data_converter
34
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
45

@@ -38,12 +39,50 @@
3839
# )
3940

4041

41-
async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
42+
def validate_client_plugins(plugins: list[Any]) -> None:
43+
"""
44+
Validate that all items in the plugins list are valid Temporal client plugins.
45+
46+
Args:
47+
plugins: List of plugins to validate
48+
49+
Raises:
50+
TypeError: If any plugin is not a valid ClientPlugin instance
51+
"""
52+
for i, plugin in enumerate(plugins):
53+
if not isinstance(plugin, ClientPlugin):
54+
raise TypeError(
55+
f"Plugin at index {i} must be an instance of temporalio.client.Plugin, "
56+
f"got {type(plugin).__name__}. Note: WorkerPlugin is not valid for workflow clients."
57+
)
58+
59+
60+
async def get_temporal_client(
61+
temporal_address: str,
62+
metrics_url: str = None,
63+
plugins: list[Any] = []
64+
) -> Client:
65+
"""
66+
Create a Temporal client with plugin integration.
67+
68+
Args:
69+
temporal_address: Temporal server address
70+
metrics_url: Optional metrics endpoint URL
71+
plugins: List of Temporal plugins to include
72+
73+
Returns:
74+
Configured Temporal client
75+
"""
76+
# Validate plugins if any are provided
77+
if plugins:
78+
validate_client_plugins(plugins)
79+
4280
if not metrics_url:
4381
client = await Client.connect(
4482
target_host=temporal_address,
4583
# data_converter=custom_data_converter,
4684
data_converter=pydantic_data_converter,
85+
plugins=plugins,
4786
)
4887
else:
4988
runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url)))
@@ -52,5 +91,6 @@ async def get_temporal_client(temporal_address: str, metrics_url: str = None) ->
5291
# data_converter=custom_data_converter,
5392
data_converter=pydantic_data_converter,
5493
runtime=runtime,
94+
plugins=plugins,
5595
)
5696
return client

src/agentex/lib/core/temporal/services/temporal_task_service.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
self._temporal_client = temporal_client
2424
self._env_vars = env_vars
2525

26+
2627
async def submit_task(self, agent: Agent, task: Task, params: dict[str, Any] | None) -> str:
2728
"""
2829
Submit a task to the async runtime for execution.

src/agentex/lib/core/temporal/workers/worker.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, overload
88

99
from aiohttp import web
10-
from temporalio.client import Client
10+
from temporalio.client import Client, Plugin as ClientPlugin
1111
from temporalio.converter import (
1212
AdvancedJSONEncoder,
1313
CompositePayloadConverter,
@@ -19,6 +19,7 @@
1919
)
2020
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
2121
from temporalio.worker import (
22+
Plugin as WorkerPlugin,
2223
UnsandboxedWorkflowRunner,
2324
Worker,
2425
)
@@ -65,11 +66,25 @@ def __init__(self) -> None:
6566
payload_converter_class=DateTimePayloadConverter,
6667
)
6768

69+
def _validate_plugins(plugins: list) -> None:
70+
"""Validate that all items in the plugins list are valid Temporal plugins."""
71+
for i, plugin in enumerate(plugins):
72+
if not isinstance(plugin, (ClientPlugin, WorkerPlugin)):
73+
raise TypeError(
74+
f"Plugin at index {i} must be an instance of temporalio.client.Plugin "
75+
f"or temporalio.worker.Plugin, got {type(plugin).__name__}"
76+
)
77+
78+
79+
80+
async def get_temporal_client(temporal_address: str, metrics_url: str = None, plugins: list = []) -> Client:
81+
82+
if plugins != []: # We don't need to validate the plugins if they are empty
83+
_validate_plugins(plugins)
6884

69-
async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
7085
if not metrics_url:
7186
client = await Client.connect(
72-
target_host=temporal_address, data_converter=custom_data_converter
87+
target_host=temporal_address, data_converter=custom_data_converter, plugins=plugins
7388
)
7489
else:
7590
runtime = Runtime(
@@ -90,6 +105,7 @@ def __init__(
90105
max_workers: int = 10,
91106
max_concurrent_activities: int = 10,
92107
health_check_port: int = 80,
108+
plugins: list = [],
93109
):
94110
self.task_queue = task_queue
95111
self.activity_handles = []
@@ -98,6 +114,7 @@ def __init__(
98114
self.health_check_server_running = False
99115
self.healthy = False
100116
self.health_check_port = health_check_port
117+
self.plugins = plugins
101118

102119
@overload
103120
async def run(
@@ -126,6 +143,7 @@ async def run(
126143
await self._register_agent()
127144
temporal_client = await get_temporal_client(
128145
temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"),
146+
plugins=self.plugins,
129147
)
130148

131149
# Enable debug mode if AgentEx debug is enabled (disables deadlock detection)

src/agentex/lib/sdk/fastacp/fastacp.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import inspect
2-
import json
32
import os
43
from pathlib import Path
54

6-
from typing import Literal
5+
from typing import Any, Literal
76
from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer
87
from agentex.lib.sdk.fastacp.impl.agentic_base_acp import AgenticBaseACP
98
from agentex.lib.sdk.fastacp.impl.sync_acp import SyncACP
@@ -50,10 +49,12 @@ def create_agentic_acp(config: AgenticACPConfig, **kwargs) -> BaseACPServer:
5049
implementation_class = AGENTIC_ACP_IMPLEMENTATIONS[config.type]
5150
# Handle temporal-specific configuration
5251
if config.type == "temporal":
53-
# Extract temporal_address from config if it's a TemporalACPConfig
52+
# Extract temporal_address and plugins from config if it's a TemporalACPConfig
5453
temporal_config = kwargs.copy()
5554
if hasattr(config, "temporal_address"):
5655
temporal_config["temporal_address"] = config.temporal_address
56+
if hasattr(config, "plugins"):
57+
temporal_config["plugins"] = config.plugins
5758
return implementation_class.create(**temporal_config)
5859
else:
5960
return implementation_class.create(**kwargs)
@@ -68,7 +69,9 @@ def locate_build_info_path() -> None:
6869

6970
@staticmethod
7071
def create(
71-
acp_type: Literal["sync", "agentic"], config: BaseACPConfig | None = None, **kwargs
72+
acp_type: Literal["sync", "agentic"],
73+
config: BaseACPConfig | None = None,
74+
**kwargs
7275
) -> BaseACPServer | SyncACP | AgenticBaseACP | TemporalACP:
7376
"""Main factory method to create any ACP type
7477

src/agentex/lib/sdk/fastacp/impl/temporal_acp.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from contextlib import asynccontextmanager
2-
from typing import AsyncGenerator, Callable
2+
from typing import Any, AsyncGenerator, Callable
33

44
from fastapi import FastAPI
55

@@ -24,18 +24,19 @@ class TemporalACP(BaseACPServer):
2424
"""
2525

2626
def __init__(
27-
self, temporal_address: str, temporal_task_service: TemporalTaskService | None = None
27+
self, temporal_address: str, temporal_task_service: TemporalTaskService | None = None, plugins: list[Any] | None = None
2828
):
2929
super().__init__()
3030
self._temporal_task_service = temporal_task_service
3131
self._temporal_address = temporal_address
32+
self._plugins = plugins or []
3233

3334
@classmethod
34-
def create(cls, temporal_address: str) -> "TemporalACP":
35+
def create(cls, temporal_address: str, plugins: list[Any] | None = None) -> "TemporalACP":
3536
logger.info("Initializing TemporalACP instance")
3637

3738
# Create instance without temporal client initially
38-
temporal_acp = cls(temporal_address=temporal_address)
39+
temporal_acp = cls(temporal_address=temporal_address, plugins=plugins)
3940
temporal_acp._setup_handlers()
4041
logger.info("TemporalACP instance initialized now")
4142
return temporal_acp
@@ -51,7 +52,8 @@ async def lifespan(app: FastAPI):
5152
if self._temporal_task_service is None:
5253
env_vars = EnvironmentVariables.refresh()
5354
temporal_client = await TemporalClient.create(
54-
temporal_address=self._temporal_address
55+
temporal_address=self._temporal_address,
56+
plugins=self._plugins
5557
)
5658
self._temporal_task_service = TemporalTaskService(
5759
temporal_client=temporal_client,

src/agentex/lib/types/fastacp.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Literal
1+
from typing import Any, Literal
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel, Field, field_validator
4+
from agentex.lib.core.clients.temporal.utils import validate_client_plugins
45

56

67
class BaseACPConfig(BaseModel):
@@ -43,12 +44,21 @@ class TemporalACPConfig(AgenticACPConfig):
4344
Attributes:
4445
type: The type of ACP implementation
4546
temporal_address: The address of the temporal server
47+
plugins: List of Temporal client plugins
4648
"""
4749

4850
type: Literal["temporal"] = Field(default="temporal", frozen=True)
4951
temporal_address: str = Field(
5052
default="temporal-frontend.temporal.svc.cluster.local:7233", frozen=True
5153
)
54+
plugins: list[Any] = Field(default=[], frozen=True)
55+
56+
@field_validator("plugins")
57+
@classmethod
58+
def validate_plugins(cls, v: list[Any]) -> list[Any]:
59+
"""Validate that all plugins are valid Temporal client plugins."""
60+
validate_client_plugins(v)
61+
return v
5262

5363

5464
class AgenticBaseACPConfig(AgenticACPConfig):

0 commit comments

Comments
 (0)