Skip to content

Commit c7731c8

Browse files
feat: add header forwarding logic
1 parent 48cdeea commit c7731c8

File tree

7 files changed

+486
-16
lines changed

7 files changed

+486
-16
lines changed

src/agentex/lib/adk/_modules/acp.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ async def create_task(
5959
start_to_close_timeout: timedelta = timedelta(seconds=5),
6060
heartbeat_timeout: timedelta = timedelta(seconds=5),
6161
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
62+
request: dict[str, Any] | None = None,
6263
) -> Task:
6364
"""
6465
Create a new task.
@@ -71,6 +72,7 @@ async def create_task(
7172
start_to_close_timeout: The start to close timeout for the task.
7273
heartbeat_timeout: The heartbeat timeout for the task.
7374
retry_policy: The retry policy for the task.
75+
request: Additional request context including headers to forward to the agent.
7476
7577
Returns:
7678
The task entry.
@@ -85,6 +87,7 @@ async def create_task(
8587
params=params,
8688
trace_id=trace_id,
8789
parent_span_id=parent_span_id,
90+
request=request,
8891
),
8992
response_type=Task,
9093
start_to_close_timeout=start_to_close_timeout,
@@ -99,6 +102,7 @@ async def create_task(
99102
params=params,
100103
trace_id=trace_id,
101104
parent_span_id=parent_span_id,
105+
request=request,
102106
)
103107

104108
async def send_event(
@@ -112,15 +116,22 @@ async def send_event(
112116
start_to_close_timeout: timedelta = timedelta(seconds=5),
113117
heartbeat_timeout: timedelta = timedelta(seconds=5),
114118
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
119+
request: dict[str, Any] | None = None,
115120
) -> Event:
116121
"""
117122
Send an event to a task.
118123
119124
Args:
120125
task_id: The ID of the task to send the event to.
121-
data: The data to send to the event.
126+
content: The content to send to the event.
122127
agent_id: The ID of the agent to send the event to.
123128
agent_name: The name of the agent to send the event to.
129+
trace_id: The trace ID for the event.
130+
parent_span_id: The parent span ID for the event.
131+
start_to_close_timeout: The start to close timeout for the event.
132+
heartbeat_timeout: The heartbeat timeout for the event.
133+
retry_policy: The retry policy for the event.
134+
request: Additional request context including headers to forward to the agent.
124135
125136
Returns:
126137
The event entry.
@@ -135,6 +146,7 @@ async def send_event(
135146
content=content,
136147
trace_id=trace_id,
137148
parent_span_id=parent_span_id,
149+
request=request,
138150
),
139151
response_type=None,
140152
start_to_close_timeout=start_to_close_timeout,
@@ -149,6 +161,7 @@ async def send_event(
149161
content=content,
150162
trace_id=trace_id,
151163
parent_span_id=parent_span_id,
164+
request=request,
152165
)
153166

