diff --git a/.gitignore b/.gitignore
index c31f84940..c3447e5d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,4 @@ temporalio/bridge/temporal_sdk_bridge*
 /sdk-python.iml
 /.zed
 *.DS_Store
+tags
diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py
index 351a2e42f..7dfd920ef 100644
--- a/temporalio/contrib/opentelemetry.py
+++ b/temporalio/contrib/opentelemetry.py
@@ -172,29 +172,39 @@ def _start_as_current_span(
         kind: opentelemetry.trace.SpanKind,
         context: Optional[Context] = None,
     ) -> Iterator[None]:
-        with self.tracer.start_as_current_span(
-            name,
-            attributes=attributes,
-            kind=kind,
-            context=context,
-            set_status_on_exception=False,
-        ) as span:
-            if input:
-                input.headers = self._context_to_headers(input.headers)
-            try:
-                yield None
-            except Exception as exc:
-                if (
-                    not isinstance(exc, ApplicationError)
-                    or exc.category != ApplicationErrorCategory.BENIGN
-                ):
-                    span.set_status(
-                        Status(
-                            status_code=StatusCode.ERROR,
-                            description=f"{type(exc).__name__}: {exc}",
+        # Make the supplied parent context current for the duration of the
+        # span so values it carries (e.g. baggage) are visible to code running
+        # inside the span, not just used for span parenting.
+        token = opentelemetry.context.attach(context) if context else None
+        try:
+            with self.tracer.start_as_current_span(
+                name,
+                attributes=attributes,
+                kind=kind,
+                context=context,
+                set_status_on_exception=False,
+            ) as span:
+                if input:
+                    input.headers = self._context_to_headers(input.headers)
+                try:
+                    yield None
+                except Exception as exc:
+                    if (
+                        not isinstance(exc, ApplicationError)
+                        or exc.category != ApplicationErrorCategory.BENIGN
+                    ):
+                        span.set_status(
+                            Status(
+                                status_code=StatusCode.ERROR,
+                                description=f"{type(exc).__name__}: {exc}",
+                            )
                         )
-                    )
-                raise
+                    raise
+        finally:
+            # Detach only when a context was attached above and it is the
+            # current context again; otherwise leave the current context alone.
+            if token and context is opentelemetry.context.get_current():
+                opentelemetry.context.detach(token)
 
     def _completed_workflow_span(
         self, params: _CompletedWorkflowSpanParams
diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py
index 9dbdfed93..fb4759be9 100644
--- a/tests/contrib/test_opentelemetry.py
+++ b/tests/contrib/test_opentelemetry.py
@@ -4,16 +4,17 @@
 import gc
 import logging
 import queue
-import sys
 import threading
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Iterable, List, Optional
+from typing import Dict, Generator, Iterable, List, Optional, cast
 
 import opentelemetry.context
 import pytest
+from opentelemetry import baggage, context
 from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
@@ -31,11 +32,6 @@
 from temporalio.testing import WorkflowEnvironment
 from temporalio.worker import UnsandboxedWorkflowRunner, Worker
 from tests.helpers import LogCapturer
-from tests.helpers.cache_eviction import (
-    CacheEvictionTearDownWorkflow,
-    WaitForeverWorkflow,
-    wait_forever_activity,
-)
 
 
 @dataclass
@@ -558,6 +554,294 @@ async def test_opentelemetry_benign_exception(client: Client):
     assert all(span.status.status_code == StatusCode.UNSET for span in spans)
 
 
+@contextmanager
+def baggage_values(values: Dict[str, str]) -> Generator[None, None, None]:
+    """Attach a context carrying ``values`` as baggage for the ``with`` body."""
+    ctx = context.get_current()
+    for key, value in values.items():
+        ctx = baggage.set_baggage(key, value, context=ctx)
+
+    token = context.attach(ctx)
+    try:
+        yield
+    finally:
+        context.detach(token)
+
+
+@pytest.fixture
+def client_with_tracing(client: Client) -> Client:
+    """Build a client from ``client``'s config using one ``TracingInterceptor``."""
+    tracer = get_tracer(__name__, tracer_provider=TracerProvider())
+    client_config = client.config()
+    client_config["interceptors"] = [TracingInterceptor(tracer)]
+    return Client(**client_config)
+
+
+def get_baggage_value(key: str) -> str:
+    """Return the current context's baggage entry for ``key``."""
+    return cast("str", baggage.get_baggage(key))
+
+
+@activity.defn
+async def read_baggage_activity() -> Dict[str, str]:
+    return {
+        "user_id": get_baggage_value("user.id"),
+        "tenant_id": get_baggage_value("tenant.id"),
+    }
+
+
+@workflow.defn
+class ReadBaggageTestWorkflow:
+    @workflow.run
+    async def run(self) -> Dict[str, str]:
+        return await workflow.execute_activity(
+            read_baggage_activity,
+            start_to_close_timeout=timedelta(seconds=10),
+        )
+
+
+async def test_opentelemetry_baggage_propagation_basic(
+    client_with_tracing: Client, env: WorkflowEnvironment
+):
+    task_queue = f"task_queue_{uuid.uuid4()}"
+    async with Worker(
+        client_with_tracing,
+        task_queue=task_queue,
+        workflows=[ReadBaggageTestWorkflow],
+        activities=[read_baggage_activity],
+    ):
+        with baggage_values({"user.id": "test-user-123", "tenant.id": "some-corp"}):
+            result = await client_with_tracing.execute_workflow(
+                ReadBaggageTestWorkflow.run,
+                id=f"workflow_{uuid.uuid4()}",
+                task_queue=task_queue,
+            )
+
+    assert (
+        result["user_id"] == "test-user-123"
+    ), "user.id baggage should propagate to activity"
+    assert (
+        result["tenant_id"] == "some-corp"
+    ), "tenant.id baggage should propagate to activity"
+
+
+@activity.defn
+async def read_baggage_local_activity() -> Dict[str, str]:
+    return {
+        "user_id": get_baggage_value("user.id"),
+        "tenant_id": get_baggage_value("tenant.id"),
+    }
+
+
+@workflow.defn
+class LocalActivityBaggageTestWorkflow:
+    @workflow.run
+    async def run(self) -> Dict[str, str]:
+        return await workflow.execute_local_activity(
+            read_baggage_local_activity,
+            start_to_close_timeout=timedelta(seconds=10),
+        )
+
+
+async def test_opentelemetry_baggage_propagation_local_activity(
+    client_with_tracing: Client, env: WorkflowEnvironment
+):
+    task_queue = f"task_queue_{uuid.uuid4()}"
+    async with Worker(
+        client_with_tracing,
+        task_queue=task_queue,
+        workflows=[LocalActivityBaggageTestWorkflow],
+        activities=[read_baggage_local_activity],
+    ):
+        with baggage_values(
+            {
+                "user.id": "test-user-456",
+                "tenant.id": "local-corp",
+            }
+        ):
+            result = await client_with_tracing.execute_workflow(
+                LocalActivityBaggageTestWorkflow.run,
+                id=f"workflow_{uuid.uuid4()}",
+                task_queue=task_queue,
+            )
+
+    assert result["user_id"] == "test-user-456"
+    assert result["tenant_id"] == "local-corp"
+
+
+retry_attempt_baggage_values: List[str] = []
+
+
+@activity.defn
+async def failing_baggage_activity() -> None:
+    retry_attempt_baggage_values.append(get_baggage_value("user.id"))
+    if activity.info().attempt < 2:
+        raise RuntimeError("Intentional failure")
+
+
+@workflow.defn
+class RetryBaggageTestWorkflow:
+    @workflow.run
+    async def run(self) -> None:
+        await workflow.execute_activity(
+            failing_baggage_activity,
+            start_to_close_timeout=timedelta(seconds=10),
+            retry_policy=RetryPolicy(initial_interval=timedelta(milliseconds=1)),
+        )
+
+
+async def test_opentelemetry_baggage_propagation_with_retries(
+    client_with_tracing: Client, env: WorkflowEnvironment
+) -> None:
+    global retry_attempt_baggage_values
+    retry_attempt_baggage_values = []
+
+    task_queue = f"task_queue_{uuid.uuid4()}"
+    async with Worker(
+        client_with_tracing,
+        task_queue=task_queue,
+        workflows=[RetryBaggageTestWorkflow],
+        activities=[failing_baggage_activity],
+    ):
+        with baggage_values({"user.id": "test-user-retry"}):
+            await client_with_tracing.execute_workflow(
+                RetryBaggageTestWorkflow.run,
+                id=f"workflow_{uuid.uuid4()}",
+                task_queue=task_queue,
+            )
+
+    # Verify baggage was present on all attempts
+    assert len(retry_attempt_baggage_values) == 2
+    assert all(v == "test-user-retry" for v in retry_attempt_baggage_values)
+
+
+@activity.defn
+async def context_clear_noop_activity() -> None:
+    pass
+
+
+@activity.defn
+async def context_clear_exception_activity() -> None:
+    raise Exception("Simulated exception")
+
+
+@workflow.defn
+class ContextClearWorkflow:
+    @workflow.run
+    async def run(self, activity_name: str) -> None:
+        # Execute whichever activity the test names so that both the
+        # succeeding and the raising activity are actually run.
+        await workflow.execute_activity(
+            activity_name,
+            start_to_close_timeout=timedelta(seconds=10),
+            retry_policy=RetryPolicy(
+                maximum_attempts=1, initial_interval=timedelta(milliseconds=1)
+            ),
+        )
+
+
+@pytest.mark.parametrize(
+    "activity_name,expect_failure",
+    [
+        ("context_clear_noop_activity", False),
+        ("context_clear_exception_activity", True),
+    ],
+)
+async def test_opentelemetry_context_restored_after_activity(
+    client_with_tracing: Client,
+    env: WorkflowEnvironment,
+    activity_name: str,
+    expect_failure: bool,
+) -> None:
+    """Every context attach must be matched by a detach, pass or fail."""
+    attach_count = 0
+    detach_count = 0
+    original_attach = context.attach
+    original_detach = context.detach
+
+    def tracked_attach(ctx):
+        nonlocal attach_count
+        attach_count += 1
+        return original_attach(ctx)
+
+    def tracked_detach(token):
+        nonlocal detach_count
+        detach_count += 1
+        return original_detach(token)
+
+    # Swap the module-level functions so attach/detach calls are counted.
+    context.attach = tracked_attach
+    context.detach = tracked_detach
+
+    try:
+        task_queue = f"task_queue_{uuid.uuid4()}"
+        async with Worker(
+            client_with_tracing,
+            task_queue=task_queue,
+            workflows=[ContextClearWorkflow],
+            activities=[context_clear_noop_activity, context_clear_exception_activity],
+        ):
+            with baggage_values({"user.id": "test-123"}):
+                try:
+                    await client_with_tracing.execute_workflow(
+                        ContextClearWorkflow.run,
+                        activity_name,
+                        id=f"workflow_{uuid.uuid4()}",
+                        task_queue=task_queue,
+                    )
+                except Exception:
+                    assert expect_failure, "This test is not expected to raise"
+                else:
+                    # Kept outside the try so this AssertionError cannot be
+                    # swallowed by the except clause above.
+                    assert (
+                        not expect_failure
+                    ), "This test should have raised an exception"
+
+        assert (
+            attach_count == detach_count
+        ), f"Context leak detected: {attach_count} attaches vs {detach_count} detaches."
+        assert attach_count > 0, "Expected at least one context attach/detach"
+
+    finally:
+        context.attach = original_attach
+        context.detach = original_detach
+
+
+@activity.defn
+async def simple_no_context_activity() -> str:
+    return "success"
+
+
+@workflow.defn
+class SimpleNoContextWorkflow:
+    @workflow.run
+    async def run(self) -> str:
+        return await workflow.execute_activity(
+            simple_no_context_activity,
+            start_to_close_timeout=timedelta(seconds=10),
+        )
+
+
+async def test_opentelemetry_interceptor_works_if_no_context(
+    client_with_tracing: Client, env: WorkflowEnvironment
+):
+    task_queue = f"task_queue_{uuid.uuid4()}"
+    async with Worker(
+        client_with_tracing,
+        task_queue=task_queue,
+        workflows=[SimpleNoContextWorkflow],
+        activities=[simple_no_context_activity],
+    ):
+        result = await client_with_tracing.execute_workflow(
+            SimpleNoContextWorkflow.run,
+            id=f"workflow_{uuid.uuid4()}",
+            task_queue=task_queue,
+        )
+
+    assert result == "success"
+
+
 # TODO(cretz): Additional tests to write
 # * query without interceptor (no headers)
 # * workflow without interceptor (no headers) but query with interceptor (headers)