|
4 | 4 | import gc |
5 | 5 | import logging |
6 | 6 | import queue |
7 | | -import sys |
8 | 7 | import threading |
9 | 8 | import uuid |
10 | 9 | from concurrent.futures import ThreadPoolExecutor |
| 10 | +from contextlib import contextmanager |
11 | 11 | from dataclasses import dataclass |
12 | 12 | from datetime import timedelta |
13 | | -from typing import Iterable, List, Optional |
| 13 | +from typing import Callable, Dict, Generator, Iterable, List, Optional, cast |
14 | 14 |
|
15 | 15 | import opentelemetry.context |
16 | 16 | import pytest |
| 17 | +from opentelemetry import baggage, context |
17 | 18 | from opentelemetry.sdk.trace import ReadableSpan, TracerProvider |
18 | 19 | from opentelemetry.sdk.trace.export import SimpleSpanProcessor |
19 | 20 | from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter |
|
31 | 32 | from temporalio.testing import WorkflowEnvironment |
32 | 33 | from temporalio.worker import UnsandboxedWorkflowRunner, Worker |
33 | 34 | from tests.helpers import LogCapturer |
34 | | -from tests.helpers.cache_eviction import ( |
35 | | - CacheEvictionTearDownWorkflow, |
36 | | - WaitForeverWorkflow, |
37 | | - wait_forever_activity, |
38 | | -) |
39 | 35 |
|
40 | 36 |
|
41 | 37 | @dataclass |
@@ -558,6 +554,286 @@ async def test_opentelemetry_benign_exception(client: Client): |
558 | 554 | assert all(span.status.status_code == StatusCode.UNSET for span in spans) |
559 | 555 |
|
560 | 556 |
|
@contextmanager
def baggage_values(values: Dict[str, str]) -> Generator[None, None, None]:
    """Attach the given baggage entries to the current OpenTelemetry context.

    Every key/value pair in ``values`` is layered on top of the context that
    is current on entry. The resulting context stays attached for the body of
    the ``with`` statement and is detached on exit, even if the body raises.
    """
    layered = context.get_current()
    for name in values:
        layered = baggage.set_baggage(name, values[name], context=layered)

    attach_token = context.attach(layered)
    try:
        yield
    finally:
        context.detach(attach_token)
| 568 | + |
| 569 | + |
@pytest.fixture
def client_with_tracing(client: Client) -> Client:
    """Build a client from ``client``'s config with a tracing interceptor.

    The ``interceptors`` entry of the copied config is replaced (not
    extended) with a single ``TracingInterceptor`` whose tracer comes from a
    fresh ``TracerProvider``.
    """
    provider = TracerProvider()
    config = client.config()
    config["interceptors"] = [
        TracingInterceptor(get_tracer(__name__, tracer_provider=provider))
    ]
    return Client(**config)
| 576 | + |
| 577 | + |
def get_baggage_value(key: str) -> str:
    """Look up ``key`` in the current baggage, typed as ``str`` for callers.

    The cast is only a type-checker hint: no conversion or check happens at
    runtime, so whatever ``baggage.get_baggage`` yields is returned as-is.
    """
    value = baggage.get_baggage(key)
    return cast("str", value)
| 580 | + |
| 581 | + |
@activity.defn
async def read_baggage_activity() -> Dict[str, str]:
    """Report the ``user.id`` and ``tenant.id`` baggage seen by the activity."""
    user_id = get_baggage_value("user.id")
    tenant_id = get_baggage_value("tenant.id")
    return {"user_id": user_id, "tenant_id": tenant_id}
| 588 | + |
| 589 | + |
@workflow.defn
class ReadBaggageTestWorkflow:
    """Workflow that returns whatever ``read_baggage_activity`` reports."""

    @workflow.run
    async def run(self) -> Dict[str, str]:
        """Run ``read_baggage_activity`` with a 10 second timeout."""
        timeout = timedelta(seconds=10)
        observed = await workflow.execute_activity(
            read_baggage_activity,
            start_to_close_timeout=timeout,
        )
        return observed
| 598 | + |
| 599 | + |
async def test_opentelemetry_baggage_propagation_basic(
    client_with_tracing: Client, env: WorkflowEnvironment
):
    """Baggage set around the client call is readable inside an activity."""
    queue_name = f"task_queue_{uuid.uuid4()}"
    worker = Worker(
        client_with_tracing,
        task_queue=queue_name,
        workflows=[ReadBaggageTestWorkflow],
        activities=[read_baggage_activity],
    )
    outgoing_baggage = {"user.id": "test-user-123", "tenant.id": "some-corp"}

    async with worker:
        # The workflow is started while the baggage is attached so the
        # tracing interceptor can pick it up from the current context.
        with baggage_values(outgoing_baggage):
            result = await client_with_tracing.execute_workflow(
                ReadBaggageTestWorkflow.run,
                id=f"workflow_{uuid.uuid4()}",
                task_queue=queue_name,
            )

    assert (
        result["user_id"] == "test-user-123"
    ), "user.id baggage should propagate to activity"
    assert (
        result["tenant_id"] == "some-corp"
    ), "tenant.id baggage should propagate to activity"
