77from typing import Any , overload
88
99from aiohttp import web
10- from temporalio .client import Client
10+ from temporalio .client import Client , Plugin as ClientPlugin
1111from temporalio .converter import (
1212 AdvancedJSONEncoder ,
1313 CompositePayloadConverter ,
1919)
2020from temporalio .runtime import OpenTelemetryConfig , Runtime , TelemetryConfig
2121from temporalio .worker import (
22+ Plugin as WorkerPlugin ,
2223 UnsandboxedWorkflowRunner ,
2324 Worker ,
2425)
@@ -65,11 +66,25 @@ def __init__(self) -> None:
6566 payload_converter_class = DateTimePayloadConverter ,
6667)
6768
69+ def _validate_plugins (plugins : list ) -> None :
70+ """Validate that all items in the plugins list are valid Temporal plugins."""
71+ for i , plugin in enumerate (plugins ):
72+ if not isinstance (plugin , (ClientPlugin , WorkerPlugin )):
73+ raise TypeError (
74+ f"Plugin at index { i } must be an instance of temporalio.client.Plugin "
75+ f"or temporalio.worker.Plugin, got { type (plugin ).__name__ } "
76+ )
77+
78+
79+
80+ async def get_temporal_client (temporal_address : str , metrics_url : str = None , plugins : list = []) -> Client :
81+
82+ if plugins != []: # We don't need to validate the plugins if they are empty
83+ _validate_plugins (plugins )
6884
69- async def get_temporal_client (temporal_address : str , metrics_url : str = None ) -> Client :
7085 if not metrics_url :
7186 client = await Client .connect (
72- target_host = temporal_address , data_converter = custom_data_converter
87+ target_host = temporal_address , data_converter = custom_data_converter , plugins = plugins
7388 )
7489 else :
7590 runtime = Runtime (
@@ -90,6 +105,7 @@ def __init__(
90105 max_workers : int = 10 ,
91106 max_concurrent_activities : int = 10 ,
92107 health_check_port : int = 80 ,
108+ plugins : list = [],
93109 ):
94110 self .task_queue = task_queue
95111 self .activity_handles = []
@@ -98,6 +114,7 @@ def __init__(
98114 self .health_check_server_running = False
99115 self .healthy = False
100116 self .health_check_port = health_check_port
117+ self .plugins = plugins
101118
102119 @overload
103120 async def run (
@@ -126,6 +143,7 @@ async def run(
126143 await self ._register_agent ()
127144 temporal_client = await get_temporal_client (
128145 temporal_address = os .environ .get ("TEMPORAL_ADDRESS" , "localhost:7233" ),
146+ plugins = self .plugins ,
129147 )
130148
131149 # Enable debug mode if AgentEx debug is enabled (disables deadlock detection)
0 commit comments