154167
async def send_message(
@@ -162,15 +175,22 @@ async def send_message(
162175
start_to_close_timeout: timedelta = timedelta(seconds=5),
163176
heartbeat_timeout: timedelta = timedelta(seconds=5),
164177
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
178+
request: dict[str, Any] | None = None,
165179
) -> List[TaskMessage]:
166180
"""
167181
Send a message to a task.
168182
169183
Args:
170-
task_id: The ID of the task to send the message to.
171184
content: The task message content to send to the task.
185+
task_id: The ID of the task to send the message to.
172186
agent_id: The ID of the agent to send the message to.
173187
agent_name: The name of the agent to send the message to.
188+
trace_id: The trace ID for the message.
189+
parent_span_id: The parent span ID for the message.
190+
start_to_close_timeout: The start to close timeout for the message.
191+
heartbeat_timeout: The heartbeat timeout for the message.
192+
retry_policy: The retry policy for the message.
193+
request: Additional request context including headers to forward to the agent.
174194
175195
Returns:
176196
The message entry.
@@ -185,6 +205,7 @@ async def send_message(
185205
content=content,
186206
trace_id=trace_id,
187207
parent_span_id=parent_span_id,
208+
request=request,
188209
),
189210
response_type=TaskMessage,
190211
start_to_close_timeout=start_to_close_timeout,
@@ -199,6 +220,7 @@ async def send_message(
199220
content=content,
200221
trace_id=trace_id,
201222
parent_span_id=parent_span_id,
223+
request=request,
202224
)
203225

204226
async def cancel_task(
@@ -212,6 +234,7 @@ async def cancel_task(
212234
start_to_close_timeout: timedelta = timedelta(seconds=5),
213235
heartbeat_timeout: timedelta = timedelta(seconds=5),
214236
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
237+
request: dict[str, Any] | None = None,
215238
) -> Task:
216239
"""
217240
Cancel a task by sending cancel request to the agent that owns the task.
@@ -226,6 +249,7 @@ async def cancel_task(
226249
start_to_close_timeout: The start to close timeout for the task.
227250
heartbeat_timeout: The heartbeat timeout for the task.
228251
retry_policy: The retry policy for the task.
252+
request: Additional request context including headers to forward to the agent.
229253
230254
Returns:
231255
The task entry.
@@ -244,6 +268,7 @@ async def cancel_task(
244268
agent_name=agent_name,
245269
trace_id=trace_id,
246270
parent_span_id=parent_span_id,
271+
request=request,
247272
),
248273
response_type=None,
249274
start_to_close_timeout=start_to_close_timeout,
@@ -258,4 +283,5 @@ async def cancel_task(
258283
agent_name=agent_name,
259284
trace_id=trace_id,
260285
parent_span_id=parent_span_id,
286+
request=request,
261287
)

src/agentex/lib/core/services/adk/acp/acp.py

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44
from agentex.lib.core.tracing.tracer import AsyncTracer
55
from agentex.lib.utils.logging import make_logger
66
from agentex.lib.utils.temporal import heartbeat_if_in_workflow
7+
# No longer need CustomHeadersConfig import
78
from agentex.types.event import Event
89
from agentex.types.task import Task
910
from agentex.types.task_message import TaskMessage
1011
from agentex.types.task_message_content import TaskMessageContent
1112
from agentex.types.task_message_content_param import TaskMessageContentParam
13+
from agentex.types.agent_rpc_params import (
14+
ParamsCancelTaskRequest as RpcParamsCancelTaskRequest,
15+
ParamsSendEventRequest as RpcParamsSendEventRequest,
16+
)
1217

1318
logger = make_logger(__name__)
1419

@@ -30,6 +35,7 @@ async def task_create(
3035
params: dict[str, Any] | None = None,
3136
trace_id: str | None = None,
3237
parent_span_id: str | None = None,
38+
request: dict[str, Any] | None = None,
3339
) -> Task:
3440
trace = self._tracer.trace(trace_id=trace_id)
3541
async with trace.span(
@@ -43,6 +49,11 @@ async def task_create(
4349
},
4450
) as span:
4551
heartbeat_if_in_workflow("task create")
52+
53+
# Extract headers from request; pass-through to agent
54+
headers = request.get("headers") if request else None
55+
filtered_headers: dict[str, str] = headers or {}
56+
4657
if agent_name:
4758
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
4859
agent_name=agent_name,
@@ -51,6 +62,7 @@ async def task_create(
5162
"name": name,
5263
"params": params,
5364
},
65+
extra_headers=filtered_headers,
5466
)
5567
elif agent_id:
5668
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -60,6 +72,7 @@ async def task_create(
6072
"name": name,
6173
"params": params,
6274
},
75+
extra_headers=filtered_headers,
6376
)
6477
else:
6578
raise ValueError("Either agent_name or agent_id must be provided")
@@ -78,6 +91,7 @@ async def message_send(
7891
task_name: str | None = None,
7992
trace_id: str | None = None,
8093
parent_span_id: str | None = None,
94+
request: dict[str, Any] | None = None,
8195
) -> List[TaskMessage]:
8296
trace = self._tracer.trace(trace_id=trace_id)
8397
async with trace.span(
@@ -92,6 +106,11 @@ async def message_send(
92106
},
93107
) as span:
94108
heartbeat_if_in_workflow("message send")
109+
110+
# Extract headers from request; pass-through to agent
111+
headers = request.get("headers") if request else None
112+
filtered_headers: dict[str, str] = headers or {}
113+
95114
if agent_name:
96115
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
97116
agent_name=agent_name,
@@ -101,6 +120,7 @@ async def message_send(
101120
"content": cast(TaskMessageContentParam, content.model_dump()),
102121
"stream": False,
103122
},
123+
extra_headers=filtered_headers,
104124
)
105125
elif agent_id:
106126
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -111,12 +131,13 @@ async def message_send(
111131
"content": cast(TaskMessageContentParam, content.model_dump()),
112132
"stream": False,
113133
},
134+
extra_headers=filtered_headers,
114135
)
115136
else:
116137
raise ValueError("Either agent_name or agent_id must be provided")
117138

