Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/agentex/lib/core/clients/temporal/temporal_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@


class TemporalClient:
def __init__(self, temporal_client: Client | None = None):
def __init__(self, temporal_client: Client | None = None, plugins: list[Any] = []):
self._client: Client = temporal_client
self._plugins = plugins

@classmethod
async def create(cls, temporal_address: str):
async def create(cls, temporal_address: str, plugins: list[Any] = []):
if temporal_address in [
"false",
"False",
Expand All @@ -88,8 +89,11 @@ async def create(cls, temporal_address: str):
]:
_client = None
else:
_client = await get_temporal_client(temporal_address)
return cls(_client)
_client = await get_temporal_client(
temporal_address,
plugins=plugins
)
return cls(_client, plugins)

async def setup(self, temporal_address: str):
self._client = await self._get_temporal_client(
Expand All @@ -109,7 +113,10 @@ async def _get_temporal_client(self, temporal_address: str) -> Client:
]:
return None
else:
return await get_temporal_client(temporal_address)
return await get_temporal_client(
temporal_address,
plugins=self._plugins
)

async def start_workflow(
self,
Expand Down
44 changes: 42 additions & 2 deletions src/agentex/lib/core/clients/temporal/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from temporalio.client import Client
from typing import Any
from temporalio.client import Client, Plugin as ClientPlugin
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig

Expand Down Expand Up @@ -38,12 +39,50 @@
# )


async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
def validate_client_plugins(plugins: list[Any]) -> None:
"""
Validate that all items in the plugins list are valid Temporal client plugins.

Args:
plugins: List of plugins to validate

Raises:
TypeError: If any plugin is not a valid ClientPlugin instance
"""
for i, plugin in enumerate(plugins):
if not isinstance(plugin, ClientPlugin):
raise TypeError(
f"Plugin at index {i} must be an instance of temporalio.client.Plugin, "
f"got {type(plugin).__name__}. Note: WorkerPlugin is not valid for workflow clients."
)


async def get_temporal_client(
temporal_address: str,
metrics_url: str = None,
plugins: list[Any] = []
) -> Client:
"""
Create a Temporal client with plugin integration.

Args:
temporal_address: Temporal server address
metrics_url: Optional metrics endpoint URL
plugins: List of Temporal plugins to include

Returns:
Configured Temporal client
"""
# Validate plugins if any are provided
if plugins:
validate_client_plugins(plugins)

if not metrics_url:
client = await Client.connect(
target_host=temporal_address,
# data_converter=custom_data_converter,
data_converter=pydantic_data_converter,
plugins=plugins,
)
else:
runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url)))
Expand All @@ -52,5 +91,6 @@ async def get_temporal_client(temporal_address: str, metrics_url: str = None) ->
# data_converter=custom_data_converter,
data_converter=pydantic_data_converter,
runtime=runtime,
plugins=plugins,
)
return client
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
self._temporal_client = temporal_client
self._env_vars = env_vars


