diff --git a/src/agentex/lib/core/temporal/services/temporal_task_service.py b/src/agentex/lib/core/temporal/services/temporal_task_service.py index c9521719..5551ebcb 100644 --- a/src/agentex/lib/core/temporal/services/temporal_task_service.py +++ b/src/agentex/lib/core/temporal/services/temporal_task_service.py @@ -51,7 +51,7 @@ async def get_state(self, task_id: str) -> WorkflowState: workflow_id=task_id, ) - async def send_event(self, agent: Agent, task: Task, event: Event) -> None: + async def send_event(self, agent: Agent, task: Task, event: Event, request: dict | None = None) -> None: return await self._temporal_client.send_signal( workflow_id=task.id, signal=SignalName.RECEIVE_EVENT.value, @@ -59,6 +59,7 @@ async def send_event(self, agent: Agent, task: Task, event: Event) -> None: agent=agent, task=task, event=event, + request=request, ).model_dump(), ) diff --git a/src/agentex/lib/sdk/fastacp/base/base_acp_server.py b/src/agentex/lib/sdk/fastacp/base/base_acp_server.py index 42fc1520..455f17a7 100644 --- a/src/agentex/lib/sdk/fastacp/base/base_acp_server.py +++ b/src/agentex/lib/sdk/fastacp/base/base_acp_server.py @@ -154,12 +154,14 @@ async def _handle_jsonrpc(self, request: Request): ), ) - # Extract application headers, excluding sensitive/transport headers per FASTACP_* rules + # Extract application headers using allowlist approach (only x-* headers) + # Matches gateway's security filtering rules # Forward filtered headers via params.request.headers to agent handlers custom_headers = { key: value for key, value in request.headers.items() - if key.lower() not in FASTACP_HEADER_SKIP_EXACT + if key.lower().startswith("x-") + and key.lower() not in FASTACP_HEADER_SKIP_EXACT and not any(key.lower().startswith(p) for p in FASTACP_HEADER_SKIP_PREFIXES) } @@ -168,6 +170,7 @@ async def _handle_jsonrpc(self, request: Request): params_data = dict(rpc_request.params) if rpc_request.params else {} # Add custom headers to the request structure if any headers were provided + # Gateway sends filtered headers via HTTP, SDK extracts and populates params.request if custom_headers: params_data["request"] = {"headers": custom_headers} params = params_model.model_validate(params_data) diff --git a/src/agentex/lib/sdk/fastacp/base/constants.py b/src/agentex/lib/sdk/fastacp/base/constants.py index ed83ffd3..c04287e0 100644 --- a/src/agentex/lib/sdk/fastacp/base/constants.py +++ b/src/agentex/lib/sdk/fastacp/base/constants.py @@ -1,24 +1,36 @@ from __future__ import annotations # Header filtering rules for FastACP server +# These rules match the gateway's security filtering -# Prefixes to skip (case-insensitive beginswith checks) -FASTACP_HEADER_SKIP_PREFIXES: tuple[str, ...] = ( - "content-", +# Hop-by-hop headers that should not be forwarded +HOP_BY_HOP_HEADERS: set[str] = { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", + "content-length", + "content-encoding", "host", - "user-agent", - "x-forwarded-", - "sec-", -) +} -# Exact header names to skip (case-insensitive matching done by lowercasing keys) -FASTACP_HEADER_SKIP_EXACT: set[str] = { - "x-agent-api-key", - "connection", - "accept-encoding", +# Sensitive headers that should never be forwarded +BLOCKED_HEADERS: set[str] = { + "authorization", "cookie", - "content-length", - "transfer-encoding", + "x-agent-api-key", } +# Legacy constants for backward compatibility +FASTACP_HEADER_SKIP_EXACT: set[str] = HOP_BY_HOP_HEADERS | BLOCKED_HEADERS + +FASTACP_HEADER_SKIP_PREFIXES: tuple[str, ...] = ( + "x-forwarded-", # proxy headers + "sec-", # security headers added by browsers +) + diff --git a/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py b/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py index 698c4055..9a8ebb19 100644 --- a/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py +++ b/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py @@ -93,6 +93,7 @@ async def handle_event_send(params: SendEventParams) -> None: agent=params.agent, task=params.task, event=params.event, + request=params.request, ) except Exception as e: diff --git a/tests/test_header_forwarding.py b/tests/test_header_forwarding.py index 9faacbf6..48cd636f 100644 --- a/tests/test_header_forwarding.py +++ b/tests/test_header_forwarding.py @@ -3,6 +3,8 @@ from typing import Any, override import sys import types +from datetime import datetime, timezone +from unittest.mock import AsyncMock, Mock import pytest from fastapi.testclient import TestClient @@ -44,8 +46,14 @@ class _StubTracer(_StubAsyncTracer): from agentex.lib.core.services.adk.acp.acp import ACPService from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer -from agentex.lib.types.acp import RPCMethod, SendMessageParams +from agentex.lib.types.acp import RPCMethod, SendMessageParams, SendEventParams from agentex.types.task_message_content import TextContent +from agentex.lib.sdk.fastacp.impl.temporal_acp import TemporalACP +from agentex.lib.core.temporal.services.temporal_task_service import TemporalTaskService +from agentex.lib.environment_variables import EnvironmentVariables +from agentex.types.agent import Agent +from agentex.types.task import Task +from agentex.types.event import Event class DummySpan: @@ -313,3 +321,221 @@ def test_filter_headers_all_types() -> None: assert result == expected + +# ============================================================================ +# Temporal Header Forwarding Tests +# ============================================================================ + +@pytest.fixture +def mock_temporal_client(): + """Create a mock TemporalClient""" + client = AsyncMock() + client.send_signal = AsyncMock(return_value=None) + return client + + +@pytest.fixture +def mock_env_vars(): + """Create mock environment variables""" + env_vars = Mock(spec=EnvironmentVariables) + env_vars.WORKFLOW_NAME = "test-workflow" + env_vars.WORKFLOW_TASK_QUEUE = "test-queue" + return env_vars + + +@pytest.fixture +def temporal_task_service(mock_temporal_client, mock_env_vars): + """Create TemporalTaskService with mocked client""" + return TemporalTaskService( + temporal_client=mock_temporal_client, + env_vars=mock_env_vars, + ) + + +@pytest.fixture +def sample_agent(): + """Create a sample agent""" + return Agent( + id="agent-123", + name="test-agent", + description="Test agent", + acp_type="agentic", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + +@pytest.fixture +def sample_task(): + """Create a sample task""" + return Task(id="task-456") + + +@pytest.fixture +def sample_event(): + """Create a sample event""" + return Event( + id="event-789", + agent_id="agent-123", + task_id="task-456", + sequence_id=1, + content=TextContent(author="user", content="Test message") + ) + + +@pytest.mark.asyncio +async def test_temporal_task_service_send_event_with_headers( + temporal_task_service, + mock_temporal_client, + sample_agent, + sample_task, + sample_event +): + """Test that TemporalTaskService forwards request headers in signal payload""" + # Given + request_headers = { + "x-user-oauth-credentials": "test-oauth-token", + "x-custom-header": "custom-value" + } + request = {"headers": request_headers} + + # When + await temporal_task_service.send_event( + agent=sample_agent, + task=sample_task, + event=sample_event, + request=request + ) + + # Then + mock_temporal_client.send_signal.assert_called_once() + call_args = mock_temporal_client.send_signal.call_args + + # Verify the signal was sent to the correct workflow + assert call_args.kwargs["workflow_id"] == sample_task.id + assert call_args.kwargs["signal"] == "receive_event" + + # Verify the payload includes the request with headers + payload = call_args.kwargs["payload"] + assert "request" in payload + assert payload["request"] == request + assert payload["request"]["headers"] == request_headers + + +@pytest.mark.asyncio +async def test_temporal_task_service_send_event_without_headers( + temporal_task_service, + mock_temporal_client, + sample_agent, + sample_task, + sample_event +): + """Test that TemporalTaskService handles missing request gracefully""" + # When - Send event without request parameter + await temporal_task_service.send_event( + agent=sample_agent, + task=sample_task, + event=sample_event, + request=None + ) + + # Then + mock_temporal_client.send_signal.assert_called_once() + call_args = mock_temporal_client.send_signal.call_args + + # Verify the payload has request as None + payload = call_args.kwargs["payload"] + assert payload["request"] is None + + +@pytest.mark.asyncio +async def test_temporal_acp_integration_with_request_headers( + mock_temporal_client, + mock_env_vars, + sample_agent, + sample_task, + sample_event +): + """Test end-to-end integration: TemporalACP -> TemporalTaskService -> TemporalClient signal""" + # Given - Create real TemporalTaskService with mocked client + task_service = TemporalTaskService( + temporal_client=mock_temporal_client, + env_vars=mock_env_vars, + ) + + # Create TemporalACP with real task service + temporal_acp = TemporalACP( + temporal_address="localhost:7233", + temporal_task_service=task_service, + ) + temporal_acp._setup_handlers() + + request_headers = { + "x-user-id": "user-123", + "authorization": "Bearer token", + "x-tenant-id": "tenant-456" + } + request = {"headers": request_headers} + + # Create SendEventParams as TemporalACP would receive it + params = SendEventParams( + agent=sample_agent, + task=sample_task, + event=sample_event, + request=request + ) + + # When - Trigger the event handler via the decorated function + # The handler is registered via @temporal_acp.on_task_event_send + # We'll directly call the task service method as the handler does + await task_service.send_event( + agent=params.agent, + task=params.task, + event=params.event, + request=params.request + ) + + # Then - Verify the temporal client received the signal with request headers + mock_temporal_client.send_signal.assert_called_once() + call_args = mock_temporal_client.send_signal.call_args + + # Verify signal payload includes request with headers + payload = call_args.kwargs["payload"] + assert payload["request"] == request + assert payload["request"]["headers"] == request_headers + + +@pytest.mark.asyncio +async def test_temporal_task_service_preserves_all_header_types( + temporal_task_service, + mock_temporal_client, + sample_agent, + sample_task, + sample_event +): + """Test that various header types are preserved correctly""" + # Given - Headers with different patterns + request_headers = { + "x-user-oauth-credentials": "oauth-token-12345", + "authorization": "Bearer jwt-token", + "x-tenant-id": "tenant-999", + "x-custom-app-header": "custom-value" + } + request = {"headers": request_headers} + + # When + await temporal_task_service.send_event( + agent=sample_agent, + task=sample_task, + event=sample_event, + request=request + ) + + # Then - Verify all headers are preserved in the signal payload + call_args = mock_temporal_client.send_signal.call_args + payload = call_args.kwargs["payload"] + + assert payload["request"]["headers"] == request_headers + # Verify each header individually + for header_name, header_value in request_headers.items(): + assert payload["request"]["headers"][header_name] == header_value