Skip to content

Commit 49cd4e7

Browse files
committed
Add metric meter support for nexus operations
1 parent dbcbc08 commit 49cd4e7

File tree

5 files changed

+212
-34
lines changed

5 files changed

+212
-34
lines changed

temporalio/nexus/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717
from ._operation_context import in_operation as in_operation
1818
from ._operation_context import info as info
1919
from ._operation_context import logger as logger
20+
from ._operation_context import metric_meter as metric_meter
2021
from ._token import WorkflowHandle as WorkflowHandle

temporalio/nexus/_operation_context.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@
1212
Any,
1313
Callable,
1414
Generator,
15+
Generic,
1516
Optional,
17+
TypeVar,
1618
Union,
1719
overload,
1820
)
1921

20-
from nexusrpc.handler import CancelOperationContext, StartOperationContext
22+
from nexusrpc.handler import (
23+
CancelOperationContext,
24+
OperationContext,
25+
StartOperationContext,
26+
)
2127
from typing_extensions import Concatenate
2228

2329
import temporalio.api.common.v1
@@ -87,8 +93,13 @@ def client() -> temporalio.client.Client:
8793
return _temporal_context().client
8894

8995

96+
def metric_meter() -> temporalio.common.MetricMeter:
97+
"""Get the metric meter for the current Nexus operation."""
98+
return _temporal_context().metric_meter
99+
100+
90101
def _temporal_context() -> (
91-
Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]
102+
_TemporalStartOperationContext | _TemporalCancelOperationContext
92103
):
93104
ctx = _try_temporal_context()
94105
if ctx is None:
@@ -97,7 +108,7 @@ def _temporal_context() -> (
97108

98109

99110
def _try_temporal_context() -> (
100-
Optional[Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]]
111+
_TemporalStartOperationContext | _TemporalCancelOperationContext | None
101112
):
102113
start_ctx = _temporal_start_operation_context.get(None)
103114
cancel_ctx = _temporal_cancel_operation_context.get(None)
@@ -119,18 +130,39 @@ def _in_nexus_backing_workflow_start_context() -> bool:
119130
return _temporal_nexus_backing_workflow_start_context.get(False)
120131

121132

122-
@dataclass
123-
class _TemporalStartOperationContext:
124-
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
133+
_OperationCtxT = TypeVar("_OperationCtxT", bound=OperationContext)
125134

126-
nexus_context: StartOperationContext
127-
"""Nexus-specific start operation context."""
135+
136+
@dataclass(kw_only=True)
137+
class _TemporalOperationCtx(Generic[_OperationCtxT]):
138+
client: temporalio.client.Client
139+
"""The Temporal client in use by the worker handling the current Nexus operation."""
128140

129141
info: Callable[[], Info]
130142
"""Temporal information about the running Nexus operation."""
131143

132-
client: temporalio.client.Client
133-
"""The Temporal client in use by the worker handling this Nexus operation."""
144+
nexus_context: _OperationCtxT
145+
"""Nexus-specific start operation context."""
146+
147+
runtime_metric_meter: temporalio.common.MetricMeter
148+
_metric_meter: temporalio.common.MetricMeter | None = None
149+
150+
@property
151+
def metric_meter(self) -> temporalio.common.MetricMeter:
152+
if not self._metric_meter:
153+
self._metric_meter = self.runtime_metric_meter.with_additional_attributes(
154+
{
155+
"nexus_service": self.nexus_context.service,
156+
"nexus_operation": self.nexus_context.operation,
157+
"task_queue": self.info().task_queue,
158+
}
159+
)
160+
return self._metric_meter
161+
162+
163+
@dataclass
164+
class _TemporalStartOperationContext(_TemporalOperationCtx[StartOperationContext]):
165+
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
134166

135167
@classmethod
136168
def get(cls) -> _TemporalStartOperationContext:
@@ -217,6 +249,11 @@ def _from_start_operation_context(
217249
**{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)},
218250
)
219251

252+
@property
253+
def metric_meter(self) -> temporalio.common.MetricMeter:
254+
"""The metric meter"""
255+
return self._temporal_context.metric_meter
256+
220257
# Overload for no-param workflow
221258
@overload
222259
async def start_workflow(
@@ -480,19 +517,10 @@ class NexusCallback:
480517
"""Header to attach to callback request."""
481518

482519

483-
@dataclass(frozen=True)
484-
class _TemporalCancelOperationContext:
520+
@dataclass
521+
class _TemporalCancelOperationContext(_TemporalOperationCtx[CancelOperationContext]):
485522
"""Context for a Nexus cancel operation being handled by a Temporal Nexus Worker."""
486523

487-
nexus_context: CancelOperationContext
488-
"""Nexus-specific cancel operation context."""
489-
490-
info: Callable[[], Info]
491-
"""Temporal information about the running Nexus cancel operation."""
492-
493-
client: temporalio.client.Client
494-
"""The Temporal client in use by the worker handling the current Nexus operation."""
495-
496524
@classmethod
497525
def get(cls) -> _TemporalCancelOperationContext:
498526
ctx = _temporal_cancel_operation_context.get(None)

temporalio/worker/_nexus.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
import concurrent.futures
7+
import contextvars
78
import json
89
import threading
910
from dataclasses import dataclass
@@ -13,15 +14,18 @@
1314
Mapping,
1415
NoReturn,
1516
Optional,
17+
ParamSpec,
1618
Sequence,
1719
Type,
20+
TypeVar,
1821
Union,
1922
)
2023

