Skip to content

Commit 89fb17f

Browse files
💥 Nexus MetricMeter Support (#1233)
* Add metric meter support for nexus operations * remove an unused import * Create nexus endpoint before creating the worker. Fix up buffered metrics test to properly assert both counter updates * move nexus tests out of overall metric tests b/c nexus doesn't work with time skipping * remove comment that doesn't apply * re-export metric_meter in nexus/__init__.py after bad merge. Make runtime_metric_meter private to help avoid confusion * fix up docstring in worker
1 parent 4439675 commit 89fb17f

File tree

7 files changed

+346
-68
lines changed

7 files changed

+346
-68
lines changed

‎temporalio/nexus/__init__.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
in_operation,
1717
info,
1818
logger,
19+
metric_meter,
1920
)
2021
from ._token import WorkflowHandle
2122

@@ -29,5 +30,6 @@
2930
"in_operation",
3031
"info",
3132
"logger",
33+
"metric_meter",
3234
"WorkflowHandle",
3335
)

‎temporalio/nexus/_operation_context.py‎

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,16 @@
1818
TYPE_CHECKING,
1919
Any,
2020
Concatenate,
21-
Optional,
22-
Union,
21+
Generic,
22+
TypeVar,
2323
overload,
2424
)
2525

26-
from nexusrpc.handler import CancelOperationContext, StartOperationContext
26+
from nexusrpc.handler import (
27+
CancelOperationContext,
28+
OperationContext,
29+
StartOperationContext,
30+
)
2731

2832
import temporalio.api.common.v1
2933
import temporalio.api.workflowservice.v1
@@ -97,6 +101,11 @@ def client() -> temporalio.client.Client:
97101
return _temporal_context().client
98102

99103

104+
def metric_meter() -> temporalio.common.MetricMeter:
105+
"""Get the metric meter for the current Nexus operation."""
106+
return _temporal_context().metric_meter
107+
108+
100109
def _temporal_context() -> (
101110
_TemporalStartOperationContext | _TemporalCancelOperationContext
102111
):
@@ -129,18 +138,39 @@ def _in_nexus_backing_workflow_start_context() -> bool:
129138
return _temporal_nexus_backing_workflow_start_context.get(False)
130139

131140

132-
@dataclass
133-
class _TemporalStartOperationContext:
134-
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
141+
_OperationCtxT = TypeVar("_OperationCtxT", bound=OperationContext)
135142

136-
nexus_context: StartOperationContext
137-
"""Nexus-specific start operation context."""
143+
144+
@dataclass(kw_only=True)
145+
class _TemporalOperationCtx(Generic[_OperationCtxT]):
146+
client: temporalio.client.Client
147+
"""The Temporal client in use by the worker handling the current Nexus operation."""
138148

139149
info: Callable[[], Info]
140150
"""Temporal information about the running Nexus operation."""
141151

142-
client: temporalio.client.Client
143-
"""The Temporal client in use by the worker handling this Nexus operation."""
152+
nexus_context: _OperationCtxT
153+
"""Nexus-specific start operation context."""
154+
155+
_runtime_metric_meter: temporalio.common.MetricMeter
156+
_metric_meter: temporalio.common.MetricMeter | None = None
157+
158+
@property
159+
def metric_meter(self) -> temporalio.common.MetricMeter:
160+
if not self._metric_meter:
161+
self._metric_meter = self._runtime_metric_meter.with_additional_attributes(
162+
{
163+
"nexus_service": self.nexus_context.service,
164+
"nexus_operation": self.nexus_context.operation,
165+
"task_queue": self.info().task_queue,
166+
}
167+
)
168+
return self._metric_meter
169+
170+
171+
@dataclass
172+
class _TemporalStartOperationContext(_TemporalOperationCtx[StartOperationContext]):
173+
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
144174

145175
@classmethod
146176
def get(cls) -> _TemporalStartOperationContext:
@@ -227,6 +257,11 @@ def _from_start_operation_context(
227257
**{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)},
228258
)
229259

