Skip to content

Commit f16d53c

Browse files
feat: add header forwarding logic
1 parent 48cdeea commit f16d53c

File tree

7 files changed

+485
-16
lines changed

7 files changed

+485
-16
lines changed

src/agentex/lib/adk/_modules/acp.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ async def create_task(
5959
start_to_close_timeout: timedelta = timedelta(seconds=5),
6060
heartbeat_timeout: timedelta = timedelta(seconds=5),
6161
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
62+
request: dict[str, Any] | None = None,
6263
) -> Task:
6364
"""
6465
Create a new task.
@@ -71,6 +72,7 @@ async def create_task(
7172
start_to_close_timeout: The start to close timeout for the task.
7273
heartbeat_timeout: The heartbeat timeout for the task.
7374
retry_policy: The retry policy for the task.
75+
request: Additional request context including headers to forward to the agent.
7476
7577
Returns:
7678
The task entry.
@@ -85,6 +87,7 @@ async def create_task(
8587
params=params,
8688
trace_id=trace_id,
8789
parent_span_id=parent_span_id,
90+
request=request,
8891
),
8992
response_type=Task,
9093
start_to_close_timeout=start_to_close_timeout,
@@ -99,6 +102,7 @@ async def create_task(
99102
params=params,
100103
trace_id=trace_id,
101104
parent_span_id=parent_span_id,
105+
request=request,
102106
)
103107

104108
async def send_event(
@@ -112,15 +116,22 @@ async def send_event(
112116
start_to_close_timeout: timedelta = timedelta(seconds=5),
113117
heartbeat_timeout: timedelta = timedelta(seconds=5),
114118
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
119+
request: dict[str, Any] | None = None,
115120
) -> Event:
116121
"""
117122
Send an event to a task.
118123
119124
Args:
120125
task_id: The ID of the task to send the event to.
121-
data: The data to send to the event.
126+
content: The content to send to the event.
122127
agent_id: The ID of the agent to send the event to.
123128
agent_name: The name of the agent to send the event to.
129+
trace_id: The trace ID for the event.
130+
parent_span_id: The parent span ID for the event.
131+
start_to_close_timeout: The start to close timeout for the event.
132+
heartbeat_timeout: The heartbeat timeout for the event.
133+
retry_policy: The retry policy for the event.
134+
request: Additional request context including headers to forward to the agent.
124135
125136
Returns:
126137
The event entry.
@@ -135,6 +146,7 @@ async def send_event(
135146
content=content,
136147
trace_id=trace_id,
137148
parent_span_id=parent_span_id,
149+
request=request,
138150
),
139151
response_type=None,
140152
start_to_close_timeout=start_to_close_timeout,
@@ -149,6 +161,7 @@ async def send_event(
149161
content=content,
150162
trace_id=trace_id,
151163
parent_span_id=parent_span_id,
164+
request=request,
152165
)
153166

