diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py
index 9e1542814..351a2e42f 100644
--- a/temporalio/contrib/opentelemetry.py
+++ b/temporalio/contrib/opentelemetry.py
@@ -26,8 +26,7 @@
 import opentelemetry.trace.propagation.tracecontext
 import opentelemetry.util.types
 from opentelemetry.context import Context
-from opentelemetry.trace import Span, SpanKind, Status, StatusCode, _Links
-from opentelemetry.util import types
+from opentelemetry.trace import Status, StatusCode
 from typing_extensions import Protocol, TypeAlias, TypedDict
 
 import temporalio.activity
@@ -473,7 +472,12 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
             )
             return await super().handle_query(input)
         finally:
-            opentelemetry.context.detach(token)
+            # In some exceptional cases this finally is executed with a
+            # different contextvars.Context than the one the token was created
+            # on. As such we do a best effort detach to avoid using a mismatched
+            # token.
+            if context is opentelemetry.context.get_current():
+                opentelemetry.context.detach(token)
 
     def handle_update_validator(
         self, input: temporalio.worker.HandleUpdateInput
@@ -545,6 +549,7 @@ def _top_level_workflow_context(
         exception: Optional[Exception] = None
         # Run under this context
         token = opentelemetry.context.attach(context)
+
         try:
             yield None
             success = True
@@ -561,7 +566,13 @@ def _top_level_workflow_context(
                     exception=exception,
                     kind=opentelemetry.trace.SpanKind.INTERNAL,
                 )
-            opentelemetry.context.detach(token)
+
+            # In some exceptional cases this finally is executed with a
+            # different contextvars.Context than the one the token was created
+            # on. As such we do a best effort detach to avoid using a mismatched
+            # token.
+            if context is opentelemetry.context.get_current():
+                opentelemetry.context.detach(token)
 
     def _context_to_headers(
         self, headers: Mapping[str, temporalio.api.common.v1.Payload]
diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py
index be6b17707..9dbdfed93 100644
--- a/tests/contrib/test_opentelemetry.py
+++ b/tests/contrib/test_opentelemetry.py
@@ -1,13 +1,19 @@
 from __future__ import annotations
 
 import asyncio
+import contextvars
+import gc
 import logging
+import queue
+import sys
+import threading
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import Iterable, List, Optional
 
+import opentelemetry.context
 import pytest
 from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
@@ -17,11 +23,20 @@
 from temporalio import activity, workflow
 from temporalio.client import Client, WithStartWorkflowOperation, WorkflowUpdateStage
 from temporalio.common import RetryPolicy, WorkflowIDConflictPolicy
-from temporalio.contrib.opentelemetry import TracingInterceptor
+from temporalio.contrib.opentelemetry import (
+    TracingInterceptor,
+    TracingWorkflowInboundInterceptor,
+)
 from temporalio.contrib.opentelemetry import workflow as otel_workflow
 from temporalio.exceptions import ApplicationError, ApplicationErrorCategory
 from temporalio.testing import WorkflowEnvironment
 from temporalio.worker import UnsandboxedWorkflowRunner, Worker
+from tests.helpers import LogCapturer
+from tests.helpers.cache_eviction import (
+    CacheEvictionTearDownWorkflow,
+    WaitForeverWorkflow,
+    wait_forever_activity,
+)
 
 
 @dataclass
@@ -420,7 +435,10 @@ def dump_spans(
                 span_links: List[str] = []
                 for link in span.links:
                     for link_span in spans:
-                        if link_span.context.span_id == link.context.span_id:
+                        if (
+                            link_span.context is not None
+                            and link_span.context.span_id == link.context.span_id
+                        ):
                             span_links.append(link_span.name)
                 span_str += f" (links: {', '.join(span_links)})"
             # Signals can duplicate in rare situations, so we make sure not to
@@ -430,7 +448,7 @@ def dump_spans(
             ret.append(span_str)
             ret += dump_spans(
                 spans,
-                parent_id=span.context.span_id,
+                parent_id=span.context.span_id if span.context else None,
                 with_attributes=with_attributes,
                 indent_depth=indent_depth + 1,
             )
@@ -547,3 +565,56 @@ async def test_opentelemetry_benign_exception(client: Client):
 # * workflow failure and wft failure
 # * signal with start
 # * signal failure and wft failure from signal
+
+
+def test_opentelemetry_safe_detach():
+    """Ensure a GC-driven exit on a foreign thread skips the mismatched detach."""
+    class _fake_self:
+        def _load_workflow_context_carrier(*args):
+            return None
+
+        def _set_on_context(self, ctx):
+            return opentelemetry.context.set_value("test-key", "test-value", ctx)
+
+        def _completed_span(*args, **kwargs):
+            pass
+
+    # create a context manager and force enter to happen on this thread, but
+    # inside a copied contextvars.Context so the attach that is intentionally
+    # never detached here does not leak into tests that run after this one
+    context_manager = TracingWorkflowInboundInterceptor._top_level_workflow_context(
+        _fake_self(),  # type: ignore
+        success_is_complete=True,
+    )
+    contextvars.copy_context().run(context_manager.__enter__)
+
+    # move reference to context manager into queue
+    q: queue.Queue = queue.Queue()
+    q.put(context_manager)
+    del context_manager
+
+    def worker():
+        # pull reference from queue and delete the last reference
+        context_manager = q.get()
+        del context_manager
+        # force gc
+        gc.collect()
+
+    with LogCapturer().logs_captured(opentelemetry.context.logger) as capturer:
+        # run forced gc on other thread so exit happens there
+        t = threading.Thread(target=worker)
+        t.start()
+        t.join(timeout=5)
+        # join() returns silently on timeout; without this check a stuck
+        # worker would let the assertion below pass vacuously
+        assert not t.is_alive(), "GC worker thread did not finish in time"
+
+    def otel_context_error(record: logging.LogRecord) -> bool:
+        return (
+            record.name == "opentelemetry.context"
+            and "Failed to detach context" in record.message
+        )
+
+    assert (
+        capturer.find(otel_context_error) is None
+    ), "Detach from context message should not be logged"
diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py
index 79d3687fd..4a3850024 100644
--- a/tests/helpers/__init__.py
+++ b/tests/helpers/__init__.py
@@ -1,11 +1,25 @@
 import asyncio
+import logging
+import logging.handlers
+import queue
 import socket
 import time
 import uuid
-from contextlib import closing
+from contextlib import closing, contextmanager
 from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
-from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar, Union
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 from temporalio.api.common.v1 import WorkflowExecution
 from temporalio.api.enums.v1 import EventType as EventType
@@ -401,3 +415,34 @@ def _format_row(items: list[str], truncate: bool = False) -> str:
             padding = len(f" *: {elapsed_ms:>4} ")
             summary_row[col_idx] = f"{' ' * padding}[{summary}]"[: col_width - 3]
         print(_format_row(summary_row))
+
+
+class LogCapturer:
+    def __init__(self) -> None:
+        self.log_queue: queue.Queue[logging.LogRecord] = queue.Queue()
+
+    @contextmanager
+    def logs_captured(self, *loggers: logging.Logger):
+        handler = logging.handlers.QueueHandler(self.log_queue)
+
+        prev_levels = [l.level for l in loggers]
+        for l in loggers:
+            l.setLevel(logging.INFO)
+            l.addHandler(handler)
+        try:
+            yield self
+        finally:
+            for i, l in enumerate(loggers):
+                l.removeHandler(handler)
+                l.setLevel(prev_levels[i])
+
+    def find_log(self, starts_with: str) -> Optional[logging.LogRecord]:
+        return self.find(lambda l: l.message.startswith(starts_with))
+
+    def find(
+        self, pred: Callable[[logging.LogRecord], bool]
+    ) -> Optional[logging.LogRecord]:
+        for record in cast(List[logging.LogRecord], self.log_queue.queue):
+            if pred(record):
+                return record
+        return None
diff --git a/tests/helpers/cache_eviction.py b/tests/helpers/cache_eviction.py
new file mode 100644
index 000000000..191d51078
--- /dev/null
+++ b/tests/helpers/cache_eviction.py
@@ -0,0 +1,68 @@
+import asyncio
+from datetime import timedelta
+
+from temporalio import activity, workflow
+
+
+@activity.defn
+async def wait_forever_activity() -> None:
+    await asyncio.Future()
+
+
+@workflow.defn
+class WaitForeverWorkflow:
+    @workflow.run
+    async def run(self) -> None:
+        await asyncio.Future()
+
+
+@workflow.defn
+class CacheEvictionTearDownWorkflow:
+    def __init__(self) -> None:
+        self._signal_count = 0
+
+    @workflow.run
+    async def run(self) -> None:
+        # Start several things in background. This is just to show that eviction
+        # can work even with these things running.
+        tasks = [
+            asyncio.create_task(
+                workflow.execute_activity(
+                    wait_forever_activity, start_to_close_timeout=timedelta(hours=1)
+                )
+            ),
+            asyncio.create_task(
+                workflow.execute_child_workflow(WaitForeverWorkflow.run)
+            ),
+            asyncio.create_task(asyncio.sleep(1000)),
+            asyncio.shield(
+                workflow.execute_activity(
+                    wait_forever_activity, start_to_close_timeout=timedelta(hours=1)
+                )
+            ),
+            asyncio.create_task(workflow.wait_condition(lambda: False)),
+        ]
+        gather_fut = asyncio.gather(*tasks, return_exceptions=True)
+        # Let's also start something in the background that we never wait on
+        asyncio.create_task(asyncio.sleep(1000))
+        try:
+            # Wait for signal count to reach 2
+            await asyncio.sleep(0.01)
+            await workflow.wait_condition(lambda: self._signal_count > 1)
+        finally:
+            # This finally, on eviction, is actually called but the command
+            # should be ignored
+            await asyncio.sleep(0.01)
+            await workflow.wait_condition(lambda: self._signal_count > 2)
+            # Cancel gather tasks and wait on them, but ignore the errors
+            for task in tasks:
+                task.cancel()
+            await gather_fut
+
+    @workflow.signal
+    async def signal(self) -> None:
+        self._signal_count += 1
+
+    @workflow.query
+    def signal_count(self) -> int:
+        return self._signal_count
diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py
index da335635b..f7735db01 100644
--- a/tests/worker/test_workflow.py
+++ b/tests/worker/test_workflow.py
@@ -131,6 +131,7 @@
 )
 from tests import DEV_SERVER_DOWNLOAD_VERSION
 from tests.helpers import (
+    LogCapturer,
     admitted_update_task,
     assert_eq_eventually,
     assert_eventually,
@@ -145,6 +146,11 @@
     unpause_and_assert,
     workflow_update_exists,
 )
+from tests.helpers.cache_eviction import (
+    CacheEvictionTearDownWorkflow,
+    WaitForeverWorkflow,
+    wait_forever_activity,
+)
 from tests.helpers.external_stack_trace import (
     ExternalStackTraceWorkflow,
     external_wait_cancel,
@@ -1992,37 +1998,6 @@ def last_signal(self) -> str:
         return self._last_signal
 
 
-class LogCapturer:
-    def __init__(self) -> None:
-        self.log_queue: queue.Queue[logging.LogRecord] = queue.Queue()
-
-    @contextmanager
-    def logs_captured(self, *loggers: logging.Logger):
-        handler = logging.handlers.QueueHandler(self.log_queue)
-
-        prev_levels = [l.level for l in loggers]
-        for l in loggers:
-            l.setLevel(logging.INFO)
-            l.addHandler(handler)
-        try:
-            yield self
-        finally:
-            for i, l in enumerate(loggers):
-                l.removeHandler(handler)
-                l.setLevel(prev_levels[i])
-
-    def find_log(self, starts_with: str) -> Optional[logging.LogRecord]:
-        return self.find(lambda l: l.message.startswith(starts_with))
-
-    def find(
-        self, pred: Callable[[logging.LogRecord], bool]
-    ) -> Optional[logging.LogRecord]:
-        for record in cast(List[logging.LogRecord], self.log_queue.queue):
-            if pred(record):
-                return record
-        return None
-
-
 async def test_workflow_logging(client: Client, env: WorkflowEnvironment):
     workflow.logger.full_workflow_info_on_extra = True
     with LogCapturer().logs_captured(
@@ -3738,70 +3713,6 @@ async def test_manual_result_type(client: Client):
     assert res4 == ManualResultType(some_string="from-query")
 
 
-@activity.defn
-async def wait_forever_activity() -> None:
-    await asyncio.Future()
-
-
-@workflow.defn
-class WaitForeverWorkflow:
-    @workflow.run
-    async def run(self) -> None:
-        await asyncio.Future()
-
-
-@workflow.defn
-class CacheEvictionTearDownWorkflow:
-    def __init__(self) -> None:
-        self._signal_count = 0
-
-    @workflow.run
-    async def run(self) -> None:
-        # Start several things in background. This is just to show that eviction
-        # can work even with these things running.
-        tasks = [
-            asyncio.create_task(
-                workflow.execute_activity(
-                    wait_forever_activity, start_to_close_timeout=timedelta(hours=1)
-                )
-            ),
-            asyncio.create_task(
-                workflow.execute_child_workflow(WaitForeverWorkflow.run)
-            ),
-            asyncio.create_task(asyncio.sleep(1000)),
-            asyncio.shield(
-                workflow.execute_activity(
-                    wait_forever_activity, start_to_close_timeout=timedelta(hours=1)
-                )
-            ),
-            asyncio.create_task(workflow.wait_condition(lambda: False)),
-        ]
-        gather_fut = asyncio.gather(*tasks, return_exceptions=True)
-        # Let's also start something in the background that we never wait on
-        asyncio.create_task(asyncio.sleep(1000))
-        try:
-            # Wait for signal count to reach 2
-            await asyncio.sleep(0.01)
-            await workflow.wait_condition(lambda: self._signal_count > 1)
-        finally:
-            # This finally, on eviction, is actually called but the command
-            # should be ignored
-            await asyncio.sleep(0.01)
-            await workflow.wait_condition(lambda: self._signal_count > 2)
-            # Cancel gather tasks and wait on them, but ignore the errors
-            for task in tasks:
-                task.cancel()
-            await gather_fut
-
-    @workflow.signal
-    async def signal(self) -> None:
-        self._signal_count += 1
-
-    @workflow.query
-    def signal_count(self) -> int:
-        return self._signal_count
-
-
 async def test_cache_eviction_tear_down(client: Client):
     # This test simulates forcing eviction. This used to raise GeneratorExit on
     # GC which triggered the finally which could run on any thread Python