99from agentex .types .task_message import TaskMessage
1010from agentex .types .task_message_content import TaskMessageContent
1111from agentex .types .task_message_content_param import TaskMessageContentParam
12+ from agentex .types .agent_rpc_params import (
13+ ParamsCancelTaskRequest as RpcParamsCancelTaskRequest ,
14+ ParamsSendEventRequest as RpcParamsSendEventRequest ,
15+ )
1216
1317logger = make_logger (__name__ )
1418
@@ -30,6 +34,7 @@ async def task_create(
3034 params : dict [str , Any ] | None = None ,
3135 trace_id : str | None = None ,
3236 parent_span_id : str | None = None ,
37+ request : dict [str , Any ] | None = None ,
3338 ) -> Task :
3439 trace = self ._tracer .trace (trace_id = trace_id )
3540 async with trace .span (
@@ -43,6 +48,10 @@ async def task_create(
4348 },
4449 ) as span :
4550 heartbeat_if_in_workflow ("task create" )
51+
52+ # Extract headers from request; pass-through to agent
53+ extra_headers = request .get ("headers" ) if request else None
54+
4655 if agent_name :
4756 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
4857 agent_name = agent_name ,
@@ -51,6 +60,7 @@ async def task_create(
5160 "name" : name ,
5261 "params" : params ,
5362 },
63+ extra_headers = extra_headers ,
5464 )
5565 elif agent_id :
5666 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -60,6 +70,7 @@ async def task_create(
6070 "name" : name ,
6171 "params" : params ,
6272 },
73+ extra_headers = extra_headers ,
6374 )
6475 else :
6576 raise ValueError ("Either agent_name or agent_id must be provided" )
@@ -78,6 +89,7 @@ async def message_send(
7889 task_name : str | None = None ,
7990 trace_id : str | None = None ,
8091 parent_span_id : str | None = None ,
92+ request : dict [str , Any ] | None = None ,
8193 ) -> List [TaskMessage ]:
8294 trace = self ._tracer .trace (trace_id = trace_id )
8395 async with trace .span (
@@ -92,6 +104,10 @@ async def message_send(
92104 },
93105 ) as span :
94106 heartbeat_if_in_workflow ("message send" )
107+
108+ # Extract headers from request; pass-through to agent
109+ extra_headers = request .get ("headers" ) if request else None
110+
95111 if agent_name :
96112 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
97113 agent_name = agent_name ,
@@ -101,6 +117,7 @@ async def message_send(
101117 "content" : cast (TaskMessageContentParam , content .model_dump ()),
102118 "stream" : False ,
103119 },
120+ extra_headers = extra_headers ,
104121 )
105122 elif agent_id :
106123 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -111,12 +128,13 @@ async def message_send(
111128 "content" : cast (TaskMessageContentParam , content .model_dump ()),
112129 "stream" : False ,
113130 },
131+ extra_headers = extra_headers ,
114132 )
115133 else :
116134 raise ValueError ("Either agent_name or agent_id must be provided" )
117135
118136 task_messages : List [TaskMessage ] = []
119- logger .info (f "json_rpc_response: { json_rpc_response } " )
137+ logger .info ("json_rpc_response: %s" , json_rpc_response )
120138 if isinstance (json_rpc_response .result , list ):
121139 for message in json_rpc_response .result :
122140 task_message = TaskMessage .model_validate (message )
@@ -137,6 +155,7 @@ async def event_send(
137155 task_name : str | None = None ,
138156 trace_id : str | None = None ,
139157 parent_span_id : str | None = None ,
158+ request : dict [str , Any ] | None = None ,
140159 ) -> Event :
141160 trace = self ._tracer .trace (trace_id = trace_id )
142161 async with trace .span (
@@ -146,27 +165,33 @@ async def event_send(
146165 "agent_id" : agent_id ,
147166 "agent_name" : agent_name ,
148167 "task_id" : task_id ,
168+ "task_name" : task_name ,
149169 "content" : content ,
150170 },
151171 ) as span :
152172 heartbeat_if_in_workflow ("event send" )
173+
174+ # Extract headers from request; pass-through to agent
175+ extra_headers = request .get ("headers" ) if request else None
176+
177+ rpc_event_params : RpcParamsSendEventRequest = {
178+ "task_id" : task_id ,
179+ "task_name" : task_name ,
180+ "content" : cast (TaskMessageContentParam , content .model_dump ()),
181+ }
153182 if agent_name :
154183 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
155184 agent_name = agent_name ,
156185 method = "event/send" ,
157- params = {
158- "task_id" : task_id ,
159- "content" : cast (TaskMessageContentParam , content .model_dump ()),
160- },
186+ params = rpc_event_params ,
187+ extra_headers = extra_headers ,
161188 )
162189 elif agent_id :
163190 json_rpc_response = await self ._agentex_client .agents .rpc (
164191 agent_id = agent_id ,
165192 method = "event/send" ,
166- params = {
167- "task_id" : task_id ,
168- "content" : cast (TaskMessageContentParam , content .model_dump ()),
169- },
193+ params = rpc_event_params ,
194+ extra_headers = extra_headers ,
170195 )
171196 else :
172197 raise ValueError ("Either agent_name or agent_id must be provided" )
@@ -184,15 +209,34 @@ async def task_cancel(
184209 agent_name : str | None = None ,
185210 trace_id : str | None = None ,
186211 parent_span_id : str | None = None ,
187- ) -> Task :
212+ request : dict [str , Any ] | None = None ,
213+ ) -> Task :
214+ """
215+ Cancel a task by sending cancel request to the agent that owns the task.
216+
217+ Args:
218+ task_id: ID of the task to cancel (passed to agent in params)
219+ task_name: Name of the task to cancel (passed to agent in params)
220+ agent_id: ID of the agent that owns the task
221+ agent_name: Name of the agent that owns the task
222+ trace_id: Trace ID for tracing
223+ parent_span_id: Parent span ID for tracing
224+ request: Additional request context including headers to forward to the agent
225+
226+ Returns:
227+ Task entry representing the cancelled task
228+
229+ Raises:
230+ ValueError: If neither agent_name nor agent_id is provided,
231+ or if neither task_name nor task_id is provided
232+ """
188233 # Require agent identification
189234 if not agent_name and not agent_id :
190235 raise ValueError ("Either agent_name or agent_id must be provided to identify the agent that owns the task" )
191236
192237 # Require task identification
193238 if not task_name and not task_id :
194239 raise ValueError ("Either task_name or task_id must be provided to identify the task to cancel" )
195-
196240 trace = self ._tracer .trace (trace_id = trace_id )
197241 async with trace .span (
198242 parent_id = parent_span_id ,
@@ -206,8 +250,11 @@ async def task_cancel(
206250 ) as span :
207251 heartbeat_if_in_workflow ("task cancel" )
208252
253+ # Extract headers from request; pass-through to agent
254+ extra_headers = request .get ("headers" ) if request else None
255+
209256 # Build params for the agent (task identification)
210- params = {}
257+ params : RpcParamsCancelTaskRequest = {}
211258 if task_id :
212259 params ["task_id" ] = task_id
213260 if task_name :
@@ -219,12 +266,15 @@ async def task_cancel(
219266 agent_name = agent_name ,
220267 method = "task/cancel" ,
221268 params = params ,
269+ extra_headers = extra_headers ,
222270 )
223271 else : # agent_id is provided (validated above)
272+ assert agent_id is not None
224273 json_rpc_response = await self ._agentex_client .agents .rpc (
225274 agent_id = agent_id ,
226275 method = "task/cancel" ,
227276 params = params ,
277+ extra_headers = extra_headers ,
228278 )
229279
230280 task_entry = Task .model_validate (json_rpc_response .result )
0 commit comments