154167
async def send_message(
@@ -162,15 +175,22 @@ async def send_message(
162175
start_to_close_timeout: timedelta = timedelta(seconds=5),
163176
heartbeat_timeout: timedelta = timedelta(seconds=5),
164177
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
178+
request: dict[str, Any] | None = None,
165179
) -> List[TaskMessage]:
166180
"""
167181
Send a message to a task.
168182
169183
Args:
170-
task_id: The ID of the task to send the message to.
171184
content: The task message content to send to the task.
185+
task_id: The ID of the task to send the message to.
172186
agent_id: The ID of the agent to send the message to.
173187
agent_name: The name of the agent to send the message to.
188+
trace_id: The trace ID for the message.
189+
parent_span_id: The parent span ID for the message.
190+
start_to_close_timeout: The start to close timeout for the message.
191+
heartbeat_timeout: The heartbeat timeout for the message.
192+
retry_policy: The retry policy for the message.
193+
request: Additional request context including headers to forward to the agent.
174194
175195
Returns:
176196
The message entry.
@@ -185,6 +205,7 @@ async def send_message(
185205
content=content,
186206
trace_id=trace_id,
187207
parent_span_id=parent_span_id,
208+
request=request,
188209
),
189210
response_type=TaskMessage,
190211
start_to_close_timeout=start_to_close_timeout,
@@ -199,6 +220,7 @@ async def send_message(
199220
content=content,
200221
trace_id=trace_id,
201222
parent_span_id=parent_span_id,
223+
request=request,
202224
)
203225

204226
async def cancel_task(
@@ -212,6 +234,7 @@ async def cancel_task(
212234
start_to_close_timeout: timedelta = timedelta(seconds=5),
213235
heartbeat_timeout: timedelta = timedelta(seconds=5),
214236
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
237+
request: dict[str, Any] | None = None,
215238
) -> Task:
216239
"""
217240
Cancel a task by sending cancel request to the agent that owns the task.
@@ -226,6 +249,7 @@ async def cancel_task(
226249
start_to_close_timeout: The start to close timeout for the task.
227250
heartbeat_timeout: The heartbeat timeout for the task.
228251
retry_policy: The retry policy for the task.
252+
request: Additional request context including headers to forward to the agent.
229253
230254
Returns:
231255
The task entry.
@@ -244,6 +268,7 @@ async def cancel_task(
244268
agent_name=agent_name,
245269
trace_id=trace_id,
246270
parent_span_id=parent_span_id,
271+
request=request,
247272
),
248273
response_type=None,
249274
start_to_close_timeout=start_to_close_timeout,
@@ -258,4 +283,5 @@ async def cancel_task(
258283
agent_name=agent_name,
259284
trace_id=trace_id,
260285
parent_span_id=parent_span_id,
286+
request=request,
261287
)

src/agentex/lib/core/services/adk/acp/acp.py

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
from agentex.types.task_message import TaskMessage
1010
from agentex.types.task_message_content import TaskMessageContent
1111
from agentex.types.task_message_content_param import TaskMessageContentParam
12+
from agentex.types.agent_rpc_params import (
13+
ParamsCancelTaskRequest as RpcParamsCancelTaskRequest,
14+
ParamsSendEventRequest as RpcParamsSendEventRequest,
15+
)
1216

1317
logger = make_logger(__name__)
1418

@@ -30,6 +34,7 @@ async def task_create(
3034
params: dict[str, Any] | None = None,
3135
trace_id: str | None = None,
3236
parent_span_id: str | None = None,
37+
request: dict[str, Any] | None = None,
3338
) -> Task:
3439
trace = self._tracer.trace(trace_id=trace_id)
3540
async with trace.span(
@@ -43,6 +48,11 @@ async def task_create(
4348
},
4449
) as span:
4550
heartbeat_if_in_workflow("task create")
51+
52+
# Extract headers from request; pass-through to agent
53+
headers = request.get("headers") if request else None
54+
filtered_headers: dict[str, str] = headers or {}
55+
4656
if agent_name:
4757
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
4858
agent_name=agent_name,
@@ -51,6 +61,7 @@ async def task_create(
5161
"name": name,
5262
"params": params,
5363
},
64+
extra_headers=filtered_headers,
5465
)
5566
elif agent_id:
5667
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -60,6 +71,7 @@ async def task_create(
6071
"name": name,
6172
"params": params,
6273
},
74+
extra_headers=filtered_headers,
6375
)
6476
else:
6577
raise ValueError("Either agent_name or agent_id must be provided")
@@ -78,6 +90,7 @@ async def message_send(
7890
task_name: str | None = None,
7991
trace_id: str | None = None,
8092
parent_span_id: str | None = None,
93+
request: dict[str, Any] | None = None,
8194
) -> List[TaskMessage]:
8295
trace = self._tracer.trace(trace_id=trace_id)
8396
async with trace.span(
@@ -92,6 +105,11 @@ async def message_send(
92105
},
93106
) as span:
94107
heartbeat_if_in_workflow("message send")
108+
109+
# Extract headers from request; pass-through to agent
110+
headers = request.get("headers") if request else None
111+
filtered_headers: dict[str, str] = headers or {}
112+
95113
if agent_name:
96114
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
97115
agent_name=agent_name,
@@ -101,6 +119,7 @@ async def message_send(
101119
"content": cast(TaskMessageContentParam, content.model_dump()),
102120
"stream": False,
103121
},
122+
extra_headers=filtered_headers,
104123
)
105124
elif agent_id:
106125
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -111,12 +130,13 @@ async def message_send(
111130
"content": cast(TaskMessageContentParam, content.model_dump()),
112131
"stream": False,
113132
},
133+
extra_headers=filtered_headers,
114134
)
115135
else:
116136
raise ValueError("Either agent_name or agent_id must be provided")
117137

