Skip to content

Commit 243a108

Browse files
feat: restructure header forwarding API from extra_headers to request.headers
- Update CustomHeadersConfig schema to remove max_header_size and max_headers_count - Change API from flat extra_headers parameter to nested request.headers structure - Implement pass-through by default behavior (headers forwarded unless filtered) - Update all ACP service methods, temporal activities, and ADK modules - Remove size/count limits from header filtering logic Breaking changes: extra_headers parameter replaced with request parameter across all ACP APIs
1 parent cdbd837 commit 243a108

File tree

4 files changed

+74
-92
lines changed

4 files changed

+74
-92
lines changed

src/agentex/lib/adk/_modules/acp.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def create_task(
5959
start_to_close_timeout: timedelta = timedelta(seconds=5),
6060
heartbeat_timeout: timedelta = timedelta(seconds=5),
6161
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
62-
extra_headers: dict[str, str] | None = None,
62+
request: dict[str, Any] | None = None,
6363
) -> Task:
6464
"""
6565
Create a new task.
@@ -72,7 +72,7 @@ async def create_task(
7272
start_to_close_timeout: The start to close timeout for the task.
7373
heartbeat_timeout: The heartbeat timeout for the task.
7474
retry_policy: The retry policy for the task.
75-
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
75+
request: Additional request context including headers to forward to the agent.
7676
7777
Returns:
7878
The task entry.
@@ -87,7 +87,7 @@ async def create_task(
8787
params=params,
8888
trace_id=trace_id,
8989
parent_span_id=parent_span_id,
90-
extra_headers=extra_headers,
90+
request=request,
9191
),
9292
response_type=Task,
9393
start_to_close_timeout=start_to_close_timeout,
@@ -102,7 +102,7 @@ async def create_task(
102102
params=params,
103103
trace_id=trace_id,
104104
parent_span_id=parent_span_id,
105-
extra_headers=extra_headers,
105+
request=request,
106106
)
107107

108108
async def send_event(
@@ -116,7 +116,7 @@ async def send_event(
116116
start_to_close_timeout: timedelta = timedelta(seconds=5),
117117
heartbeat_timeout: timedelta = timedelta(seconds=5),
118118
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
119-
extra_headers: dict[str, str] | None = None,
119+
request: dict[str, Any] | None = None,
120120
) -> Event:
121121
"""
122122
Send an event to a task.
@@ -131,7 +131,7 @@ async def send_event(
131131
start_to_close_timeout: The start to close timeout for the event.
132132
heartbeat_timeout: The heartbeat timeout for the event.
133133
retry_policy: The retry policy for the event.
134-
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
134+
request: Additional request context including headers to forward to the agent.
135135
136136
Returns:
137137
The event entry.
@@ -146,7 +146,7 @@ async def send_event(
146146
content=content,
147147
trace_id=trace_id,
148148
parent_span_id=parent_span_id,
149-
extra_headers=extra_headers,
149+
request=request,
150150
),
151151
response_type=None,
152152
start_to_close_timeout=start_to_close_timeout,
@@ -161,7 +161,7 @@ async def send_event(
161161
content=content,
162162
trace_id=trace_id,
163163
parent_span_id=parent_span_id,
164-
extra_headers=extra_headers,
164+
request=request,
165165
)
166166

167167
async def send_message(
@@ -175,7 +175,7 @@ async def send_message(
175175
start_to_close_timeout: timedelta = timedelta(seconds=5),
176176
heartbeat_timeout: timedelta = timedelta(seconds=5),
177177
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
178-
extra_headers: dict[str, str] | None = None,
178+
request: dict[str, Any] | None = None,
179179
) -> List[TaskMessage]:
180180
"""
181181
Send a message to a task.
@@ -190,7 +190,7 @@ async def send_message(
190190
start_to_close_timeout: The start to close timeout for the message.
191191
heartbeat_timeout: The heartbeat timeout for the message.
192192
retry_policy: The retry policy for the message.
193-
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
193+
request: Additional request context including headers to forward to the agent.
194194
195195
Returns:
196196
The message entry.
@@ -205,7 +205,7 @@ async def send_message(
205205
content=content,
206206
trace_id=trace_id,
207207
parent_span_id=parent_span_id,
208-
extra_headers=extra_headers,
208+
request=request,
209209
),
210210
response_type=TaskMessage,
211211
start_to_close_timeout=start_to_close_timeout,
@@ -220,7 +220,7 @@ async def send_message(
220220
content=content,
221221
trace_id=trace_id,
222222
parent_span_id=parent_span_id,
223-
extra_headers=extra_headers,
223+
request=request,
224224
)
225225

226226
async def cancel_task(
@@ -234,7 +234,7 @@ async def cancel_task(
234234
start_to_close_timeout: timedelta = timedelta(seconds=5),
235235
heartbeat_timeout: timedelta = timedelta(seconds=5),
236236
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
237-
extra_headers: dict[str, str] | None = None,
237+
request: dict[str, Any] | None = None,
238238
) -> Task:
239239
"""
240240
Cancel a task by sending cancel request to the agent that owns the task.
@@ -249,7 +249,7 @@ async def cancel_task(
249249
start_to_close_timeout: The start to close timeout for the task.
250250
heartbeat_timeout: The heartbeat timeout for the task.
251251
retry_policy: The retry policy for the task.
252-
extra_headers: Additional HTTP headers to forward to the agent (filtered by agent's allowlist).
252+
request: Additional request context including headers to forward to the agent.
253253
254254
Returns:
255255
The task entry.
@@ -268,7 +268,7 @@ async def cancel_task(
268268
agent_name=agent_name,
269269
trace_id=trace_id,
270270
parent_span_id=parent_span_id,
271-
extra_headers=extra_headers,
271+
request=request,
272272
),
273273
response_type=None,
274274
start_to_close_timeout=start_to_close_timeout,
@@ -283,5 +283,5 @@ async def cancel_task(
283283
agent_name=agent_name,
284284
trace_id=trace_id,
285285
parent_span_id=parent_span_id,
286-
extra_headers=extra_headers,
286+
request=request,
287287
)

src/agentex/lib/core/services/adk/acp/acp.py

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,69 +26,66 @@ def __init__(
2626

2727
async def _get_agent_header_config(
2828
self, agent_name: str | None, agent_id: str | None
29-
) -> CustomHeadersConfig:
29+
) -> CustomHeadersConfig | None:
3030
"""
3131
Get agent's header configuration from manifest.
3232
33-
For now, we'll return a default empty config since we need to implement
34-
manifest loading. This provides secure-by-default behavior.
33+
Returns None for pass-through behavior (all headers forwarded).
34+
Returns CustomHeadersConfig only if agent has specific filtering requirements.
3535
3636
TODO: Implement actual manifest loading to get agent's custom_headers config
3737
"""
38-
# Default configuration - no headers allowed (secure by default)
38+
# For now, return None to enable pass-through behavior
3939
# This will be replaced with actual manifest loading
40-
return CustomHeadersConfig(
41-
strategy="allowlist",
42-
allowed_headers=[],
43-
max_header_size=8192,
44-
max_headers_count=50
45-
)
40+
return None
4641

4742
def _filter_headers(
4843
self,
49-
extra_headers: dict[str, str] | None,
50-
config: CustomHeadersConfig
44+
headers: dict[str, str] | None,
45+
config: CustomHeadersConfig | None
5146
) -> dict[str, str]:
5247
"""
5348
Filter headers based on agent's configuration.
5449
50+
Pass-through by default: if no config provided, all headers are forwarded.
51+
If config exists, only allowlisted headers are forwarded.
52+
5553
Args:
56-
extra_headers: Headers to filter
57-
config: Agent's header configuration
54+
headers: Headers to filter
55+
config: Agent's header configuration (None = pass-through all)
5856
5957
Returns:
60-
Filtered headers dictionary containing only allowed headers
58+
Filtered headers dictionary
6159
"""
62-
if not extra_headers or not config.allowed_headers:
60+
if not headers:
6361
return {}
62+
63+
# Pass-through behavior: if no config, forward all headers
64+
if config is None:
65+
logger.debug("No header filtering config found, passing through all %d headers", len(headers))
66+
return headers
6467

65-
filtered = {}
66-
headers_added = 0
67-
68-
for header_name, header_value in extra_headers.items():
69-
# Check if we've hit the max headers limit
70-
if headers_added >= config.max_headers_count:
71-
logger.warning("Reached maximum header count limit (%s), ignoring remaining headers", config.max_headers_count)
72-
break
68+
# Apply filtering based on allowlist
69+
if not config.allowed_headers:
70+
logger.debug("Empty allowlist in config, blocking all headers")
71+
return {}
7372

74-
# Check against allowlist patterns
73+
filtered = {}
74+
for header_name, header_value in headers.items():
75+
# Check against allowlist patterns (case-insensitive)
7576
header_allowed = False
7677
for pattern in config.allowed_headers:
7778
if fnmatch.fnmatch(header_name.lower(), pattern.lower()):
7879
header_allowed = True
7980
break
8081

8182
if header_allowed:
82-
# Apply size limits
83-
if len(header_value) <= config.max_header_size:
84-
filtered[header_name] = header_value
85-
headers_added += 1
86-
logger.debug("Allowed header: %s", header_name)
87-
else:
88-
logger.warning("Header '%s' exceeds size limit (%s > %s), ignoring", header_name, len(header_value), config.max_header_size)
83+
filtered[header_name] = header_value
84+
logger.debug("Allowed header: %s", header_name)
8985
else:
9086
logger.debug("Header '%s' not in allowlist, ignoring", header_name)
9187

88+
logger.debug("Filtered %d headers from %d based on allowlist", len(filtered), len(headers))
9289
return filtered
9390

9491
async def task_create(
@@ -99,7 +96,7 @@ async def task_create(
9996
params: dict[str, Any] | None = None,
10097
trace_id: str | None = None,
10198
parent_span_id: str | None = None,
102-
extra_headers: dict[str, str] | None = None,
99+
request: dict[str, Any] | None = None,
103100
) -> Task:
104101
trace = self._tracer.trace(trace_id=trace_id)
105102
async with trace.span(
@@ -114,9 +111,10 @@ async def task_create(
114111
) as span:
115112
heartbeat_if_in_workflow("task create")
116113

117-
# Get agent's header configuration and filter headers
114+
# Extract headers from request and filter them
115+
headers = request.get("headers") if request else None
118116
header_config = await self._get_agent_header_config(agent_name, agent_id)
119-
filtered_headers = self._filter_headers(extra_headers, header_config)
117+
filtered_headers = self._filter_headers(headers, header_config)
120118

121119
if agent_name:
122120
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
@@ -155,7 +153,7 @@ async def message_send(
155153
task_name: str | None = None,
156154
trace_id: str | None = None,
157155
parent_span_id: str | None = None,
158-
extra_headers: dict[str, str] | None = None,
156+
request: dict[str, Any] | None = None,
159157
) -> List[TaskMessage]:
160158
trace = self._tracer.trace(trace_id=trace_id)
161159
async with trace.span(
@@ -171,9 +169,10 @@ async def message_send(
171169
) as span:
172170
heartbeat_if_in_workflow("message send")
173171

174-
# Get agent's header configuration and filter headers
172+
# Extract headers from request and filter them
173+
headers = request.get("headers") if request else None
175174
header_config = await self._get_agent_header_config(agent_name, agent_id)
176-
filtered_headers = self._filter_headers(extra_headers, header_config)
175+
filtered_headers = self._filter_headers(headers, header_config)
177176

178177
if agent_name:
179178
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
@@ -222,7 +221,7 @@ async def event_send(
222221
task_name: str | None = None,
223222
trace_id: str | None = None,
224223
parent_span_id: str | None = None,
225-
extra_headers: dict[str, str] | None = None,
224+
request: dict[str, Any] | None = None,
226225
) -> Event:
227226
trace = self._tracer.trace(trace_id=trace_id)
228227
async with trace.span(
@@ -237,9 +236,10 @@ async def event_send(
237236
) as span:
238237
heartbeat_if_in_workflow("event send")
239238

240-
# Get agent's header configuration and filter headers
239+
# Extract headers from request and filter them
240+
headers = request.get("headers") if request else None
241241
header_config = await self._get_agent_header_config(agent_name, agent_id)
242-
filtered_headers = self._filter_headers(extra_headers, header_config)
242+
filtered_headers = self._filter_headers(headers, header_config)
243243

244244
if agent_name:
245245
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
@@ -277,7 +277,7 @@ async def task_cancel(
277277
agent_name: str | None = None,
278278
trace_id: str | None = None,
279279
parent_span_id: str | None = None,
280-
extra_headers: dict[str, str] | None = None,
280+
request: dict[str, Any] | None = None,
281281
) -> Task:
282282
"""
283283
Cancel a task by sending cancel request to the agent that owns the task.
@@ -318,9 +318,10 @@ async def task_cancel(
318318
) as span:
319319
heartbeat_if_in_workflow("task cancel")
320320

321-
# Get agent's header configuration and filter headers
321+
# Extract headers from request and filter them
322+
headers = request.get("headers") if request else None
322323
header_config = await self._get_agent_header_config(agent_name, agent_id)
323-
filtered_headers = self._filter_headers(extra_headers, header_config)
324+
filtered_headers = self._filter_headers(headers, header_config)
324325

325326
# Build params for the agent (task identification)
326327
params = {}

src/agentex/lib/core/temporal/activities/adk/acp/acp_activities.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,31 +26,31 @@ class TaskCreateParams(BaseModelWithTraceParams):
2626
agent_id: str | None = None
2727
agent_name: str | None = None
2828
params: dict[str, Any] | None = None
29-
extra_headers: dict[str, str] | None = None
29+
request: dict[str, Any] | None = None
3030

3131

3232
class MessageSendParams(BaseModelWithTraceParams):
3333
agent_id: str | None = None
3434
agent_name: str | None = None
3535
task_id: str | None = None
3636
content: TaskMessageContent
37-
extra_headers: dict[str, str] | None = None
37+
request: dict[str, Any] | None = None
3838

3939

4040
class EventSendParams(BaseModelWithTraceParams):
4141
agent_id: str | None = None
4242
agent_name: str | None = None
4343
task_id: str | None = None
4444
content: TaskMessageContent
45-
extra_headers: dict[str, str] | None = None
45+
request: dict[str, Any] | None = None
4646

4747

4848
class TaskCancelParams(BaseModelWithTraceParams):
4949
task_id: str | None = None
5050
task_name: str | None = None
5151
agent_id: str | None = None
5252
agent_name: str | None = None
53-
extra_headers: dict[str, str] | None = None
53+
request: dict[str, Any] | None = None
5454

5555

5656
class ACPActivities:
@@ -66,7 +66,7 @@ async def task_create(self, params: TaskCreateParams) -> Task:
6666
params=params.params,
6767
trace_id=params.trace_id,
6868
parent_span_id=params.parent_span_id,
69-
extra_headers=params.extra_headers,
69+
request=params.request,
7070
)
7171

7272
@activity.defn(name=ACPActivityName.MESSAGE_SEND)
@@ -78,7 +78,7 @@ async def message_send(self, params: MessageSendParams) -> List[TaskMessage]:
7878
content=params.content,
7979
trace_id=params.trace_id,
8080
parent_span_id=params.parent_span_id,
81-
extra_headers=params.extra_headers,
81+
request=params.request,
8282
)
8383

8484
@activity.defn(name=ACPActivityName.EVENT_SEND)
@@ -90,7 +90,7 @@ async def event_send(self, params: EventSendParams) -> Event:
9090
content=params.content,
9191
trace_id=params.trace_id,
9292
parent_span_id=params.parent_span_id,
93-
extra_headers=params.extra_headers,
93+
request=params.request,
9494
)
9595

9696
@activity.defn(name=ACPActivityName.TASK_CANCEL)
@@ -102,5 +102,5 @@ async def task_cancel(self, params: TaskCancelParams) -> Task:
102102
agent_name=params.agent_name,
103103
trace_id=params.trace_id,
104104
parent_span_id=params.parent_span_id,
105-
extra_headers=params.extra_headers,
105+
request=params.request,
106106
)

0 commit comments

Comments
 (0)