| 623 | + |
| 624 | + |
@activity.defn
async def read_baggage_local_activity() -> Dict[str, str]:
    """Report the ``user.id`` and ``tenant.id`` baggage seen by the activity."""
    observed = {
        "user_id": get_baggage_value("user.id"),
        "tenant_id": get_baggage_value("tenant.id"),
    }
    # The cast is a no-op at runtime; it only pins the declared return type.
    return cast(Dict[str, str], observed)
| 634 | + |
| 635 | + |
@workflow.defn
class LocalActivityBaggageTestWorkflow:
    """Workflow that returns what ``read_baggage_local_activity`` reports."""

    @workflow.run
    async def run(self) -> Dict[str, str]:
        """Run ``read_baggage_local_activity`` as a local activity (10s timeout)."""
        timeout = timedelta(seconds=10)
        observed = await workflow.execute_local_activity(
            read_baggage_local_activity,
            start_to_close_timeout=timeout,
        )
        return observed
| 644 | + |
| 645 | + |
async def test_opentelemetry_baggage_propagation_local_activity(
    client_with_tracing: Client, env: WorkflowEnvironment
):
    """Baggage set around the client call is readable inside a local activity."""
    queue_name = f"task_queue_{uuid.uuid4()}"
    worker = Worker(
        client_with_tracing,
        task_queue=queue_name,
        workflows=[LocalActivityBaggageTestWorkflow],
        activities=[read_baggage_local_activity],
    )
    outgoing_baggage = {
        "user.id": "test-user-456",
        "tenant.id": "local-corp",
    }

    async with worker:
        with baggage_values(outgoing_baggage):
            result = await client_with_tracing.execute_workflow(
                LocalActivityBaggageTestWorkflow.run,
                id=f"workflow_{uuid.uuid4()}",
                task_queue=queue_name,
            )

    assert result["user_id"] == "test-user-456"
    assert result["tenant_id"] == "local-corp"
| 670 | + |
| 671 | + |
# ``user.id`` baggage value recorded by each ``failing_baggage_activity`` call.
retry_attempt_baggage_values: List[str] = []


@activity.defn
async def failing_baggage_activity() -> None:
    """Record the ``user.id`` baggage, then fail while the attempt is below 2."""
    observed = get_baggage_value("user.id")
    retry_attempt_baggage_values.append(observed)

    attempt = activity.info().attempt
    if attempt < 2:
        raise RuntimeError("Intentional failure")
| 680 | + |
| 681 | + |
@workflow.defn
class RetryBaggageTestWorkflow:
    """Workflow that runs ``failing_baggage_activity`` with fast retries."""

    @workflow.run
    async def run(self) -> None:
        """Run the activity with a 1 ms initial retry interval (10s timeout)."""
        fast_retries = RetryPolicy(initial_interval=timedelta(milliseconds=1))
        await workflow.execute_activity(
            failing_baggage_activity,
            start_to_close_timeout=timedelta(seconds=10),
            retry_policy=fast_retries,
        )
| 691 | + |
| 692 | + |
async def test_opentelemetry_baggage_propagation_with_retries(
    client_with_tracing: Client, env: WorkflowEnvironment
) -> None:
    """Baggage is visible on every attempt of a retried activity."""
    # Start from a fresh list so earlier runs cannot affect the counts below.
    global retry_attempt_baggage_values
    retry_attempt_baggage_values = []

    queue_name = f"task_queue_{uuid.uuid4()}"
    worker = Worker(
        client_with_tracing,
        task_queue=queue_name,
        workflows=[RetryBaggageTestWorkflow],
        activities=[failing_baggage_activity],
    )
    async with worker:
        with baggage_values({"user.id": "test-user-retry"}):
            await client_with_tracing.execute_workflow(
                RetryBaggageTestWorkflow.run,
                id=f"workflow_{uuid.uuid4()}",
                task_queue=queue_name,
            )

    # Verify baggage was present on all attempts
    assert len(retry_attempt_baggage_values) == 2
    assert not any(
        seen != "test-user-retry" for seen in retry_attempt_baggage_values
    )
| 716 | + |
| 717 | + |
@activity.defn
async def context_clear_noop_activity() -> None:
    """Do nothing and complete successfully."""
    return None