118138
task_messages: List[TaskMessage] = []
119-
logger.info(f"json_rpc_response: {json_rpc_response}")
139+
logger.info("json_rpc_response: %s", json_rpc_response)
120140
if isinstance(json_rpc_response.result, list):
121141
for message in json_rpc_response.result:
122142
task_message = TaskMessage.model_validate(message)
@@ -137,6 +157,7 @@ async def event_send(
137157
task_name: str | None = None,
138158
trace_id: str | None = None,
139159
parent_span_id: str | None = None,
160+
request: dict[str, Any] | None = None,
140161
) -> Event:
141162
trace = self._tracer.trace(trace_id=trace_id)
142163
async with trace.span(
@@ -146,27 +167,34 @@ async def event_send(
146167
"agent_id": agent_id,
147168
"agent_name": agent_name,
148169
"task_id": task_id,
170+
"task_name": task_name,
149171
"content": content,
150172
},
151173
) as span:
152174
heartbeat_if_in_workflow("event send")
175+
176+
# Extract headers from request; pass-through to agent
177+
headers = request.get("headers") if request else None
178+
filtered_headers: dict[str, str] = headers or {}
179+
180+
rpc_event_params: RpcParamsSendEventRequest = {
181+
"task_id": task_id,
182+
"task_name": task_name,
183+
"content": cast(TaskMessageContentParam, content.model_dump()),
184+
}
153185
if agent_name:
154186
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
155187
agent_name=agent_name,
156188
method="event/send",
157-
params={
158-
"task_id": task_id,
159-
"content": cast(TaskMessageContentParam, content.model_dump()),
160-
},
189+
params=rpc_event_params,
190+
extra_headers=filtered_headers,
161191
)
162192
elif agent_id:
163193
json_rpc_response = await self._agentex_client.agents.rpc(
164194
agent_id=agent_id,
165195
method="event/send",
166-
params={
167-
"task_id": task_id,
168-
"content": cast(TaskMessageContentParam, content.model_dump()),
169-
},
196+
params=rpc_event_params,
197+
extra_headers=filtered_headers,
170198
)
171199
else:
172200
raise ValueError("Either agent_name or agent_id must be provided")
@@ -184,15 +212,34 @@ async def task_cancel(
184212
agent_name: str | None = None,
185213
trace_id: str | None = None,
186214
parent_span_id: str | None = None,
187-
) -> Task:
215+
request: dict[str, Any] | None = None,
216+
) -> Task:
217+
"""
218+
Cancel a task by sending cancel request to the agent that owns the task.
219+
220+
Args:
221+
task_id: ID of the task to cancel (passed to agent in params)
222+
task_name: Name of the task to cancel (passed to agent in params)
223+
agent_id: ID of the agent that owns the task
224+
agent_name: Name of the agent that owns the task
225+
trace_id: Trace ID for tracing
226+
parent_span_id: Parent span ID for tracing
227+
request: Additional request context including headers to forward to the agent
228+
229+
Returns:
230+
Task entry representing the cancelled task
231+
232+
Raises:
233+
ValueError: If neither agent_name nor agent_id is provided,
234+
or if neither task_name nor task_id is provided
235+
"""
188236
# Require agent identification
189237
if not agent_name and not agent_id:
190238
raise ValueError("Either agent_name or agent_id must be provided to identify the agent that owns the task")
191239

192240
# Require task identification
193241
if not task_name and not task_id:
194242
raise ValueError("Either task_name or task_id must be provided to identify the task to cancel")
195-
196243
trace = self._tracer.trace(trace_id=trace_id)
197244
async with trace.span(
198245
parent_id=parent_span_id,
@@ -206,8 +253,12 @@ async def task_cancel(
206253
) as span:
207254
heartbeat_if_in_workflow("task cancel")
208255

256+
# Extract headers from request; pass-through to agent
257+
headers = request.get("headers") if request else None
258+
filtered_headers: dict[str, str] = headers or {}
259+
209260
# Build params for the agent (task identification)
210-
params = {}
261+
params: RpcParamsCancelTaskRequest = {}
211262
if task_id:
212263
params["task_id"] = task_id
213264
if task_name:
@@ -219,12 +270,15 @@ async def task_cancel(
219270
agent_name=agent_name,
220271
method="task/cancel",
221272
params=params,
273+
extra_headers=filtered_headers,
222274
)
223275
else: # agent_id is provided (validated above)
276+
assert agent_id is not None
224277
json_rpc_response = await self._agentex_client.agents.rpc(
225278
agent_id=agent_id,
226279
method="task/cancel",
227280
params=params,
281+
extra_headers=filtered_headers,
228282
)
229283

230284
task_entry = Task.model_validate(json_rpc_response.result)

0 commit comments

Comments
 (0)