Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
from ._operation_context import in_operation as in_operation
from ._operation_context import info as info
from ._operation_context import logger as logger
from ._operation_context import metric_meter as metric_meter
from ._token import WorkflowHandle as WorkflowHandle
70 changes: 49 additions & 21 deletions temporalio/nexus/_operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@
Any,
Callable,
Generator,
Generic,
Optional,
TypeVar,
Union,
overload,
)

from nexusrpc.handler import CancelOperationContext, StartOperationContext
from nexusrpc.handler import (
CancelOperationContext,
OperationContext,
StartOperationContext,
)
from typing_extensions import Concatenate

import temporalio.api.common.v1
Expand Down Expand Up @@ -87,8 +93,13 @@ def client() -> temporalio.client.Client:
return _temporal_context().client


def metric_meter() -> temporalio.common.MetricMeter:
"""Get the metric meter for the current Nexus operation."""
return _temporal_context().metric_meter


def _temporal_context() -> (
Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]
_TemporalStartOperationContext | _TemporalCancelOperationContext
):
ctx = _try_temporal_context()
if ctx is None:
Expand All @@ -97,7 +108,7 @@ def _temporal_context() -> (


def _try_temporal_context() -> (
Optional[Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]]
_TemporalStartOperationContext | _TemporalCancelOperationContext | None
):
start_ctx = _temporal_start_operation_context.get(None)
cancel_ctx = _temporal_cancel_operation_context.get(None)
Expand All @@ -119,18 +130,39 @@ def _in_nexus_backing_workflow_start_context() -> bool:
return _temporal_nexus_backing_workflow_start_context.get(False)


@dataclass
class _TemporalStartOperationContext:
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
_OperationCtxT = TypeVar("_OperationCtxT", bound=OperationContext)

nexus_context: StartOperationContext
"""Nexus-specific start operation context."""

@dataclass(kw_only=True)
class _TemporalOperationCtx(Generic[_OperationCtxT]):
client: temporalio.client.Client
"""The Temporal client in use by the worker handling the current Nexus operation."""

info: Callable[[], Info]
"""Temporal information about the running Nexus operation."""

client: temporalio.client.Client
"""The Temporal client in use by the worker handling this Nexus operation."""
nexus_context: _OperationCtxT
"""Nexus-specific start operation context."""

runtime_metric_meter: temporalio.common.MetricMeter
_metric_meter: temporalio.common.MetricMeter | None = None

@property
def metric_meter(self) -> temporalio.common.MetricMeter:
if not self._metric_meter:
self._metric_meter = self.runtime_metric_meter.with_additional_attributes(
{
"nexus_service": self.nexus_context.service,
"nexus_operation": self.nexus_context.operation,
"task_queue": self.info().task_queue,
}
)
return self._metric_meter


@dataclass
class _TemporalStartOperationContext(_TemporalOperationCtx[StartOperationContext]):
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""

@classmethod
def get(cls) -> _TemporalStartOperationContext:
Expand Down Expand Up @@ -218,6 +250,11 @@ def _from_start_operation_context(
**{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)},
)

@property
def metric_meter(self) -> temporalio.common.MetricMeter:
"""The metric meter"""
return self._temporal_context.metric_meter

# Overload for no-param workflow
@overload
async def start_workflow(
Expand Down Expand Up @@ -481,19 +518,10 @@ class NexusCallback:
"""Header to attach to callback request."""


@dataclass(frozen=True)
class _TemporalCancelOperationContext:
@dataclass
class _TemporalCancelOperationContext(_TemporalOperationCtx[CancelOperationContext]):
"""Context for a Nexus cancel operation being handled by a Temporal Nexus Worker."""

nexus_context: CancelOperationContext
"""Nexus-specific cancel operation context."""

info: Callable[[], Info]
"""Temporal information about the running Nexus cancel operation."""

client: temporalio.client.Client
"""The Temporal client in use by the worker handling the current Nexus operation."""

@classmethod
def get(cls) -> _TemporalCancelOperationContext:
ctx = _temporal_cancel_operation_context.get(None)
Expand Down
41 changes: 37 additions & 4 deletions temporalio/worker/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import concurrent.futures
import contextvars
import json
import threading
from dataclasses import dataclass
Expand All @@ -13,8 +14,10 @@
Mapping,
NoReturn,
Optional,
ParamSpec,
Sequence,
Type,
TypeVar,
Union,
)

