44import uuid
55from collections .abc import Callable
66from concurrent .futures import ThreadPoolExecutor
7- from typing import Any , overload
7+ from typing import Any
88
99from aiohttp import web
1010from temporalio .client import Client
2222 UnsandboxedWorkflowRunner ,
2323 Worker ,
2424)
25+ from temporalio .contrib .openai_agents import OpenAIAgentsPlugin
2526
2627from agentex .lib .utils .logging import make_logger
2728from agentex .lib .utils .registration import register_agent
@@ -66,10 +67,18 @@ def __init__(self) -> None:
6667)
6768
6869
69- async def get_temporal_client (temporal_address : str , metrics_url : str = None ) -> Client :
70+ async def get_temporal_client (temporal_address : str , metrics_url : str = None , enable_openai_agents_plugin : bool = False ) -> Client :
71+ plugins = []
72+
73+ # Add OpenAI Agents plugin if enabled
74+ if enable_openai_agents_plugin :
75+ plugins .append (OpenAIAgentsPlugin ())
76+
7077 if not metrics_url :
7178 client = await Client .connect (
72- target_host = temporal_address , data_converter = custom_data_converter
79+ target_host = temporal_address ,
80+ data_converter = custom_data_converter ,
81+ plugins = plugins if plugins else None ,
7382 )
7483 else :
7584 runtime = Runtime (
@@ -79,6 +88,7 @@ async def get_temporal_client(temporal_address: str, metrics_url: str = None) ->
7988 target_host = temporal_address ,
8089 data_converter = custom_data_converter ,
8190 runtime = runtime ,
91+ plugins = plugins if plugins else None ,
8292 )
8393 return client
8494
@@ -90,6 +100,7 @@ def __init__(
90100 max_workers : int = 10 ,
91101 max_concurrent_activities : int = 10 ,
92102 health_check_port : int = 80 ,
103+ enable_openai_agents_plugin : bool = False ,
93104 ):
94105 self .task_queue = task_queue
95106 self .activity_handles = []
@@ -98,49 +109,30 @@ def __init__(
98109 self .health_check_server_running = False
99110 self .healthy = False
100111 self .health_check_port = health_check_port
112+ self .enable_openai_agents_plugin = enable_openai_agents_plugin
101113
102- @overload
103114 async def run (
104115 self ,
105116 activities : list [Callable ],
106- * ,
107117 workflow : type ,
108- ) -> None : ...
109-
110- @overload
111- async def run (
112- self ,
113- activities : list [Callable ],
114- * ,
115- workflows : list [type ],
116- ) -> None : ...
117-
118- async def run (
119- self ,
120- activities : list [Callable ],
121- * ,
122- workflow : type | None = None ,
123- workflows : list [type ] | None = None ,
124118 ):
125119 await self .start_health_check_server ()
126120 await self ._register_agent ()
127121 temporal_client = await get_temporal_client (
128122 temporal_address = os .environ .get ("TEMPORAL_ADDRESS" , "localhost:7233" ),
123+ enable_openai_agents_plugin = self .enable_openai_agents_plugin ,
129124 )
130125
131126 # Enable debug mode if AgentEx debug is enabled (disables deadlock detection)
132127 debug_enabled = os .environ .get ("AGENTEX_DEBUG_ENABLED" , "false" ).lower () == "true"
133128 if debug_enabled :
134129 logger .info ("🐛 [WORKER] Temporal debug mode enabled - deadlock detection disabled" )
135130
136- if workflow is None and workflows is None :
137- raise ValueError ("Either workflow or workflows must be provided" )
138-
139131 worker = Worker (
140132 client = temporal_client ,
141133 task_queue = self .task_queue ,
142134 activity_executor = ThreadPoolExecutor (max_workers = self .max_workers ),
143- workflows = [workflow ] if workflows is None else workflows ,
135+ workflows = [workflow ],
144136 activities = activities ,
145137 workflow_runner = UnsandboxedWorkflowRunner (),
146138 max_concurrent_activities = self .max_concurrent_activities ,
0 commit comments