Skip to content

Commit 0a9dcff

Browse files
feat: add easy access to headers
1 parent eee0ea2 commit 0a9dcff

File tree

4 files changed

+87
-53
lines changed

4 files changed

+87
-53
lines changed

src/agentex/lib/sdk/fastacp/base/base_acp_server.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,17 @@ async def _handle_jsonrpc(self, request: Request):
128128
),
129129
)
130130

131-
# Parse params into appropriate model based on method
131+
# Extract custom headers (already filtered by AgentEx based on agent manifest)
132+
custom_headers = {
133+
key: value for key, value in request.headers.items()
134+
if key.lower().startswith('x-') # Custom headers typically start with 'x-'
135+
}
136+
137+
# Parse params into appropriate model based on method and include headers
132138
params_model = PARAMS_MODEL_BY_METHOD[method]
133-
params = params_model.model_validate(rpc_request.params)
139+
params_data = dict(rpc_request.params) if rpc_request.params else {}
140+
params_data['extra_headers'] = custom_headers if custom_headers else None
141+
params = params_model.model_validate(params_data)
134142

135143
if method in RPC_SYNC_METHODS:
136144
handler = self._handlers[method]

src/agentex/lib/types/acp.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class CreateTaskParams(BaseModel):
2525
agent: The agent that the task was sent to.
2626
task: The task to be created.
2727
params: The parameters for the task as inputted by the user.
28+
extra_headers: Custom headers forwarded to this agent (filtered by manifest allowlist).
2829
"""
2930

3031
agent: Agent = Field(..., description="The agent that the task was sent to")
@@ -33,6 +34,10 @@ class CreateTaskParams(BaseModel):
3334
None,
3435
description="The parameters for the task as inputted by the user",
3536
)
37+
extra_headers: dict[str, str] | None = Field(
38+
default=None,
39+
description="Custom headers forwarded to this agent (filtered by manifest allowlist)",
40+
)
3641

3742

3843
class SendMessageParams(BaseModel):
@@ -43,6 +48,7 @@ class SendMessageParams(BaseModel):
4348
task: The task that the message was sent to.
4449
content: The message that was sent to the agent.
4550
stream: Whether to stream the message back to the agentex server from the agent.
51+
extra_headers: Custom headers forwarded to this agent (filtered by manifest allowlist).
4652
"""
4753

4854
agent: Agent = Field(..., description="The agent that the message was sent to")
@@ -54,6 +60,10 @@ class SendMessageParams(BaseModel):
5460
False,
5561
description="Whether to stream the message back to the agentex server from the agent",
5662
)
63+
extra_headers: dict[str, str] | None = Field(
64+
default=None,
65+
description="Custom headers forwarded to this agent (filtered by manifest allowlist)",
66+
)
5767

5868

5969
class SendEventParams(BaseModel):
@@ -63,11 +73,16 @@ class SendEventParams(BaseModel):
6373
agent: The agent that the event was sent to.
6474
task: The task that the message was sent to.
6575
event: The event that was sent to the agent.
76+
extra_headers: Custom headers forwarded to this agent (filtered by manifest allowlist).
6677
"""
6778

6879
agent: Agent = Field(..., description="The agent that the event was sent to")
6980
task: Task = Field(..., description="The task that the message was sent to")
7081
event: Event = Field(..., description="The event that was sent to the agent")
82+
extra_headers: dict[str, str] | None = Field(
83+
default=None,
84+
description="Custom headers forwarded to this agent (filtered by manifest allowlist)",
85+
)
7186

7287

7388
class CancelTaskParams(BaseModel):
@@ -76,10 +91,15 @@ class CancelTaskParams(BaseModel):
7691
Attributes:
7792
agent: The agent that the task was sent to.
7893
task: The task that was cancelled.
94+
extra_headers: Custom headers forwarded to this agent (filtered by manifest allowlist).
7995
"""
8096

8197
agent: Agent = Field(..., description="The agent that the task was sent to")
8298
task: Task = Field(..., description="The task that was cancelled")
99+
extra_headers: dict[str, str] | None = Field(
100+
default=None,
101+
description="Custom headers forwarded to this agent (filtered by manifest allowlist)",
102+
)
83103

84104

85105
RPC_SYNC_METHODS = [

src/agentex/lib/types/agent_configs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Literal
2-
from pydantic import BaseModel, Field, model_validator, validator
2+
from pydantic import BaseModel, Field, field_validator, model_validator
33

44

55
class TemporalWorkflowConfig(BaseModel):
@@ -101,9 +101,9 @@ class TemporalConfig(BaseModel):
101101
description="List of temporal workflow configurations. Used when enabled=true.",
102102
)
103103