118139
task_messages: List[TaskMessage] = []
119-
logger.info(f"json_rpc_response: {json_rpc_response}")
140+
logger.info("json_rpc_response: %s", json_rpc_response)
120141
if isinstance(json_rpc_response.result, list):
121142
for message in json_rpc_response.result:
122143
task_message = TaskMessage.model_validate(message)
@@ -137,6 +158,7 @@ async def event_send(
137158
task_name: str | None = None,
138159
trace_id: str | None = None,
139160
parent_span_id: str | None = None,
161+
request: dict[str, Any] | None = None,
140162
) -> Event:
141163
trace = self._tracer.trace(trace_id=trace_id)
142164
async with trace.span(
@@ -146,27 +168,34 @@ async def event_send(
146168
"agent_id": agent_id,
147169
"agent_name": agent_name,
148170
"task_id": task_id,
171+
"task_name": task_name,
149172
"content": content,
150173
},
151174
) as span:
152175
heartbeat_if_in_workflow("event send")
176+
177+
# Extract headers from request; pass-through to agent
178+
headers = request.get("headers") if request else None
179+
filtered_headers: dict[str, str] = headers or {}
180+
181+
rpc_event_params: RpcParamsSendEventRequest = {
182+
"task_id": task_id,
183+
"task_name": task_name,
184+
"content": cast(TaskMessageContentParam, content.model_dump()),
185+
}
153186
if agent_name:
154187
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
155188
agent_name=agent_name,
156189
method="event/send",
157-
params={
158-
"task_id": task_id,
159-
"content": cast(TaskMessageContentParam, content.model_dump()),
160-
},
190+
params=rpc_event_params,
191+
extra_headers=filtered_headers,
161192
)
162193
elif agent_id:
163194
json_rpc_response = await self._agentex_client.agents.rpc(
164195
agent_id=agent_id,
165196
method="event/send",
166-
params={
167-
"task_id": task_id,
168-
"content": cast(TaskMessageContentParam, content.model_dump()),
169-
},
197+
params=rpc_event_params,
198+
extra_headers=filtered_headers,
170199
)
171200
else:
172201
raise ValueError("Either agent_name or agent_id must be provided")
@@ -184,15 +213,34 @@ async def task_cancel(
184213
agent_name: str | None = None,
185214
trace_id: str | None = None,
186215
parent_span_id: str | None = None,
187-
) -> Task:
216+
request: dict[str, Any] | None = None,
217+
) -> Task:
218+
"""
219+
Cancel a task by sending cancel request to the agent that owns the task.
220+
221+
Args:
222+
task_id: ID of the task to cancel (passed to agent in params)
223+
task_name: Name of the task to cancel (passed to agent in params)
224+
agent_id: ID of the agent that owns the task
225+
agent_name: Name of the agent that owns the task
226+
trace_id: Trace ID for tracing
227+
parent_span_id: Parent span ID for tracing
228+
request: Additional request context including headers to forward to the agent
229+
230+
Returns:
231+
Task entry representing the cancelled task
232+
233+
Raises:
234+
ValueError: If neither agent_name nor agent_id is provided,
235+
or if neither task_name nor task_id is provided
236+
"""
188237
# Require agent identification
189238
if not agent_name and not agent_id:
190239
raise ValueError("Either agent_name or agent_id must be provided to identify the agent that owns the task")
191240

192241
# Require task identification
193242
if not task_name and not task_id:
194243
raise ValueError("Either task_name or task_id must be provided to identify the task to cancel")
195-
196244
trace = self._tracer.trace(trace_id=trace_id)
197245
async with trace.span(
198246
parent_id=parent_span_id,
@@ -206,8 +254,12 @@ async def task_cancel(
206254
) as span:
207255
heartbeat_if_in_workflow("task cancel")
208256

257+
# Extract headers from request; pass-through to agent
258+
headers = request.get("headers") if request else None
259+
filtered_headers: dict[str, str] = headers or {}
260+
209261
# Build params for the agent (task identification)
210-
params = {}
262+
params: RpcParamsCancelTaskRequest = {}
211263
if task_id:
212264
params["task_id"] = task_id
213265
if task_name:
@@ -219,12 +271,15 @@ async def task_cancel(
219271
agent_name=agent_name,
220272
method="task/cancel",
221273
params=params,
274+
extra_headers=filtered_headers,
222275
)
223276
else: # agent_id is provided (validated above)
277+
assert agent_id is not None
224278
json_rpc_response = await self._agentex_client.agents.rpc(
225279
agent_id=agent_id,
226280
method="task/cancel",
227281
params=params,
282+
extra_headers=filtered_headers,
228283
)
229284

230285
task_entry = Task.model_validate(json_rpc_response.result)

0 commit comments

Comments
 (0)