11from typing import Any , List , cast
2+ import fnmatch
23
34from agentex import AsyncAgentex
45from agentex .lib .core .tracing .tracer import AsyncTracer
56from agentex .lib .utils .logging import make_logger
67from agentex .lib .utils .temporal import heartbeat_if_in_workflow
8+ from agentex .lib .types .agent_configs import CustomHeadersConfig
79from agentex .types .event import Event
810from agentex .types .task import Task
911from agentex .types .task_message import TaskMessage
@@ -22,6 +24,73 @@ def __init__(
2224 self ._agentex_client = agentex_client
2325 self ._tracer = tracer
2426
27+ async def _get_agent_header_config (
28+ self , agent_name : str | None , agent_id : str | None
29+ ) -> CustomHeadersConfig :
30+ """
31+ Get agent's header configuration from manifest.
32+
33+ For now, we'll return a default empty config since we need to implement
34+ manifest loading. This provides secure-by-default behavior.
35+
36+ TODO: Implement actual manifest loading to get agent's custom_headers config
37+ """
38+ # Default configuration - no headers allowed (secure by default)
39+ # This will be replaced with actual manifest loading
40+ return CustomHeadersConfig (
41+ strategy = "allowlist" ,
42+ allowed_headers = [],
43+ max_header_size = 8192 ,
44+ max_headers_count = 50
45+ )
46+
47+ def _filter_headers (
48+ self ,
49+ extra_headers : dict [str , str ] | None ,
50+ config : CustomHeadersConfig
51+ ) -> dict [str , str ]:
52+ """
53+ Filter headers based on agent's configuration.
54+
55+ Args:
56+ extra_headers: Headers to filter
57+ config: Agent's header configuration
58+
59+ Returns:
60+ Filtered headers dictionary containing only allowed headers
61+ """
62+ if not extra_headers or not config .allowed_headers :
63+ return {}
64+
65+ filtered = {}
66+ headers_added = 0
67+
68+ for header_name , header_value in extra_headers .items ():
69+ # Check if we've hit the max headers limit
70+ if headers_added >= config .max_headers_count :
71+ logger .warning ("Reached maximum header count limit (%s), ignoring remaining headers" , config .max_headers_count )
72+ break
73+
74+ # Check against allowlist patterns
75+ header_allowed = False
76+ for pattern in config .allowed_headers :
77+ if fnmatch .fnmatch (header_name .lower (), pattern .lower ()):
78+ header_allowed = True
79+ break
80+
81+ if header_allowed :
82+ # Apply size limits
83+ if len (header_value ) <= config .max_header_size :
84+ filtered [header_name ] = header_value
85+ headers_added += 1
86+ logger .debug ("Allowed header: %s" , header_name )
87+ else :
88+ logger .warning ("Header '%s' exceeds size limit (%s > %s), ignoring" , header_name , len (header_value ), config .max_header_size )
89+ else :
90+ logger .debug ("Header '%s' not in allowlist, ignoring" , header_name )
91+
92+ return filtered
93+
2594 async def task_create (
2695 self ,
2796 name : str | None = None ,
@@ -30,6 +99,7 @@ async def task_create(
3099 params : dict [str , Any ] | None = None ,
31100 trace_id : str | None = None ,
32101 parent_span_id : str | None = None ,
102+ extra_headers : dict [str , str ] | None = None ,
33103 ) -> Task :
34104 trace = self ._tracer .trace (trace_id = trace_id )
35105 async with trace .span (
@@ -43,6 +113,11 @@ async def task_create(
43113 },
44114 ) as span :
45115 heartbeat_if_in_workflow ("task create" )
116+
117+ # Get agent's header configuration and filter headers
118+ header_config = await self ._get_agent_header_config (agent_name , agent_id )
119+ filtered_headers = self ._filter_headers (extra_headers , header_config )
120+
46121 if agent_name :
47122 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
48123 agent_name = agent_name ,
@@ -51,6 +126,7 @@ async def task_create(
51126 "name" : name ,
52127 "params" : params ,
53128 },
129+ extra_headers = filtered_headers ,
54130 )
55131 elif agent_id :
56132 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -60,6 +136,7 @@ async def task_create(
60136 "name" : name ,
61137 "params" : params ,
62138 },
139+ extra_headers = filtered_headers ,
63140 )
64141 else :
65142 raise ValueError ("Either agent_name or agent_id must be provided" )
@@ -78,6 +155,7 @@ async def message_send(
78155 task_name : str | None = None ,
79156 trace_id : str | None = None ,
80157 parent_span_id : str | None = None ,
158+ extra_headers : dict [str , str ] | None = None ,
81159 ) -> List [TaskMessage ]:
82160 trace = self ._tracer .trace (trace_id = trace_id )
83161 async with trace .span (
@@ -92,6 +170,11 @@ async def message_send(
92170 },
93171 ) as span :
94172 heartbeat_if_in_workflow ("message send" )
173+
174+ # Get agent's header configuration and filter headers
175+ header_config = await self ._get_agent_header_config (agent_name , agent_id )
176+ filtered_headers = self ._filter_headers (extra_headers , header_config )
177+
95178 if agent_name :
96179 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
97180 agent_name = agent_name ,
@@ -101,6 +184,7 @@ async def message_send(
101184 "content" : cast (TaskMessageContentParam , content .model_dump ()),
102185 "stream" : False ,
103186 },
187+ extra_headers = filtered_headers ,
104188 )
105189 elif agent_id :
106190 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -111,12 +195,13 @@ async def message_send(
111195 "content" : cast (TaskMessageContentParam , content .model_dump ()),
112196 "stream" : False ,
113197 },
198+ extra_headers = filtered_headers ,
114199 )
115200 else :
116201 raise ValueError ("Either agent_name or agent_id must be provided" )
117202
118203 task_messages : List [TaskMessage ] = []
119- logger .info (f "json_rpc_response: { json_rpc_response } " )
204+ logger .info ("json_rpc_response: %s" , json_rpc_response )
120205 if isinstance (json_rpc_response .result , list ):
121206 for message in json_rpc_response .result :
122207 task_message = TaskMessage .model_validate (message )
@@ -137,6 +222,7 @@ async def event_send(
137222 task_name : str | None = None ,
138223 trace_id : str | None = None ,
139224 parent_span_id : str | None = None ,
225+ extra_headers : dict [str , str ] | None = None ,
140226 ) -> Event :
141227 trace = self ._tracer .trace (trace_id = trace_id )
142228 async with trace .span (
@@ -150,6 +236,11 @@ async def event_send(
150236 },
151237 ) as span :
152238 heartbeat_if_in_workflow ("event send" )
239+
240+ # Get agent's header configuration and filter headers
241+ header_config = await self ._get_agent_header_config (agent_name , agent_id )
242+ filtered_headers = self ._filter_headers (extra_headers , header_config )
243+
153244 if agent_name :
154245 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
155246 agent_name = agent_name ,
@@ -158,6 +249,7 @@ async def event_send(
158249 "task_id" : task_id ,
159250 "content" : cast (TaskMessageContentParam , content .model_dump ()),
160251 },
252+ extra_headers = filtered_headers ,
161253 )
162254 elif agent_id :
163255 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -167,6 +259,7 @@ async def event_send(
167259 "task_id" : task_id ,
168260 "content" : cast (TaskMessageContentParam , content .model_dump ()),
169261 },
262+ extra_headers = filtered_headers ,
170263 )
171264 else :
172265 raise ValueError ("Either agent_name or agent_id must be provided" )
@@ -182,6 +275,7 @@ async def task_cancel(
182275 task_name : str | None = None ,
183276 trace_id : str | None = None ,
184277 parent_span_id : str | None = None ,
278+ extra_headers : dict [str , str ] | None = None ,
185279 ) -> Task :
186280 trace = self ._tracer .trace (trace_id = trace_id )
187281 async with trace .span (
@@ -193,13 +287,23 @@ async def task_cancel(
193287 },
194288 ) as span :
195289 heartbeat_if_in_workflow ("task cancel" )
290+
291+ # Note: The original implementation seems to treat task_name as agent_name and task_id as agent_id
292+ # This maintains backward compatibility while adding header support
293+ # Get agent's header configuration and filter headers
294+ agent_name = task_name if task_name else None
295+ agent_id = task_id if task_id else None
296+ header_config = await self ._get_agent_header_config (agent_name , agent_id )
297+ filtered_headers = self ._filter_headers (extra_headers , header_config )
298+
196299 if task_name :
197300 json_rpc_response = await self ._agentex_client .agents .rpc_by_name (
198301 agent_name = task_name ,
199302 method = "task/cancel" ,
200303 params = {
201304 "task_name" : task_name ,
202305 },
306+ extra_headers = filtered_headers ,
203307 )
204308 elif task_id :
205309 json_rpc_response = await self ._agentex_client .agents .rpc (
@@ -208,6 +312,7 @@ async def task_cancel(
208312 params = {
209313 "task_id" : task_id ,
210314 },
315+ extra_headers = filtered_headers ,
211316 )
212317 else :
213318 raise ValueError ("Either task_name or task_id must be provided" )
0 commit comments