Skip to content

Commit eee0ea2

Browse files
feat: add header forwarding logic
1 parent 9665b81 commit eee0ea2

File tree

7 files changed

+359
-6
lines changed

7 files changed

+359
-6
lines changed

src/agentex/lib/adk/_modules/acp.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ async def create_task(
5959
start_to_close_timeout: timedelta = timedelta(seconds=5),
6060
heartbeat_timeout: timedelta = timedelta(seconds=5),
6161
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
62+
extra_headers: dict[str, str] | None = None,
6263
) -> Task:
6364
"""
6465
Create a new task.
@@ -71,6 +72,7 @@ async def create_task(
7172
start_to_close_timeout: The start to close timeout for the task.
7273
heartbeat_timeout: The heartbeat timeout for the task.
7374
retry_policy: The retry policy for the task.
75+
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
7476
7577
Returns:
7678
The task entry.
@@ -85,6 +87,7 @@ async def create_task(
8587
params=params,
8688
trace_id=trace_id,
8789
parent_span_id=parent_span_id,
90+
extra_headers=extra_headers,
8891
),
8992
response_type=Task,
9093
start_to_close_timeout=start_to_close_timeout,
@@ -99,6 +102,7 @@ async def create_task(
99102
params=params,
100103
trace_id=trace_id,
101104
parent_span_id=parent_span_id,
105+
extra_headers=extra_headers,
102106
)
103107

104108
async def send_event(
@@ -112,15 +116,22 @@ async def send_event(
112116
start_to_close_timeout: timedelta = timedelta(seconds=5),
113117
heartbeat_timeout: timedelta = timedelta(seconds=5),
114118
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
119+
extra_headers: dict[str, str] | None = None,
115120
) -> Event:
116121
"""
117122
Send an event to a task.
118123
119124
Args:
120125
task_id: The ID of the task to send the event to.
121-
data: The data to send to the event.
126+
content: The content to send to the event.
122127
agent_id: The ID of the agent to send the event to.
123128
agent_name: The name of the agent to send the event to.
129+
trace_id: The trace ID for the event.
130+
parent_span_id: The parent span ID for the event.
131+
start_to_close_timeout: The start to close timeout for the event.
132+
heartbeat_timeout: The heartbeat timeout for the event.
133+
retry_policy: The retry policy for the event.
134+
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
124135
125136
Returns:
126137
The event entry.
@@ -135,6 +146,7 @@ async def send_event(
135146
content=content,
136147
trace_id=trace_id,
137148
parent_span_id=parent_span_id,
149+
extra_headers=extra_headers,
138150
),
139151
response_type=None,
140152
start_to_close_timeout=start_to_close_timeout,
@@ -149,6 +161,7 @@ async def send_event(
149161
content=content,
150162
trace_id=trace_id,
151163
parent_span_id=parent_span_id,
164+
extra_headers=extra_headers,
152165
)
153166