260+
@property
261+
def metric_meter(self) -> temporalio.common.MetricMeter:
262+
"""The metric meter"""
263+
return self._temporal_context.metric_meter
264+
230265
# Overload for no-param workflow
231266
@overload
232267
async def start_workflow(
@@ -480,19 +515,10 @@ class NexusCallback:
480515
"""Header to attach to callback request."""
481516

482517

483-
@dataclass(frozen=True)
484-
class _TemporalCancelOperationContext:
518+
@dataclass
519+
class _TemporalCancelOperationContext(_TemporalOperationCtx[CancelOperationContext]):
485520
"""Context for a Nexus cancel operation being handled by a Temporal Nexus Worker."""
486521

487-
nexus_context: CancelOperationContext
488-
"""Nexus-specific cancel operation context."""
489-
490-
info: Callable[[], Info]
491-
"""Temporal information about the running Nexus cancel operation."""
492-
493-
client: temporalio.client.Client
494-
"""The Temporal client in use by the worker handling the current Nexus operation."""
495-
496522
@classmethod
497523
def get(cls) -> _TemporalCancelOperationContext:
498524
ctx = _temporal_cancel_operation_context.get(None)

‎temporalio/worker/_nexus.py‎

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44

55
import asyncio
66
import concurrent.futures
7+
import contextvars
78
import json
89
import threading
910
from collections.abc import Callable, Mapping, Sequence
1011
from dataclasses import dataclass
1112
from typing import (
1213
Any,
1314
NoReturn,
14-
Optional,
15-
Type,
16-
Union,
15+
ParamSpec,
16+
TypeVar,
1717
)
1818

1919
import google.protobuf.json_format
@@ -64,19 +64,25 @@ def __init__(
6464
data_converter: temporalio.converter.DataConverter,
6565
interceptors: Sequence[Interceptor],
6666
metric_meter: temporalio.common.MetricMeter,
67-
executor: concurrent.futures.Executor | None,
67+
executor: concurrent.futures.ThreadPoolExecutor | None,
6868
) -> None:
6969
# TODO: make it possible to query task queue of bridge worker instead of passing
7070
# unused task_queue into _NexusWorker, _ActivityWorker, etc?
7171
self._bridge_worker = bridge_worker
7272
self._client = client
7373
self._task_queue = task_queue
74-
self._handler = Handler(service_handlers, executor)
74+
75+
self._metric_meter = metric_meter
76+
77+
# If an executor is provided, we wrap the executor with one that will
78+
# copy the contextvars.Context to the thread on submit
79+
handler_executor = _ContextPropagatingExecutor(executor) if executor else None
80+
81+
self._handler = Handler(service_handlers, handler_executor)
7582
self._data_converter = data_converter
7683
# TODO(nexus-preview): interceptors
7784
self._interceptors = interceptors
78-
# TODO(nexus-preview): metric_meter
79-
self._metric_meter = metric_meter
85+
8086
self._running_tasks: dict[bytes, _RunningNexusTask] = {}
8187
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()
8288

@@ -204,6 +210,7 @@ async def _handle_cancel_operation_task(
204210
info=lambda: Info(task_queue=self._task_queue),
205211
nexus_context=ctx,
206212
client=self._client,
213+
_runtime_metric_meter=self._metric_meter,
207214
).set()
208215
try:
209216
try:
@@ -321,6 +328,7 @@ async def _start_operation(
321328
nexus_context=ctx,
322329
client=self._client,
323330
info=lambda: Info(task_queue=self._task_queue),
331+
_runtime_metric_meter=self._metric_meter,
324332
).set()
325333
input = LazyValue(
326334
serializer=_DummyPayloadSerializer(
@@ -595,3 +603,25 @@ def cancel(self, reason: str) -> bool:
595603
self._thread_evt.set()
596604
self._async_evt.set()
597605
return True
606+
607+
608+
_P = ParamSpec("_P")
609+
_T = TypeVar("_T")
610+
611+
612+
class _ContextPropagatingExecutor(concurrent.futures.Executor):
613+
def __init__(self, executor: concurrent.futures.ThreadPoolExecutor) -> None:
614+
self._executor = executor
615+
616+
def submit(
617+
self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs
618+
) -> concurrent.futures.Future[_T]:
619+
ctx = contextvars.copy_context()
620+
621+
def wrapped(*a: _P.args, **k: _P.kwargs) -> _T:
622+
return ctx.run(fn, *a, **k)
623+
624+
return self._executor.submit(wrapped, *args, **kwargs)
625+
626+
def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
627+
return self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)

‎temporalio/worker/_worker.py‎

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(
107107
workflows: Sequence[type] = [],
108108
activity_executor: concurrent.futures.Executor | None = None,
109109
workflow_task_executor: concurrent.futures.ThreadPoolExecutor | None = None,
110-
nexus_task_executor: concurrent.futures.Executor | None = None,
110+
nexus_task_executor: concurrent.futures.ThreadPoolExecutor | None = None,
111111
workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(),
112112
unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(),
113113
plugins: Sequence[Plugin] = [],
@@ -186,8 +186,7 @@ def __init__(
186186
the worker is shut down.
187187
nexus_task_executor: Executor to use for non-async
188188
Nexus operations. This is required if any operation start methods
189-
are non-``async def``. :py:class:`concurrent.futures.ThreadPoolExecutor`
190-
is recommended.
189+
are non-``async def``.
191190
192191
.. warning::
193192
This parameter is experimental and unstable.
@@ -893,7 +892,7 @@ class WorkerConfig(TypedDict, total=False):
893892
workflows: Sequence[type]
894893
activity_executor: concurrent.futures.Executor | None
895894
workflow_task_executor: concurrent.futures.ThreadPoolExecutor | None
896-
nexus_task_executor: concurrent.futures.Executor | None
895+
nexus_task_executor: concurrent.futures.ThreadPoolExecutor | None
897896
workflow_runner: WorkflowRunner
898897
unsandboxed_workflow_runner: WorkflowRunner
899898
plugins: Sequence[Plugin]

‎tests/helpers/metrics.py‎

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from collections.abc import Mapping
2+
3+
4+
class PromMetricMatcher:
5+
def __init__(self, prom_lines: list[str]) -> None:
6+
self._prom_lines = prom_lines
7+
8+
# Intentionally naive metric checker
9+
def matches_metric_line(
10+
self, line: str, name: str, at_least_labels: Mapping[str, str], value: int
11+
) -> bool:
12+
# Must have metric name
13+
if not line.startswith(name + "{"):
14+
return False
15+
# Must have labels (don't escape for this test)
16+
for k, v in at_least_labels.items():
17+
if f'{k}="{v}"' not in line:
18+
return False
19+
return line.endswith(f" {value}")
20+
21+
def assert_metric_exists(
22+
self, name: str, at_least_labels: Mapping[str, str], value: int
23+
) -> None:
24+
assert any(
25+
self.matches_metric_line(line, name, at_least_labels, value)
26+
for line in self._prom_lines
27+
)
28+
29+
def assert_description_exists(self, name: str, description: str) -> None:
30+
assert f"# HELP {name} {description}" in self._prom_lines

0 commit comments

Comments
 (0)