async def submit_task(self, agent: Agent, task: Task, params: dict[str, Any] | None) -> str:
"""
Submit a task to the async runtime for execution.
Expand Down
24 changes: 21 additions & 3 deletions src/agentex/lib/core/temporal/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, overload

from aiohttp import web
from temporalio.client import Client
from temporalio.client import Client, Plugin as ClientPlugin
from temporalio.converter import (
AdvancedJSONEncoder,
CompositePayloadConverter,
Expand All @@ -19,6 +19,7 @@
)
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
from temporalio.worker import (
Plugin as WorkerPlugin,
UnsandboxedWorkflowRunner,
Worker,
)
Expand Down Expand Up @@ -65,11 +66,25 @@ def __init__(self) -> None:
payload_converter_class=DateTimePayloadConverter,
)

def _validate_plugins(plugins: list) -> None:
"""Validate that all items in the plugins list are valid Temporal plugins."""
for i, plugin in enumerate(plugins):
if not isinstance(plugin, (ClientPlugin, WorkerPlugin)):
raise TypeError(
f"Plugin at index {i} must be an instance of temporalio.client.Plugin "
f"or temporalio.worker.Plugin, got {type(plugin).__name__}"
)



async def get_temporal_client(temporal_address: str, metrics_url: str = None, plugins: list = []) -> Client:

if plugins != []: # We don't need to validate the plugins if they are empty
_validate_plugins(plugins)

async def get_temporal_client(temporal_address: str, metrics_url: str = None) -> Client:
if not metrics_url:
client = await Client.connect(
target_host=temporal_address, data_converter=custom_data_converter
target_host=temporal_address, data_converter=custom_data_converter, plugins=plugins
)
else:
runtime = Runtime(
Expand All @@ -90,6 +105,7 @@ def __init__(
max_workers: int = 10,
max_concurrent_activities: int = 10,
health_check_port: int = 80,
plugins: list = [],
):
self.task_queue = task_queue
self.activity_handles = []
Expand All @@ -98,6 +114,7 @@ def __init__(
self.health_check_server_running = False
self.healthy = False
self.health_check_port = health_check_port
self.plugins = plugins

@overload
async def run(
Expand Down Expand Up @@ -126,6 +143,7 @@ async def run(
await self._register_agent()
temporal_client = await get_temporal_client(
temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"),
plugins=self.plugins,
)

# Enable debug mode if AgentEx debug is enabled (disables deadlock detection)
Expand Down
11 changes: 7 additions & 4 deletions src/agentex/lib/sdk/fastacp/fastacp.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import inspect
import json
import os
from pathlib import Path

from typing import Literal
from typing import Any, Literal
from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer
from agentex.lib.sdk.fastacp.impl.agentic_base_acp import AgenticBaseACP
from agentex.lib.sdk.fastacp.impl.sync_acp import SyncACP
Expand Down Expand Up @@ -50,10 +49,12 @@ def create_agentic_acp(config: AgenticACPConfig, **kwargs) -> BaseACPServer:
implementation_class = AGENTIC_ACP_IMPLEMENTATIONS[config.type]
# Handle temporal-specific configuration
if config.type == "temporal":
# Extract temporal_address from config if it's a TemporalACPConfig
# Extract temporal_address and plugins from config if it's a TemporalACPConfig
temporal_config = kwargs.copy()
if hasattr(config, "temporal_address"):
temporal_config["temporal_address"] = config.temporal_address
if hasattr(config, "plugins"):
temporal_config["plugins"] = config.plugins
return implementation_class.create(**temporal_config)
else:
return implementation_class.create(**kwargs)
Expand All @@ -68,7 +69,9 @@ def locate_build_info_path() -> None:

@staticmethod
def create(
acp_type: Literal["sync", "agentic"], config: BaseACPConfig | None = None, **kwargs
acp_type: Literal["sync", "agentic"],
config: BaseACPConfig | None = None,
**kwargs
) -> BaseACPServer | SyncACP | AgenticBaseACP | TemporalACP:
"""Main factory method to create any ACP type

Expand Down
12 changes: 7 additions & 5 deletions src/agentex/lib/sdk/fastacp/impl/temporal_acp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Callable
from typing import Any, AsyncGenerator, Callable

from fastapi import FastAPI

Expand All @@ -24,18 +24,19 @@ class TemporalACP(BaseACPServer):
"""

def __init__(
self, temporal_address: str, temporal_task_service: TemporalTaskService | None = None
self, temporal_address: str, temporal_task_service: TemporalTaskService | None = None, plugins: list[Any] | None = None
):
super().__init__()
self._temporal_task_service = temporal_task_service
self._temporal_address = temporal_address
self._plugins = plugins or []

@classmethod
def create(cls, temporal_address: str) -> "TemporalACP":
def create(cls, temporal_address: str, plugins: list[Any] | None = None) -> "TemporalACP":
logger.info("Initializing TemporalACP instance")

# Create instance without temporal client initially
temporal_acp = cls(temporal_address=temporal_address)
temporal_acp = cls(temporal_address=temporal_address, plugins=plugins)
temporal_acp._setup_handlers()
logger.info("TemporalACP instance initialized now")
return temporal_acp
Expand All @@ -51,7 +52,8 @@ async def lifespan(app: FastAPI):
if self._temporal_task_service is None:
env_vars = EnvironmentVariables.refresh()
temporal_client = await TemporalClient.create(
temporal_address=self._temporal_address
temporal_address=self._temporal_address,
plugins=self._plugins
)
self._temporal_task_service = TemporalTaskService(
temporal_client=temporal_client,
Expand Down
14 changes: 12 additions & 2 deletions src/agentex/lib/types/fastacp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Literal
from typing import Any, Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from agentex.lib.core.clients.temporal.utils import validate_client_plugins


class BaseACPConfig(BaseModel):
Expand Down Expand Up @@ -43,12 +44,21 @@ class TemporalACPConfig(AgenticACPConfig):
Attributes:
type: The type of ACP implementation
temporal_address: The address of the temporal server
plugins: List of Temporal client plugins
"""

type: Literal["temporal"] = Field(default="temporal", frozen=True)
temporal_address: str = Field(
default="temporal-frontend.temporal.svc.cluster.local:7233", frozen=True
)
plugins: list[Any] = Field(default=[], frozen=True)

@field_validator("plugins")
@classmethod
def validate_plugins(cls, v: list[Any]) -> list[Any]:
"""Validate that all plugins are valid Temporal client plugins."""
validate_client_plugins(v)
return v


class AgenticBaseACPConfig(AgenticACPConfig):
Expand Down
Loading