154167
async def send_message(
@@ -162,15 +175,22 @@ async def send_message(
162175
start_to_close_timeout: timedelta = timedelta(seconds=5),
163176
heartbeat_timeout: timedelta = timedelta(seconds=5),
164177
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
178+
extra_headers: dict[str, str] | None = None,
165179
) -> List[TaskMessage]:
166180
"""
167181
Send a message to a task.
168182
169183
Args:
170-
task_id: The ID of the task to send the message to.
171184
content: The task message content to send to the task.
185+
task_id: The ID of the task to send the message to.
172186
agent_id: The ID of the agent to send the message to.
173187
agent_name: The name of the agent to send the message to.
188+
trace_id: The trace ID for the message.
189+
parent_span_id: The parent span ID for the message.
190+
start_to_close_timeout: The start to close timeout for the message.
191+
heartbeat_timeout: The heartbeat timeout for the message.
192+
retry_policy: The retry policy for the message.
193+
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
174194
175195
Returns:
176196
The message entry.
@@ -185,6 +205,7 @@ async def send_message(
185205
content=content,
186206
trace_id=trace_id,
187207
parent_span_id=parent_span_id,
208+
extra_headers=extra_headers,
188209
),
189210
response_type=TaskMessage,
190211
start_to_close_timeout=start_to_close_timeout,
@@ -199,6 +220,7 @@ async def send_message(
199220
content=content,
200221
trace_id=trace_id,
201222
parent_span_id=parent_span_id,
223+
extra_headers=extra_headers,
202224
)
203225

204226
async def cancel_task(
@@ -210,6 +232,7 @@ async def cancel_task(
210232
start_to_close_timeout: timedelta = timedelta(seconds=5),
211233
heartbeat_timeout: timedelta = timedelta(seconds=5),
212234
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
235+
extra_headers: dict[str, str] | None = None,
213236
) -> Task:
214237
"""
215238
Cancel a task.
@@ -222,6 +245,7 @@ async def cancel_task(
222245
start_to_close_timeout: The start to close timeout for the task.
223246
heartbeat_timeout: The heartbeat timeout for the task.
224247
retry_policy: The retry policy for the task.
248+
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
225249
226250
Returns:
227251
The task entry.
@@ -234,6 +258,7 @@ async def cancel_task(
234258
task_name=task_name,
235259
trace_id=trace_id,
236260
parent_span_id=parent_span_id,
261+
extra_headers=extra_headers,
237262
),
238263
response_type=None,
239264
start_to_close_timeout=start_to_close_timeout,
@@ -246,4 +271,5 @@ async def cancel_task(
246271
task_name=task_name,
247272
trace_id=trace_id,
248273
parent_span_id=parent_span_id,
274+
extra_headers=extra_headers,
249275
)

src/agentex/lib/core/services/adk/acp/acp.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from typing import Any, List, cast
2+
import fnmatch
23

34
from agentex import AsyncAgentex
45
from agentex.lib.core.tracing.tracer import AsyncTracer
56
from agentex.lib.utils.logging import make_logger
67
from agentex.lib.utils.temporal import heartbeat_if_in_workflow
8+
from agentex.lib.types.agent_configs import CustomHeadersConfig
79
from agentex.types.event import Event
810
from agentex.types.task import Task
911
from agentex.types.task_message import TaskMessage
@@ -22,6 +24,73 @@ def __init__(
2224
self._agentex_client = agentex_client
2325
self._tracer = tracer
2426

27+
async def _get_agent_header_config(
28+
self, agent_name: str | None, agent_id: str | None
29+
) -> CustomHeadersConfig:
30+
"""
31+
Get agent's header configuration from manifest.
32+
33+
For now, we'll return a default empty config since we need to implement
34+
manifest loading. This provides secure-by-default behavior.
35+
36+
TODO: Implement actual manifest loading to get agent's custom_headers config
37+
"""
38+
# Default configuration - no headers allowed (secure by default)
39+
# This will be replaced with actual manifest loading
40+
return CustomHeadersConfig(
41+
strategy="allowlist",
42+
allowed_headers=[],
43+
max_header_size=8192,
44+
max_headers_count=50
45+
)
46+
47+
def _filter_headers(
48+
self,
49+
extra_headers: dict[str, str] | None,
50+
config: CustomHeadersConfig
51+
) -> dict[str, str]:
52+
"""
53+
Filter headers based on agent's configuration.
54+
55+
Args:
56+
extra_headers: Headers to filter
57+
config: Agent's header configuration
58+
59+
Returns:
60+
Filtered headers dictionary containing only allowed headers
61+
"""
62+
if not extra_headers or not config.allowed_headers:
63+
return {}
64+
65+
filtered = {}
66+
headers_added = 0
67+
68+
for header_name, header_value in extra_headers.items():
69+
# Check if we've hit the max headers limit
70+
if headers_added >= config.max_headers_count:
71+
logger.warning("Reached maximum header count limit (%s), ignoring remaining headers", config.max_headers_count)
72+
break
73+
74+
# Check against allowlist patterns
75+
header_allowed = False
76+
for pattern in config.allowed_headers:
77+
if fnmatch.fnmatch(header_name.lower(), pattern.lower()):
78+
header_allowed = True
79+
break
80+
81+
if header_allowed:
82+
# Apply size limits
83+
if len(header_value) <= config.max_header_size:
84+
filtered[header_name] = header_value
85+
headers_added += 1
86+
logger.debug("Allowed header: %s", header_name)
87+
else:
88+
logger.warning("Header '%s' exceeds size limit (%s > %s), ignoring", header_name, len(header_value), config.max_header_size)
89+
else:
90+
logger.debug("Header '%s' not in allowlist, ignoring", header_name)
91+
92+
return filtered
93+
2594
async def task_create(
2695
self,
2796
name: str | None = None,
@@ -30,6 +99,7 @@ async def task_create(
3099
params: dict[str, Any] | None = None,
31100
trace_id: str | None = None,
32101
parent_span_id: str | None = None,
102+
extra_headers: dict[str, str] | None = None,
33103
) -> Task:
34104
trace = self._tracer.trace(trace_id=trace_id)
35105
async with trace.span(
@@ -43,6 +113,11 @@ async def task_create(
43113
},
44114
) as span:
45115
heartbeat_if_in_workflow("task create")
116+
117+
# Get agent's header configuration and filter headers
118+
header_config = await self._get_agent_header_config(agent_name, agent_id)
119+
filtered_headers = self._filter_headers(extra_headers, header_config)
120+
46121
if agent_name:
47122
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
48123
agent_name=agent_name,
@@ -51,6 +126,7 @@ async def task_create(
51126
"name": name,
52127
"params": params,
53128
},
129+
extra_headers=filtered_headers,
54130
)
55131
elif agent_id:
56132
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -60,6 +136,7 @@ async def task_create(
60136
"name": name,
61137
"params": params,
62138
},
139+
extra_headers=filtered_headers,
63140
)
64141
else:
65142
raise ValueError("Either agent_name or agent_id must be provided")
@@ -78,6 +155,7 @@ async def message_send(
78155
task_name: str | None = None,
79156
trace_id: str | None = None,
80157
parent_span_id: str | None = None,
158+
extra_headers: dict[str, str] | None = None,
81159
) -> List[TaskMessage]:
82160
trace = self._tracer.trace(trace_id=trace_id)
83161
async with trace.span(
@@ -92,6 +170,11 @@ async def message_send(
92170
},
93171
) as span:
94172
heartbeat_if_in_workflow("message send")
173+
174+
# Get agent's header configuration and filter headers
175+
header_config = await self._get_agent_header_config(agent_name, agent_id)
176+
filtered_headers = self._filter_headers(extra_headers, header_config)
177+
95178
if agent_name:
96179
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
97180
agent_name=agent_name,
@@ -101,6 +184,7 @@ async def message_send(
101184
"content": cast(TaskMessageContentParam, content.model_dump()),
102185
"stream": False,
103186
},
187+
extra_headers=filtered_headers,
104188
)
105189
elif agent_id:
106190
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -111,12 +195,13 @@ async def message_send(
111195
"content": cast(TaskMessageContentParam, content.model_dump()),
112196
"stream": False,
113197
},
198+
extra_headers=filtered_headers,
114199
)
115200
else:
116201
raise ValueError("Either agent_name or agent_id must be provided")
117202

