Skip to content

Commit 8f754b4

Browse files
romank0tconley1428
andauthored
Opentelemetry baggage propagation fix (#1174)
* Add tags to gitignore * adds failing baggage propagation test * adds baggage propagation in activity interceptor Change the activity interceptor to use context.attach()/detach() pattern instead of passing context as a parameter to start_as_current_span(). The fix follows the standard OpenTelemetry pattern used by other instrumentations (django, gRPC, etc.) and ensures proper context management with try/finally for detach. * adds baggage propagation tests Add additional tests to verify baggage propagation in scenarios: - multiple values - local activity - retries in activity * adds more tests for baggage propagation Two important edge case tests: - exceptions handling - when no current context is available * moves context handling to _start_as_current_span * cleanup and improvements after review * adds test for context cleanup in the interceptor * fixes static checks * only clear context if exit on the same thread as enter * removes global constant --------- Co-authored-by: tconley1428 <[email protected]>
1 parent 231cc67 commit 8f754b4

File tree

3 files changed

+311
-29
lines changed

3 files changed

+311
-29
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ temporalio/bridge/temporal_sdk_bridge*
1010
/sdk-python.iml
1111
/.zed
1212
*.DS_Store
13+
tags

temporalio/contrib/opentelemetry.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -172,29 +172,34 @@ def _start_as_current_span(
172172
kind: opentelemetry.trace.SpanKind,
173173
context: Optional[Context] = None,
174174
) -> Iterator[None]:
175-
with self.tracer.start_as_current_span(
176-
name,
177-
attributes=attributes,
178-
kind=kind,
179-
context=context,
180-
set_status_on_exception=False,
181-
) as span:
182-
if input:
183-
input.headers = self._context_to_headers(input.headers)
184-
try:
185-
yield None
186-
except Exception as exc:
187-
if (
188-
not isinstance(exc, ApplicationError)
189-
or exc.category != ApplicationErrorCategory.BENIGN
190-
):
191-
span.set_status(
192-
Status(
193-
status_code=StatusCode.ERROR,
194-
description=f"{type(exc).__name__}: {exc}",
175+
token = opentelemetry.context.attach(context) if context else None
176+
try:
177+
with self.tracer.start_as_current_span(
178+
name,
179+
attributes=attributes,
180+
kind=kind,
181+
context=context,
182+
set_status_on_exception=False,
183+
) as span:
184+
if input:
185+
input.headers = self._context_to_headers(input.headers)
186+
try:
187+
yield None
188+
except Exception as exc:
189+
if (
190+
not isinstance(exc, ApplicationError)
191+
or exc.category != ApplicationErrorCategory.BENIGN
192+
):
193+
span.set_status(
194+
Status(
195+
status_code=StatusCode.ERROR,
196+
description=f"{type(exc).__name__}: {exc}",
197+
)
195198
)
196-
)
197-
raise
199+
raise
200+
finally:
201+
if token and context is opentelemetry.context.get_current():
202+
opentelemetry.context.detach(token)
198203

199204
def _completed_workflow_span(
200205
self, params: _CompletedWorkflowSpanParams

tests/contrib/test_opentelemetry.py

Lines changed: 283 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
import gc
55
import logging
66
import queue
7-
import sys
87
import threading
98
import uuid
109
from concurrent.futures import ThreadPoolExecutor
10+
from contextlib import contextmanager
1111
from dataclasses import dataclass
1212
from datetime import timedelta
13-
from typing import Iterable, List, Optional
13+
from typing import Callable, Dict, Generator, Iterable, List, Optional, cast
1414

1515
import opentelemetry.context
1616
import pytest
17+
from opentelemetry import baggage, context
1718
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
1819
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
1920
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
@@ -31,11 +32,6 @@
3132
from temporalio.testing import WorkflowEnvironment
3233
from temporalio.worker import UnsandboxedWorkflowRunner, Worker
3334
from tests.helpers import LogCapturer
34-
from tests.helpers.cache_eviction import (
35-
CacheEvictionTearDownWorkflow,
36-
WaitForeverWorkflow,
37-
wait_forever_activity,
38-
)
3935

4036

4137
@dataclass
@@ -558,6 +554,286 @@ async def test_opentelemetry_benign_exception(client: Client):
558554
assert all(span.status.status_code == StatusCode.UNSET for span in spans)
559555

560556

557+
@contextmanager
558+
def baggage_values(values: Dict[str, str]) -> Generator[None, None, None]:
559+
ctx = context.get_current()
560+
for key, value in values.items():
561+
ctx = baggage.set_baggage(key, value, context=ctx)
562+
563+
token = context.attach(ctx)
564+
try:
565+
yield
566+
finally:
567+
context.detach(token)
568+
569+
570+
@pytest.fixture
571+
def client_with_tracing(client: Client) -> Client:
572+
tracer = get_tracer(__name__, tracer_provider=TracerProvider())
573+
client_config = client.config()
574+
client_config["interceptors"] = [TracingInterceptor(tracer)]
575+
return Client(**client_config)
576+
577+
578+
def get_baggage_value(key: str) -> str:
579+
return cast("str", baggage.get_baggage(key))
580+
581+
582+
@activity.defn
583+
async def read_baggage_activity() -> Dict[str, str]:
584+
return {
585+
"user_id": get_baggage_value("user.id"),
586+
"tenant_id": get_baggage_value("tenant.id"),
587+
}
588+
589+
590+
@workflow.defn
591+
class ReadBaggageTestWorkflow:
592+
@workflow.run
593+
async def run(self) -> Dict[str, str]:
594+
return await workflow.execute_activity(
595+
read_baggage_activity,
596+
start_to_close_timeout=timedelta(seconds=10),
597+
)
598+
599+
600+
async def test_opentelemetry_baggage_propagation_basic(
601+
client_with_tracing: Client, env: WorkflowEnvironment
602+
):
603+
task_queue = f"task_queue_{uuid.uuid4()}"
604+
async with Worker(
605+
client_with_tracing,
606+
task_queue=task_queue,
607+
workflows=[ReadBaggageTestWorkflow],
608+
activities=[read_baggage_activity],
609+
):
610+
with baggage_values({"user.id": "test-user-123", "tenant.id": "some-corp"}):
611+
result = await client_with_tracing.execute_workflow(
612+
ReadBaggageTestWorkflow.run,
613+
id=f"workflow_{uuid.uuid4()}",
614+
task_queue=task_queue,
615+
)
616+
617+
assert (
618+
result["user_id"] == "test-user-123"
619+
), "user.id baggage should propagate to activity"
620+
assert (
621+
result["tenant_id"] == "some-corp"
622+
), "tenant.id baggage should propagate to activity"
623+
624+
625+
@activity.defn
626+
async def read_baggage_local_activity() -> Dict[str, str]:
627+
return cast(
628+
Dict[str, str],
629+
{
630+
"user_id": get_baggage_value("user.id"),
631+
"tenant_id": get_baggage_value("tenant.id"),
632+
},
633+
)
634+
635+
636+
@workflow.defn
637+
class LocalActivityBaggageTestWorkflow:
638+
@workflow.run
639+
async def run(self) -> Dict[str, str]:
640+
return await workflow.execute_local_activity(
641+
read_baggage_local_activity,
642+
start_to_close_timeout=timedelta(seconds=10),
643+
)
644+
645+
646+
async def test_opentelemetry_baggage_propagation_local_activity(
647+
client_with_tracing: Client, env: WorkflowEnvironment
648+
):
649+
task_queue = f"task_queue_{uuid.uuid4()}"
650+
async with Worker(
651+
client_with_tracing,
652+
task_queue=task_queue,
653+
workflows=[LocalActivityBaggageTestWorkflow],
654+
activities=[read_baggage_local_activity],
655+
):
656+
with baggage_values(
657+
{
658+
"user.id": "test-user-456",
659+
"tenant.id": "local-corp",
660+
}
661+
):
662+
result = await client_with_tracing.execute_workflow(
663+
LocalActivityBaggageTestWorkflow.run,
664+
id=f"workflow_{uuid.uuid4()}",
665+
task_queue=task_queue,
666+
)
667+
668+
assert result["user_id"] == "test-user-456"
669+
assert result["tenant_id"] == "local-corp"
670+
671+
672+
retry_attempt_baggage_values: List[str] = []
673+
674+
675+
@activity.defn
676+
async def failing_baggage_activity() -> None:
677+
retry_attempt_baggage_values.append(get_baggage_value("user.id"))
678+
if activity.info().attempt < 2:
679+
raise RuntimeError("Intentional failure")
680+
681+
682+
@workflow.defn
683+
class RetryBaggageTestWorkflow:
684+
@workflow.run
685+
async def run(self) -> None:
686+
await workflow.execute_activity(
687+
failing_baggage_activity,
688+
start_to_close_timeout=timedelta(seconds=10),
689+
retry_policy=RetryPolicy(initial_interval=timedelta(milliseconds=1)),
690+
)
691+
692+
693+
async def test_opentelemetry_baggage_propagation_with_retries(
694+
client_with_tracing: Client, env: WorkflowEnvironment
695+
) -> None:
696+
global retry_attempt_baggage_values
697+
retry_attempt_baggage_values = []
698+
699+
task_queue = f"task_queue_{uuid.uuid4()}"
700+
async with Worker(
701+
client_with_tracing,
702+
task_queue=task_queue,
703+
workflows=[RetryBaggageTestWorkflow],
704+
activities=[failing_baggage_activity],
705+
):
706+
with baggage_values({"user.id": "test-user-retry"}):
707+
await client_with_tracing.execute_workflow(
708+
RetryBaggageTestWorkflow.run,
709+
id=f"workflow_{uuid.uuid4()}",
710+
task_queue=task_queue,
711+
)
712+
713+
# Verify baggage was present on all attempts
714+
assert len(retry_attempt_baggage_values) == 2
715+
assert all(v == "test-user-retry" for v in retry_attempt_baggage_values)
716+
717+
718+
@activity.defn
719+
async def context_clear_noop_activity() -> None:
720+
pass
721+
722+
723+
@activity.defn
724+
async def context_clear_exception_activity() -> None:
725+
raise Exception("Simulated exception")
726+
727+
728+
@workflow.defn
729+
class ContextClearWorkflow:
730+
@workflow.run
731+
async def run(self) -> None:
732+
await workflow.execute_activity(
733+
context_clear_noop_activity,
734+
start_to_close_timeout=timedelta(seconds=10),
735+
retry_policy=RetryPolicy(
736+
maximum_attempts=1, initial_interval=timedelta(milliseconds=1)
737+
),
738+
)
739+
740+
741+
@pytest.mark.parametrize(
742+
"activity,expect_failure",
743+
[
744+
(context_clear_noop_activity, not True),
745+
(context_clear_exception_activity, True),
746+
],
747+
)
748+
async def test_opentelemetry_context_restored_after_activity(
749+
client_with_tracing: Client,
750+
env: WorkflowEnvironment,
751+
activity: Callable[[], None],
752+
expect_failure: bool,
753+
) -> None:
754+
attach_count = 0
755+
detach_count = 0
756+
original_attach = context.attach
757+
original_detach = context.detach
758+
759+
def tracked_attach(ctx):
760+
nonlocal attach_count
761+
attach_count += 1
762+
return original_attach(ctx)
763+
764+
def tracked_detach(token):
765+
nonlocal detach_count
766+
detach_count += 1
767+
return original_detach(token)
768+
769+
context.attach = tracked_attach
770+
context.detach = tracked_detach
771+
772+
try:
773+
task_queue = f"task_queue_{uuid.uuid4()}"
774+
async with Worker(
775+
client_with_tracing,
776+
task_queue=task_queue,
777+
workflows=[ContextClearWorkflow],
778+
activities=[activity],
779+
):
780+
with baggage_values({"user.id": "test-123"}):
781+
try:
782+
await client_with_tracing.execute_workflow(
783+
ContextClearWorkflow.run,
784+
id=f"workflow_{uuid.uuid4()}",
785+
task_queue=task_queue,
786+
)
787+
assert (
788+
not expect_failure
789+
), "This test should have raised an exception"
790+
except Exception:
791+
assert expect_failure, "This test is not expeced to raise"
792+
793+
assert (
794+
attach_count == detach_count
795+
), f"Context leak detected: {attach_count} attaches vs {detach_count} detaches. "
796+
assert attach_count > 0, "Expected at least one context attach/detach"
797+
798+
finally:
799+
context.attach = original_attach
800+
context.detach = original_detach
801+
802+
803+
@activity.defn
804+
async def simple_no_context_activity() -> str:
805+
return "success"
806+
807+
808+
@workflow.defn
809+
class SimpleNoContextWorkflow:
810+
@workflow.run
811+
async def run(self) -> str:
812+
return await workflow.execute_activity(
813+
simple_no_context_activity,
814+
start_to_close_timeout=timedelta(seconds=10),
815+
)
816+
817+
818+
async def test_opentelemetry_interceptor_works_if_no_context(
819+
client_with_tracing: Client, env: WorkflowEnvironment
820+
):
821+
task_queue = f"task_queue_{uuid.uuid4()}"
822+
async with Worker(
823+
client_with_tracing,
824+
task_queue=task_queue,
825+
workflows=[SimpleNoContextWorkflow],
826+
activities=[simple_no_context_activity],
827+
):
828+
result = await client_with_tracing.execute_workflow(
829+
SimpleNoContextWorkflow.run,
830+
id=f"workflow_{uuid.uuid4()}",
831+
task_queue=task_queue,
832+
)
833+
834+
assert result == "success"
835+
836+
561837
# TODO(cretz): Additional tests to write
562838
# * query without interceptor (no headers)
563839
# * workflow without interceptor (no headers) but query with interceptor (headers)

0 commit comments

Comments
 (0)