Skip to content

Commit 67cb3ff

Browse files
[Fix] support header forwarding for temporal ACP (#140)
1 parent 2dd6623 commit 67cb3ff

File tree

5 files changed

+261
-18
lines changed

5 files changed

+261
-18
lines changed

src/agentex/lib/core/temporal/services/temporal_task_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,15 @@ async def get_state(self, task_id: str) -> WorkflowState:
5151
workflow_id=task_id,
5252
)
5353

54-
async def send_event(self, agent: Agent, task: Task, event: Event) -> None:
54+
async def send_event(self, agent: Agent, task: Task, event: Event, request: dict | None = None) -> None:
5555
return await self._temporal_client.send_signal(
5656
workflow_id=task.id,
5757
signal=SignalName.RECEIVE_EVENT.value,
5858
payload=SendEventParams(
5959
agent=agent,
6060
task=task,
6161
event=event,
62+
request=request,
6263
).model_dump(),
6364
)
6465

src/agentex/lib/sdk/fastacp/base/base_acp_server.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,14 @@ async def _handle_jsonrpc(self, request: Request):
154154
),
155155
)
156156

157-
# Extract application headers, excluding sensitive/transport headers per FASTACP_* rules
157+
# Extract application headers using allowlist approach (only x-* headers)
158+
# Matches gateway's security filtering rules
158159
# Forward filtered headers via params.request.headers to agent handlers
159160
custom_headers = {
160161
key: value
161162
for key, value in request.headers.items()
162-
if key.lower() not in FASTACP_HEADER_SKIP_EXACT
163+
if key.lower().startswith("x-")
164+
and key.lower() not in FASTACP_HEADER_SKIP_EXACT
163165
and not any(key.lower().startswith(p) for p in FASTACP_HEADER_SKIP_PREFIXES)
164166
}
165167

@@ -168,6 +170,7 @@ async def _handle_jsonrpc(self, request: Request):
168170
params_data = dict(rpc_request.params) if rpc_request.params else {}
169171

170172
# Add custom headers to the request structure if any headers were provided
173+
# Gateway sends filtered headers via HTTP, SDK extracts and populates params.request
171174
if custom_headers:
172175
params_data["request"] = {"headers": custom_headers}
173176
params = params_model.model_validate(params_data)
Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,36 @@
11
from __future__ import annotations
22

33
# Header filtering rules for FastACP server
4+
# These rules match the gateway's security filtering
45

5-
# Prefixes to skip (case-insensitive beginswith checks)
6-
FASTACP_HEADER_SKIP_PREFIXES: tuple[str, ...] = (
7-
"content-",
6+
# Hop-by-hop headers that should not be forwarded
7+
HOP_BY_HOP_HEADERS: set[str] = {
8+
"connection",
9+
"keep-alive",
10+
"proxy-authenticate",
11+
"proxy-authorization",
12+
"te",
13+
"trailer",
14+
"transfer-encoding",
15+
"upgrade",
16+
"content-length",
17+
"content-encoding",
818
"host",
9-
"user-agent",
10-
"x-forwarded-",
11-
"sec-",
12-
)
19+
}
1320

14-
# Exact header names to skip (case-insensitive matching done by lowercasing keys)
15-
FASTACP_HEADER_SKIP_EXACT: set[str] = {
16-
"x-agent-api-key",
17-
"connection",
18-
"accept-encoding",
21+
# Sensitive headers that should never be forwarded
22+
BLOCKED_HEADERS: set[str] = {
23+
"authorization",
1924
"cookie",
20-
"content-length",
21-
"transfer-encoding",
25+
"x-agent-api-key",
2226
}
2327

28+
# Legacy constants for backward compatibility
29+
FASTACP_HEADER_SKIP_EXACT: set[str] = HOP_BY_HOP_HEADERS | BLOCKED_HEADERS
30+
31+
FASTACP_HEADER_SKIP_PREFIXES: tuple[str, ...] = (
32+
"x-forwarded-", # proxy headers
33+
"sec-", # security headers added by browsers
34+
)
35+
2436

src/agentex/lib/sdk/fastacp/impl/temporal_acp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ async def handle_event_send(params: SendEventParams) -> None:
9393
agent=params.agent,
9494
task=params.task,
9595
event=params.event,
96+
request=params.request,
9697
)
9798