| 721 | + |
| 722 | + |
@activity.defn
async def context_clear_exception_activity() -> None:
    """Always fail by raising a plain ``Exception``."""
    failure = Exception("Simulated exception")
    raise failure
| 726 | + |
| 727 | + |
@workflow.defn
class ContextClearWorkflow:
    """Workflow that runs one activity, selected by name, with a single attempt."""

    @workflow.run
    async def run(self, activity_name: str) -> None:
        """Execute the activity registered under ``activity_name``.

        Args:
            activity_name: Registered name of the activity to schedule.
        """
        # The activity is chosen by the caller so that the parametrized test
        # really runs the activity it registered on the worker. A single
        # attempt makes a raising activity fail the workflow right away.
        await workflow.execute_activity(
            activity_name,
            start_to_close_timeout=timedelta(seconds=10),
            retry_policy=RetryPolicy(
                maximum_attempts=1, initial_interval=timedelta(milliseconds=1)
            ),
        )


@pytest.mark.parametrize(
    "activity,expect_failure",
    [
        (context_clear_noop_activity, False),
        (context_clear_exception_activity, True),
    ],
)
async def test_opentelemetry_context_restored_after_activity(
    client_with_tracing: Client,
    env: WorkflowEnvironment,
    activity: Callable[[], None],
    expect_failure: bool,
) -> None:
    """Check that every context attach is paired with a detach.

    ``context.attach`` and ``context.detach`` are temporarily replaced with
    counting wrappers while a workflow runs the parametrized activity. The
    counts must match whether the activity succeeds or raises.
    """
    attach_count = 0
    detach_count = 0
    original_attach = context.attach
    original_detach = context.detach

    def tracked_attach(ctx):
        nonlocal attach_count
        attach_count += 1
        return original_attach(ctx)

    def tracked_detach(token):
        nonlocal detach_count
        detach_count += 1
        return original_detach(token)

    context.attach = tracked_attach
    context.detach = tracked_detach

    try:
        task_queue = f"task_queue_{uuid.uuid4()}"
        async with Worker(
            client_with_tracing,
            task_queue=task_queue,
            workflows=[ContextClearWorkflow],
            activities=[activity],
        ):
            with baggage_values({"user.id": "test-123"}):
                try:
                    await client_with_tracing.execute_workflow(
                        ContextClearWorkflow.run,
                        # Without an explicit name, @activity.defn registers
                        # the activity under its function name.
                        activity.__name__,
                        id=f"workflow_{uuid.uuid4()}",
                        task_queue=task_queue,
                    )
                except Exception:
                    assert expect_failure, "This test is not expected to raise"
                else:
                    # Kept out of the ``try`` body: an AssertionError raised
                    # there would be caught by ``except Exception`` above and
                    # an unexpected success would go unnoticed.
                    assert (
                        not expect_failure
                    ), "This test should have raised an exception"

        assert (
            attach_count == detach_count
        ), f"Context leak detected: {attach_count} attaches vs {detach_count} detaches. "
        assert attach_count > 0, "Expected at least one context attach/detach"

    finally:
        # Always restore the real functions so later tests are unaffected.
        context.attach = original_attach
        context.detach = original_detach
| 801 | + |
| 802 | + |
@activity.defn
async def simple_no_context_activity() -> str:
    """Return the fixed string ``"success"``."""
    outcome = "success"
    return outcome
| 806 | + |
| 807 | + |
@workflow.defn
class SimpleNoContextWorkflow:
    """Workflow that returns the result of ``simple_no_context_activity``."""

    @workflow.run
    async def run(self) -> str:
        """Run ``simple_no_context_activity`` with a 10 second timeout."""
        timeout = timedelta(seconds=10)
        outcome = await workflow.execute_activity(
            simple_no_context_activity,
            start_to_close_timeout=timeout,
        )
        return outcome
| 816 | + |
| 817 | + |
async def test_opentelemetry_interceptor_works_if_no_context(
    client_with_tracing: Client, env: WorkflowEnvironment
):
    """A workflow started without any baggage attached still completes."""
    queue_name = f"task_queue_{uuid.uuid4()}"
    worker = Worker(
        client_with_tracing,
        task_queue=queue_name,
        workflows=[SimpleNoContextWorkflow],
        activities=[simple_no_context_activity],
    )
    async with worker:
        result = await client_with_tracing.execute_workflow(
            SimpleNoContextWorkflow.run,
            id=f"workflow_{uuid.uuid4()}",
            task_queue=queue_name,
        )

    assert result == "success"
| 835 | + |
| 836 | + |
561 | 837 | # TODO(cretz): Additional tests to write |
562 | 838 | # * query without interceptor (no headers) |
563 | 839 | # * workflow without interceptor (no headers) but query with interceptor (headers) |
|
0 commit comments