99from agentex .types .task_message import TaskMessage
1010from agentex .types .task_message_content import TaskMessageContent
1111from agentex .types .task_message_content_param import TaskMessageContentParam
12+ from agentex .types .agent_rpc_params import (
13+ ParamsCancelTaskRequest as RpcParamsCancelTaskRequest ,
14+ ParamsSendEventRequest as RpcParamsSendEventRequest ,
15+ )
1216
1317logger = make_logger (__name__ )
1418
@@ -30,6 +34,7 @@ async def task_create(
3034 params : dict [str , Any ] | None = None ,
3135 trace_id : str | None = None ,
3236 parent_span_id : str | None = None ,
37+ request : dict [str , Any ] | None = None ,
3338 ) -> Task :
3439 trace = self ._tracer .trace (trace_id = trace_id )
3540 async with trace .span (
@@ -43,6 +48,11 @@ async def task_create(
4348 },
4449 ) as span :
4550 heartbeat_if_in_workflow ("task create" )
51+
52+ # Extract headers from request; pass-through to agent
53+ headers = request .get ("headers" ) if request else None
54+ filtered_headers : dict [str , str ] = headers or {}
55+
4656 if agent_name :
4757 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
4858 agent_name = agent_name ,
@@ -51,6 +61,7 @@ async def task_create(
5161 "name" : name ,
5262 "params" : params ,
5363 },
64+ extra_headers = filtered_headers ,
5465 )
5566 elif agent_id :
5667 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -60,6 +71,7 @@ async def task_create(
6071 "name" : name ,
6172 "params" : params ,
6273 },
74+ extra_headers = filtered_headers ,
6375 )
6476 else :
6577 raise ValueError ("Either agent_name or agent_id must be provided" )
@@ -78,6 +90,7 @@ async def message_send(
7890 task_name : str | None = None ,
7991 trace_id : str | None = None ,
8092 parent_span_id : str | None = None ,
93+ request : dict [str , Any ] | None = None ,
8194 ) -> List [TaskMessage ]:
8295 trace = self ._tracer .trace (trace_id = trace_id )
8396 async with trace .span (
@@ -92,6 +105,11 @@ async def message_send(
92105 },
93106 ) as span :
94107 heartbeat_if_in_workflow ("message send" )
108+
109+ # Extract headers from request; pass-through to agent
110+ headers = request .get ("headers" ) if request else None
111+ filtered_headers : dict [str , str ] = headers or {}
112+
95113 if agent_name :
96114 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
97115 agent_name = agent_name ,
@@ -101,6 +119,7 @@ async def message_send(
101119 "content" : cast (TaskMessageContentParam , content .model_dump ()),
102120 "stream" : False ,
103121 },
122+ extra_headers = filtered_headers ,
104123 )
105124 elif agent_id :
106125 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -111,12 +130,13 @@ async def message_send(
111130 "content" : cast (TaskMessageContentParam , content .model_dump ()),
112131 "stream" : False ,
113132 },
133+ extra_headers = filtered_headers ,
114134 )
115135 else :
116136 raise ValueError ("Either agent_name or agent_id must be provided" )
117137
118138 task_messages : List [TaskMessage ] = []
119- logger .info (f "json_rpc_response: { json_rpc_response } " )
139+ logger .info ("json_rpc_response: %s" , json_rpc_response )
120140 if isinstance (json_rpc_response .result , list ):
121141 for message in json_rpc_response .result :
122142 task_message = TaskMessage .model_validate (message )
@@ -137,6 +157,7 @@ async def event_send(
137157 task_name : str | None = None ,
138158 trace_id : str | None = None ,
139159 parent_span_id : str | None = None ,
160+ request : dict [str , Any ] | None = None ,
140161 ) -> Event :
141162 trace = self ._tracer .trace (trace_id = trace_id )
142163 async with trace .span (
@@ -146,27 +167,34 @@ async def event_send(
146167 "agent_id" : agent_id ,
147168 "agent_name" : agent_name ,
148169 "task_id" : task_id ,
170+ "task_name" : task_name ,
149171 "content" : content ,
150172 },
151173 ) as span :
152174 heartbeat_if_in_workflow ("event send" )
175+
176+ # Extract headers from request; pass-through to agent
177+ headers = request .get ("headers" ) if request else None
178+ filtered_headers : dict [str , str ] = headers or {}
179+
180+ rpc_event_params : RpcParamsSendEventRequest = {
181+ "task_id" : task_id ,
182+ "task_name" : task_name ,
183+ "content" : cast (TaskMessageContentParam , content .model_dump ()),
184+ }
153185 if agent_name :
154186 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
155187 agent_name = agent_name ,
156188 method = "event/send" ,
157- params = {
158- "task_id" : task_id ,
159- "content" : cast (TaskMessageContentParam , content .model_dump ()),
160- },
189+ params = rpc_event_params ,
190+ extra_headers = filtered_headers ,
161191 )
162192 elif agent_id :
163193 json_rpc_response = await self ._agentex_client .agents .rpc (
164194 agent_id = agent_id ,
165195 method = "event/send" ,
166- params = {
167- "task_id" : task_id ,
168- "content" : cast (TaskMessageContentParam , content .model_dump ()),
169- },
196+ params = rpc_event_params ,
197+ extra_headers = filtered_headers ,
170198 )
171199 else :
172200 raise ValueError ("Either agent_name or agent_id must be provided" )
@@ -184,15 +212,34 @@ async def task_cancel(
184212 agent_name : str | None = None ,
185213 trace_id : str | None = None ,
186214 parent_span_id : str | None = None ,
187- ) -> Task :
215+ request : dict [str , Any ] | None = None ,
216+ ) -> Task :
217+ """
218+ Cancel a task by sending cancel request to the agent that owns the task.
219+
220+ Args:
221+ task_id: ID of the task to cancel (passed to agent in params)
222+ task_name: Name of the task to cancel (passed to agent in params)
223+ agent_id: ID of the agent that owns the task
224+ agent_name: Name of the agent that owns the task
225+ trace_id: Trace ID for tracing
226+ parent_span_id: Parent span ID for tracing
227+ request: Additional request context including headers to forward to the agent
228+
229+ Returns:
230+ Task entry representing the cancelled task
231+
232+ Raises:
233+ ValueError: If neither agent_name nor agent_id is provided,
234+ or if neither task_name nor task_id is provided
235+ """
188236 # Require agent identification
189237 if not agent_name and not agent_id :
190238 raise ValueError ("Either agent_name or agent_id must be provided to identify the agent that owns the task" )
191239
192240 # Require task identification
193241 if not task_name and not task_id :
194242 raise ValueError ("Either task_name or task_id must be provided to identify the task to cancel" )
195-
196243 trace = self ._tracer .trace (trace_id = trace_id )
197244 async with trace .span (
198245 parent_id = parent_span_id ,
@@ -206,8 +253,12 @@ async def task_cancel(
206253 ) as span :
207254 heartbeat_if_in_workflow ("task cancel" )
208255
256+ # Extract headers from request; pass-through to agent
257+ headers = request .get ("headers" ) if request else None
258+ filtered_headers : dict [str , str ] = headers or {}
259+
209260 # Build params for the agent (task identification)
210- params = {}
261+ params : RpcParamsCancelTaskRequest = {}
211262 if task_id :
212263 params ["task_id" ] = task_id
213264 if task_name :
@@ -219,12 +270,15 @@ async def task_cancel(
219270 agent_name = agent_name ,
220271 method = "task/cancel" ,
221272 params = params ,
273+ extra_headers = filtered_headers ,
222274 )
223275 else : # agent_id is provided (validated above)
276+ assert agent_id is not None
224277 json_rpc_response = await self ._agentex_client .agents .rpc (
225278 agent_id = agent_id ,
226279 method = "task/cancel" ,
227280 params = params ,
281+ extra_headers = filtered_headers ,
228282 )
229283
230284 task_entry = Task .model_validate (json_rpc_response .result )
0 commit comments