diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index de9164716..b2897b053 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -17,4 +17,5 @@ from ._operation_context import in_operation as in_operation from ._operation_context import info as info from ._operation_context import logger as logger +from ._operation_context import metric_meter as metric_meter from ._token import WorkflowHandle as WorkflowHandle diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 098bba8a1..90cb03788 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -12,12 +12,18 @@ Any, Callable, Generator, + Generic, Optional, + TypeVar, Union, overload, ) -from nexusrpc.handler import CancelOperationContext, StartOperationContext +from nexusrpc.handler import ( + CancelOperationContext, + OperationContext, + StartOperationContext, +) from typing_extensions import Concatenate import temporalio.api.common.v1 @@ -87,8 +93,13 @@ def client() -> temporalio.client.Client: return _temporal_context().client +def metric_meter() -> temporalio.common.MetricMeter: + """Get the metric meter for the current Nexus operation.""" + return _temporal_context().metric_meter + + def _temporal_context() -> ( - Union[_TemporalStartOperationContext, _TemporalCancelOperationContext] + _TemporalStartOperationContext | _TemporalCancelOperationContext ): ctx = _try_temporal_context() if ctx is None: @@ -97,7 +108,7 @@ def _temporal_context() -> ( def _try_temporal_context() -> ( - Optional[Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]] + _TemporalStartOperationContext | _TemporalCancelOperationContext | None ): start_ctx = _temporal_start_operation_context.get(None) cancel_ctx = _temporal_cancel_operation_context.get(None) @@ -119,18 +130,39 @@ def _in_nexus_backing_workflow_start_context() -> bool: return _temporal_nexus_backing_workflow_start_context.get(False) -@dataclass -class _TemporalStartOperationContext: - """Context for a Nexus start operation being handled by a Temporal Nexus Worker.""" +_OperationCtxT = TypeVar("_OperationCtxT", bound=OperationContext) - nexus_context: StartOperationContext - """Nexus-specific start operation context.""" + +@dataclass(kw_only=True) +class _TemporalOperationCtx(Generic[_OperationCtxT]): + client: temporalio.client.Client + """The Temporal client in use by the worker handling the current Nexus operation.""" info: Callable[[], Info] """Temporal information about the running Nexus operation.""" - client: temporalio.client.Client - """The Temporal client in use by the worker handling this Nexus operation.""" + nexus_context: _OperationCtxT + """Nexus-specific start operation context.""" + + runtime_metric_meter: temporalio.common.MetricMeter + _metric_meter: temporalio.common.MetricMeter | None = None + + @property + def metric_meter(self) -> temporalio.common.MetricMeter: + if not self._metric_meter: + self._metric_meter = self.runtime_metric_meter.with_additional_attributes( + { + "nexus_service": self.nexus_context.service, + "nexus_operation": self.nexus_context.operation, + "task_queue": self.info().task_queue, + } + ) + return self._metric_meter + + +@dataclass +class _TemporalStartOperationContext(_TemporalOperationCtx[StartOperationContext]): + """Context for a Nexus start operation being handled by a Temporal Nexus Worker.""" @classmethod def get(cls) -> _TemporalStartOperationContext: @@ -218,6 +250,11 @@ def _from_start_operation_context( **{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}, ) + @property + def metric_meter(self) -> temporalio.common.MetricMeter: + """The metric meter""" + return self._temporal_context.metric_meter + # Overload for no-param workflow @overload async def start_workflow( @@ -481,19 +518,10 @@ class NexusCallback: """Header to attach to callback request.""" -@dataclass(frozen=True) -class _TemporalCancelOperationContext: +@dataclass +class _TemporalCancelOperationContext(_TemporalOperationCtx[CancelOperationContext]): """Context for a Nexus cancel operation being handled by a Temporal Nexus Worker.""" - nexus_context: CancelOperationContext - """Nexus-specific cancel operation context.""" - - info: Callable[[], Info] - """Temporal information about the running Nexus cancel operation.""" - - client: temporalio.client.Client - """The Temporal client in use by the worker handling the current Nexus operation.""" - @classmethod def get(cls) -> _TemporalCancelOperationContext: ctx = _temporal_cancel_operation_context.get(None) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 1083cc620..d7db6291c 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -4,6 +4,7 @@ import asyncio import concurrent.futures +import contextvars import json import threading from dataclasses import dataclass @@ -13,8 +14,10 @@ Mapping, NoReturn, Optional, + ParamSpec, Sequence, Type, + TypeVar, Union, ) @@ -66,19 +69,25 @@ def __init__( data_converter: temporalio.converter.DataConverter, interceptors: Sequence[Interceptor], metric_meter: temporalio.common.MetricMeter, - executor: Optional[concurrent.futures.Executor], + executor: concurrent.futures.ThreadPoolExecutor | None, ) -> None: # TODO: make it possible to query task queue of bridge worker instead of passing # unused task_queue into _NexusWorker, _ActivityWorker, etc? self._bridge_worker = bridge_worker self._client = client self._task_queue = task_queue - self._handler = Handler(service_handlers, executor) + + self._metric_meter = metric_meter + + # If an executor is provided, we wrap the executor with one that will + # copy the contextvars.Context to the thread on submit + handler_executor = _ContextPropagatingExecutor(executor) if executor else None + + self._handler = Handler(service_handlers, handler_executor) self._data_converter = data_converter # TODO(nexus-preview): interceptors self._interceptors = interceptors - # TODO(nexus-preview): metric_meter - self._metric_meter = metric_meter + self._running_tasks: dict[bytes, _RunningNexusTask] = {} self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue() @@ -206,6 +215,7 @@ async def _handle_cancel_operation_task( info=lambda: Info(task_queue=self._task_queue), nexus_context=ctx, client=self._client, + runtime_metric_meter=self._metric_meter, ).set() try: try: @@ -323,6 +333,7 @@ async def _start_operation( nexus_context=ctx, client=self._client, info=lambda: Info(task_queue=self._task_queue), + runtime_metric_meter=self._metric_meter, ).set() input = LazyValue( serializer=_DummyPayloadSerializer( @@ -597,3 +608,25 @@ def cancel(self, reason: str) -> bool: self._thread_evt.set() self._async_evt.set() return True + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class _ContextPropagatingExecutor(concurrent.futures.Executor): + def __init__(self, executor: concurrent.futures.ThreadPoolExecutor) -> None: + self._executor = executor + + def submit( + self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs + ) -> concurrent.futures.Future[_T]: + ctx = contextvars.copy_context() + + def wrapped(*a: _P.args, **k: _P.kwargs) -> _T: + return ctx.run(fn, *a, **k) + + return self._executor.submit(wrapped, *args, **kwargs) + + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + return self._executor.shutdown(wait=wait, cancel_futures=cancel_futures) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 1afae2c78..0795c6e66 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -108,7 +108,7 @@ def __init__( workflows: Sequence[Type] = [], activity_executor: Optional[concurrent.futures.Executor] = None, workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None, - nexus_task_executor: Optional[concurrent.futures.Executor] = None, + nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None, workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(), unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(), plugins: Sequence[Plugin] = [], @@ -187,8 +187,7 @@ def __init__( the worker is shut down. nexus_task_executor: Executor to use for non-async Nexus operations. This is required if any operation start methods - are non-`async def`. :py:class:`concurrent.futures.ThreadPoolExecutor` - is recommended. + are non-`async def`. .. warning:: This parameter is experimental and unstable. @@ -884,7 +883,7 @@ class WorkerConfig(TypedDict, total=False): workflows: Sequence[Type] activity_executor: Optional[concurrent.futures.Executor] workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] - nexus_task_executor: Optional[concurrent.futures.Executor] + nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] workflow_runner: WorkflowRunner unsandboxed_workflow_runner: WorkflowRunner plugins: Sequence[Plugin] diff --git a/tests/helpers/metrics.py b/tests/helpers/metrics.py new file mode 100644 index 000000000..d5869d46b --- /dev/null +++ b/tests/helpers/metrics.py @@ -0,0 +1,30 @@ +from collections.abc import Mapping + + +class PromMetricMatcher: + def __init__(self, prom_lines: list[str]) -> None: + self._prom_lines = prom_lines + + # Intentionally naive metric checker + def matches_metric_line( + self, line: str, name: str, at_least_labels: Mapping[str, str], value: int + ) -> bool: + # Must have metric name + if not line.startswith(name + "{"): + return False + # Must have labels (don't escape for this test) + for k, v in at_least_labels.items(): + if f'{k}="{v}"' not in line: + return False + return line.endswith(f" {value}") + + def assert_metric_exists( + self, name: str, at_least_labels: Mapping[str, str], value: int + ) -> None: + assert any( + self.matches_metric_line(line, name, at_least_labels, value) + for line in self._prom_lines + ) + + def assert_description_exists(self, name: str, description: str) -> None: + assert f"# HELP {name} {description}" in self._prom_lines diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 0c0dd988a..d120d711c 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -1,10 +1,12 @@ from __future__ import annotations import asyncio +import concurrent.futures import uuid from dataclasses import dataclass from enum import IntEnum from typing import Any, Awaitable, Callable, Union +from urllib.request import urlopen import nexusrpc import nexusrpc.handler @@ -37,9 +39,18 @@ from temporalio.converter import PayloadConverter from temporalio.exceptions import ApplicationError, CancelledError, NexusOperationError from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.runtime import ( + BUFFERED_METRIC_KIND_COUNTER, + MetricBuffer, + PrometheusConfig, + Runtime, + TelemetryConfig, +) from temporalio.service import RPCError, RPCStatusCode from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker +from tests.helpers import find_free_port, new_worker +from tests.helpers.metrics import PromMetricMatcher from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name # TODO(nexus-prerelease): test availability of Temporal client etc in async context set by worker @@ -239,7 +250,7 @@ def __init__( request_cancel: bool, task_queue: str, ) -> None: - self.nexus_client: workflow.NexusClient[ServiceInterface] = ( + self.nexus_client: workflow.NexusClient[ServiceInterface | ServiceImpl] = ( workflow.create_nexus_client( service={ CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, @@ -890,7 +901,7 @@ async def run( f"Invalid combination of caller_reference ({caller_reference}) and name_override ({name_override})" ) - nexus_client = workflow.create_nexus_client( + nexus_client: workflow.NexusClient[Any] = workflow.create_nexus_client( service=service_cls, endpoint=make_nexus_endpoint_name(task_queue), ) @@ -1409,3 +1420,198 @@ async def test_workflow_run_operation_overloads( if op != "no_param" else OverloadTestValue(value=0) ) + + +@nexusrpc.handler.service_handler +class CustomMetricsService: + @nexusrpc.handler.sync_operation + async def custom_metric_op( + self, ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> None: + counter = nexus.metric_meter().create_counter( + "my-operation-counter", "my-operation-description", "my-operation-unit" + ) + counter.add(12) + counter.add(30, {"my-operation-extra-attr": 12.34}) + + @nexusrpc.handler.sync_operation + def custom_metric_op_executor( + self, ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> None: + counter = nexus.metric_meter().create_counter( + "my-executor-operation-counter", + "my-executor-operation-description", + "my-executor-operation-unit", + ) + counter.add(12) + counter.add(30, {"my-executor-operation-extra-attr": 12.34}) + + +@workflow.defn +class CustomMetricsWorkflow: + @workflow.run + async def run(self, task_queue: str) -> None: + nexus_client = workflow.create_nexus_client( + service=CustomMetricsService, endpoint=make_nexus_endpoint_name(task_queue) + ) + + await nexus_client.execute_operation( + CustomMetricsService.custom_metric_op, None + ) + await nexus_client.execute_operation( + CustomMetricsService.custom_metric_op_executor, None + ) + + +async def test_workflow_caller_custom_metrics(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + + task_queue = str(uuid.uuid4()) + await create_nexus_endpoint(task_queue, client) + + # Create new runtime with Prom server + prom_addr = f"127.0.0.1:{find_free_port()}" + runtime = Runtime( + telemetry=TelemetryConfig( + metrics=PrometheusConfig(bind_address=prom_addr), metric_prefix="foo_" + ) + ) + + # New client with the runtime + client = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=runtime, + ) + + async with new_worker( + client, + CustomMetricsWorkflow, + task_queue=task_queue, + nexus_service_handlers=[CustomMetricsService()], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ) as worker: + # Run workflow + await client.execute_workflow( + CustomMetricsWorkflow.run, + worker.task_queue, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Get Prom dump + with urlopen(url=f"http://{prom_addr}/metrics") as f: + prom_str: str = f.read().decode("utf-8") + prom_lines = prom_str.splitlines() + + prom_matcher = PromMetricMatcher(prom_lines) + + prom_matcher.assert_description_exists( + "my_operation_counter", "my-operation-description" + ) + prom_matcher.assert_metric_exists("my_operation_counter", {}, 12) + prom_matcher.assert_metric_exists( + "my_operation_counter", + { + "my_operation_extra_attr": "12.34", + # Also confirm some nexus operation labels + "nexus_service": CustomMetricsService.__name__, + "nexus_operation": CustomMetricsService.custom_metric_op.__name__, + "task_queue": worker.task_queue, + }, + 30, + ) + prom_matcher.assert_description_exists( + "my_executor_operation_counter", "my-executor-operation-description" + ) + prom_matcher.assert_metric_exists("my_executor_operation_counter", {}, 12) + prom_matcher.assert_metric_exists( + "my_executor_operation_counter", + { + "my_executor_operation_extra_attr": "12.34", + # Also confirm some nexus operation labels + "nexus_service": CustomMetricsService.__name__, + "nexus_operation": CustomMetricsService.custom_metric_op_executor.__name__, + "task_queue": worker.task_queue, + }, + 30, + ) + + +async def test_workflow_caller_buffered_metrics( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + + # Create runtime with metric buffer + buffer = MetricBuffer(10000) + runtime = Runtime( + telemetry=TelemetryConfig(metrics=buffer, metric_prefix="some_prefix_") + ) + + # Confirm no updates yet + assert not buffer.retrieve_updates() + + # Create a new client on the runtime and execute the custom metric workflow + client = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=runtime, + ) + task_queue = str(uuid.uuid4()) + await create_nexus_endpoint(task_queue, client) + async with new_worker( + client, + CustomMetricsWorkflow, + task_queue=task_queue, + nexus_service_handlers=[CustomMetricsService()], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ) as worker: + await client.execute_workflow( + CustomMetricsWorkflow.run, + worker.task_queue, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Drain updates and confirm updates exist as expected + updates = buffer.retrieve_updates() + # Check for Nexus metrics + assert any( + update.metric.name == "my-operation-counter" + and update.metric.kind == BUFFERED_METRIC_KIND_COUNTER + and update.metric.description == "my-operation-description" + and update.attributes["nexus_service"] == CustomMetricsService.__name__ + and update.attributes["nexus_operation"] + == CustomMetricsService.custom_metric_op.__name__ + and update.attributes["task_queue"] == worker.task_queue + and "my-operation-extra-attr" not in update.attributes + and update.value == 12 + for update in updates + ) + assert any( + update.metric.name == "my-operation-counter" + and update.attributes.get("my-operation-extra-attr") == 12.34 + and update.value == 30 + for update in updates + ) + assert any( + update.metric.name == "my-executor-operation-counter" + and update.metric.description == "my-executor-operation-description" + and update.metric.kind == BUFFERED_METRIC_KIND_COUNTER + and update.attributes["nexus_service"] == CustomMetricsService.__name__ + and update.attributes["nexus_operation"] + == CustomMetricsService.custom_metric_op_executor.__name__ + and update.attributes["task_queue"] == worker.task_queue + and "my-executor-operation-extra-attr" not in update.attributes + and update.value == 12 + for update in updates + ) + assert any( + update.metric.name == "my-executor-operation-counter" + and update.attributes.get("my-executor-operation-extra-attr") == 12.34 + and update.value == 30 + for update in updates + ) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 4356bc34e..2b3d139fd 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -153,6 +153,7 @@ ExternalStackTraceWorkflow, external_wait_cancel, ) +from tests.helpers.metrics import PromMetricMatcher @workflow.defn @@ -4000,33 +4001,13 @@ async def test_workflow_custom_metrics(client: Client): prom_str: str = f.read().decode("utf-8") prom_lines = prom_str.splitlines() - # Intentionally naive metric checker - def matches_metric_line( - line: str, name: str, at_least_labels: Mapping[str, str], value: int - ) -> bool: - # Must have metric name - if not line.startswith(name + "{"): - return False - # Must have labels (don't escape for this test) - for k, v in at_least_labels.items(): - if f'{k}="{v}"' not in line: - return False - return line.endswith(f" {value}") - - def assert_metric_exists( - name: str, at_least_labels: Mapping[str, str], value: int - ) -> None: - assert any( - matches_metric_line(line, name, at_least_labels, value) - for line in prom_lines - ) - - def assert_description_exists(name: str, description: str) -> None: - assert f"# HELP {name} {description}" in prom_lines + prom_matcher = PromMetricMatcher(prom_lines) # Check some metrics are as we expect - assert_description_exists("my_runtime_gauge", "my-runtime-description") - assert_metric_exists( + prom_matcher.assert_description_exists( + "my_runtime_gauge", "my-runtime-description" + ) + prom_matcher.assert_metric_exists( "my_runtime_gauge", { "my_runtime_extra_attr1": "val1", @@ -4036,9 +4017,11 @@ def assert_description_exists(name: str, description: str) -> None: }, 90, ) - assert_description_exists("my_workflow_histogram", "my-workflow-description") - assert_metric_exists("my_workflow_histogram_sum", {}, 56) - assert_metric_exists( + prom_matcher.assert_description_exists( + "my_workflow_histogram", "my-workflow-description" + ) + prom_matcher.assert_metric_exists("my_workflow_histogram_sum", {}, 56) + prom_matcher.assert_metric_exists( "my_workflow_histogram_sum", { "my_workflow_extra_attr": "1234", @@ -4049,9 +4032,11 @@ def assert_description_exists(name: str, description: str) -> None: }, 78, ) - assert_description_exists("my_activity_counter", "my-activity-description") - assert_metric_exists("my_activity_counter", {}, 12) - assert_metric_exists( + prom_matcher.assert_description_exists( + "my_activity_counter", "my-activity-description" + ) + prom_matcher.assert_metric_exists("my_activity_counter", {}, 12) + prom_matcher.assert_metric_exists( "my_activity_counter", { "my_activity_extra_attr": "12.34", @@ -4063,7 +4048,7 @@ def assert_description_exists(name: str, description: str) -> None: 34, ) # Also check Temporal metric got its prefix - assert_metric_exists( + prom_matcher.assert_metric_exists( "foo_workflow_completed", {"workflow_type": "CustomMetricsWorkflow"}, 1 )