From fc2a2e61c93c245b7a3084e2551abe853768adae Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 10 Oct 2025 19:30:27 -0700 Subject: [PATCH 01/21] ensure that the context used to detach the token is the same as what was used to attach it --- temporalio/contrib/opentelemetry.py | 31 +++++++----- tests/contrib/test_opentelemetry.py | 78 ++++++++++++++++++++++++++++- 2 files changed, 95 insertions(+), 14 deletions(-) diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 04d40d544..a8abf39d3 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -423,21 +423,21 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: # We do not put anything that happens in a query handler on the workflow # span. context_header = input.headers.get(self.header_key) - context: opentelemetry.context.Context + attach_context: opentelemetry.context.Context link_context_carrier: Optional[_CarrierDict] = None if context_header: context_carrier = self.payload_converter.from_payloads([context_header])[0] - context = self.text_map_propagator.extract(context_carrier) + attach_context = self.text_map_propagator.extract(context_carrier) # If there is a workflow span, use it as the link link_context_carrier = self._load_workflow_context_carrier() else: # Use an empty context - context = opentelemetry.context.Context() + attach_context = opentelemetry.context.Context() # We need to put this interceptor on the context too - context = self._set_on_context(context) + attach_context = self._set_on_context(attach_context) # Run under context with new span - token = opentelemetry.context.attach(context) + token = opentelemetry.context.attach(attach_context) try: # This won't be created if there was no context header self._completed_span( @@ -449,7 +449,9 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: ) return await super().handle_query(input) finally: - opentelemetry.context.detach(token) + detach_context = opentelemetry.context.get_current() + if detach_context is attach_context: + opentelemetry.context.detach(token) def handle_update_validator( self, input: temporalio.worker.HandleUpdateInput @@ -508,19 +510,20 @@ def _top_level_workflow_context( ) -> Iterator[None]: # Load context only if there is a carrier, otherwise use empty context context_carrier = self._load_workflow_context_carrier() - context: opentelemetry.context.Context + attach_context: opentelemetry.context.Context if context_carrier: - context = self.text_map_propagator.extract(context_carrier) + attach_context = self.text_map_propagator.extract(context_carrier) else: - context = opentelemetry.context.Context() + attach_context = opentelemetry.context.Context() # We need to put this interceptor on the context too - context = self._set_on_context(context) + attach_context = self._set_on_context(attach_context) # Need to know whether completed and whether there was a fail-workflow # exception success = False exception: Optional[Exception] = None # Run under this context - token = opentelemetry.context.attach(context) + token = opentelemetry.context.attach(attach_context) + try: yield None success = True @@ -537,8 +540,12 @@ def _top_level_workflow_context( exception=exception, kind=opentelemetry.trace.SpanKind.INTERNAL, ) - opentelemetry.context.detach(token) + detach_context = opentelemetry.context.get_current() + if detach_context is attach_context: + opentelemetry.context.detach(token) + + # def _context_to_headers( self, 
headers: Mapping[str, temporalio.api.common.v1.Payload] ) -> Mapping[str, temporalio.api.common.v1.Payload]: diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index 0b797f606..a787458aa 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -2,6 +2,7 @@ import asyncio import logging +import sys import uuid from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass @@ -21,6 +22,11 @@ from temporalio.exceptions import ApplicationError, ApplicationErrorCategory from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker +from tests.worker.test_workflow import ( + CacheEvictionTearDownWorkflow, + WaitForeverWorkflow, + wait_forever_activity, +) # Passing through because Python 3.9 has an import bug at # https://github.com/python/cpython/issues/91351 @@ -321,7 +327,10 @@ def dump_spans( span_links: List[str] = [] for link in span.links: for link_span in spans: - if link_span.context.span_id == link.context.span_id: + if ( + link_span.context is not None + and link_span.context.span_id == link.context.span_id + ): span_links.append(link_span.name) span_str += f" (links: {', '.join(span_links)})" # Signals can duplicate in rare situations, so we make sure not to @@ -331,7 +340,7 @@ def dump_spans( ret.append(span_str) ret += dump_spans( spans, - parent_id=span.context.span_id, + parent_id=span.context.span_id if span.context else None, with_attributes=with_attributes, indent_depth=indent_depth + 1, ) @@ -448,3 +457,68 @@ async def test_opentelemetry_benign_exception(client: Client): # * workflow failure and wft failure # * signal with start # * signal failure and wft failure from signal + + +async def test_opentelemetry_safe_detach(client: Client): + # This test simulates forcing eviction. This purposely raises GeneratorExit on + # GC which triggers the finally which could run on any thread Python + # chooses. 
When this occurs, we should not detach the token from the context + # b/c the context no longer exists + + # Create a tracer that has an in-memory exporter + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = get_tracer(__name__, tracer_provider=provider) + + class _OtelLogSpy(logging.Handler): + def __init__(self, level: int | str = 0) -> None: + self.seenOtelFailedMessage = False + super().__init__(level) + + def emit(self, record: logging.LogRecord) -> None: + if not self.seenOtelFailedMessage: + self.seenOtelFailedMessage = ( + record.levelno == logging.ERROR + and record.name == "opentelemetry.context" + and record.message == "Failed to detach context" + ) + + async with Worker( + client, + workflows=[CacheEvictionTearDownWorkflow, WaitForeverWorkflow], + activities=[wait_forever_activity], + max_cached_workflows=0, + task_queue=f"task_queue_{uuid.uuid4()}", + disable_safe_workflow_eviction=True, + interceptors=[TracingInterceptor(tracer)], + ) as worker: + # Put a hook to catch unraisable exceptions + old_hook = sys.unraisablehook + hook_calls: List[sys.UnraisableHookArgs] = [] + sys.unraisablehook = hook_calls.append + log_spy = _OtelLogSpy() + logging.getLogger().addHandler(log_spy) + try: + handle = await client.start_workflow( + CacheEvictionTearDownWorkflow.run, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # CacheEvictionTearDownWorkflow requires 3 signals to be sent + await handle.signal(CacheEvictionTearDownWorkflow.signal) + await handle.signal(CacheEvictionTearDownWorkflow.signal) + await handle.signal(CacheEvictionTearDownWorkflow.signal) + + await handle.result() + finally: + sys.unraisablehook = old_hook + logging.getLogger().removeHandler(log_spy) + + # Confirm at least 1 exception + assert hook_calls + + assert ( + not log_spy.seenOtelFailedMessage + ), "Detach from context message should not be logged" From 7a515299d3def37664e01547530a04d69d4adcb9 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 13 Oct 2025 10:34:32 -0700 Subject: [PATCH 02/21] Update test to use LogCapturer helper. 
Move shared test workflow and helpers to test/helpers --- temporalio/contrib/opentelemetry.py | 9 +-- tests/contrib/test_opentelemetry.py | 65 +++++++++--------- tests/helpers/__init__.py | 49 +++++++++++++- tests/helpers/cache_evitction.py | 67 ++++++++++++++++++ tests/worker/test_workflow.py | 101 ++-------------------------- 5 files changed, 153 insertions(+), 138 deletions(-) create mode 100644 tests/helpers/cache_evitction.py diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index a8abf39d3..6bb6f5c79 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -26,8 +26,7 @@ import opentelemetry.trace.propagation.tracecontext import opentelemetry.util.types from opentelemetry.context import Context -from opentelemetry.trace import Span, SpanKind, Status, StatusCode, _Links -from opentelemetry.util import types +from opentelemetry.trace import Status, StatusCode from typing_extensions import Protocol, TypeAlias, TypedDict import temporalio.activity @@ -449,8 +448,7 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: ) return await super().handle_query(input) finally: - detach_context = opentelemetry.context.get_current() - if detach_context is attach_context: + if attach_context == opentelemetry.context.get_current(): opentelemetry.context.detach(token) def handle_update_validator( @@ -541,8 +539,7 @@ def _top_level_workflow_context( kind=opentelemetry.trace.SpanKind.INTERNAL, ) - detach_context = opentelemetry.context.get_current() - if detach_context is attach_context: + if attach_context == opentelemetry.context.get_current(): opentelemetry.context.detach(token) # diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index a787458aa..802c6b466 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -9,6 +9,7 @@ from datetime import timedelta from typing import Iterable, List, Optional +import opentelemetry.context from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter @@ -22,11 +23,12 @@ from temporalio.exceptions import ApplicationError, ApplicationErrorCategory from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker -from tests.worker.test_workflow import ( +from tests.helpers.cache_evitction import ( CacheEvictionTearDownWorkflow, WaitForeverWorkflow, wait_forever_activity, ) +from tests.helpers import LogCapturer # Passing through because Python 3.9 has an import bug at # https://github.com/python/cpython/issues/91351 @@ -471,19 +473,6 @@ async def test_opentelemetry_safe_detach(client: Client): provider.add_span_processor(SimpleSpanProcessor(exporter)) tracer = get_tracer(__name__, tracer_provider=provider) - class _OtelLogSpy(logging.Handler): - def __init__(self, level: int | str = 0) -> None: - self.seenOtelFailedMessage = False - super().__init__(level) - - def emit(self, record: logging.LogRecord) -> None: - if not self.seenOtelFailedMessage: - self.seenOtelFailedMessage = ( - record.levelno == logging.ERROR - and record.name == "opentelemetry.context" - and record.message == "Failed to detach context" - ) - async with Worker( client, workflows=[CacheEvictionTearDownWorkflow, WaitForeverWorkflow], @@ -497,28 +486,34 @@ def emit(self, record: logging.LogRecord) -> None: old_hook = 
sys.unraisablehook hook_calls: List[sys.UnraisableHookArgs] = [] sys.unraisablehook = hook_calls.append - log_spy = _OtelLogSpy() - logging.getLogger().addHandler(log_spy) - try: - handle = await client.start_workflow( - CacheEvictionTearDownWorkflow.run, - id=f"wf-{uuid.uuid4()}", - task_queue=worker.task_queue, - ) - # CacheEvictionTearDownWorkflow requires 3 signals to be sent - await handle.signal(CacheEvictionTearDownWorkflow.signal) - await handle.signal(CacheEvictionTearDownWorkflow.signal) - await handle.signal(CacheEvictionTearDownWorkflow.signal) + with LogCapturer().logs_captured(opentelemetry.context.logger) as capturer: + try: + handle = await client.start_workflow( + CacheEvictionTearDownWorkflow.run, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # CacheEvictionTearDownWorkflow requires 3 signals to be sent + await handle.signal(CacheEvictionTearDownWorkflow.signal) + await handle.signal(CacheEvictionTearDownWorkflow.signal) + await handle.signal(CacheEvictionTearDownWorkflow.signal) - await handle.result() - finally: - sys.unraisablehook = old_hook - logging.getLogger().removeHandler(log_spy) + await handle.result() + finally: + sys.unraisablehook = old_hook - # Confirm at least 1 exception - assert hook_calls + # Confirm at least 1 exception + assert hook_calls + + def otel_context_error(record: logging.LogRecord) -> bool: + return ( + record.levelno == logging.ERROR + and record.name == "opentelemetry.context" + and record.message == "Failed to detach context" + ) - assert ( - not log_spy.seenOtelFailedMessage - ), "Detach from context message should not be logged" + assert ( + capturer.find(otel_context_error) is None + ), "Detach from context message should not be logged" diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index 79d3687fd..4a3850024 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -1,11 +1,25 @@ import asyncio +import logging +import logging.handlers +import queue import socket import time import uuid -from contextlib import closing +from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar, Union +from typing import ( + Any, + Awaitable, + Callable, + List, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, +) from temporalio.api.common.v1 import WorkflowExecution from temporalio.api.enums.v1 import EventType as EventType @@ -401,3 +415,34 @@ def _format_row(items: list[str], truncate: bool = False) -> str: padding = len(f" *: {elapsed_ms:>4} ") summary_row[col_idx] = f"{' ' * padding}[{summary}]"[: col_width - 3] print(_format_row(summary_row)) + + +class LogCapturer: + def __init__(self) -> None: + self.log_queue: queue.Queue[logging.LogRecord] = queue.Queue() + + @contextmanager + def logs_captured(self, *loggers: logging.Logger): + handler = logging.handlers.QueueHandler(self.log_queue) + + prev_levels = [l.level for l in loggers] + for l in loggers: + l.setLevel(logging.INFO) + l.addHandler(handler) + try: + yield self + finally: + for i, l in enumerate(loggers): + l.removeHandler(handler) + l.setLevel(prev_levels[i]) + + def find_log(self, starts_with: str) -> Optional[logging.LogRecord]: + return self.find(lambda l: l.message.startswith(starts_with)) + + def find( + self, pred: Callable[[logging.LogRecord], bool] + ) -> Optional[logging.LogRecord]: + for record in cast(List[logging.LogRecord], self.log_queue.queue): + if 
pred(record): + return record + return None diff --git a/tests/helpers/cache_evitction.py b/tests/helpers/cache_evitction.py new file mode 100644 index 000000000..654655ab2 --- /dev/null +++ b/tests/helpers/cache_evitction.py @@ -0,0 +1,67 @@ +import asyncio +from datetime import timedelta +from temporalio import activity, workflow + + +@activity.defn +async def wait_forever_activity() -> None: + await asyncio.Future() + + +@workflow.defn +class WaitForeverWorkflow: + @workflow.run + async def run(self) -> None: + await asyncio.Future() + + +@workflow.defn +class CacheEvictionTearDownWorkflow: + def __init__(self) -> None: + self._signal_count = 0 + + @workflow.run + async def run(self) -> None: + # Start several things in background. This is just to show that eviction + # can work even with these things running. + tasks = [ + asyncio.create_task( + workflow.execute_activity( + wait_forever_activity, start_to_close_timeout=timedelta(hours=1) + ) + ), + asyncio.create_task( + workflow.execute_child_workflow(WaitForeverWorkflow.run) + ), + asyncio.create_task(asyncio.sleep(1000)), + asyncio.shield( + workflow.execute_activity( + wait_forever_activity, start_to_close_timeout=timedelta(hours=1) + ) + ), + asyncio.create_task(workflow.wait_condition(lambda: False)), + ] + gather_fut = asyncio.gather(*tasks, return_exceptions=True) + # Let's also start something in the background that we never wait on + asyncio.create_task(asyncio.sleep(1000)) + try: + # Wait for signal count to reach 2 + await asyncio.sleep(0.01) + await workflow.wait_condition(lambda: self._signal_count > 1) + finally: + # This finally, on eviction, is actually called but the command + # should be ignored + await asyncio.sleep(0.01) + await workflow.wait_condition(lambda: self._signal_count > 2) + # Cancel gather tasks and wait on them, but ignore the errors + for task in tasks: + task.cancel() + await gather_fut + + @workflow.signal + async def signal(self) -> None: + self._signal_count += 1 + + @workflow.query + def signal_count(self) -> int: + return self._signal_count diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index a987d1b34..acb8c8012 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -143,11 +143,17 @@ pause_and_assert, unpause_and_assert, workflow_update_exists, + LogCapturer, ) from tests.helpers.external_stack_trace import ( ExternalStackTraceWorkflow, external_wait_cancel, ) +from tests.helpers.cache_evitction import ( + CacheEvictionTearDownWorkflow, + WaitForeverWorkflow, + wait_forever_activity, +) # Passing through because Python 3.9 has an import bug at # https://github.com/python/cpython/issues/91351 @@ -1996,37 +2002,6 @@ def last_signal(self) -> str: return self._last_signal -class LogCapturer: - def __init__(self) -> None: - self.log_queue: queue.Queue[logging.LogRecord] = queue.Queue() - - @contextmanager - def logs_captured(self, *loggers: logging.Logger): - handler = logging.handlers.QueueHandler(self.log_queue) - - prev_levels = [l.level for l in loggers] - for l in loggers: - l.setLevel(logging.INFO) - l.addHandler(handler) - try: - yield self - finally: - for i, l in enumerate(loggers): - l.removeHandler(handler) - l.setLevel(prev_levels[i]) - - def find_log(self, starts_with: str) -> Optional[logging.LogRecord]: - return self.find(lambda l: l.message.startswith(starts_with)) - - def find( - self, pred: Callable[[logging.LogRecord], bool] - ) -> Optional[logging.LogRecord]: - for record in cast(List[logging.LogRecord], 
self.log_queue.queue): - if pred(record): - return record - return None - - async def test_workflow_logging(client: Client, env: WorkflowEnvironment): workflow.logger.full_workflow_info_on_extra = True with LogCapturer().logs_captured( @@ -3739,70 +3714,6 @@ async def test_manual_result_type(client: Client): assert res4 == ManualResultType(some_string="from-query") -@activity.defn -async def wait_forever_activity() -> None: - await asyncio.Future() - - -@workflow.defn -class WaitForeverWorkflow: - @workflow.run - async def run(self) -> None: - await asyncio.Future() - - -@workflow.defn -class CacheEvictionTearDownWorkflow: - def __init__(self) -> None: - self._signal_count = 0 - - @workflow.run - async def run(self) -> None: - # Start several things in background. This is just to show that eviction - # can work even with these things running. - tasks = [ - asyncio.create_task( - workflow.execute_activity( - wait_forever_activity, start_to_close_timeout=timedelta(hours=1) - ) - ), - asyncio.create_task( - workflow.execute_child_workflow(WaitForeverWorkflow.run) - ), - asyncio.create_task(asyncio.sleep(1000)), - asyncio.shield( - workflow.execute_activity( - wait_forever_activity, start_to_close_timeout=timedelta(hours=1) - ) - ), - asyncio.create_task(workflow.wait_condition(lambda: False)), - ] - gather_fut = asyncio.gather(*tasks, return_exceptions=True) - # Let's also start something in the background that we never wait on - asyncio.create_task(asyncio.sleep(1000)) - try: - # Wait for signal count to reach 2 - await asyncio.sleep(0.01) - await workflow.wait_condition(lambda: self._signal_count > 1) - finally: - # This finally, on eviction, is actually called but the command - # should be ignored - await asyncio.sleep(0.01) - await workflow.wait_condition(lambda: self._signal_count > 2) - # Cancel gather tasks and wait on them, but ignore the errors - for task in tasks: - task.cancel() - await gather_fut - - @workflow.signal - async def signal(self) -> None: - self._signal_count += 1 - - @workflow.query - def signal_count(self) -> int: - return self._signal_count - - async def test_cache_eviction_tear_down(client: Client): # This test simulates forcing eviction. 
This used to raise GeneratorExit on # GC which triggered the finally which could run on any thread Python From 1b8d9e17a70f2aefe02ab364ba8d0fece65de59b Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 13 Oct 2025 10:47:27 -0700 Subject: [PATCH 03/21] run formatter --- tests/contrib/test_opentelemetry.py | 2 +- tests/helpers/cache_evitction.py | 1 + tests/worker/test_workflow.py | 10 +++++----- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index 802c6b466..87bc2b1b7 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -23,12 +23,12 @@ from temporalio.exceptions import ApplicationError, ApplicationErrorCategory from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker +from tests.helpers import LogCapturer from tests.helpers.cache_evitction import ( CacheEvictionTearDownWorkflow, WaitForeverWorkflow, wait_forever_activity, ) -from tests.helpers import LogCapturer # Passing through because Python 3.9 has an import bug at # https://github.com/python/cpython/issues/91351 diff --git a/tests/helpers/cache_evitction.py b/tests/helpers/cache_evitction.py index 654655ab2..191d51078 100644 --- a/tests/helpers/cache_evitction.py +++ b/tests/helpers/cache_evitction.py @@ -1,5 +1,6 @@ import asyncio from datetime import timedelta + from temporalio import activity, workflow diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index acb8c8012..b76451d00 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -130,6 +130,7 @@ ) from tests import DEV_SERVER_DOWNLOAD_VERSION from tests.helpers import ( + LogCapturer, admitted_update_task, assert_eq_eventually, assert_eventually, @@ -143,17 +144,16 @@ pause_and_assert, unpause_and_assert, workflow_update_exists, - LogCapturer, -) -from tests.helpers.external_stack_trace import ( - ExternalStackTraceWorkflow, - external_wait_cancel, ) from tests.helpers.cache_evitction import ( CacheEvictionTearDownWorkflow, WaitForeverWorkflow, wait_forever_activity, ) +from tests.helpers.external_stack_trace import ( + ExternalStackTraceWorkflow, + external_wait_cancel, +) # Passing through because Python 3.9 has an import bug at # https://github.com/python/cpython/issues/91351 From b656e8435df0f6403ebd80963da8777cfc0e99b7 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 13 Oct 2025 16:06:32 -0700 Subject: [PATCH 04/21] Fix up test with log capturer. Only log warnings if there are no hook calls --- tests/contrib/test_opentelemetry.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index 87bc2b1b7..c6e161f3e 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -505,15 +505,17 @@ async def test_opentelemetry_safe_detach(client: Client): sys.unraisablehook = old_hook # Confirm at least 1 exception - assert hook_calls + if len(hook_calls) < 1: + logging.warning( + "Expected at least 1 exception. 
Unable to properly verify context detachment" + ) def otel_context_error(record: logging.LogRecord) -> bool: return ( - record.levelno == logging.ERROR - and record.name == "opentelemetry.context" - and record.message == "Failed to detach context" + record.name == "opentelemetry.context" + and "Failed to detach context" in record.message ) - assert ( - capturer.find(otel_context_error) is None - ), "Detach from context message should not be logged" + assert capturer.find(otel_context_error) is None, ( + "Detach from context message should not be logged" + ) From 5abc5a439daa0f324c9d6e14d131293c6414e2ac Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 13 Oct 2025 16:08:36 -0700 Subject: [PATCH 05/21] run formatter --- tests/contrib/test_opentelemetry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index c6e161f3e..bae1323f0 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -516,6 +516,6 @@ def otel_context_error(record: logging.LogRecord) -> bool: and "Failed to detach context" in record.message ) - assert capturer.find(otel_context_error) is None, ( - "Detach from context message should not be logged" - ) + assert ( + capturer.find(otel_context_error) is None + ), "Detach from context message should not be logged" From d1bf879e9f50c7d8e172c5a329ea68b13948365d Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 14 Oct 2025 11:44:08 -0700 Subject: [PATCH 06/21] Move to a class based context manager to manually manage context --- temporalio/contrib/opentelemetry.py | 194 +++++++++++++++++++--------- temporalio/worker/_workflow.py | 3 + 2 files changed, 134 insertions(+), 63 deletions(-) diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 6bb6f5c79..fae1470ce 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -2,11 +2,18 @@ from __future__ import annotations -from contextlib import contextmanager +import contextvars +import asyncio +from contextlib import ( + AbstractContextManager, + contextmanager, +) from dataclasses import dataclass +from types import TracebackType from typing import ( Any, Callable, + ContextManager, Dict, Iterator, Mapping, @@ -386,15 +393,30 @@ async def execute_workflow( """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.execute_workflow`. """ - with self._top_level_workflow_context(success_is_complete=True): - # Entrypoint of workflow should be `server` in OTel - self._completed_span( - f"RunWorkflow:{temporalio.workflow.info().workflow_type}", - kind=opentelemetry.trace.SpanKind.SERVER, - ) - return await super().execute_workflow(input) + with self._top_level_workflow_context(success_is_complete=True) as ctx: + return await ctx.run(asyncio.create_task, self._execute_workflow(input)) + + async def _execute_workflow( + self, input: temporalio.worker.ExecuteWorkflowInput + ) -> Any: + # Entrypoint of workflow should be `server` in OTel + self._completed_span( + f"RunWorkflow:{temporalio.workflow.info().workflow_type}", + kind=opentelemetry.trace.SpanKind.SERVER, + ) + # with self._with_complete_span(success_is_complete=True): + return await super().execute_workflow(input) async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: + """Implementation of + :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`. 
+ """ + # Create a span in the current context for the signal and link any + # header given + with self._top_level_workflow_context(success_is_complete=False) as ctx: + return await ctx.run(asyncio.create_task, self._handle_signal(input)) + + async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`. """ @@ -406,18 +428,22 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non link_context_carrier = self.payload_converter.from_payloads( [link_context_header] )[0] - with self._top_level_workflow_context(success_is_complete=False): - self._completed_span( - f"HandleSignal:{input.signal}", - link_context_carrier=link_context_carrier, - kind=opentelemetry.trace.SpanKind.SERVER, - ) - await super().handle_signal(input) + self._completed_span( + f"HandleSignal:{input.signal}", + link_context_carrier=link_context_carrier, + kind=opentelemetry.trace.SpanKind.SERVER, + ) + await super().handle_signal(input) async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_query`. """ + # handle_query does not manage the contextvars.Context itself because + # scheduling an asyncio task in a read only operation is not allowed. + # The operation is synchronous which makes default contextvars.Context + # safe. + # Only trace this if there is a header, and make that span the parent. # We do not put anything that happens in a query handler on the workflow # span. @@ -448,11 +474,19 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: ) return await super().handle_query(input) finally: - if attach_context == opentelemetry.context.get_current(): - opentelemetry.context.detach(token) + opentelemetry.context.detach(token) def handle_update_validator( self, input: temporalio.worker.HandleUpdateInput + ) -> None: + """Implementation of + :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`. + """ + with self._top_level_workflow_context(success_is_complete=False) as ctx: + ctx.run(self._handle_update_validator, input) + + def _handle_update_validator( + self, input: temporalio.worker.HandleUpdateInput ) -> None: """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`. @@ -463,7 +497,6 @@ def handle_update_validator( link_context_carrier = self.payload_converter.from_payloads( [link_context_header] )[0] - with self._top_level_workflow_context(success_is_complete=False): self._completed_span( f"ValidateUpdate:{input.update}", link_context_carrier=link_context_carrier, @@ -473,6 +506,17 @@ def handle_update_validator( async def handle_update_handler( self, input: temporalio.worker.HandleUpdateInput + ) -> Any: + """Implementation of + :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`. + """ + with self._top_level_workflow_context(success_is_complete=False) as ctx: + return await ctx.run( + asyncio.create_task, self._handle_update_handler(input) + ) + + async def _handle_update_handler( + self, input: temporalio.worker.HandleUpdateInput ) -> Any: """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`. 
@@ -483,13 +527,12 @@ async def handle_update_handler( link_context_carrier = self.payload_converter.from_payloads( [link_context_header] )[0] - with self._top_level_workflow_context(success_is_complete=False): - self._completed_span( - f"HandleUpdate:{input.update}", - link_context_carrier=link_context_carrier, - kind=opentelemetry.trace.SpanKind.SERVER, - ) - return await super().handle_update_handler(input) + self._completed_span( + f"HandleUpdate:{input.update}", + link_context_carrier=link_context_carrier, + kind=opentelemetry.trace.SpanKind.SERVER, + ) + return await super().handle_update_handler(input) def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]: if self._workflow_context_carrier: @@ -502,47 +545,13 @@ def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]: )[0] return self._workflow_context_carrier - @contextmanager def _top_level_workflow_context( self, *, success_is_complete: bool - ) -> Iterator[None]: - # Load context only if there is a carrier, otherwise use empty context - context_carrier = self._load_workflow_context_carrier() - attach_context: opentelemetry.context.Context - if context_carrier: - attach_context = self.text_map_propagator.extract(context_carrier) - else: - attach_context = opentelemetry.context.Context() - # We need to put this interceptor on the context too - attach_context = self._set_on_context(attach_context) - # Need to know whether completed and whether there was a fail-workflow - # exception - success = False - exception: Optional[Exception] = None - # Run under this context - token = opentelemetry.context.attach(attach_context) - - try: - yield None - success = True - except temporalio.exceptions.FailureError as err: - # We only record the failure errors since those are the only ones - # that lead to workflow completions - exception = err - raise - finally: - # Create a completed span before detaching context - if exception or (success and success_is_complete): - self._completed_span( - f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}", - exception=exception, - kind=opentelemetry.trace.SpanKind.INTERNAL, - ) - - if attach_context == opentelemetry.context.get_current(): - opentelemetry.context.detach(token) + ) -> ContextManager[contextvars.Context]: + return self._TopLevelWorkflowContextManager( + self, success_is_complete=success_is_complete + ) - # def _context_to_headers( self, headers: Mapping[str, temporalio.api.common.v1.Payload] ) -> Mapping[str, temporalio.api.common.v1.Payload]: @@ -616,6 +625,65 @@ def _set_on_context( ) -> opentelemetry.context.Context: return opentelemetry.context.set_value(_interceptor_context_key, self, context) + class _TopLevelWorkflowContextManager(AbstractContextManager): + def __init__( + self, + interceptor: TracingWorkflowInboundInterceptor, + *, + success_is_complete: bool, + ): + self._ctx = contextvars.copy_context() + self._token = None + self._owner = interceptor + self._success_is_complete = success_is_complete + + def __enter__(self): + self._ctx.run(self._start) + return self._ctx + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], # noqa: F811 + ) -> bool | None: + self._ctx.run(self._end, exc_type, exc_value, traceback) + + def _start(self): + # Load context only if there is a carrier, otherwise use empty context + context_carrier = self._owner._load_workflow_context_carrier() + attach_context: opentelemetry.context.Context + if context_carrier: + attach_context = 
self._owner.text_map_propagator.extract( + context_carrier + ) + else: + attach_context = opentelemetry.context.Context() + # We need to put this interceptor on the context too + attach_context = self._owner._set_on_context(attach_context) + self._token = opentelemetry.context.attach(attach_context) + + def _end( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + _traceback: Optional[TracebackType], + ): + success = exc_type is None + exception: Optional[temporalio.exceptions.FailureError] = None + if isinstance(exc_value, temporalio.exceptions.FailureError): + exception = exc_value + + if (success and self._success_is_complete) or exception is not None: + self._owner._completed_span( + f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}", + exception=exception, + kind=opentelemetry.trace.SpanKind.INTERNAL, + ) + + if self._token: + opentelemetry.context.detach(self._token) + class _TracingWorkflowOutboundInterceptor( temporalio.worker.WorkflowOutboundInterceptor diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 6e7c254aa..75b0c9059 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -337,6 +337,9 @@ async def _handle_activation( raise deadlock_exc from None except Exception as err: + if isinstance(err, GeneratorExit): + print("###############generator exit#############") + if isinstance(err, _DeadlockError): err.swap_traceback() From 6ca781cdbdaeba7f9ea2251004f319c9cc1495bd Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 16 Oct 2025 11:07:01 -0700 Subject: [PATCH 07/21] apply context management to openai_agents tracing --- .../openai_agents/_trace_interceptor.py | 68 +++++++++++-------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/temporalio/contrib/openai_agents/_trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py index 20d489b65..9f1874819 100644 --- a/temporalio/contrib/openai_agents/_trace_interceptor.py +++ b/temporalio/contrib/openai_agents/_trace_interceptor.py @@ -2,6 +2,8 @@ from __future__ import annotations +import asyncio +import contextvars import random import uuid from contextlib import contextmanager @@ -400,48 +402,58 @@ async def signal_external_workflow( def start_activity( self, input: temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: - trace = get_trace_provider().get_current_trace() - span: Optional[Span] = None - if trace: - span = custom_span( - name="temporal:startActivity", data={"activity": input.activity} - ) - span.start(mark_as_current=True) - - set_header_from_context(input, temporalio.workflow.payload_converter()) - handle = self.next.start_activity(input) + ctx = contextvars.copy_context() + span = ctx.run( + self._create_span, + name="temporal:startActivity", + data={"activity": input.activity}, + input=input, + ) + handle = ctx.run(self.next.start_activity, input) if span: - handle.add_done_callback(lambda _: span.finish()) # type: ignore + handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore return handle async def start_child_workflow( self, input: temporalio.worker.StartChildWorkflowInput ) -> temporalio.workflow.ChildWorkflowHandle: - trace = get_trace_provider().get_current_trace() - span: Optional[Span] = None - if trace: - span = custom_span( - name="temporal:startChildWorkflow", data={"workflow": input.workflow} - ) - span.start(mark_as_current=True) - set_header_from_context(input, temporalio.workflow.payload_converter()) - handle = 
await self.next.start_child_workflow(input) + ctx = contextvars.copy_context() + span = ctx.run( + self._create_span, + name="temporal:startChildWorkflow", + data={"workflow": input.workflow}, + input=input, + ) + handle = await ctx.run( + asyncio.create_task, self.next.start_child_workflow(input) + ) if span: - handle.add_done_callback(lambda _: span.finish()) # type: ignore + handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore return handle def start_local_activity( self, input: temporalio.worker.StartLocalActivityInput ) -> temporalio.workflow.ActivityHandle: + ctx = contextvars.copy_context() + span = ctx.run( + self._create_span, + name="temporal:startLocalActivity", + data={"activity": input.activity}, + input=input, + ) + handle = ctx.run(self.next.start_local_activity, input) + if span: + handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore + return handle + + def _create_span( + self, name: str, data: dict[str, Any], input: _InputWithHeaders + ) -> Optional[Span]: trace = get_trace_provider().get_current_trace() span: Optional[Span] = None if trace: - span = custom_span( - name="temporal:startLocalActivity", data={"activity": input.activity} - ) + span = custom_span(name=name, data=data) span.start(mark_as_current=True) + set_header_from_context(input, temporalio.workflow.payload_converter()) - handle = self.next.start_local_activity(input) - if span: - handle.add_done_callback(lambda _: span.finish()) # type: ignore - return handle + return span From 6060b3542e3d8f2c3ca2fba23e3ff3e5182af512 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 16 Oct 2025 11:30:48 -0700 Subject: [PATCH 08/21] use 3.11 create_task when possible to avoid an extra context copy --- .../openai_agents/_trace_interceptor.py | 12 ++++++--- temporalio/contrib/opentelemetry.py | 26 +++++++++++++++---- tests/nexus/test_handler.py | 10 ++++--- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/temporalio/contrib/openai_agents/_trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py index 9f1874819..e8f357daf 100644 --- a/temporalio/contrib/openai_agents/_trace_interceptor.py +++ b/temporalio/contrib/openai_agents/_trace_interceptor.py @@ -5,6 +5,7 @@ import asyncio import contextvars import random +import sys import uuid from contextlib import contextmanager from typing import Any, Mapping, Optional, Protocol, Type @@ -424,9 +425,14 @@ async def start_child_workflow( data={"workflow": input.workflow}, input=input, ) - handle = await ctx.run( - asyncio.create_task, self.next.start_child_workflow(input) - ) + if sys.version_info >= (3, 11): + handle = await asyncio.create_task( + self.next.start_child_workflow(input), context=ctx + ) + else: + handle = await ctx.run( + asyncio.create_task, self.next.start_child_workflow(input) + ) if span: handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore return handle diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index fae1470ce..7e0883c97 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -9,6 +9,7 @@ contextmanager, ) from dataclasses import dataclass +import sys from types import TracebackType from typing import ( Any, @@ -394,7 +395,12 @@ async def execute_workflow( :py:meth:`temporalio.worker.WorkflowInboundInterceptor.execute_workflow`. 
""" with self._top_level_workflow_context(success_is_complete=True) as ctx: - return await ctx.run(asyncio.create_task, self._execute_workflow(input)) + if sys.version_info >= (3, 11): + return await asyncio.create_task( + self._execute_workflow(input), context=ctx + ) + else: + return await ctx.run(asyncio.create_task, self._execute_workflow(input)) async def _execute_workflow( self, input: temporalio.worker.ExecuteWorkflowInput @@ -414,7 +420,12 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non # Create a span in the current context for the signal and link any # header given with self._top_level_workflow_context(success_is_complete=False) as ctx: - return await ctx.run(asyncio.create_task, self._handle_signal(input)) + if sys.version_info >= (3, 11): + return await asyncio.create_task( + self._handle_signal(input), context=ctx + ) + else: + return await ctx.run(asyncio.create_task, self._handle_signal(input)) async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: """Implementation of @@ -511,9 +522,14 @@ async def handle_update_handler( :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`. """ with self._top_level_workflow_context(success_is_complete=False) as ctx: - return await ctx.run( - asyncio.create_task, self._handle_update_handler(input) - ) + if sys.version_info >= (3, 11): + return await asyncio.create_task( + self._handle_update_handler(input), context=ctx + ) + else: + return await ctx.run( + asyncio.create_task, self._handle_update_handler(input) + ) async def _handle_update_handler( self, input: temporalio.worker.HandleUpdateInput diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index c805a967c..1f3420da3 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -20,7 +20,7 @@ import uuid from collections.abc import Mapping from concurrent.futures.thread import ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from types import MappingProxyType from typing import Any, Callable, Optional, Union @@ -313,7 +313,9 @@ async def non_serializable_output( class SuccessfulResponse: status_code: int body_json: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], bool]]] = None - headers: Mapping[str, str] = SUCCESSFUL_RESPONSE_HEADERS + headers: Mapping[str, str] = field( + default_factory=lambda: SUCCESSFUL_RESPONSE_HEADERS + ) @dataclass @@ -325,7 +327,9 @@ class UnsuccessfulResponse: # Expected value of inverse of non_retryable attribute of exception. 
retryable_exception: bool = True body_json: Optional[Callable[[dict[str, Any]], bool]] = None - headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS + headers: Mapping[str, str] = field( + default_factory=lambda: UNSUCCESSFUL_RESPONSE_HEADERS + ) class _TestCase: From 13e99a3c0e0706dc97e8fee2f5135886ec363d28 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 16 Oct 2025 11:40:12 -0700 Subject: [PATCH 09/21] run formatter --- temporalio/contrib/opentelemetry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 7e0883c97..92a41cb13 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -2,14 +2,14 @@ from __future__ import annotations -import contextvars import asyncio +import contextvars +import sys from contextlib import ( AbstractContextManager, contextmanager, ) from dataclasses import dataclass -import sys from types import TracebackType from typing import ( Any, From 853f603d75456d2cec14391d9de17a67044e3853 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 16 Oct 2025 11:49:46 -0700 Subject: [PATCH 10/21] fix typing lint errors --- temporalio/contrib/opentelemetry.py | 4 ++-- temporalio/worker/_workflow.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 92a41cb13..013cb53ab 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -649,7 +649,7 @@ def __init__( success_is_complete: bool, ): self._ctx = contextvars.copy_context() - self._token = None + self._token: Optional[contextvars.Token] = None self._owner = interceptor self._success_is_complete = success_is_complete @@ -662,7 +662,7 @@ def __exit__( exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], # noqa: F811 - ) -> bool | None: + ): self._ctx.run(self._end, exc_type, exc_value, traceback) def _start(self): diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 75b0c9059..6e7c254aa 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -337,9 +337,6 @@ async def _handle_activation( raise deadlock_exc from None except Exception as err: - if isinstance(err, GeneratorExit): - print("###############generator exit#############") - if isinstance(err, _DeadlockError): err.swap_traceback() From 7ae97d0e12bf3c77f2117f334f93e18d2b98212c Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 16 Oct 2025 11:56:27 -0700 Subject: [PATCH 11/21] Fix a few more typing errors --- temporalio/contrib/openai_agents/_trace_interceptor.py | 4 ++-- temporalio/contrib/opentelemetry.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/temporalio/contrib/openai_agents/_trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py index e8f357daf..bfe6994d7 100644 --- a/temporalio/contrib/openai_agents/_trace_interceptor.py +++ b/temporalio/contrib/openai_agents/_trace_interceptor.py @@ -426,11 +426,11 @@ async def start_child_workflow( input=input, ) if sys.version_info >= (3, 11): - handle = await asyncio.create_task( + handle: temporalio.workflow.ChildWorkflowHandle = await asyncio.create_task( self.next.start_child_workflow(input), context=ctx ) else: - handle = await ctx.run( + handle: temporalio.workflow.ChildWorkflowHandle = await ctx.run( asyncio.create_task, self.next.start_child_workflow(input) ) if span: 
diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 013cb53ab..b6626400c 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -421,11 +421,9 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non # header given with self._top_level_workflow_context(success_is_complete=False) as ctx: if sys.version_info >= (3, 11): - return await asyncio.create_task( - self._handle_signal(input), context=ctx - ) + await asyncio.create_task(self._handle_signal(input), context=ctx) else: - return await ctx.run(asyncio.create_task, self._handle_signal(input)) + await ctx.run(asyncio.create_task, self._handle_signal(input)) async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: """Implementation of From aaccb0053c565103a8066c0f22b80a77b22c7ee6 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 16 Oct 2025 12:23:44 -0700 Subject: [PATCH 12/21] Use a lambda to wrap task creation to appease the type linter --- temporalio/contrib/openai_agents/_trace_interceptor.py | 2 +- temporalio/contrib/opentelemetry.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/temporalio/contrib/openai_agents/_trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py index bfe6994d7..b9deda1a9 100644 --- a/temporalio/contrib/openai_agents/_trace_interceptor.py +++ b/temporalio/contrib/openai_agents/_trace_interceptor.py @@ -431,7 +431,7 @@ async def start_child_workflow( ) else: handle: temporalio.workflow.ChildWorkflowHandle = await ctx.run( - asyncio.create_task, self.next.start_child_workflow(input) + lambda: asyncio.create_task(self.next.start_child_workflow(input)) ) if span: handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index b6626400c..44d44a7b5 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -15,6 +15,7 @@ Any, Callable, ContextManager, + Coroutine, Dict, Iterator, Mapping, @@ -423,7 +424,7 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non if sys.version_info >= (3, 11): await asyncio.create_task(self._handle_signal(input), context=ctx) else: - await ctx.run(asyncio.create_task, self._handle_signal(input)) + await ctx.run(lambda: asyncio.create_task(self._handle_signal(input))) async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: """Implementation of From fe8b0a55c124d51e37516a4bdedfc26879323352 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 17 Oct 2025 13:41:34 -0700 Subject: [PATCH 13/21] revert manual context management explorations --- .../openai_agents/_trace_interceptor.py | 72 +++--- temporalio/contrib/opentelemetry.py | 209 ++++++------------ tests/nexus/test_handler.py | 10 +- 3 files changed, 93 insertions(+), 198 deletions(-) diff --git a/temporalio/contrib/openai_agents/_trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py index b9deda1a9..20d489b65 100644 --- a/temporalio/contrib/openai_agents/_trace_interceptor.py +++ b/temporalio/contrib/openai_agents/_trace_interceptor.py @@ -2,10 +2,7 @@ from __future__ import annotations -import asyncio -import contextvars import random -import sys import uuid from contextlib import contextmanager from typing import Any, Mapping, Optional, Protocol, Type @@ -403,63 +400,48 @@ async def signal_external_workflow( def start_activity( self, input: 
temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: - ctx = contextvars.copy_context() - span = ctx.run( - self._create_span, - name="temporal:startActivity", - data={"activity": input.activity}, - input=input, - ) - handle = ctx.run(self.next.start_activity, input) + trace = get_trace_provider().get_current_trace() + span: Optional[Span] = None + if trace: + span = custom_span( + name="temporal:startActivity", data={"activity": input.activity} + ) + span.start(mark_as_current=True) + + set_header_from_context(input, temporalio.workflow.payload_converter()) + handle = self.next.start_activity(input) if span: - handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore + handle.add_done_callback(lambda _: span.finish()) # type: ignore return handle async def start_child_workflow( self, input: temporalio.worker.StartChildWorkflowInput ) -> temporalio.workflow.ChildWorkflowHandle: - ctx = contextvars.copy_context() - span = ctx.run( - self._create_span, - name="temporal:startChildWorkflow", - data={"workflow": input.workflow}, - input=input, - ) - if sys.version_info >= (3, 11): - handle: temporalio.workflow.ChildWorkflowHandle = await asyncio.create_task( - self.next.start_child_workflow(input), context=ctx - ) - else: - handle: temporalio.workflow.ChildWorkflowHandle = await ctx.run( - lambda: asyncio.create_task(self.next.start_child_workflow(input)) + trace = get_trace_provider().get_current_trace() + span: Optional[Span] = None + if trace: + span = custom_span( + name="temporal:startChildWorkflow", data={"workflow": input.workflow} ) + span.start(mark_as_current=True) + set_header_from_context(input, temporalio.workflow.payload_converter()) + handle = await self.next.start_child_workflow(input) if span: - handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore + handle.add_done_callback(lambda _: span.finish()) # type: ignore return handle def start_local_activity( self, input: temporalio.worker.StartLocalActivityInput ) -> temporalio.workflow.ActivityHandle: - ctx = contextvars.copy_context() - span = ctx.run( - self._create_span, - name="temporal:startLocalActivity", - data={"activity": input.activity}, - input=input, - ) - handle = ctx.run(self.next.start_local_activity, input) - if span: - handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore - return handle - - def _create_span( - self, name: str, data: dict[str, Any], input: _InputWithHeaders - ) -> Optional[Span]: trace = get_trace_provider().get_current_trace() span: Optional[Span] = None if trace: - span = custom_span(name=name, data=data) + span = custom_span( + name="temporal:startLocalActivity", data={"activity": input.activity} + ) span.start(mark_as_current=True) - set_header_from_context(input, temporalio.workflow.payload_converter()) - return span + handle = self.next.start_local_activity(input) + if span: + handle.add_done_callback(lambda _: span.finish()) # type: ignore + return handle diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 9f4b42266..136320c60 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -2,20 +2,11 @@ from __future__ import annotations -import asyncio -import contextvars -import sys -from contextlib import ( - AbstractContextManager, - contextmanager, -) +from contextlib import contextmanager from dataclasses import dataclass -from types import TracebackType from typing import ( Any, Callable, - ContextManager, - Coroutine, Dict, Iterator, 
Mapping, @@ -419,38 +410,15 @@ async def execute_workflow( """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.execute_workflow`. """ - with self._top_level_workflow_context(success_is_complete=True) as ctx: - if sys.version_info >= (3, 11): - return await asyncio.create_task( - self._execute_workflow(input), context=ctx - ) - else: - return await ctx.run(asyncio.create_task, self._execute_workflow(input)) - - async def _execute_workflow( - self, input: temporalio.worker.ExecuteWorkflowInput - ) -> Any: - # Entrypoint of workflow should be `server` in OTel - self._completed_span( - f"RunWorkflow:{temporalio.workflow.info().workflow_type}", - kind=opentelemetry.trace.SpanKind.SERVER, - ) - # with self._with_complete_span(success_is_complete=True): - return await super().execute_workflow(input) + with self._top_level_workflow_context(success_is_complete=True): + # Entrypoint of workflow should be `server` in OTel + self._completed_span( + f"RunWorkflow:{temporalio.workflow.info().workflow_type}", + kind=opentelemetry.trace.SpanKind.SERVER, + ) + return await super().execute_workflow(input) async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: - """Implementation of - :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`. - """ - # Create a span in the current context for the signal and link any - # header given - with self._top_level_workflow_context(success_is_complete=False) as ctx: - if sys.version_info >= (3, 11): - await asyncio.create_task(self._handle_signal(input), context=ctx) - else: - await ctx.run(lambda: asyncio.create_task(self._handle_signal(input))) - - async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_signal`. """ @@ -462,22 +430,18 @@ async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> No link_context_carrier = self.payload_converter.from_payloads( [link_context_header] )[0] - self._completed_span( - f"HandleSignal:{input.signal}", - link_context_carrier=link_context_carrier, - kind=opentelemetry.trace.SpanKind.SERVER, - ) - await super().handle_signal(input) + with self._top_level_workflow_context(success_is_complete=False): + self._completed_span( + f"HandleSignal:{input.signal}", + link_context_carrier=link_context_carrier, + kind=opentelemetry.trace.SpanKind.SERVER, + ) + await super().handle_signal(input) async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_query`. """ - # handle_query does not manage the contextvars.Context itself because - # scheduling an asyncio task in a read only operation is not allowed. - # The operation is synchronous which makes default contextvars.Context - # safe. - # Only trace this if there is a header, and make that span the parent. # We do not put anything that happens in a query handler on the workflow # span. @@ -508,19 +472,11 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: ) return await super().handle_query(input) finally: - opentelemetry.context.detach(token) + if attach_context == opentelemetry.context.get_current(): + opentelemetry.context.detach(token) def handle_update_validator( self, input: temporalio.worker.HandleUpdateInput - ) -> None: - """Implementation of - :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`. 
- """ - with self._top_level_workflow_context(success_is_complete=False) as ctx: - ctx.run(self._handle_update_validator, input) - - def _handle_update_validator( - self, input: temporalio.worker.HandleUpdateInput ) -> None: """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_validator`. @@ -531,6 +487,7 @@ def _handle_update_validator( link_context_carrier = self.payload_converter.from_payloads( [link_context_header] )[0] + with self._top_level_workflow_context(success_is_complete=False): self._completed_span( f"ValidateUpdate:{input.update}", link_context_carrier=link_context_carrier, @@ -540,22 +497,6 @@ def _handle_update_validator( async def handle_update_handler( self, input: temporalio.worker.HandleUpdateInput - ) -> Any: - """Implementation of - :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`. - """ - with self._top_level_workflow_context(success_is_complete=False) as ctx: - if sys.version_info >= (3, 11): - return await asyncio.create_task( - self._handle_update_handler(input), context=ctx - ) - else: - return await ctx.run( - asyncio.create_task, self._handle_update_handler(input) - ) - - async def _handle_update_handler( - self, input: temporalio.worker.HandleUpdateInput ) -> Any: """Implementation of :py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`. @@ -566,12 +507,13 @@ async def _handle_update_handler( link_context_carrier = self.payload_converter.from_payloads( [link_context_header] )[0] - self._completed_span( - f"HandleUpdate:{input.update}", - link_context_carrier=link_context_carrier, - kind=opentelemetry.trace.SpanKind.SERVER, - ) - return await super().handle_update_handler(input) + with self._top_level_workflow_context(success_is_complete=False): + self._completed_span( + f"HandleUpdate:{input.update}", + link_context_carrier=link_context_carrier, + kind=opentelemetry.trace.SpanKind.SERVER, + ) + return await super().handle_update_handler(input) def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]: if self._workflow_context_carrier: @@ -584,13 +526,47 @@ def _load_workflow_context_carrier(self) -> Optional[_CarrierDict]: )[0] return self._workflow_context_carrier + @contextmanager def _top_level_workflow_context( self, *, success_is_complete: bool - ) -> ContextManager[contextvars.Context]: - return self._TopLevelWorkflowContextManager( - self, success_is_complete=success_is_complete - ) + ) -> Iterator[None]: + # Load context only if there is a carrier, otherwise use empty context + context_carrier = self._load_workflow_context_carrier() + attach_context: opentelemetry.context.Context + if context_carrier: + attach_context = self.text_map_propagator.extract(context_carrier) + else: + attach_context = opentelemetry.context.Context() + # We need to put this interceptor on the context too + attach_context = self._set_on_context(attach_context) + # Need to know whether completed and whether there was a fail-workflow + # exception + success = False + exception: Optional[Exception] = None + # Run under this context + token = opentelemetry.context.attach(attach_context) + + try: + yield None + success = True + except temporalio.exceptions.FailureError as err: + # We only record the failure errors since those are the only ones + # that lead to workflow completions + exception = err + raise + finally: + # Create a completed span before detaching context + if exception or (success and success_is_complete): + self._completed_span( + 
f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}", + exception=exception, + kind=opentelemetry.trace.SpanKind.INTERNAL, + ) + + if attach_context == opentelemetry.context.get_current(): + opentelemetry.context.detach(token) + # def _context_to_headers( self, headers: Mapping[str, temporalio.api.common.v1.Payload] ) -> Mapping[str, temporalio.api.common.v1.Payload]: @@ -664,65 +640,6 @@ def _set_on_context( ) -> opentelemetry.context.Context: return opentelemetry.context.set_value(_interceptor_context_key, self, context) - class _TopLevelWorkflowContextManager(AbstractContextManager): - def __init__( - self, - interceptor: TracingWorkflowInboundInterceptor, - *, - success_is_complete: bool, - ): - self._ctx = contextvars.copy_context() - self._token: Optional[contextvars.Token] = None - self._owner = interceptor - self._success_is_complete = success_is_complete - - def __enter__(self): - self._ctx.run(self._start) - return self._ctx - - def __exit__( - self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], # noqa: F811 - ): - self._ctx.run(self._end, exc_type, exc_value, traceback) - - def _start(self): - # Load context only if there is a carrier, otherwise use empty context - context_carrier = self._owner._load_workflow_context_carrier() - attach_context: opentelemetry.context.Context - if context_carrier: - attach_context = self._owner.text_map_propagator.extract( - context_carrier - ) - else: - attach_context = opentelemetry.context.Context() - # We need to put this interceptor on the context too - attach_context = self._owner._set_on_context(attach_context) - self._token = opentelemetry.context.attach(attach_context) - - def _end( - self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - _traceback: Optional[TracebackType], - ): - success = exc_type is None - exception: Optional[temporalio.exceptions.FailureError] = None - if isinstance(exc_value, temporalio.exceptions.FailureError): - exception = exc_value - - if (success and self._success_is_complete) or exception is not None: - self._owner._completed_span( - f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}", - exception=exception, - kind=opentelemetry.trace.SpanKind.INTERNAL, - ) - - if self._token: - opentelemetry.context.detach(self._token) - class _TracingWorkflowOutboundInterceptor( temporalio.worker.WorkflowOutboundInterceptor diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 1f3420da3..c805a967c 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -20,7 +20,7 @@ import uuid from collections.abc import Mapping from concurrent.futures.thread import ThreadPoolExecutor -from dataclasses import dataclass, field +from dataclasses import dataclass from types import MappingProxyType from typing import Any, Callable, Optional, Union @@ -313,9 +313,7 @@ async def non_serializable_output( class SuccessfulResponse: status_code: int body_json: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], bool]]] = None - headers: Mapping[str, str] = field( - default_factory=lambda: SUCCESSFUL_RESPONSE_HEADERS - ) + headers: Mapping[str, str] = SUCCESSFUL_RESPONSE_HEADERS @dataclass @@ -327,9 +325,7 @@ class UnsuccessfulResponse: # Expected value of inverse of non_retryable attribute of exception. 
     retryable_exception: bool = True
     body_json: Optional[Callable[[dict[str, Any]], bool]] = None
-    headers: Mapping[str, str] = field(
-        default_factory=lambda: UNSUCCESSFUL_RESPONSE_HEADERS
-    )
+    headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS
 
 
 class _TestCase:

From 2274ba495f11b0d37ebe434c5bb1b61145200dc0 Mon Sep 17 00:00:00 2001
From: Alex Mazzeo
Date: Fri, 17 Oct 2025 13:42:29 -0700
Subject: [PATCH 14/21] Add comment explaining the check. Use `is` to ensure
 that the context is not just equal but is the same object

---
 temporalio/contrib/opentelemetry.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py
index 136320c60..1cf96282b 100644
--- a/temporalio/contrib/opentelemetry.py
+++ b/temporalio/contrib/opentelemetry.py
@@ -472,7 +472,11 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
             )
             return await super().handle_query(input)
         finally:
-            if attach_context == opentelemetry.context.get_current():
+            # In some exceptional cases this finally is executed with a
+            # different contextvars.Context than the one the token was created
+            # on. As such we do a best effort detach to avoid using a mismatched
+            # token.
+            if attach_context is opentelemetry.context.get_current():
                 opentelemetry.context.detach(token)
 
     def handle_update_validator(
@@ -563,10 +567,13 @@ def _top_level_workflow_context(
                 kind=opentelemetry.trace.SpanKind.INTERNAL,
             )
 
-        if attach_context == opentelemetry.context.get_current():
+        # In some exceptional cases this finally is executed with a
+        # different contextvars.Context than the one the token was created
+        # on. As such we do a best effort detach to avoid using a mismatched
+        # token.
+        if attach_context is opentelemetry.context.get_current():
             opentelemetry.context.detach(token)
 
-    # 
     def _context_to_headers(
         self, headers: Mapping[str, temporalio.api.common.v1.Payload]
     ) -> Mapping[str, temporalio.api.common.v1.Payload]:

From 460ed71cde120a596589ae8dc10f87749e734727 Mon Sep 17 00:00:00 2001
From: Alex Mazzeo
Date: Fri, 17 Oct 2025 13:45:10 -0700
Subject: [PATCH 15/21] use original variable name

---
 temporalio/contrib/opentelemetry.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py
index 1cf96282b..351a2e42f 100644
--- a/temporalio/contrib/opentelemetry.py
+++ b/temporalio/contrib/opentelemetry.py
@@ -446,21 +446,21 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
         # We do not put anything that happens in a query handler on the workflow
         # span.
context_header = input.headers.get(self.header_key) - attach_context: opentelemetry.context.Context + context: opentelemetry.context.Context link_context_carrier: Optional[_CarrierDict] = None if context_header: context_carrier = self.payload_converter.from_payloads([context_header])[0] - attach_context = self.text_map_propagator.extract(context_carrier) + context = self.text_map_propagator.extract(context_carrier) # If there is a workflow span, use it as the link link_context_carrier = self._load_workflow_context_carrier() else: # Use an empty context - attach_context = opentelemetry.context.Context() + context = opentelemetry.context.Context() # We need to put this interceptor on the context too - attach_context = self._set_on_context(attach_context) + context = self._set_on_context(context) # Run under context with new span - token = opentelemetry.context.attach(attach_context) + token = opentelemetry.context.attach(context) try: # This won't be created if there was no context header self._completed_span( @@ -476,7 +476,7 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: # different contextvars.Context than the one the token was created # on. As such we do a best effort detach to avoid using a mismatched # token. - if attach_context is opentelemetry.context.get_current(): + if context is opentelemetry.context.get_current(): opentelemetry.context.detach(token) def handle_update_validator( @@ -536,19 +536,19 @@ def _top_level_workflow_context( ) -> Iterator[None]: # Load context only if there is a carrier, otherwise use empty context context_carrier = self._load_workflow_context_carrier() - attach_context: opentelemetry.context.Context + context: opentelemetry.context.Context if context_carrier: - attach_context = self.text_map_propagator.extract(context_carrier) + context = self.text_map_propagator.extract(context_carrier) else: - attach_context = opentelemetry.context.Context() + context = opentelemetry.context.Context() # We need to put this interceptor on the context too - attach_context = self._set_on_context(attach_context) + context = self._set_on_context(context) # Need to know whether completed and whether there was a fail-workflow # exception success = False exception: Optional[Exception] = None # Run under this context - token = opentelemetry.context.attach(attach_context) + token = opentelemetry.context.attach(context) try: yield None @@ -571,7 +571,7 @@ def _top_level_workflow_context( # different contextvars.Context than the one the token was created # on. As such we do a best effort detach to avoid using a mismatched # token. 
- if attach_context is opentelemetry.context.get_current(): + if context is opentelemetry.context.get_current(): opentelemetry.context.detach(token) def _context_to_headers( From 83ae1fb396223277e1471b50eeb614e43e3ab97e Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 17 Oct 2025 14:52:08 -0700 Subject: [PATCH 16/21] Fix typo --- tests/contrib/test_opentelemetry.py | 2 +- tests/helpers/{cache_evitction.py => cache_eviction.py} | 0 tests/worker/test_workflow.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename tests/helpers/{cache_evitction.py => cache_eviction.py} (100%) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index 4c7360506..2ffebc8c1 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -24,7 +24,7 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers import LogCapturer -from tests.helpers.cache_evitction import ( +from tests.helpers.cache_eviction import ( CacheEvictionTearDownWorkflow, WaitForeverWorkflow, wait_forever_activity, diff --git a/tests/helpers/cache_evitction.py b/tests/helpers/cache_eviction.py similarity index 100% rename from tests/helpers/cache_evitction.py rename to tests/helpers/cache_eviction.py diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index b76451d00..671752544 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -145,7 +145,7 @@ unpause_and_assert, workflow_update_exists, ) -from tests.helpers.cache_evitction import ( +from tests.helpers.cache_eviction import ( CacheEvictionTearDownWorkflow, WaitForeverWorkflow, wait_forever_activity, From c053aa6afff9c4dff58df959b3ed50b2ad588776 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 21 Oct 2025 14:19:12 -0700 Subject: [PATCH 17/21] Add some logs to help debug test flaking with timeouts --- tests/contrib/test_opentelemetry.py | 13 ++++++++++--- tests/helpers/cache_eviction.py | 6 ++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index f419bb541..023a6dd6d 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -588,6 +588,7 @@ async def test_opentelemetry_safe_detach(client: Client): with LogCapturer().logs_captured(opentelemetry.context.logger) as capturer: try: + print("===== in detach test") handle = await client.start_workflow( CacheEvictionTearDownWorkflow.run, id=f"wf-{uuid.uuid4()}", @@ -595,14 +596,20 @@ async def test_opentelemetry_safe_detach(client: Client): ) # CacheEvictionTearDownWorkflow requires 3 signals to be sent + print("===== signal 1") await handle.signal(CacheEvictionTearDownWorkflow.signal) + print("===== signal 2") await handle.signal(CacheEvictionTearDownWorkflow.signal) + print("===== signal 3") await handle.signal(CacheEvictionTearDownWorkflow.signal) + print("===== awaiting result") await handle.result() finally: sys.unraisablehook = old_hook + print("===== inspecting logs") + # Confirm at least 1 exception if len(hook_calls) < 1: logging.warning( @@ -615,6 +622,6 @@ def otel_context_error(record: logging.LogRecord) -> bool: and "Failed to detach context" in record.message ) - assert ( - capturer.find(otel_context_error) is None - ), "Detach from context message should not be logged" + assert capturer.find(otel_context_error) is None, ( + "Detach from context message should not be logged" + ) diff --git 
a/tests/helpers/cache_eviction.py b/tests/helpers/cache_eviction.py index 191d51078..8adebc91e 100644 --- a/tests/helpers/cache_eviction.py +++ b/tests/helpers/cache_eviction.py @@ -46,17 +46,23 @@ async def run(self) -> None: # Let's also start something in the background that we never wait on asyncio.create_task(asyncio.sleep(1000)) try: + print("----- in evict workflow") # Wait for signal count to reach 2 await asyncio.sleep(0.01) + print("----- waiting for signal") await workflow.wait_condition(lambda: self._signal_count > 1) finally: # This finally, on eviction, is actually called but the command # should be ignored + print("----- sleeping") await asyncio.sleep(0.01) + print("----- waiting for signals 2 & 3") await workflow.wait_condition(lambda: self._signal_count > 2) # Cancel gather tasks and wait on them, but ignore the errors for task in tasks: task.cancel() + + print("----- evict workflow ending") await gather_fut @workflow.signal From 6e233658762f8ef18adc3159cb2a12bcf34ca32e Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 21 Oct 2025 14:21:54 -0700 Subject: [PATCH 18/21] apply formatting --- tests/contrib/test_opentelemetry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index 023a6dd6d..fc046bf18 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -622,6 +622,6 @@ def otel_context_error(record: logging.LogRecord) -> bool: and "Failed to detach context" in record.message ) - assert capturer.find(otel_context_error) is None, ( - "Detach from context message should not be logged" - ) + assert ( + capturer.find(otel_context_error) is None + ), "Detach from context message should not be logged" From db3bb0e212601efa45b8fd2160ab0372d58bf039 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 21 Oct 2025 16:48:58 -0700 Subject: [PATCH 19/21] Revert "apply formatting" This reverts commit 6e233658762f8ef18adc3159cb2a12bcf34ca32e. --- tests/contrib/test_opentelemetry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index fc046bf18..023a6dd6d 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -622,6 +622,6 @@ def otel_context_error(record: logging.LogRecord) -> bool: and "Failed to detach context" in record.message ) - assert ( - capturer.find(otel_context_error) is None - ), "Detach from context message should not be logged" + assert capturer.find(otel_context_error) is None, ( + "Detach from context message should not be logged" + ) From 50cbd42a39440fda2f94a054e311ff678704ef9d Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 21 Oct 2025 16:49:00 -0700 Subject: [PATCH 20/21] Revert "Add some logs to help debug test flaking with timeouts" This reverts commit c053aa6afff9c4dff58df959b3ed50b2ad588776. 
--- tests/contrib/test_opentelemetry.py | 13 +++---------- tests/helpers/cache_eviction.py | 6 ------ 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index 023a6dd6d..f419bb541 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -588,7 +588,6 @@ async def test_opentelemetry_safe_detach(client: Client): with LogCapturer().logs_captured(opentelemetry.context.logger) as capturer: try: - print("===== in detach test") handle = await client.start_workflow( CacheEvictionTearDownWorkflow.run, id=f"wf-{uuid.uuid4()}", @@ -596,20 +595,14 @@ async def test_opentelemetry_safe_detach(client: Client): ) # CacheEvictionTearDownWorkflow requires 3 signals to be sent - print("===== signal 1") await handle.signal(CacheEvictionTearDownWorkflow.signal) - print("===== signal 2") await handle.signal(CacheEvictionTearDownWorkflow.signal) - print("===== signal 3") await handle.signal(CacheEvictionTearDownWorkflow.signal) - print("===== awaiting result") await handle.result() finally: sys.unraisablehook = old_hook - print("===== inspecting logs") - # Confirm at least 1 exception if len(hook_calls) < 1: logging.warning( @@ -622,6 +615,6 @@ def otel_context_error(record: logging.LogRecord) -> bool: and "Failed to detach context" in record.message ) - assert capturer.find(otel_context_error) is None, ( - "Detach from context message should not be logged" - ) + assert ( + capturer.find(otel_context_error) is None + ), "Detach from context message should not be logged" diff --git a/tests/helpers/cache_eviction.py b/tests/helpers/cache_eviction.py index 8adebc91e..191d51078 100644 --- a/tests/helpers/cache_eviction.py +++ b/tests/helpers/cache_eviction.py @@ -46,23 +46,17 @@ async def run(self) -> None: # Let's also start something in the background that we never wait on asyncio.create_task(asyncio.sleep(1000)) try: - print("----- in evict workflow") # Wait for signal count to reach 2 await asyncio.sleep(0.01) - print("----- waiting for signal") await workflow.wait_condition(lambda: self._signal_count > 1) finally: # This finally, on eviction, is actually called but the command # should be ignored - print("----- sleeping") await asyncio.sleep(0.01) - print("----- waiting for signals 2 & 3") await workflow.wait_condition(lambda: self._signal_count > 2) # Cancel gather tasks and wait on them, but ignore the errors for task in tasks: task.cancel() - - print("----- evict workflow ending") await gather_fut @workflow.signal From c806d6b3e3b73ee3d1cd5ef8c617d9ce9bbf9c1f Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 21 Oct 2025 16:58:40 -0700 Subject: [PATCH 21/21] move safe detach test to a model that forces __exit__ on a different thread. 
--- tests/contrib/test_opentelemetry.py | 103 +++++++++++++--------------- 1 file changed, 48 insertions(+), 55 deletions(-) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index f419bb541..9dbdfed93 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -1,8 +1,11 @@ from __future__ import annotations import asyncio +import gc import logging +import queue import sys +import threading import uuid from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass @@ -19,7 +22,10 @@ from temporalio import activity, workflow from temporalio.client import Client, WithStartWorkflowOperation, WorkflowUpdateStage from temporalio.common import RetryPolicy, WorkflowIDConflictPolicy -from temporalio.contrib.opentelemetry import TracingInterceptor +from temporalio.contrib.opentelemetry import ( + TracingInterceptor, + TracingWorkflowInboundInterceptor, +) from temporalio.contrib.opentelemetry import workflow as otel_workflow from temporalio.exceptions import ApplicationError, ApplicationErrorCategory from temporalio.testing import WorkflowEnvironment @@ -560,61 +566,48 @@ async def test_opentelemetry_benign_exception(client: Client): # * signal failure and wft failure from signal -async def test_opentelemetry_safe_detach(client: Client): - # This test simulates forcing eviction. This purposely raises GeneratorExit on - # GC which triggers the finally which could run on any thread Python - # chooses. When this occurs, we should not detach the token from the context - # b/c the context no longer exists +def test_opentelemetry_safe_detach(): + class _fake_self: + def _load_workflow_context_carrier(*args): + return None - # Create a tracer that has an in-memory exporter - exporter = InMemorySpanExporter() - provider = TracerProvider() - provider.add_span_processor(SimpleSpanProcessor(exporter)) - tracer = get_tracer(__name__, tracer_provider=provider) + def _set_on_context(self, ctx): + return opentelemetry.context.set_value("test-key", "test-value", ctx) - async with Worker( - client, - workflows=[CacheEvictionTearDownWorkflow, WaitForeverWorkflow], - activities=[wait_forever_activity], - max_cached_workflows=0, - task_queue=f"task_queue_{uuid.uuid4()}", - disable_safe_workflow_eviction=True, - interceptors=[TracingInterceptor(tracer)], - ) as worker: - # Put a hook to catch unraisable exceptions - old_hook = sys.unraisablehook - hook_calls: List[sys.UnraisableHookArgs] = [] - sys.unraisablehook = hook_calls.append - - with LogCapturer().logs_captured(opentelemetry.context.logger) as capturer: - try: - handle = await client.start_workflow( - CacheEvictionTearDownWorkflow.run, - id=f"wf-{uuid.uuid4()}", - task_queue=worker.task_queue, - ) - - # CacheEvictionTearDownWorkflow requires 3 signals to be sent - await handle.signal(CacheEvictionTearDownWorkflow.signal) - await handle.signal(CacheEvictionTearDownWorkflow.signal) - await handle.signal(CacheEvictionTearDownWorkflow.signal) + def _completed_span(*args, **kwargs): + pass - await handle.result() - finally: - sys.unraisablehook = old_hook - - # Confirm at least 1 exception - if len(hook_calls) < 1: - logging.warning( - "Expected at least 1 exception. 
Unable to properly verify context detachment" - ) - - def otel_context_error(record: logging.LogRecord) -> bool: - return ( - record.name == "opentelemetry.context" - and "Failed to detach context" in record.message - ) + # create a context manager and force enter to happen on this thread + context_manager = TracingWorkflowInboundInterceptor._top_level_workflow_context( + _fake_self(), # type: ignore + success_is_complete=True, + ) + context_manager.__enter__() + + # move reference to context manager into queue + q: queue.Queue = queue.Queue() + q.put(context_manager) + del context_manager + + def worker(): + # pull reference from queue and delete the last reference + context_manager = q.get() + del context_manager + # force gc + gc.collect() + + with LogCapturer().logs_captured(opentelemetry.context.logger) as capturer: + # run forced gc on other thread so exit happens there + t = threading.Thread(target=worker) + t.start() + t.join(timeout=5) + + def otel_context_error(record: logging.LogRecord) -> bool: + return ( + record.name == "opentelemetry.context" + and "Failed to detach context" in record.message + ) - assert ( - capturer.find(otel_context_error) is None - ), "Detach from context message should not be logged" + assert ( + capturer.find(otel_context_error) is None + ), "Detach from context message should not be logged"
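Editorial note: taken together, the series converges on one small guard -- a context token is only detached while the contextvars.Context that produced it is still current. The following is a minimal standalone sketch of that pattern using only the opentelemetry-api package; it is not part of the patches above, and the helper name run_under_context and its body are illustrative only.

import opentelemetry.context


def run_under_context(context: opentelemetry.context.Context) -> None:
    # attach() returns a token bound to the contextvars.Context active right now.
    token = opentelemetry.context.attach(context)
    try:
        ...  # work that should observe `context` as the current OTel context
    finally:
        # If this finally block runs on a different contextvars.Context (for
        # example after a GeneratorExit raised during workflow eviction/GC),
        # get_current() returns a different object and detach() would log
        # "Failed to detach context". Comparing by identity (`is`), not
        # equality, turns the detach into a best-effort no-op in that case.
        if context is opentelemetry.context.get_current():
            opentelemetry.context.detach(token)

As PATCH 14's subject notes, the check is deliberately an identity comparison: two Context mappings may compare equal even though the token was created on a different contextvars.Context, so `is` is the safer test.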