2124
import google.protobuf.json_format
2225
import nexusrpc.handler
2326
from nexusrpc import LazyValue
2427
from nexusrpc.handler import CancelOperationContext, Handler, StartOperationContext
28+
from typing_extensions import Self
2529

2630
import temporalio.api.common.v1
2731
import temporalio.api.enums.v1
@@ -66,19 +70,25 @@ def __init__(
6670
data_converter: temporalio.converter.DataConverter,
6771
interceptors: Sequence[Interceptor],
6872
metric_meter: temporalio.common.MetricMeter,
69-
executor: Optional[concurrent.futures.Executor],
73+
executor: concurrent.futures.ThreadPoolExecutor | None,
7074
) -> None:
7175
# TODO: make it possible to query task queue of bridge worker instead of passing
7276
# unused task_queue into _NexusWorker, _ActivityWorker, etc?
7377
self._bridge_worker = bridge_worker
7478
self._client = client
7579
self._task_queue = task_queue
76-
self._handler = Handler(service_handlers, executor)
80+
81+
self._metric_meter = metric_meter
82+
83+
# If an executor is provided, we wrap the executor with one that will
84+
# copy the contextvars.Context to the thread on submit
85+
handler_executor = _ContextPropagatingExecutor(executor) if executor else None
86+
87+
self._handler = Handler(service_handlers, handler_executor)
7788
self._data_converter = data_converter
7889
# TODO(nexus-preview): interceptors
7990
self._interceptors = interceptors
80-
# TODO(nexus-preview): metric_meter
81-
self._metric_meter = metric_meter
91+
8292
self._running_tasks: dict[bytes, _RunningNexusTask] = {}
8393
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()
8494

@@ -206,6 +216,7 @@ async def _handle_cancel_operation_task(
206216
info=lambda: Info(task_queue=self._task_queue),
207217
nexus_context=ctx,
208218
client=self._client,
219+
runtime_metric_meter=self._metric_meter,
209220
).set()
210221
try:
211222
try:
@@ -323,6 +334,7 @@ async def _start_operation(
323334
nexus_context=ctx,
324335
client=self._client,
325336
info=lambda: Info(task_queue=self._task_queue),
337+
runtime_metric_meter=self._metric_meter,
326338
).set()
327339
input = LazyValue(
328340
serializer=_DummyPayloadSerializer(
@@ -597,3 +609,25 @@ def cancel(self, reason: str) -> bool:
597609
self._thread_evt.set()
598610
self._async_evt.set()
599611
return True
612+
613+
614+
_P = ParamSpec("_P")
615+
_T = TypeVar("_T")
616+
617+
618+
class _ContextPropagatingExecutor(concurrent.futures.Executor):
619+
def __init__(self, executor: concurrent.futures.ThreadPoolExecutor) -> None:
620+
self._executor = executor
621+
622+
def submit(
623+
self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs
624+
) -> concurrent.futures.Future[_T]:
625+
ctx = contextvars.copy_context()
626+
627+
def wrapped(*a: _P.args, **k: _P.kwargs) -> _T:
628+
return ctx.run(fn, *a, **k)
629+
630+
return self._executor.submit(wrapped, *args, **kwargs)
631+
632+
def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
633+
return self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)

temporalio/worker/_worker.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(
108108
workflows: Sequence[Type] = [],
109109
activity_executor: Optional[concurrent.futures.Executor] = None,
110110
workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None,
111-
nexus_task_executor: Optional[concurrent.futures.Executor] = None,
111+
nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None,
112112
workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(),
113113
unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(),
114114
plugins: Sequence[Plugin] = [],
@@ -187,8 +187,7 @@ def __init__(
187187
the worker is shut down.
188188
nexus_task_executor: Executor to use for non-async
189189
Nexus operations. This is required if any operation start methods
190-
are non-`async def`. :py:class:`concurrent.futures.ThreadPoolExecutor`
191-
is recommended.
190+
are non-`async def`.
192191
193192
.. warning::
194193
This parameter is experimental and unstable.
@@ -880,7 +879,7 @@ class WorkerConfig(TypedDict, total=False):
880879
workflows: Sequence[Type]
881880
activity_executor: Optional[concurrent.futures.Executor]
882881
workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor]
883-
nexus_task_executor: Optional[concurrent.futures.Executor]
882+
nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor]
884883
workflow_runner: WorkflowRunner
885884
unsandboxed_workflow_runner: WorkflowRunner
886885
plugins: Sequence[Plugin]

0 commit comments

Comments
 (0)