118203
task_messages: List[TaskMessage] = []
119-
logger.info(f"json_rpc_response: {json_rpc_response}")
204+
logger.info("json_rpc_response: %s", json_rpc_response)
120205
if isinstance(json_rpc_response.result, list):
121206
for message in json_rpc_response.result:
122207
task_message = TaskMessage.model_validate(message)
@@ -137,6 +222,7 @@ async def event_send(
137222
task_name: str | None = None,
138223
trace_id: str | None = None,
139224
parent_span_id: str | None = None,
225+
extra_headers: dict[str, str] | None = None,
140226
) -> Event:
141227
trace = self._tracer.trace(trace_id=trace_id)
142228
async with trace.span(
@@ -150,6 +236,11 @@ async def event_send(
150236
},
151237
) as span:
152238
heartbeat_if_in_workflow("event send")
239+
240+
# Get agent's header configuration and filter headers
241+
header_config = await self._get_agent_header_config(agent_name, agent_id)
242+
filtered_headers = self._filter_headers(extra_headers, header_config)
243+
153244
if agent_name:
154245
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
155246
agent_name=agent_name,
@@ -158,6 +249,7 @@ async def event_send(
158249
"task_id": task_id,
159250
"content": cast(TaskMessageContentParam, content.model_dump()),
160251
},
252+
extra_headers=filtered_headers,
161253
)
162254
elif agent_id:
163255
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -167,6 +259,7 @@ async def event_send(
167259
"task_id": task_id,
168260
"content": cast(TaskMessageContentParam, content.model_dump()),
169261
},
262+
extra_headers=filtered_headers,
170263
)
171264
else:
172265
raise ValueError("Either agent_name or agent_id must be provided")
@@ -182,6 +275,7 @@ async def task_cancel(
182275
task_name: str | None = None,
183276
trace_id: str | None = None,
184277
parent_span_id: str | None = None,
278+
extra_headers: dict[str, str] | None = None,
185279
) -> Task:
186280
trace = self._tracer.trace(trace_id=trace_id)
187281
async with trace.span(
@@ -193,13 +287,23 @@ async def task_cancel(
193287
},
194288
) as span:
195289
heartbeat_if_in_workflow("task cancel")
290+
291+
# Note: The original implementation seems to treat task_name as agent_name and task_id as agent_id
292+
# This maintains backward compatibility while adding header support
293+
# Get agent's header configuration and filter headers
294+
agent_name = task_name if task_name else None
295+
agent_id = task_id if task_id else None
296+
header_config = await self._get_agent_header_config(agent_name, agent_id)
297+
filtered_headers = self._filter_headers(extra_headers, header_config)
298+
196299
if task_name:
197300
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
198301
agent_name=task_name,
199302
method="task/cancel",
200303
params={
201304
"task_name": task_name,
202305
},
306+
extra_headers=filtered_headers,
203307
)
204308
elif task_id:
205309
json_rpc_response = await self._agentex_client.agents.rpc(
@@ -208,6 +312,7 @@ async def task_cancel(
208312
params={
209313
"task_id": task_id,
210314
},
315+
extra_headers=filtered_headers,
211316
)
212317
else:
213318
raise ValueError("Either task_name or task_id must be provided")

0 commit comments

Comments
 (0)