104-
@validator("workflows")
104+
@field_validator("workflows")
105105
@classmethod
106-
def validate_workflows_not_empty(cls, v):
106+
def validate_workflows_not_empty(cls, v: list[TemporalWorkflowConfig]) -> list[TemporalWorkflowConfig]:
107107
"""Ensure workflows list is not empty when provided"""
108108
if v is not None and len(v) == 0:
109109
raise ValueError("workflows list cannot be empty when provided")

tests/test_header_filtering.py

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,16 @@
11
"""
22
Tests for custom header filtering functionality.
33
"""
4+
import fnmatch
45
import pytest
5-
from unittest.mock import AsyncMock, MagicMock
66

77
from agentex.lib.types.agent_configs import CustomHeadersConfig
8-
from agentex.lib.core.services.adk.acp.acp import ACPService
98

109

11-
class TestHeaderFiltering:
12-
"""Test header filtering functionality in ACPService."""
13-
14-
def setup_method(self):
15-
"""Set up test fixtures."""
16-
self.mock_client = MagicMock()
17-
self.mock_tracer = MagicMock()
18-
self.acp_service = ACPService(
19-
agentex_client=self.mock_client,
20-
tracer=self.mock_tracer
21-
)
10+
class TestCustomHeadersConfig:
11+
"""Test CustomHeadersConfig functionality."""
2212

23-
def test_custom_headers_config_creation(self):
13+
def test_custom_headers_config_creation(self) -> None:
2414
"""Test that CustomHeadersConfig can be created with defaults."""
2515
config = CustomHeadersConfig()
2616

@@ -29,7 +19,7 @@ def test_custom_headers_config_creation(self):
2919
assert config.max_header_size == 8192
3020
assert config.max_headers_count == 50
3121

