Skip to content

Commit 6a92035

Browse files
feat: add header forwarding logic
1 parent 48cdeea commit 6a92035

File tree

6 files changed

+370
-6
lines changed

6 files changed

+370
-6
lines changed

src/agentex/lib/adk/_modules/acp.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ async def create_task(
5959
start_to_close_timeout: timedelta = timedelta(seconds=5),
6060
heartbeat_timeout: timedelta = timedelta(seconds=5),
6161
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
62+
extra_headers: dict[str, str] | None = None,
6263
) -> Task:
6364
"""
6465
Create a new task.
@@ -71,6 +72,7 @@ async def create_task(
7172
start_to_close_timeout: The start to close timeout for the task.
7273
heartbeat_timeout: The heartbeat timeout for the task.
7374
retry_policy: The retry policy for the task.
75+
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
7476
7577
Returns:
7678
The task entry.
@@ -85,6 +87,7 @@ async def create_task(
8587
params=params,
8688
trace_id=trace_id,
8789
parent_span_id=parent_span_id,
90+
extra_headers=extra_headers,
8891
),
8992
response_type=Task,
9093
start_to_close_timeout=start_to_close_timeout,
@@ -99,6 +102,7 @@ async def create_task(
99102
params=params,
100103
trace_id=trace_id,
101104
parent_span_id=parent_span_id,
105+
extra_headers=extra_headers,
102106
)
103107

104108
async def send_event(
@@ -112,15 +116,22 @@ async def send_event(
112116
start_to_close_timeout: timedelta = timedelta(seconds=5),
113117
heartbeat_timeout: timedelta = timedelta(seconds=5),
114118
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
119+
extra_headers: dict[str, str] | None = None,
115120
) -> Event:
116121
"""
117122
Send an event to a task.
118123
119124
Args:
120125
task_id: The ID of the task to send the event to.
121-
data: The data to send to the event.
126+
content: The content to send to the event.
122127
agent_id: The ID of the agent to send the event to.
123128
agent_name: The name of the agent to send the event to.
129+
trace_id: The trace ID for the event.
130+
parent_span_id: The parent span ID for the event.
131+
start_to_close_timeout: The start to close timeout for the event.
132+
heartbeat_timeout: The heartbeat timeout for the event.
133+
retry_policy: The retry policy for the event.
134+
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
124135
125136
Returns:
126137
The event entry.
@@ -135,6 +146,7 @@ async def send_event(
135146
content=content,
136147
trace_id=trace_id,
137148
parent_span_id=parent_span_id,
149+
extra_headers=extra_headers,
138150
),
139151
response_type=None,
140152
start_to_close_timeout=start_to_close_timeout,
@@ -149,6 +161,7 @@ async def send_event(
149161
content=content,
150162
trace_id=trace_id,
151163
parent_span_id=parent_span_id,
164+
extra_headers=extra_headers,
152165
)
153166

154167
async def send_message(
@@ -162,15 +175,22 @@ async def send_message(
162175
start_to_close_timeout: timedelta = timedelta(seconds=5),
163176
heartbeat_timeout: timedelta = timedelta(seconds=5),
164177
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
178+
extra_headers: dict[str, str] | None = None,
165179
) -> List[TaskMessage]:
166180
"""
167181
Send a message to a task.
168182
169183
Args:
170-
task_id: The ID of the task to send the message to.
171184
content: The task message content to send to the task.
185+
task_id: The ID of the task to send the message to.
172186
agent_id: The ID of the agent to send the message to.
173187
agent_name: The name of the agent to send the message to.
188+
trace_id: The trace ID for the message.
189+
parent_span_id: The parent span ID for the message.
190+
start_to_close_timeout: The start to close timeout for the message.
191+
heartbeat_timeout: The heartbeat timeout for the message.
192+
retry_policy: The retry policy for the message.
193+
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
174194
175195
Returns:
176196
The message entry.
@@ -185,6 +205,7 @@ async def send_message(
185205
content=content,
186206
trace_id=trace_id,
187207
parent_span_id=parent_span_id,
208+
extra_headers=extra_headers,
188209
),
189210
response_type=TaskMessage,
190211
start_to_close_timeout=start_to_close_timeout,
@@ -199,6 +220,7 @@ async def send_message(
199220
content=content,
200221
trace_id=trace_id,
201222
parent_span_id=parent_span_id,
223+
extra_headers=extra_headers,
202224
)
203225

204226
async def cancel_task(
@@ -212,6 +234,7 @@ async def cancel_task(
212234
start_to_close_timeout: timedelta = timedelta(seconds=5),
213235
heartbeat_timeout: timedelta = timedelta(seconds=5),
214236
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
237+
extra_headers: dict[str, str] | None = None,
215238
) -> Task:
216239
"""
217240
Cancel a task by sending cancel request to the agent that owns the task.
@@ -226,6 +249,7 @@ async def cancel_task(
226249
start_to_close_timeout: The start to close timeout for the task.
227250
heartbeat_timeout: The heartbeat timeout for the task.
228251
retry_policy: The retry policy for the task.
252+
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
229253
230254
Returns:
231255
The task entry.
@@ -244,6 +268,7 @@ async def cancel_task(
244268
agent_name=agent_name,
245269
trace_id=trace_id,
246270
parent_span_id=parent_span_id,
271+
extra_headers=extra_headers,
247272
),
248273
response_type=None,
249274
start_to_close_timeout=start_to_close_timeout,
@@ -258,4 +283,5 @@ async def cancel_task(
258283
agent_name=agent_name,
259284
trace_id=trace_id,
260285
parent_span_id=parent_span_id,
286+
extra_headers=extra_headers,
261287
)

0 commit comments

Comments
 (0)