Skip to content

Commit f9ef768

Browse files
Pass-through header forwarding with server-side exclusions (#99)
* feat: add header forwarding logic * renamed headers + filtered_headers as extra_headers
1 parent 1074a2a commit f9ef768

File tree

7 files changed

+481
-16
lines changed

7 files changed

+481
-16
lines changed

src/agentex/lib/adk/_modules/acp.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ async def create_task(
5959
start_to_close_timeout: timedelta = timedelta(seconds=5),
6060
heartbeat_timeout: timedelta = timedelta(seconds=5),
6161
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
62+
request: dict[str, Any] | None = None,
6263
) -> Task:
6364
"""
6465
Create a new task.
@@ -71,6 +72,7 @@ async def create_task(
7172
start_to_close_timeout: The start to close timeout for the task.
7273
heartbeat_timeout: The heartbeat timeout for the task.
7374
retry_policy: The retry policy for the task.
75+
request: Additional request context including headers to forward to the agent.
7476
7577
Returns:
7678
The task entry.
@@ -85,6 +87,7 @@ async def create_task(
8587
params=params,
8688
trace_id=trace_id,
8789
parent_span_id=parent_span_id,
90+
request=request,
8891
),
8992
response_type=Task,
9093
start_to_close_timeout=start_to_close_timeout,
@@ -99,6 +102,7 @@ async def create_task(
99102
params=params,
100103
trace_id=trace_id,
101104
parent_span_id=parent_span_id,
105+
request=request,
102106
)
103107

104108
async def send_event(
@@ -112,15 +116,22 @@ async def send_event(
112116
start_to_close_timeout: timedelta = timedelta(seconds=5),
113117
heartbeat_timeout: timedelta = timedelta(seconds=5),
114118
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
119+
request: dict[str, Any] | None = None,
115120
) -> Event:
116121
"""
117122
Send an event to a task.
118123
119124
Args:
120125
task_id: The ID of the task to send the event to.
121-
data: The data to send to the event.
126+
content: The content to send to the event.
122127
agent_id: The ID of the agent to send the event to.
123128
agent_name: The name of the agent to send the event to.
129+
trace_id: The trace ID for the event.
130+
parent_span_id: The parent span ID for the event.
131+
start_to_close_timeout: The start to close timeout for the event.
132+
heartbeat_timeout: The heartbeat timeout for the event.
133+
retry_policy: The retry policy for the event.
134+
request: Additional request context including headers to forward to the agent.
124135
125136
Returns:
126137
The event entry.
@@ -135,6 +146,7 @@ async def send_event(
135146
content=content,
136147
trace_id=trace_id,
137148
parent_span_id=parent_span_id,
149+
request=request,
138150
),
139151
response_type=None,
140152
start_to_close_timeout=start_to_close_timeout,
@@ -149,6 +161,7 @@ async def send_event(
149161
content=content,
150162
trace_id=trace_id,
151163
parent_span_id=parent_span_id,
164+
request=request,
152165
)
153166

154167
async def send_message(
@@ -162,15 +175,22 @@ async def send_message(
162175
start_to_close_timeout: timedelta = timedelta(seconds=5),
163176
heartbeat_timeout: timedelta = timedelta(seconds=5),
164177
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
178+
request: dict[str, Any] | None = None,
165179
) -> List[TaskMessage]:
166180
"""
167181
Send a message to a task.
168182
169183
Args:
170-
task_id: The ID of the task to send the message to.
171184
content: The task message content to send to the task.
185+
task_id: The ID of the task to send the message to.
172186
agent_id: The ID of the agent to send the message to.
173187
agent_name: The name of the agent to send the message to.
188+
trace_id: The trace ID for the message.
189+
parent_span_id: The parent span ID for the message.
190+
start_to_close_timeout: The start to close timeout for the message.
191+
heartbeat_timeout: The heartbeat timeout for the message.
192+
retry_policy: The retry policy for the message.
193+
request: Additional request context including headers to forward to the agent.
174194
175195
Returns:
176196
The message entry.
@@ -185,6 +205,7 @@ async def send_message(
185205
content=content,
186206
trace_id=trace_id,
187207
parent_span_id=parent_span_id,
208+
request=request,
188209
),
189210
response_type=TaskMessage,
190211
start_to_close_timeout=start_to_close_timeout,
@@ -199,6 +220,7 @@ async def send_message(
199220
content=content,
200221
trace_id=trace_id,
201222
parent_span_id=parent_span_id,
223+
request=request,
202224
)
203225

204226
async def cancel_task(
@@ -212,6 +234,7 @@ async def cancel_task(
212234
start_to_close_timeout: timedelta = timedelta(seconds=5),
213235
heartbeat_timeout: timedelta = timedelta(seconds=5),
214236
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
237+
request: dict[str, Any] | None = None,
215238
) -> Task:
216239
"""
217240
Cancel a task by sending cancel request to the agent that owns the task.
@@ -226,6 +249,7 @@ async def cancel_task(
226249
start_to_close_timeout: The start to close timeout for the task.
227250
heartbeat_timeout: The heartbeat timeout for the task.
228251
retry_policy: The retry policy for the task.
252+
request: Additional request context including headers to forward to the agent.
229253
230254
Returns:
231255
The task entry.
@@ -244,6 +268,7 @@ async def cancel_task(
244268
agent_name=agent_name,
245269
trace_id=trace_id,
246270
parent_span_id=parent_span_id,
271+
request=request,
247272
),
248273
response_type=None,
249274
start_to_close_timeout=start_to_close_timeout,
@@ -258,4 +283,5 @@ async def cancel_task(
258283
agent_name=agent_name,
259284
trace_id=trace_id,
260285
parent_span_id=parent_span_id,
286+
request=request,
261287
)

src/agentex/lib/core/services/adk/acp/acp.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
from agentex.types.task_message import TaskMessage
1010
from agentex.types.task_message_content import TaskMessageContent
1111
from agentex.types.task_message_content_param import TaskMessageContentParam
12+
from agentex.types.agent_rpc_params import (
13+
ParamsCancelTaskRequest as RpcParamsCancelTaskRequest,
14+
ParamsSendEventRequest as RpcParamsSendEventRequest,
15+
)
1216

1317
logger = make_logger(__name__)
1418

@@ -30,6 +34,7 @@ async def task_create(
3034
params: dict[str, Any] | None = None,
3135
trace_id: str | None = None,
3236
parent_span_id: str | None = None,
37+
request: dict[str, Any] | None = None,
3338
) -> Task:
3439
trace = self._tracer.trace(trace_id=trace_id)
3540
async with trace.span(
@@ -43,6 +48,10 @@ async def task_create(
4348
},
4449
) as span:
4550
heartbeat_if_in_workflow("task create")
51+
52+
# Extract headers from request; pass-through to agent
53+
extra_headers = request.get("headers") if request else None
54+
4655
if agent_name:
4756
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
4857
agent_name=agent_name,
@@ -51,6 +60,7 @@ async def task_create(
5160
"name": name,
5261
"params": params,
5362
},
63+
extra_headers=extra_headers,
5464
)
5565
elif agent_id:
5666
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -60,6 +70,7 @@ async def task_create(
6070
"name": name,
6171
"params": params,
6272
},
73+
extra_headers=extra_headers,
6374
)
6475
else:
6576
raise ValueError("Either agent_name or agent_id must be provided")
@@ -78,6 +89,7 @@ async def message_send(
7889
task_name: str | None = None,
7990
trace_id: str | None = None,
8091
parent_span_id: str | None = None,
92+
request: dict[str, Any] | None = None,
8193
) -> List[TaskMessage]:
8294
trace = self._tracer.trace(trace_id=trace_id)
8395
async with trace.span(
@@ -92,6 +104,10 @@ async def message_send(
92104
},
93105
) as span:
94106
heartbeat_if_in_workflow("message send")
107+
108+
# Extract headers from request; pass-through to agent
109+
extra_headers = request.get("headers") if request else None
110+
95111
if agent_name:
96112
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
97113
agent_name=agent_name,
@@ -101,6 +117,7 @@ async def message_send(
101117
"content": cast(TaskMessageContentParam, content.model_dump()),
102118
"stream": False,
103119
},
120+
extra_headers=extra_headers,
104121
)
105122
elif agent_id:
106123
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -111,12 +128,13 @@ async def message_send(
111128
"content": cast(TaskMessageContentParam, content.model_dump()),
112129
"stream": False,
113130
},
131+
extra_headers=extra_headers,
114132
)
115133
else:
116134
raise ValueError("Either agent_name or agent_id must be provided")
117135