32-
def test_custom_headers_config_with_headers(self):
22+
def test_custom_headers_config_with_headers(self) -> None:
3323
"""Test CustomHeadersConfig with allowed headers."""
3424
config = CustomHeadersConfig(
3525
allowed_headers=["x-user-email", "x-tenant-id"],
@@ -42,7 +32,7 @@ def test_custom_headers_config_with_headers(self):
4232
assert config.max_header_size == 4096
4333
assert config.max_headers_count == 10
4434

45-
def test_custom_headers_config_validation(self):
35+
def test_custom_headers_config_validation(self) -> None:
4636
"""Test CustomHeadersConfig validation for positive values."""
4737
# Test negative max_header_size
4838
with pytest.raises(ValueError, match="max_header_size must be greater than 0"):
@@ -52,23 +42,57 @@ def test_custom_headers_config_validation(self):
5242
with pytest.raises(ValueError, match="max_headers_count must be greater than 0"):
5343
CustomHeadersConfig(max_headers_count=0)
5444

55-
def test_filter_headers_no_headers(self):
45+
46+
def filter_headers_standalone(
47+
headers: dict[str, str] | None,
48+
config: CustomHeadersConfig
49+
) -> dict[str, str]:
50+
"""Standalone header filtering function for testing."""
51+
if not headers:
52+
return {}
53+
54+
if not config.allowed_headers:
55+
return {}
56+
57+
filtered = {}
58+
for header_name, header_value in headers.items():
59+
# Check size limit
60+
if len(header_value) > config.max_header_size:
61+
continue
62+
63+
# Check if header matches any allowed pattern (case-insensitive)
64+
for allowed_pattern in config.allowed_headers:
65+
if fnmatch.fnmatch(header_name.lower(), allowed_pattern.lower()):
66+
filtered[header_name] = header_value
67+
break
68+
69+
# Check count limit
70+
if len(filtered) >= config.max_headers_count:
71+
break
72+
73+
return filtered
74+
75+
76+
class TestHeaderFiltering:
77+
"""Test header filtering logic."""
78+
79+
def test_filter_headers_no_headers(self) -> None:
5680
"""Test header filtering with no input headers."""
5781
config = CustomHeadersConfig(allowed_headers=["x-user-email"])
58-
result = self.acp_service._filter_headers(None, config)
82+
result = filter_headers_standalone(None, config)
5983
assert result == {}
6084

61-
result = self.acp_service._filter_headers({}, config)
85+
result = filter_headers_standalone({}, config)
6286
assert result == {}
6387

64-
def test_filter_headers_no_allowed_headers(self):
88+
def test_filter_headers_no_allowed_headers(self) -> None:
6589
"""Test header filtering with no allowed headers (secure by default)."""
6690
config = CustomHeadersConfig(allowed_headers=[])
6791
headers = {"x-user-email": "[email protected]", "x-admin-token": "secret"}
68-
result = self.acp_service._filter_headers(headers, config)
92+
result = filter_headers_standalone(headers, config)
6993
assert result == {}
7094

71-
def test_filter_headers_allowed_headers(self):
95+
def test_filter_headers_allowed_headers(self) -> None:
7296
"""Test header filtering with allowed headers."""
7397
config = CustomHeadersConfig(allowed_headers=["x-user-email", "x-tenant-id"])
7498
headers = {
@@ -77,15 +101,15 @@ def test_filter_headers_allowed_headers(self):
77101
"x-admin-token": "secret", # Should be filtered out
78102
"content-type": "application/json" # Should be filtered out
79103
}
80-
result = self.acp_service._filter_headers(headers, config)
104+
result = filter_headers_standalone(headers, config)
81105

82106
expected = {
83107
"x-user-email": "[email protected]",
84108
"x-tenant-id": "tenant123"
85109
}
86110
assert result == expected
87111

88-
def test_filter_headers_case_insensitive_patterns(self):
112+
def test_filter_headers_case_insensitive_patterns(self) -> None:
89113
"""Test header filtering with case-insensitive pattern matching."""
90114
config = CustomHeadersConfig(allowed_headers=["X-User-Email", "x-tenant-*"])
91115
headers = {
@@ -94,7 +118,7 @@ def test_filter_headers_case_insensitive_patterns(self):
94118
"x-tenant-name": "acme", # Should match x-tenant-*
95119
"x-admin-token": "secret" # Should be filtered out
96120
}
97-
result = self.acp_service._filter_headers(headers, config)
121+
result = filter_headers_standalone(headers, config)
98122

99123
expected = {
100124
"x-user-email": "[email protected]",
@@ -103,29 +127,22 @@ def test_filter_headers_case_insensitive_patterns(self):
103127
}
104128
assert result == expected
105129

106-
def test_filter_headers_size_limit(self):
130+
def test_filter_headers_size_limit(self) -> None:
107131
"""Test header filtering with size limits."""
108132
config = CustomHeadersConfig(
109-
allowed_headers=["x-data"],
133+
allowed_headers=["x-data", "x-large-data"],
110134
max_header_size=10 # Very small limit for testing
111135
)
112136
headers = {
113137
"x-data": "small", # 5 chars - should pass
114-
"x-data-large": "this is a very long header value that exceeds the limit" # Should be rejected
115-
}
116-
# Note: x-data-large would be rejected anyway since it's not in allowed_headers
117-
# Let's test with allowed header that's too long
118-
config.allowed_headers = ["x-data", "x-large-data"]
119-
headers = {
120-
"x-data": "small", # Should pass
121138
"x-large-data": "this header value is way too long for the configured limit" # Should be rejected due to size
122139
}
123-
result = self.acp_service._filter_headers(headers, config)
140+
result = filter_headers_standalone(headers, config)
124141

125142
expected = {"x-data": "small"}
126143
assert result == expected
127144

128-
def test_filter_headers_count_limit(self):
145+
def test_filter_headers_count_limit(self) -> None:
129146
"""Test header filtering with count limits."""
130147
config = CustomHeadersConfig(
131148
allowed_headers=["x-header-*"],
@@ -137,24 +154,13 @@ def test_filter_headers_count_limit(self):
137154
"x-header-3": "value3", # Should be ignored due to count limit
138155
"x-header-4": "value4" # Should be ignored due to count limit
139156
}
140-
result = self.acp_service._filter_headers(headers, config)
157+
result = filter_headers_standalone(headers, config)
141158

142159
# Should only get first 2 headers that match
143160
assert len(result) == 2
144161
assert "x-header-1" in result
145162
assert "x-header-2" in result
146163

147-
@pytest.mark.asyncio
148-
async def test_get_agent_header_config_default(self):
149-
"""Test getting default agent header configuration."""
150-
config = await self.acp_service._get_agent_header_config("test-agent", None)
151-
152-
# Should return secure default (no headers allowed)
153-
assert config.strategy == "allowlist"
154-
assert config.allowed_headers == []
155-
assert config.max_header_size == 8192
156-
assert config.max_headers_count == 50
157-
158164

159165
if __name__ == "__main__":
160166
pytest.main([__file__])

0 commit comments

Comments
 (0)