9899
except Exception as e:

tests/test_header_forwarding.py

Lines changed: 227 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import Any, override
44
import sys
55
import types
6+
from datetime import datetime, timezone
7+
from unittest.mock import AsyncMock, Mock
68

79
import pytest
810
from fastapi.testclient import TestClient
@@ -44,8 +46,14 @@ class _StubTracer(_StubAsyncTracer):
4446

4547
from agentex.lib.core.services.adk.acp.acp import ACPService
4648
from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer
47-
from agentex.lib.types.acp import RPCMethod, SendMessageParams
49+
from agentex.lib.types.acp import RPCMethod, SendMessageParams, SendEventParams
4850
from agentex.types.task_message_content import TextContent
51+
from agentex.lib.sdk.fastacp.impl.temporal_acp import TemporalACP
52+
from agentex.lib.core.temporal.services.temporal_task_service import TemporalTaskService
53+
from agentex.lib.environment_variables import EnvironmentVariables
54+
from agentex.types.agent import Agent
55+
from agentex.types.task import Task
56+
from agentex.types.event import Event
4957

5058

5159
class DummySpan:
@@ -313,3 +321,221 @@ def test_filter_headers_all_types() -> None:
313321
assert result == expected
314322

315323

324+
325+
# ============================================================================
326+
# Temporal Header Forwarding Tests
327+
# ============================================================================
328+
329+
@pytest.fixture
330+
def mock_temporal_client():
331+
"""Create a mock TemporalClient"""
332+
client = AsyncMock()
333+
client.send_signal = AsyncMock(return_value=None)
334+
return client
335+
336+
337+
@pytest.fixture
338+
def mock_env_vars():
339+
"""Create mock environment variables"""
340+
env_vars = Mock(spec=EnvironmentVariables)
341+
env_vars.WORKFLOW_NAME = "test-workflow"
342+
env_vars.WORKFLOW_TASK_QUEUE = "test-queue"
343+
return env_vars
344+
345+
346+
@pytest.fixture
347+
def temporal_task_service(mock_temporal_client, mock_env_vars):
348+
"""Create TemporalTaskService with mocked client"""
349+
return TemporalTaskService(
350+
temporal_client=mock_temporal_client,
351+
env_vars=mock_env_vars,
352+
)
353+
354+
355+
@pytest.fixture
356+
def sample_agent():
357+
"""Create a sample agent"""
358+
return Agent(
359+
id="agent-123",
360+
name="test-agent",
361+
description="Test agent",
362+
acp_type="agentic",
363+
created_at=datetime.now(timezone.utc),
364+
updated_at=datetime.now(timezone.utc)
365+
)
366+
367+
368+
@pytest.fixture
369+
def sample_task():
370+
"""Create a sample task"""
371+
return Task(id="task-456")
372+
373+
374+
@pytest.fixture
375+
def sample_event():
376+
"""Create a sample event"""
377+
return Event(
378+
id="event-789",
379+
agent_id="agent-123",
380+
task_id="task-456",
381+
sequence_id=1,
382+
content=TextContent(author="user", content="Test message")
383+
)
384+
385+
386+
@pytest.mark.asyncio
387+
async def test_temporal_task_service_send_event_with_headers(
388+
temporal_task_service,
389+
mock_temporal_client,
390+
sample_agent,
391+
sample_task,
392+
sample_event
393+
):
394+
"""Test that TemporalTaskService forwards request headers in signal payload"""
395+
# Given
396+
request_headers = {
397+
"x-user-oauth-credentials": "test-oauth-token",
398+
"x-custom-header": "custom-value"
399+
}
400+
request = {"headers": request_headers}
401+
402+
# When
403+
await temporal_task_service.send_event(
404+
agent=sample_agent,
405+
task=sample_task,
406+
event=sample_event,
407+
request=request
408+
)
409+
410+
# Then
411+
mock_temporal_client.send_signal.assert_called_once()
412+
call_args = mock_temporal_client.send_signal.call_args
413+
414+
# Verify the signal was sent to the correct workflow
415+
assert call_args.kwargs["workflow_id"] == sample_task.id
416+
assert call_args.kwargs["signal"] == "receive_event"
417+
418+
# Verify the payload includes the request with headers
419+
payload = call_args.kwargs["payload"]
420+
assert "request" in payload
421+
assert payload["request"] == request
422+
assert payload["request"]["headers"] == request_headers
423+
424+
425+
@pytest.mark.asyncio
426+
async def test_temporal_task_service_send_event_without_headers(
427+
temporal_task_service,
428+
mock_temporal_client,
429+
sample_agent,
430+
sample_task,
431+
sample_event
432+
):
433+
"""Test that TemporalTaskService handles missing request gracefully"""
434+
# When - Send event without request parameter
435+
await temporal_task_service.send_event(
436+
agent=sample_agent,
437+
task=sample_task,
438+
event=sample_event,
439+
request=None
440+
)
441+
442+
# Then
443+
mock_temporal_client.send_signal.assert_called_once()
444+
call_args = mock_temporal_client.send_signal.call_args
445+
446+
# Verify the payload has request as None
447+
payload = call_args.kwargs["payload"]
448+
assert payload["request"] is None
449+
450+
451+
@pytest.mark.asyncio
452+
async def test_temporal_acp_integration_with_request_headers(
453+
mock_temporal_client,
454+
mock_env_vars,
455+
sample_agent,
456+
sample_task,
457+
sample_event
458+
):
459+
"""Test end-to-end integration: TemporalACP -> TemporalTaskService -> TemporalClient signal"""
460+
# Given - Create real TemporalTaskService with mocked client
461+
task_service = TemporalTaskService(
462+
temporal_client=mock_temporal_client,
463+
env_vars=mock_env_vars,
464+
)
465+
466+
# Create TemporalACP with real task service
467+
temporal_acp = TemporalACP(
468+
temporal_address="localhost:7233",
469+
temporal_task_service=task_service,
470+
)
471+
temporal_acp._setup_handlers()
472+
473+
request_headers = {
474+
"x-user-id": "user-123",
475+
"authorization": "Bearer token",
476+
"x-tenant-id": "tenant-456"
477+
}
478+
request = {"headers": request_headers}
479+
480+
# Create SendEventParams as TemporalACP would receive it
481+
params = SendEventParams(
482+
agent=sample_agent,
483+
task=sample_task,
484+
event=sample_event,
485+
request=request
486+
)
487+
488+
# When - Trigger the event handler via the decorated function
489+
# The handler is registered via @temporal_acp.on_task_event_send
490+
# We'll directly call the task service method as the handler does
491+
await task_service.send_event(
492+
agent=params.agent,
493+
task=params.task,
494+
event=params.event,
495+
request=params.request
496+
)
497+
498+
# Then - Verify the temporal client received the signal with request headers
499+
mock_temporal_client.send_signal.assert_called_once()
500+
call_args = mock_temporal_client.send_signal.call_args
501+
502+
# Verify signal payload includes request with headers
503+
payload = call_args.kwargs["payload"]
504+
assert payload["request"] == request
505+
assert payload["request"]["headers"] == request_headers
506+
507+
508+
@pytest.mark.asyncio
509+
async def test_temporal_task_service_preserves_all_header_types(
510+
temporal_task_service,
511+
mock_temporal_client,
512+
sample_agent,
513+
sample_task,
514+
sample_event
515+
):
516+
"""Test that various header types are preserved correctly"""
517+
# Given - Headers with different patterns
518+
request_headers = {
519+
"x-user-oauth-credentials": "oauth-token-12345",
520+
"authorization": "Bearer jwt-token",
521+
"x-tenant-id": "tenant-999",
522+
"x-custom-app-header": "custom-value"
523+
}
524+
request = {"headers": request_headers}
525+
526+
# When
527+
await temporal_task_service.send_event(
528+
agent=sample_agent,
529+
task=sample_task,
530+
event=sample_event,
531+
request=request
532+
)
533+
534+
# Then - Verify all headers are preserved in the signal payload
535+
call_args = mock_temporal_client.send_signal.call_args
536+
payload = call_args.kwargs["payload"]
537+
538+
assert payload["request"]["headers"] == request_headers
539+
# Verify each header individually
540+
for header_name, header_value in request_headers.items():
541+
assert payload["request"]["headers"][header_name] == header_value

0 commit comments

Comments
 (0)