118136
task_messages: List[TaskMessage] = []
119-
logger.info(f"json_rpc_response: {json_rpc_response}")
137+
logger.info("json_rpc_response: %s", json_rpc_response)
120138
if isinstance(json_rpc_response.result, list):
121139
for message in json_rpc_response.result:
122140
task_message = TaskMessage.model_validate(message)
@@ -137,6 +155,7 @@ async def event_send(
137155
task_name: str | None = None,
138156
trace_id: str | None = None,
139157
parent_span_id: str | None = None,
158+
request: dict[str, Any] | None = None,
140159
) -> Event:
141160
trace = self._tracer.trace(trace_id=trace_id)
142161
async with trace.span(
@@ -146,27 +165,33 @@ async def event_send(
146165
"agent_id": agent_id,
147166
"agent_name": agent_name,
148167
"task_id": task_id,
168+
"task_name": task_name,
149169
"content": content,
150170
},
151171
) as span:
152172
heartbeat_if_in_workflow("event send")
173+
174+
# Extract headers from request; pass-through to agent
175+
extra_headers = request.get("headers") if request else None
176+
177+
rpc_event_params: RpcParamsSendEventRequest = {
178+
"task_id": task_id,
179+
"task_name": task_name,
180+
"content": cast(TaskMessageContentParam, content.model_dump()),
181+
}
153182
if agent_name:
154183
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
155184
agent_name=agent_name,
156185
method="event/send",
157-
params={
158-
"task_id": task_id,
159-
"content": cast(TaskMessageContentParam, content.model_dump()),
160-
},
186+
params=rpc_event_params,
187+
extra_headers=extra_headers,
161188
)
162189
elif agent_id:
163190
json_rpc_response = await self._agentex_client.agents.rpc(
164191
agent_id=agent_id,
165192
method="event/send",
166-
params={
167-
"task_id": task_id,
168-
"content": cast(TaskMessageContentParam, content.model_dump()),
169-
},
193+
params=rpc_event_params,
194+
extra_headers=extra_headers,
170195
)
171196
else:
172197
raise ValueError("Either agent_name or agent_id must be provided")
@@ -184,15 +209,34 @@ async def task_cancel(
184209
agent_name: str | None = None,
185210
trace_id: str | None = None,
186211
parent_span_id: str | None = None,
187-
) -> Task:
212+
request: dict[str, Any] | None = None,
213+
) -> Task:
214+
"""
215+
Cancel a task by sending cancel request to the agent that owns the task.
216+
217+
Args:
218+
task_id: ID of the task to cancel (passed to agent in params)
219+
task_name: Name of the task to cancel (passed to agent in params)
220+
agent_id: ID of the agent that owns the task
221+
agent_name: Name of the agent that owns the task
222+
trace_id: Trace ID for tracing
223+
parent_span_id: Parent span ID for tracing
224+
request: Additional request context including headers to forward to the agent
225+
226+
Returns:
227+
Task entry representing the cancelled task
228+
229+
Raises:
230+
ValueError: If neither agent_name nor agent_id is provided,
231+
or if neither task_name nor task_id is provided
232+
"""
188233
# Require agent identification
189234
if not agent_name and not agent_id:
190235
raise ValueError("Either agent_name or agent_id must be provided to identify the agent that owns the task")
191236

