Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions temporalio/contrib/opentelemetry/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,26 @@
_ContextT = TypeVar("_ContextT", bound=nexusrpc.handler.OperationContext)


def _restore_context(previous: Context | None, attached: Context | None) -> None:
    """Make ``previous`` the current OTel context again, if it is safe to do so.

    The restore is done by attaching ``previous`` a second time instead of
    detaching with the token returned by the original attach. Detaching resets
    a ``contextvars`` Token, and that reset fails (OTel logs "Failed to detach
    context") whenever the cleanup executes under a different
    ``contextvars.Context`` than the attach did, for example when the workflow
    event loop resumes inside ``contextvars.copy_context().run(...)``. Copying
    a context keeps the stored OTel context value, so the identity check below
    still succeeds, but the Token is no longer usable. Attaching only performs
    a ``ContextVar.set``, which works under any ``contextvars.Context``.

    Args:
        previous: The OTel context that was current before ``attached`` was
            attached, or ``None`` when nothing was attached.
        attached: The OTel context the caller attached.
    """
    # Nothing was attached by the caller, so there is nothing to undo.
    if previous is None:
        return
    # Something else has replaced our context since we attached it; leave the
    # current context alone rather than clobbering it.
    if opentelemetry.context.get_current() is not attached:
        return
    opentelemetry.context.attach(previous)


