Skip to content

Commit fc2a2e6

Browse files
committed
ensure that the context used to detach the token is the same as what was used to attach it
1 parent cde3427 commit fc2a2e6

File tree

2 files changed

+95
-14
lines changed

2 files changed

+95
-14
lines changed

temporalio/contrib/opentelemetry.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -423,21 +423,21 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
423423
# We do not put anything that happens in a query handler on the workflow
424424
# span.
425425
context_header = input.headers.get(self.header_key)
426-
context: opentelemetry.context.Context
426+
attach_context: opentelemetry.context.Context
427427
link_context_carrier: Optional[_CarrierDict] = None
428428
if context_header:
429429
context_carrier = self.payload_converter.from_payloads([context_header])[0]
430-
context = self.text_map_propagator.extract(context_carrier)
430+
attach_context = self.text_map_propagator.extract(context_carrier)
431431
# If there is a workflow span, use it as the link
432432
link_context_carrier = self._load_workflow_context_carrier()
433433
else:
434434
# Use an empty context
435-
context = opentelemetry.context.Context()
435+
attach_context = opentelemetry.context.Context()
436436

437437
# We need to put this interceptor on the context too
438-
context = self._set_on_context(context)
438+
attach_context = self._set_on_context(attach_context)
439439
# Run under context with new span
440-
token = opentelemetry.context.attach(context)
440+
token = opentelemetry.context.attach(attach_context)
441441
try:
442442
# This won't be created if there was no context header
443443
self._completed_span(
@@ -449,7 +449,9 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
449449
)
450450
return await super().handle_query(input)
451451
finally:
452-
opentelemetry.context.detach(token)
452+
detach_context = opentelemetry.context.get_current()
453+
if detach_context is attach_context:
454+
opentelemetry.context.detach(token)
453455