192237
# Require task identification
193238
if not task_name and not task_id:
194239
raise ValueError("Either task_name or task_id must be provided to identify the task to cancel")
195-
196240
trace = self._tracer.trace(trace_id=trace_id)
197241
async with trace.span(
198242
parent_id=parent_span_id,
@@ -206,8 +250,11 @@ async def task_cancel(
206250
) as span:
207251
heartbeat_if_in_workflow("task cancel")
208252

253+
# Extract headers from request; pass-through to agent
254+
extra_headers = request.get("headers") if request else None
255+
209256
# Build params for the agent (task identification)
210-
params = {}
257+
params: RpcParamsCancelTaskRequest = {}
211258
if task_id:
212259
params["task_id"] = task_id
213260
if task_name:
@@ -219,12 +266,15 @@ async def task_cancel(
219266
agent_name=agent_name,
220267
method="task/cancel",
221268
params=params,
269+
extra_headers=extra_headers,
222270
)
223271
else: # agent_id is provided (validated above)
272+
assert agent_id is not None
224273
json_rpc_response = await self._agentex_client.agents.rpc(
225274
agent_id=agent_id,
226275
method="task/cancel",
227276
params=params,
277+
extra_headers=extra_headers,
228278
)
229279

230280
task_entry = Task.model_validate(json_rpc_response.result)

0 commit comments

Comments
 (0)