class TracingInterceptor(temporalio.client.Interceptor, temporalio.worker.Interceptor):
"""Interceptor that supports client and worker OpenTelemetry span creation
and propagation.
Expand Down Expand Up @@ -182,7 +202,9 @@ def _start_as_current_span(
kind: opentelemetry.trace.SpanKind,
context: Context | None = None,
) -> Iterator[None]:
token = opentelemetry.context.attach(context) if context else None
previous = opentelemetry.context.get_current() if context else None
if context:
opentelemetry.context.attach(context)
try:
with self.tracer.start_as_current_span(
name,
Expand Down Expand Up @@ -219,8 +241,7 @@ def _start_as_current_span(
)
raise
finally:
if token and context is opentelemetry.context.get_current():
opentelemetry.context.detach(token)
_restore_context(previous, context)

def _completed_workflow_span(
self, params: _CompletedWorkflowSpanParams
Expand Down Expand Up @@ -551,7 +572,8 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
# We need to put this interceptor on the context too
context = self._set_on_context(context)
# Run under context with new span
token = opentelemetry.context.attach(context)
previous = opentelemetry.context.get_current()
opentelemetry.context.attach(context)
try:
# This won't be created if there was no context header
self._completed_span(
Expand All @@ -563,12 +585,7 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
)
return await super().handle_query(input)
finally:
# In some exceptional cases this finally is executed with a
# different contextvars.Context than the one the token was created
# on. As such we do a best effort detach to avoid using a mismatched
# token.
if context is opentelemetry.context.get_current():
opentelemetry.context.detach(token)
_restore_context(previous, context)

def handle_update_validator(
self, input: temporalio.worker.HandleUpdateInput
Expand Down Expand Up @@ -639,7 +656,8 @@ def _top_level_workflow_context(
success = False
exception: Exception | None = None
# Run under this context
token = opentelemetry.context.attach(context)
previous = opentelemetry.context.get_current()
opentelemetry.context.attach(context)

try:
yield None
Expand All @@ -650,20 +668,15 @@ def _top_level_workflow_context(
exception = err
raise
finally:
# Create a completed span before detaching context
# Create a completed span before restoring context
if exception or (success and success_is_complete):
self._completed_span(
f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}",
exception=exception,
kind=opentelemetry.trace.SpanKind.INTERNAL,
)

# In some exceptional cases this finally is executed with a
# different contextvars.Context than the one the token was created
# on. As such we do a best effort detach to avoid using a mismatched
# token.
if context is opentelemetry.context.get_current():
opentelemetry.context.detach(token)
_restore_context(previous, context)

def _context_to_headers(
self, headers: Mapping[str, temporalio.api.common.v1.Payload]
Expand Down
90 changes: 25 additions & 65 deletions tests/contrib/opentelemetry/test_opentelemetry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from __future__ import annotations

import asyncio
import gc
import logging
import queue
import threading
import uuid
from collections.abc import Callable, Generator, Iterable
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -14,9 +11,9 @@
from typing import Any, cast

import nexusrpc
import opentelemetry.context
import pytest
from opentelemetry import baggage, context
from opentelemetry.context import Context
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
Expand All @@ -27,7 +24,6 @@
from temporalio.common import RetryPolicy, WorkflowIDConflictPolicy
from temporalio.contrib.opentelemetry import (
TracingInterceptor,
TracingWorkflowInboundInterceptor,
)
from temporalio.contrib.opentelemetry import workflow as otel_workflow
from temporalio.exceptions import (
Expand All @@ -37,7 +33,6 @@
)
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import UnsandboxedWorkflowRunner, Worker
from tests.helpers import LogCapturer
from tests.helpers.nexus import make_nexus_endpoint_name


Expand Down Expand Up @@ -853,20 +848,31 @@ async def test_opentelemetry_context_restored_after_activity(
activity: Callable[[], None],
expect_failure: bool,
) -> None:
attach_count = 0
detach_count = 0
original_attach = context.attach
original_detach = context.detach
baseline = context.get_current()

# Model OTel's single current-context ContextVar across every attach/detach
# so we can assert it returns to the baseline. The interceptor restores by
# re-attaching the previous context rather than detaching a token, so a
# token-pairing count no longer balances; the real leak invariant is that
# the current context is restored to where it started.
attach_count = 0
modeled_current = baseline
previous_by_token: dict[int, Context] = {}

def tracked_attach(ctx): # type:ignore[reportMissingParameterType]
nonlocal attach_count
def tracked_attach(context: Context) -> Any:
nonlocal attach_count, modeled_current
attach_count += 1
return original_attach(ctx)
token = original_attach(context)
previous_by_token[id(token)] = modeled_current
modeled_current = context
return token

def tracked_detach(token): # type:ignore[reportMissingParameterType]
nonlocal detach_count
detach_count += 1
return original_detach(token)
def tracked_detach(token: Any) -> None:
nonlocal modeled_current
modeled_current = previous_by_token.pop(id(token), baseline)
original_detach(token)

context.attach = tracked_attach
context.detach = tracked_detach
Expand All @@ -892,10 +898,11 @@ def tracked_detach(token): # type:ignore[reportMissingParameterType]
except Exception:
assert expect_failure, "This test is not expeced to raise"

assert attach_count == detach_count, (
f"Context leak detected: {attach_count} attaches vs {detach_count} detaches. "
assert attach_count > 0, "Expected at least one context attach"
assert modeled_current == baseline, (
"Context leak detected: current context was not restored to baseline "
f"(modeled current={modeled_current!r}, baseline={baseline!r})"
)
assert attach_count > 0, "Expected at least one context attach/detach"

finally:
context.attach = original_attach
Expand Down Expand Up @@ -986,50 +993,3 @@ async def test_opentelemetry_standalone_activity_tracing(
assert start_activity_span.attributes is not None
assert start_activity_span.attributes["temporalActivityID"] == activity_id
assert start_activity_span.attributes["temporalActivityType"] == "tracing_activity"


def test_opentelemetry_safe_detach() -> None:
    """Check that exiting the workflow context on another thread logs no detach error.

    The context manager returned by ``_top_level_workflow_context`` is entered
    on the calling thread. Its only remaining reference is then handed to a
    second thread, which drops it and forces a garbage collection so that the
    context manager's exit runs there. The test passes when no
    "Failed to detach context" record appears on the ``opentelemetry.context``
    logger while that thread runs.
    """

    # Minimal stand-in for the interceptor instance: it provides only the
    # three attributes stubbed here, which is all this test supplies to
    # ``_top_level_workflow_context``.
    class _fake_self:
        def _load_workflow_context_carrier(*_args):
            # No carrier available.
            return None

        def _set_on_context(self, ctx: Any):
            # Return a new context derived from ``ctx`` with a marker value set.
            return opentelemetry.context.set_value("test-key", "test-value", ctx)

        def _completed_span(*args: Any, **_kwargs: Any):
            # Span creation is irrelevant to this test.
            pass

    # Create the context manager and force enter to happen on this thread.
    context_manager = TracingWorkflowInboundInterceptor._top_level_workflow_context(
        _fake_self(),  # type: ignore
        success_is_complete=True,
    )
    context_manager.__enter__()

    # Move the reference to the context manager into a queue so that this
    # thread no longer holds it.
    q: queue.Queue = queue.Queue()
    q.put(context_manager)
    del context_manager

    def worker():
        # Pull the reference from the queue and delete the last reference.
        context_manager = q.get()
        del context_manager
        # Force gc so cleanup of the context manager happens on this thread.
        gc.collect()

    with LogCapturer().logs_captured(opentelemetry.context.logger) as capturer:
        # Run the forced gc on another thread so the exit happens there.
        t = threading.Thread(target=worker)
        t.start()
        # NOTE(review): the join result is not checked; if the worker were still
        # running after 5 seconds the assertion below would still be evaluated.
        t.join(timeout=5)

    def otel_context_error(record: logging.LogRecord) -> bool:
        # Matches the record this test asserts is absent: the detach-failure
        # message on the ``opentelemetry.context`` logger.
        return (
            record.name == "opentelemetry.context"
            and "Failed to detach context" in record.message
        )

    assert capturer.find(otel_context_error) is None, (
        "Detach from context message should not be logged"
    )
Loading