454456
def handle_update_validator(
455457
self, input: temporalio.worker.HandleUpdateInput
@@ -508,19 +510,20 @@ def _top_level_workflow_context(
508510
) -> Iterator[None]:
509511
# Load context only if there is a carrier, otherwise use empty context
510512
context_carrier = self._load_workflow_context_carrier()
511-
context: opentelemetry.context.Context
513+
attach_context: opentelemetry.context.Context
512514
if context_carrier:
513-
context = self.text_map_propagator.extract(context_carrier)
515+
attach_context = self.text_map_propagator.extract(context_carrier)
514516
else:
515-
context = opentelemetry.context.Context()
517+
attach_context = opentelemetry.context.Context()
516518
# We need to put this interceptor on the context too
517-
context = self._set_on_context(context)
519+
attach_context = self._set_on_context(attach_context)
518520
# Need to know whether completed and whether there was a fail-workflow
519521
# exception
520522
success = False
521523
exception: Optional[Exception] = None
522524
# Run under this context
523-
token = opentelemetry.context.attach(context)
525+
token = opentelemetry.context.attach(attach_context)
526+
524527
try:
525528
yield None
526529
success = True
@@ -537,8 +540,12 @@ def _top_level_workflow_context(
537540
exception=exception,
538541
kind=opentelemetry.trace.SpanKind.INTERNAL,
539542
)
540-
opentelemetry.context.detach(token)
541543

544+
detach_context = opentelemetry.context.get_current()
545+
if detach_context is attach_context:
546+
opentelemetry.context.detach(token)
547+
548+
#
542549
def _context_to_headers(
543550
self, headers: Mapping[str, temporalio.api.common.v1.Payload]
544551
) -> Mapping[str, temporalio.api.common.v1.Payload]:

tests/contrib/test_opentelemetry.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import logging
5+
import sys
56
import uuid
67
from concurrent.futures import ThreadPoolExecutor
78
from dataclasses import dataclass
@@ -21,6 +22,11 @@
2122
from temporalio.exceptions import ApplicationError, ApplicationErrorCategory
2223
from temporalio.testing import WorkflowEnvironment
2324
from temporalio.worker import UnsandboxedWorkflowRunner, Worker
25+
from tests.worker.test_workflow import (
26+
CacheEvictionTearDownWorkflow,
27+
WaitForeverWorkflow,
28+
wait_forever_activity,
29+
)
2430

2531
# Passing through because Python 3.9 has an import bug at
2632
# https://github.com/python/cpython/issues/91351
@@ -321,7 +327,10 @@ def dump_spans(
321327
span_links: List[str] = []
322328
for link in span.links:
323329
for link_span in spans:
324-
if link_span.context.span_id == link.context.span_id:
330+
if (
331+
link_span.context is not None
332+
and link_span.context.span_id == link.context.span_id
333+
):
325334
span_links.append(link_span.name)
326335
span_str += f" (links: {', '.join(span_links)})"
327336
# Signals can duplicate in rare situations, so we make sure not to
@@ -331,7 +340,7 @@ def dump_spans(
331340
ret.append(span_str)
332341
ret += dump_spans(
333342
spans,
334-
parent_id=span.context.span_id,
343+
parent_id=span.context.span_id if span.context else None,
335344
with_attributes=with_attributes,
336345
indent_depth=indent_depth + 1,
337346
)
@@ -448,3 +457,68 @@ async def test_opentelemetry_benign_exception(client: Client):
448457
# * workflow failure and wft failure
449458
# * signal with start
450459
# * signal failure and wft failure from signal
460+
461+
462+
async def test_opentelemetry_safe_detach(client: Client):
463+
# This test simulates forcing eviction. This purposely raises GeneratorExit on
464+
# GC which triggers the finally which could run on any thread Python
465+
# chooses. When this occurs, we should not detach the token from the context
466+
# b/c the context no longer exists
467+
468+
# Create a tracer that has an in-memory exporter
469+
exporter = InMemorySpanExporter()
470+
provider = TracerProvider()
471+
provider.add_span_processor(SimpleSpanProcessor(exporter))
472+
tracer = get_tracer(__name__, tracer_provider=provider)
473+
474+
class _OtelLogSpy(logging.Handler):
475+
def __init__(self, level: int | str = 0) -> None:
476+
self.seenOtelFailedMessage = False
477+
super().__init__(level)
478+
479+
def emit(self, record: logging.LogRecord) -> None:
480+
if not self.seenOtelFailedMessage:
481+
self.seenOtelFailedMessage = (
482+
record.levelno == logging.ERROR
483+
and record.name == "opentelemetry.context"
484+
and record.message == "Failed to detach context"
485+
)
486+
487+
async with Worker(
488+
client,
489+
workflows=[CacheEvictionTearDownWorkflow, WaitForeverWorkflow],
490+
activities=[wait_forever_activity],
491+
max_cached_workflows=0,
492+
task_queue=f"task_queue_{uuid.uuid4()}",
493+
disable_safe_workflow_eviction=True,
494+
interceptors=[TracingInterceptor(tracer)],
495+
) as worker:
496+
# Put a hook to catch unraisable exceptions
497+
old_hook = sys.unraisablehook
498+
hook_calls: List[sys.UnraisableHookArgs] = []
499+
sys.unraisablehook = hook_calls.append
500+
log_spy = _OtelLogSpy()
501+
logging.getLogger().addHandler(log_spy)
502+
try:
503+
handle = await client.start_workflow(
504+
CacheEvictionTearDownWorkflow.run,
505+
id=f"wf-{uuid.uuid4()}",
506+
task_queue=worker.task_queue,
507+
)
508+
509+
# CacheEvictionTearDownWorkflow requires 3 signals to be sent
510+
await handle.signal(CacheEvictionTearDownWorkflow.signal)
511+
await handle.signal(CacheEvictionTearDownWorkflow.signal)
512+
await handle.signal(CacheEvictionTearDownWorkflow.signal)
513+
514+
await handle.result()
515+
finally:
516+
sys.unraisablehook = old_hook
517+
logging.getLogger().removeHandler(log_spy)
518+
519+
# Confirm at least 1 exception
520+
assert hook_calls
521+
522+
assert (
523+
not log_spy.seenOtelFailedMessage
524+
), "Detach from context message should not be logged"

0 commit comments

Comments
 (0)