|
3 | 3 | from typing import Any, override |
4 | 4 | import sys |
5 | 5 | import types |
| 6 | +from datetime import datetime, timezone |
| 7 | +from unittest.mock import AsyncMock, Mock |
6 | 8 |
|
7 | 9 | import pytest |
8 | 10 | from fastapi.testclient import TestClient |
@@ -44,8 +46,14 @@ class _StubTracer(_StubAsyncTracer): |
44 | 46 |
|
45 | 47 | from agentex.lib.core.services.adk.acp.acp import ACPService |
46 | 48 | from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer |
47 | | -from agentex.lib.types.acp import RPCMethod, SendMessageParams |
| 49 | +from agentex.lib.types.acp import RPCMethod, SendMessageParams, SendEventParams |
48 | 50 | from agentex.types.task_message_content import TextContent |
| 51 | +from agentex.lib.sdk.fastacp.impl.temporal_acp import TemporalACP |
| 52 | +from agentex.lib.core.temporal.services.temporal_task_service import TemporalTaskService |
| 53 | +from agentex.lib.environment_variables import EnvironmentVariables |
| 54 | +from agentex.types.agent import Agent |
| 55 | +from agentex.types.task import Task |
| 56 | +from agentex.types.event import Event |
49 | 57 |
|
50 | 58 |
|
51 | 59 | class DummySpan: |
@@ -313,3 +321,221 @@ def test_filter_headers_all_types() -> None: |
313 | 321 | assert result == expected |
314 | 322 |
|
315 | 323 |
|
| 324 | + |
| 325 | +# ============================================================================ |
| 326 | +# Temporal Header Forwarding Tests |
| 327 | +# ============================================================================ |
| 328 | + |
| 329 | +@pytest.fixture |
| 330 | +def mock_temporal_client(): |
| 331 | + """Create a mock TemporalClient""" |
| 332 | + client = AsyncMock() |
| 333 | + client.send_signal = AsyncMock(return_value=None) |
| 334 | + return client |
| 335 | + |
| 336 | + |
| 337 | +@pytest.fixture |
| 338 | +def mock_env_vars(): |
| 339 | + """Create mock environment variables""" |
| 340 | + env_vars = Mock(spec=EnvironmentVariables) |
| 341 | + env_vars.WORKFLOW_NAME = "test-workflow" |
| 342 | + env_vars.WORKFLOW_TASK_QUEUE = "test-queue" |
| 343 | + return env_vars |
| 344 | + |
| 345 | + |
| 346 | +@pytest.fixture |
| 347 | +def temporal_task_service(mock_temporal_client, mock_env_vars): |
| 348 | + """Create TemporalTaskService with mocked client""" |
| 349 | + return TemporalTaskService( |
| 350 | + temporal_client=mock_temporal_client, |
| 351 | + env_vars=mock_env_vars, |
| 352 | + ) |
| 353 | + |
| 354 | + |
| 355 | +@pytest.fixture |
| 356 | +def sample_agent(): |
| 357 | + """Create a sample agent""" |
| 358 | + return Agent( |
| 359 | + id="agent-123", |
| 360 | + name="test-agent", |
| 361 | + description="Test agent", |
| 362 | + acp_type="agentic", |
| 363 | + created_at=datetime.now(timezone.utc), |
| 364 | + updated_at=datetime.now(timezone.utc) |
| 365 | + ) |
| 366 | + |
| 367 | + |
| 368 | +@pytest.fixture |
| 369 | +def sample_task(): |
| 370 | + """Create a sample task""" |
| 371 | + return Task(id="task-456") |
| 372 | + |
| 373 | + |
| 374 | +@pytest.fixture |
| 375 | +def sample_event(): |
| 376 | + """Create a sample event""" |
| 377 | + return Event( |
| 378 | + id="event-789", |
| 379 | + agent_id="agent-123", |
| 380 | + task_id="task-456", |
| 381 | + sequence_id=1, |
| 382 | + content=TextContent(author="user", content="Test message") |
| 383 | + ) |
| 384 | + |
| 385 | + |
| 386 | +@pytest.mark.asyncio |
| 387 | +async def test_temporal_task_service_send_event_with_headers( |
| 388 | + temporal_task_service, |
| 389 | + mock_temporal_client, |
| 390 | + sample_agent, |
| 391 | + sample_task, |
| 392 | + sample_event |
| 393 | +): |
| 394 | + """Test that TemporalTaskService forwards request headers in signal payload""" |
| 395 | + # Given |
| 396 | + request_headers = { |
| 397 | + "x-user-oauth-credentials": "test-oauth-token", |
| 398 | + "x-custom-header": "custom-value" |
| 399 | + } |
| 400 | + request = {"headers": request_headers} |
| 401 | + |
| 402 | + # When |
| 403 | + await temporal_task_service.send_event( |
| 404 | + agent=sample_agent, |
| 405 | + task=sample_task, |
| 406 | + event=sample_event, |
| 407 | + request=request |
| 408 | + ) |
| 409 | + |
| 410 | + # Then |
| 411 | + mock_temporal_client.send_signal.assert_called_once() |
| 412 | + call_args = mock_temporal_client.send_signal.call_args |
| 413 | + |
| 414 | + # Verify the signal was sent to the correct workflow |
| 415 | + assert call_args.kwargs["workflow_id"] == sample_task.id |
| 416 | + assert call_args.kwargs["signal"] == "receive_event" |
| 417 | + |
| 418 | + # Verify the payload includes the request with headers |
| 419 | + payload = call_args.kwargs["payload"] |
| 420 | + assert "request" in payload |
| 421 | + assert payload["request"] == request |
| 422 | + assert payload["request"]["headers"] == request_headers |
| 423 | + |
| 424 | + |
| 425 | +@pytest.mark.asyncio |
| 426 | +async def test_temporal_task_service_send_event_without_headers( |
| 427 | + temporal_task_service, |
| 428 | + mock_temporal_client, |
| 429 | + sample_agent, |
| 430 | + sample_task, |
| 431 | + sample_event |
| 432 | +): |
| 433 | + """Test that TemporalTaskService handles missing request gracefully""" |
| 434 | + # When - Send event without request parameter |
| 435 | + await temporal_task_service.send_event( |
| 436 | + agent=sample_agent, |
| 437 | + task=sample_task, |
| 438 | + event=sample_event, |
| 439 | + request=None |
| 440 | + ) |
| 441 | + |
| 442 | + # Then |
| 443 | + mock_temporal_client.send_signal.assert_called_once() |
| 444 | + call_args = mock_temporal_client.send_signal.call_args |
| 445 | + |
| 446 | + # Verify the payload has request as None |
| 447 | + payload = call_args.kwargs["payload"] |
| 448 | + assert payload["request"] is None |
| 449 | + |
| 450 | + |
| 451 | +@pytest.mark.asyncio |
| 452 | +async def test_temporal_acp_integration_with_request_headers( |
| 453 | + mock_temporal_client, |
| 454 | + mock_env_vars, |
| 455 | + sample_agent, |
| 456 | + sample_task, |
| 457 | + sample_event |
| 458 | +): |
| 459 | + """Test end-to-end integration: TemporalACP -> TemporalTaskService -> TemporalClient signal""" |
| 460 | + # Given - Create real TemporalTaskService with mocked client |
| 461 | + task_service = TemporalTaskService( |
| 462 | + temporal_client=mock_temporal_client, |
| 463 | + env_vars=mock_env_vars, |
| 464 | + ) |
| 465 | + |
| 466 | + # Create TemporalACP with real task service |
| 467 | + temporal_acp = TemporalACP( |
| 468 | + temporal_address="localhost:7233", |
| 469 | + temporal_task_service=task_service, |
| 470 | + ) |
| 471 | + temporal_acp._setup_handlers() |
| 472 | + |
| 473 | + request_headers = { |
| 474 | + "x-user-id": "user-123", |
| 475 | + "authorization": "Bearer token", |
| 476 | + "x-tenant-id": "tenant-456" |
| 477 | + } |
| 478 | + request = {"headers": request_headers} |
| 479 | + |
| 480 | + # Create SendEventParams as TemporalACP would receive it |
| 481 | + params = SendEventParams( |
| 482 | + agent=sample_agent, |
| 483 | + task=sample_task, |
| 484 | + event=sample_event, |
| 485 | + request=request |
| 486 | + ) |
| 487 | + |
| 488 | + # When - Trigger the event handler via the decorated function |
| 489 | + # The handler is registered via @temporal_acp.on_task_event_send |
| 490 | + # We'll directly call the task service method as the handler does |
| 491 | + await task_service.send_event( |
| 492 | + agent=params.agent, |
| 493 | + task=params.task, |
| 494 | + event=params.event, |
| 495 | + request=params.request |
| 496 | + ) |
| 497 | + |
| 498 | + # Then - Verify the temporal client received the signal with request headers |
| 499 | + mock_temporal_client.send_signal.assert_called_once() |
| 500 | + call_args = mock_temporal_client.send_signal.call_args |
| 501 | + |
| 502 | + # Verify signal payload includes request with headers |
| 503 | + payload = call_args.kwargs["payload"] |
| 504 | + assert payload["request"] == request |
| 505 | + assert payload["request"]["headers"] == request_headers |
| 506 | + |
| 507 | + |
| 508 | +@pytest.mark.asyncio |
| 509 | +async def test_temporal_task_service_preserves_all_header_types( |
| 510 | + temporal_task_service, |
| 511 | + mock_temporal_client, |
| 512 | + sample_agent, |
| 513 | + sample_task, |
| 514 | + sample_event |
| 515 | +): |
| 516 | + """Test that various header types are preserved correctly""" |
| 517 | + # Given - Headers with different patterns |
| 518 | + request_headers = { |
| 519 | + "x-user-oauth-credentials": "oauth-token-12345", |
| 520 | + "authorization": "Bearer jwt-token", |
| 521 | + "x-tenant-id": "tenant-999", |
| 522 | + "x-custom-app-header": "custom-value" |
| 523 | + } |
| 524 | + request = {"headers": request_headers} |
| 525 | + |
| 526 | + # When |
| 527 | + await temporal_task_service.send_event( |
| 528 | + agent=sample_agent, |
| 529 | + task=sample_task, |
| 530 | + event=sample_event, |
| 531 | + request=request |
| 532 | + ) |
| 533 | + |
| 534 | + # Then - Verify all headers are preserved in the signal payload |
| 535 | + call_args = mock_temporal_client.send_signal.call_args |
| 536 | + payload = call_args.kwargs["payload"] |
| 537 | + |
| 538 | + assert payload["request"]["headers"] == request_headers |
| 539 | + # Verify each header individually |
| 540 | + for header_name, header_value in request_headers.items(): |
| 541 | + assert payload["request"]["headers"][header_name] == header_value |
0 commit comments