44from agentex .lib .core .tracing .tracer import AsyncTracer
55from agentex .lib .utils .logging import make_logger
66from agentex .lib .utils .temporal import heartbeat_if_in_workflow
7+ # No longer need CustomHeadersConfig import
78from agentex .types .event import Event
89from agentex .types .task import Task
910from agentex .types .task_message import TaskMessage
1011from agentex .types .task_message_content import TaskMessageContent
1112from agentex .types .task_message_content_param import TaskMessageContentParam
13+ from agentex .types .agent_rpc_params import (
14+ ParamsCancelTaskRequest as RpcParamsCancelTaskRequest ,
15+ ParamsSendEventRequest as RpcParamsSendEventRequest ,
16+ )
1217
1318logger = make_logger (__name__ )
1419
@@ -30,6 +35,7 @@ async def task_create(
3035 params : dict [str , Any ] | None = None ,
3136 trace_id : str | None = None ,
3237 parent_span_id : str | None = None ,
38+ request : dict [str , Any ] | None = None ,
3339 ) -> Task :
3440 trace = self ._tracer .trace (trace_id = trace_id )
3541 async with trace .span (
@@ -43,6 +49,11 @@ async def task_create(
4349 },
4450 ) as span :
4551 heartbeat_if_in_workflow ("task create" )
52+
53+ # Extract headers from request; pass-through to agent
54+ headers = request .get ("headers" ) if request else None
55+ filtered_headers : dict [str , str ] = headers or {}
56+
4657 if agent_name :
4758 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
4859 agent_name = agent_name ,
@@ -51,6 +62,7 @@ async def task_create(
5162 "name" : name ,
5263 "params" : params ,
5364 },
65+ extra_headers = filtered_headers ,
5466 )
5567 elif agent_id :
5668 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -60,6 +72,7 @@ async def task_create(
6072 "name" : name ,
6173 "params" : params ,
6274 },
75+ extra_headers = filtered_headers ,
6376 )
6477 else :
6578 raise ValueError ("Either agent_name or agent_id must be provided" )
@@ -78,6 +91,7 @@ async def message_send(
7891 task_name : str | None = None ,
7992 trace_id : str | None = None ,
8093 parent_span_id : str | None = None ,
94+ request : dict [str , Any ] | None = None ,
8195 ) -> List [TaskMessage ]:
8296 trace = self ._tracer .trace (trace_id = trace_id )
8397 async with trace .span (
@@ -92,6 +106,11 @@ async def message_send(
92106 },
93107 ) as span :
94108 heartbeat_if_in_workflow ("message send" )
109+
110+ # Extract headers from request; pass-through to agent
111+ headers = request .get ("headers" ) if request else None
112+ filtered_headers : dict [str , str ] = headers or {}
113+
95114 if agent_name :
96115 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
97116 agent_name = agent_name ,
@@ -101,6 +120,7 @@ async def message_send(
101120 "content" : cast (TaskMessageContentParam , content .model_dump ()),
102121 "stream" : False ,
103122 },
123+ extra_headers = filtered_headers ,
104124 )
105125 elif agent_id :
106126 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -111,12 +131,13 @@ async def message_send(
111131 "content" : cast (TaskMessageContentParam , content .model_dump ()),
112132 "stream" : False ,
113133 },
134+ extra_headers = filtered_headers ,
114135 )
115136 else :
116137 raise ValueError ("Either agent_name or agent_id must be provided" )
117138
118139 task_messages : List [TaskMessage ] = []
119- logger .info (f "json_rpc_response: { json_rpc_response } " )
140+ logger .info ("json_rpc_response: %s" , json_rpc_response )
120141 if isinstance (json_rpc_response .result , list ):
121142 for message in json_rpc_response .result :
122143 task_message = TaskMessage .model_validate (message )
@@ -137,6 +158,7 @@ async def event_send(
137158 task_name : str | None = None ,
138159 trace_id : str | None = None ,
139160 parent_span_id : str | None = None ,
161+ request : dict [str , Any ] | None = None ,
140162 ) -> Event :
141163 trace = self ._tracer .trace (trace_id = trace_id )
142164 async with trace .span (
@@ -146,27 +168,34 @@ async def event_send(
146168 "agent_id" : agent_id ,
147169 "agent_name" : agent_name ,
148170 "task_id" : task_id ,
171+ "task_name" : task_name ,
149172 "content" : content ,
150173 },
151174 ) as span :
152175 heartbeat_if_in_workflow ("event send" )
176+
177+ # Extract headers from request; pass-through to agent
178+ headers = request .get ("headers" ) if request else None
179+ filtered_headers : dict [str , str ] = headers or {}
180+
181+ rpc_event_params : RpcParamsSendEventRequest = {
182+ "task_id" : task_id ,
183+ "task_name" : task_name ,
184+ "content" : cast (TaskMessageContentParam , content .model_dump ()),
185+ }
153186 if agent_name :
154187 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
155188 agent_name = agent_name ,
156189 method = "event/send" ,
157- params = {
158- "task_id" : task_id ,
159- "content" : cast (TaskMessageContentParam , content .model_dump ()),
160- },
190+ params = rpc_event_params ,
191+ extra_headers = filtered_headers ,
161192 )
162193 elif agent_id :
163194 json_rpc_response = await self ._agentex_client .agents .rpc (
164195 agent_id = agent_id ,
165196 method = "event/send" ,
166- params = {
167- "task_id" : task_id ,
168- "content" : cast (TaskMessageContentParam , content .model_dump ()),
169- },
197+ params = rpc_event_params ,
198+ extra_headers = filtered_headers ,
170199 )
171200 else :
172201 raise ValueError ("Either agent_name or agent_id must be provided" )
@@ -184,15 +213,34 @@ async def task_cancel(
184213 agent_name : str | None = None ,
185214 trace_id : str | None = None ,
186215 parent_span_id : str | None = None ,
187- ) -> Task :
216+ request : dict [str , Any ] | None = None ,
217+ ) -> Task :
218+ """
219+ Cancel a task by sending cancel request to the agent that owns the task.
220+
221+ Args:
222+ task_id: ID of the task to cancel (passed to agent in params)
223+ task_name: Name of the task to cancel (passed to agent in params)
224+ agent_id: ID of the agent that owns the task
225+ agent_name: Name of the agent that owns the task
226+ trace_id: Trace ID for tracing
227+ parent_span_id: Parent span ID for tracing
228+ request: Additional request context including headers to forward to the agent
229+
230+ Returns:
231+ Task entry representing the cancelled task
232+
233+ Raises:
234+ ValueError: If neither agent_name nor agent_id is provided,
235+ or if neither task_name nor task_id is provided
236+ """
188237 # Require agent identification
189238 if not agent_name and not agent_id :
190239 raise ValueError ("Either agent_name or agent_id must be provided to identify the agent that owns the task" )
191240
192241 # Require task identification
193242 if not task_name and not task_id :
194243 raise ValueError ("Either task_name or task_id must be provided to identify the task to cancel" )
195-
196244 trace = self ._tracer .trace (trace_id = trace_id )
197245 async with trace .span (
198246 parent_id = parent_span_id ,
@@ -206,8 +254,12 @@ async def task_cancel(
206254 ) as span :
207255 heartbeat_if_in_workflow ("task cancel" )
208256
257+ # Extract headers from request; pass-through to agent
258+ headers = request .get ("headers" ) if request else None
259+ filtered_headers : dict [str , str ] = headers or {}
260+
209261 # Build params for the agent (task identification)
210- params = {}
262+ params : RpcParamsCancelTaskRequest = {}
211263 if task_id :
212264 params ["task_id" ] = task_id
213265 if task_name :
@@ -219,12 +271,15 @@ async def task_cancel(
219271 agent_name = agent_name ,
220272 method = "task/cancel" ,
221273 params = params ,
274+ extra_headers = filtered_headers ,
222275 )
223276 else : # agent_id is provided (validated above)
277+ assert agent_id is not None
224278 json_rpc_response = await self ._agentex_client .agents .rpc (
225279 agent_id = agent_id ,
226280 method = "task/cancel" ,
227281 params = params ,
282+ extra_headers = filtered_headers ,
228283 )
229284
230285 task_entry = Task .model_validate (json_rpc_response .result )
0 commit comments