Expand Down Expand Up @@ -66,19 +69,25 @@ def __init__(
data_converter: temporalio.converter.DataConverter,
interceptors: Sequence[Interceptor],
metric_meter: temporalio.common.MetricMeter,
executor: Optional[concurrent.futures.Executor],
executor: concurrent.futures.ThreadPoolExecutor | None,
) -> None:
# TODO: make it possible to query task queue of bridge worker instead of passing
# unused task_queue into _NexusWorker, _ActivityWorker, etc?
self._bridge_worker = bridge_worker
self._client = client
self._task_queue = task_queue
self._handler = Handler(service_handlers, executor)

self._metric_meter = metric_meter

# If an executor is provided, we wrap the executor with one that will
# copy the contextvars.Context to the thread on submit
handler_executor = _ContextPropagatingExecutor(executor) if executor else None

self._handler = Handler(service_handlers, handler_executor)
self._data_converter = data_converter
# TODO(nexus-preview): interceptors
self._interceptors = interceptors
# TODO(nexus-preview): metric_meter
self._metric_meter = metric_meter

self._running_tasks: dict[bytes, _RunningNexusTask] = {}
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()

Expand Down Expand Up @@ -206,6 +215,7 @@ async def _handle_cancel_operation_task(
info=lambda: Info(task_queue=self._task_queue),
nexus_context=ctx,
client=self._client,
runtime_metric_meter=self._metric_meter,
).set()
try:
try:
Expand Down Expand Up @@ -323,6 +333,7 @@ async def _start_operation(
nexus_context=ctx,
client=self._client,
info=lambda: Info(task_queue=self._task_queue),
runtime_metric_meter=self._metric_meter,
).set()
input = LazyValue(
serializer=_DummyPayloadSerializer(
Expand Down Expand Up @@ -597,3 +608,25 @@ def cancel(self, reason: str) -> bool:
self._thread_evt.set()
self._async_evt.set()
return True


_P = ParamSpec("_P")
_T = TypeVar("_T")


class _ContextPropagatingExecutor(concurrent.futures.Executor):
def __init__(self, executor: concurrent.futures.ThreadPoolExecutor) -> None:
self._executor = executor

def submit(
self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs
) -> concurrent.futures.Future[_T]:
ctx = contextvars.copy_context()

def wrapped(*a: _P.args, **k: _P.kwargs) -> _T:
return ctx.run(fn, *a, **k)

return self._executor.submit(wrapped, *args, **kwargs)

def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
return self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
7 changes: 3 additions & 4 deletions temporalio/worker/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
workflows: Sequence[Type] = [],
activity_executor: Optional[concurrent.futures.Executor] = None,
workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None,
nexus_task_executor: Optional[concurrent.futures.Executor] = None,
nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None,
workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(),
unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(),
plugins: Sequence[Plugin] = [],
Expand Down Expand Up @@ -187,8 +187,7 @@ def __init__(
the worker is shut down.
nexus_task_executor: Executor to use for non-async
Nexus operations. This is required if any operation start methods
are non-`async def`. :py:class:`concurrent.futures.ThreadPoolExecutor`
is recommended.
are non-`async def`.

.. warning::
This parameter is experimental and unstable.
Expand Down Expand Up @@ -884,7 +883,7 @@ class WorkerConfig(TypedDict, total=False):
workflows: Sequence[Type]
activity_executor: Optional[concurrent.futures.Executor]
workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor]
nexus_task_executor: Optional[concurrent.futures.Executor]
nexus_task_executor: Optional[concurrent.futures.ThreadPoolExecutor]
workflow_runner: WorkflowRunner
unsandboxed_workflow_runner: WorkflowRunner
plugins: Sequence[Plugin]
Expand Down
30 changes: 30 additions & 0 deletions tests/helpers/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from collections.abc import Mapping


class PromMetricMatcher:
def __init__(self, prom_lines: list[str]) -> None:
self._prom_lines = prom_lines

# Intentionally naive metric checker
def matches_metric_line(
self, line: str, name: str, at_least_labels: Mapping[str, str], value: int
) -> bool:
# Must have metric name
if not line.startswith(name + "{"):
return False
# Must have labels (don't escape for this test)
for k, v in at_least_labels.items():
if f'{k}="{v}"' not in line:
return False
return line.endswith(f" {value}")

def assert_metric_exists(
self, name: str, at_least_labels: Mapping[str, str], value: int
) -> None:
assert any(
self.matches_metric_line(line, name, at_least_labels, value)
for line in self._prom_lines
)

def assert_description_exists(self, name: str, description: str) -> None:
assert f"# HELP {name} {description}" in self._prom_lines
Loading