@@ -26,69 +26,66 @@ def __init__(
2626
2727 async def _get_agent_header_config (
2828 self , agent_name : str | None , agent_id : str | None
29- ) -> CustomHeadersConfig :
29+ ) -> CustomHeadersConfig | None :
3030 """
3131 Get agent's header configuration from manifest.
3232
33- For now, we'll return a default empty config since we need to implement
34- manifest loading. This provides secure-by-default behavior .
33+ Returns None for pass-through behavior (all headers forwarded).
34+ Returns CustomHeadersConfig only if agent has specific filtering requirements .
3535
3636 TODO: Implement actual manifest loading to get agent's custom_headers config
3737 """
38- # Default configuration - no headers allowed (secure by default)
38+ # For now, return None to enable pass-through behavior
3939 # This will be replaced with actual manifest loading
40- return CustomHeadersConfig (
41- strategy = "allowlist" ,
42- allowed_headers = [],
43- max_header_size = 8192 ,
44- max_headers_count = 50
45- )
40+ return None
4641
4742 def _filter_headers (
4843 self ,
49- extra_headers : dict [str , str ] | None ,
50- config : CustomHeadersConfig
44+ headers : dict [str , str ] | None ,
45+ config : CustomHeadersConfig | None
5146 ) -> dict [str , str ]:
5247 """
5348 Filter headers based on agent's configuration.
5449
50+ Pass-through by default: if no config provided, all headers are forwarded.
51+ If config exists, only allowlisted headers are forwarded.
52+
5553 Args:
56- extra_headers : Headers to filter
57- config: Agent's header configuration
54+ headers : Headers to filter
55+ config: Agent's header configuration (None = pass-through all)
5856
5957 Returns:
60- Filtered headers dictionary containing only allowed headers
58+ Filtered headers dictionary
6159 """
62- if not extra_headers or not config . allowed_headers :
60+ if not headers :
6361 return {}
62+
63+ # Pass-through behavior: if no config, forward all headers
64+ if config is None :
65+ logger .debug ("No header filtering config found, passing through all %d headers" , len (headers ))
66+ return headers
6467
65- filtered = {}
66- headers_added = 0
67-
68- for header_name , header_value in extra_headers .items ():
69- # Check if we've hit the max headers limit
70- if headers_added >= config .max_headers_count :
71- logger .warning ("Reached maximum header count limit (%s), ignoring remaining headers" , config .max_headers_count )
72- break
68+ # Apply filtering based on allowlist
69+ if not config .allowed_headers :
70+ logger .debug ("Empty allowlist in config, blocking all headers" )
71+ return {}
7372
74- # Check against allowlist patterns
73+ filtered = {}
74+ for header_name , header_value in headers .items ():
75+ # Check against allowlist patterns (case-insensitive)
7576 header_allowed = False
7677 for pattern in config .allowed_headers :
7778 if fnmatch .fnmatch (header_name .lower (), pattern .lower ()):
7879 header_allowed = True
7980 break
8081
8182 if header_allowed :
82- # Apply size limits
83- if len (header_value ) <= config .max_header_size :
84- filtered [header_name ] = header_value
85- headers_added += 1
86- logger .debug ("Allowed header: %s" , header_name )
87- else :
88- logger .warning ("Header '%s' exceeds size limit (%s > %s), ignoring" , header_name , len (header_value ), config .max_header_size )
83+ filtered [header_name ] = header_value
84+ logger .debug ("Allowed header: %s" , header_name )
8985 else :
9086 logger .debug ("Header '%s' not in allowlist, ignoring" , header_name )
9187
88+ logger .debug ("Filtered %d headers from %d based on allowlist" , len (filtered ), len (headers ))
9289 return filtered
9390
9491 async def task_create (
@@ -99,7 +96,7 @@ async def task_create(
9996 params : dict [str , Any ] | None = None ,
10097 trace_id : str | None = None ,
10198 parent_span_id : str | None = None ,
102- extra_headers : dict [str , str ] | None = None ,
99+ request : dict [str , Any ] | None = None ,
103100 ) -> Task :
104101 trace = self ._tracer .trace (trace_id = trace_id )
105102 async with trace .span (
@@ -114,9 +111,10 @@ async def task_create(
114111 ) as span :
115112 heartbeat_if_in_workflow ("task create" )
116113
117- # Get agent's header configuration and filter headers
114+ # Extract headers from request and filter them
115+ headers = request .get ("headers" ) if request else None
118116 header_config = await self ._get_agent_header_config (agent_name , agent_id )
119- filtered_headers = self ._filter_headers (extra_headers , header_config )
117+ filtered_headers = self ._filter_headers (headers , header_config )
120118
121119 if agent_name :
122120 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
@@ -155,7 +153,7 @@ async def message_send(
155153 task_name : str | None = None ,
156154 trace_id : str | None = None ,
157155 parent_span_id : str | None = None ,
158- extra_headers : dict [str , str ] | None = None ,
156+ request : dict [str , Any ] | None = None ,
159157 ) -> List [TaskMessage ]:
160158 trace = self ._tracer .trace (trace_id = trace_id )
161159 async with trace .span (
@@ -171,9 +169,10 @@ async def message_send(
171169 ) as span :
172170 heartbeat_if_in_workflow ("message send" )
173171
174- # Get agent's header configuration and filter headers
172+ # Extract headers from request and filter them
173+ headers = request .get ("headers" ) if request else None
175174 header_config = await self ._get_agent_header_config (agent_name , agent_id )
176- filtered_headers = self ._filter_headers (extra_headers , header_config )
175+ filtered_headers = self ._filter_headers (headers , header_config )
177176
178177 if agent_name :
179178 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
@@ -222,7 +221,7 @@ async def event_send(
222221 task_name : str | None = None ,
223222 trace_id : str | None = None ,
224223 parent_span_id : str | None = None ,
225- extra_headers : dict [str , str ] | None = None ,
224+ request : dict [str , Any ] | None = None ,
226225 ) -> Event :
227226 trace = self ._tracer .trace (trace_id = trace_id )
228227 async with trace .span (
@@ -237,9 +236,10 @@ async def event_send(
237236 ) as span :
238237 heartbeat_if_in_workflow ("event send" )
239238
240- # Get agent's header configuration and filter headers
239+ # Extract headers from request and filter them
240+ headers = request .get ("headers" ) if request else None
241241 header_config = await self ._get_agent_header_config (agent_name , agent_id )
242- filtered_headers = self ._filter_headers (extra_headers , header_config )
242+ filtered_headers = self ._filter_headers (headers , header_config )
243243
244244 if agent_name :
245245 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
@@ -277,7 +277,7 @@ async def task_cancel(
277277 agent_name : str | None = None ,
278278 trace_id : str | None = None ,
279279 parent_span_id : str | None = None ,
280- extra_headers : dict [str , str ] | None = None ,
280+ request : dict [str , Any ] | None = None ,
281281 ) -> Task :
282282 """
283283 Cancel a task by sending cancel request to the agent that owns the task.
@@ -318,9 +318,10 @@ async def task_cancel(
318318 ) as span :
319319 heartbeat_if_in_workflow ("task cancel" )
320320
321- # Get agent's header configuration and filter headers
321+ # Extract headers from request and filter them
322+ headers = request .get ("headers" ) if request else None
322323 header_config = await self ._get_agent_header_config (agent_name , agent_id )
323- filtered_headers = self ._filter_headers (extra_headers , header_config )
324+ filtered_headers = self ._filter_headers (headers , header_config )
324325
325326 # Build params for the agent (task identification)
326327 params = {}
0 commit comments