Skip to content

Commit 0f19e49

Browse files
dandavisontconley1428
authored andcommitted
Provide client in activity context (#740)
* Expose client as activity.client()
1 parent daebf39 commit 0f19e49

File tree

7 files changed

+161
-25
lines changed

7 files changed

+161
-25
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,6 +1257,7 @@ calls in the `temporalio.activity` package make use of it. Specifically:
12571257

12581258
* `in_activity()` - Whether an activity context is present
12591259
* `info()` - Returns the immutable info of the currently running activity
1260+
* `client()` - Returns the Temporal client used by this worker. Only available in `async def` activities.
12601261
* `heartbeat(*details)` - Record a heartbeat
12611262
* `is_cancelled()` - Whether a cancellation has been requested on this activity
12621263
* `wait_for_cancelled()` - `async` call to wait for cancellation request

temporalio/activity.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from dataclasses import dataclass
2020
from datetime import datetime, timedelta
2121
from typing import (
22+
TYPE_CHECKING,
2223
Any,
2324
Callable,
2425
Iterator,
@@ -42,6 +43,9 @@
4243

4344
from .types import CallableType
4445

46+
if TYPE_CHECKING:
47+
from temporalio.client import Client
48+
4549

4650
@overload
4751
def defn(fn: CallableType) -> CallableType: ...
@@ -181,6 +185,7 @@ class _Context:
181185
temporalio.converter.PayloadConverter,
182186
]
183187
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
188+
client: Optional[Client]
184189
cancellation_details: _ActivityCancellationDetailsHolder
185190
_logger_details: Optional[Mapping[str, Any]] = None
186191
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
@@ -273,13 +278,37 @@ def wait_sync(self, timeout: Optional[float] = None) -> None:
273278
self.thread_event.wait(timeout)
274279

275280

281+
def client() -> Client:
282+
"""Return a Temporal Client for use in the current activity.
283+
284+
The client is only available in `async def` activities.
285+
286+
In tests it is not available automatically, but you can pass a client when creating a
287+
:py:class:`temporalio.testing.ActivityEnvironment`.
288+
289+
Returns:
290+
:py:class:`temporalio.client.Client` for use in the current activity.
291+
292+
Raises:
293+
RuntimeError: When the client is not available.
294+
"""
295+
client = _Context.current().client
296+
if not client:
297+
raise RuntimeError(
298+
"No client available. The client is only available in `async def` "
299+
"activities; not in `def` activities. In tests you can pass a "
300+
"client when creating ActivityEnvironment."
301+
)
302+
return client
303+
304+
276305
def in_activity() -> bool:
277306
"""Whether the current code is inside an activity.
278307
279308
Returns:
280309
True if in an activity, False otherwise.
281310
"""
282-
return not _current_context.get(None) is None
311+
return _current_context.get(None) is not None
283312

284313

285314
def info() -> Info:
@@ -576,8 +605,10 @@ def _apply_to_callable(
576605
fn=fn,
577606
# iscoroutinefunction does not return true for async __call__
578607
# TODO(cretz): Why can't MyPy handle this?
579-
is_async=inspect.iscoroutinefunction(fn)
580-
or inspect.iscoroutinefunction(fn.__call__), # type: ignore
608+
is_async=(
609+
inspect.iscoroutinefunction(fn)
610+
or inspect.iscoroutinefunction(fn.__call__) # type: ignore
611+
),
581612
no_thread_cancel_exception=no_thread_cancel_exception,
582613
),
583614
)

temporalio/testing/_activity.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import temporalio.converter
1717
import temporalio.exceptions
1818
import temporalio.worker._activity
19+
from temporalio.client import Client
1920

2021
_Params = ParamSpec("_Params")
2122
_Return = TypeVar("_Return")
@@ -63,7 +64,7 @@ class ActivityEnvironment:
6364
take effect. Default is noop.
6465
"""
6566

66-
def __init__(self) -> None:
67+
def __init__(self, client: Optional[Client] = None) -> None:
6768
"""Create an ActivityEnvironment for running activity code."""
6869
self.info = _default_info
6970
self.on_heartbeat: Callable[..., None] = lambda *args: None
@@ -74,6 +75,7 @@ def __init__(self) -> None:
7475
self._cancelled = False
7576
self._worker_shutdown = False
7677
self._activities: Set[_Activity] = set()
78+
self._client = client
7779
self._cancellation_details = (
7880
temporalio.activity._ActivityCancellationDetailsHolder()
7981
)
@@ -128,18 +130,21 @@ def run(
128130
The callable's result.
129131
"""
130132
# Create an activity and run it
131-
return _Activity(self, fn).run(*args, **kwargs)
133+
return _Activity(self, fn, self._client).run(*args, **kwargs)
132134

133135

134136
class _Activity:
135137
def __init__(
136138
self,
137139
env: ActivityEnvironment,
138140
fn: Callable,
141+
client: Optional[Client],
139142
) -> None:
140143
self.env = env
141144
self.fn = fn
142-
self.is_async = inspect.iscoroutinefunction(fn)
145+
self.is_async = inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(
146+
fn.__call__ # type: ignore
147+
)
143148
self.cancel_thread_raiser: Optional[
144149
temporalio.worker._activity._ThreadExceptionRaiser
145150
] = None
@@ -163,11 +168,14 @@ def __init__(
163168
thread_event=threading.Event(),
164169
async_event=asyncio.Event() if self.is_async else None,
165170
),
166-
shield_thread_cancel_exception=None
167-
if not self.cancel_thread_raiser
168-
else self.cancel_thread_raiser.shielded,
171+
shield_thread_cancel_exception=(
172+
None
173+
if not self.cancel_thread_raiser
174+
else self.cancel_thread_raiser.shielded
175+
),
169176
payload_converter_class_or_instance=env.payload_converter,
170177
runtime_metric_meter=env.metric_meter,
178+
client=client if self.is_async else None,
171179
cancellation_details=env._cancellation_details,
172180
)
173181
self.task: Optional[asyncio.Task] = None

temporalio/worker/_activity.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
data_converter: temporalio.converter.DataConverter,
7070
interceptors: Sequence[Interceptor],
7171
metric_meter: temporalio.common.MetricMeter,
72+
client: temporalio.client.Client,
7273
encode_headers: bool,
7374
) -> None:
7475
self._bridge_worker = bridge_worker
@@ -86,6 +87,7 @@ def __init__(
8687
None
8788
)
8889
self._seen_sync_activity = False
90+
self._client = client
8991

9092
# Validate and build activity dict
9193
self._activities: Dict[str, temporalio.activity._Definition] = {}
@@ -571,11 +573,14 @@ async def _execute_activity(
571573
heartbeat=None,
572574
cancelled_event=running_activity.cancelled_event,
573575
worker_shutdown_event=self._worker_shutdown_event,
574-
shield_thread_cancel_exception=None
575-
if not running_activity.cancel_thread_raiser
576-
else running_activity.cancel_thread_raiser.shielded,
576+
shield_thread_cancel_exception=(
577+
None
578+
if not running_activity.cancel_thread_raiser
579+
else running_activity.cancel_thread_raiser.shielded
580+
),
577581
payload_converter_class_or_instance=self._data_converter.payload_converter,
578582
runtime_metric_meter=None if sync_non_threaded else self._metric_meter,
583+
client=self._client if not running_activity.sync else None,
579584
cancellation_details=running_activity.cancellation_details,
580585
)
581586
)
@@ -681,7 +686,7 @@ def _raise_in_thread_if_pending_unlocked(self) -> None:
681686

682687

683688
class _ActivityInboundImpl(ActivityInboundInterceptor):
684-
def __init__(
689+
def __init__( # type: ignore[reportMissingSuperCall]
685690
self, worker: _ActivityWorker, running_activity: _RunningActivity
686691
) -> None:
687692
# We are intentionally not calling the base class's __init__ here
@@ -788,7 +793,7 @@ async def heartbeat_with_context(*details: Any) -> None:
788793

789794

790795
class _ActivityOutboundImpl(ActivityOutboundInterceptor):
791-
def __init__(self, worker: _ActivityWorker, info: temporalio.activity.Info) -> None:
796+
def __init__(self, worker: _ActivityWorker, info: temporalio.activity.Info) -> None: # type: ignore[reportMissingSuperCall]
792797
# We are intentionally not calling the base class's __init__ here
793798
self._worker = worker
794799
self._info = info
@@ -840,11 +845,12 @@ def _execute_sync_activity(
840845
worker_shutdown_event=temporalio.activity._CompositeEvent(
841846
thread_event=worker_shutdown_event, async_event=None
842847
),
843-
shield_thread_cancel_exception=None
844-
if not cancel_thread_raiser
845-
else cancel_thread_raiser.shielded,
848+
shield_thread_cancel_exception=(
849+
None if not cancel_thread_raiser else cancel_thread_raiser.shielded
850+
),
846851
payload_converter_class_or_instance=payload_converter_class_or_instance,
847852
runtime_metric_meter=runtime_metric_meter,
853+
client=None,
848854
cancellation_details=cancellation_details,
849855
)
850856
)

temporalio/worker/_worker.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,10 @@ def __init__(
411411
data_converter=client_config["data_converter"],
412412
interceptors=interceptors,
413413
metric_meter=self._runtime.metric_meter,
414-
encode_headers=client_config["header_codec_behavior"]
415-
== HeaderCodecBehavior.CODEC,
414+
client=client,
415+
encode_headers=(
416+
client_config["header_codec_behavior"] == HeaderCodecBehavior.CODEC
417+
),
416418
)
417419
self._nexus_worker: Optional[_NexusWorker] = None
418420
if nexus_service_handlers:
@@ -577,12 +579,12 @@ def config(self) -> WorkerConfig:
577579
@property
578580
def task_queue(self) -> str:
579581
"""Task queue this worker is on."""
580-
return self._config["task_queue"]
582+
return self._config["task_queue"] # type: ignore[reportTypedDictNotRequiredAccess]
581583

582584
@property
583585
def client(self) -> temporalio.client.Client:
584586
"""Client currently set on the worker."""
585-
return self._config["client"]
587+
return self._config["client"] # type: ignore[reportTypedDictNotRequiredAccess]
586588

587589
@client.setter
588590
def client(self, value: temporalio.client.Client) -> None:
@@ -679,9 +681,9 @@ async def raise_on_shutdown():
679681
)
680682
if exception:
681683
logger.error("Worker failed, shutting down", exc_info=exception)
682-
if self._config["on_fatal_error"]:
684+
if self._config["on_fatal_error"]: # type: ignore[reportTypedDictNotRequiredAccess]
683685
try:
684-
await self._config["on_fatal_error"](exception)
686+
await self._config["on_fatal_error"](exception) # type: ignore[reportTypedDictNotRequiredAccess]
685687
except:
686688
logger.warning("Fatal error handler failed")
687689

@@ -692,7 +694,7 @@ async def raise_on_shutdown():
692694

693695
# Cancel the shutdown task (safe if already done)
694696
tasks[None].cancel()
695-
graceful_timeout = self._config["graceful_shutdown_timeout"]
697+
graceful_timeout = self._config["graceful_shutdown_timeout"] # type: ignore[reportTypedDictNotRequiredAccess]
696698
logger.info(
697699
f"Beginning worker shutdown, will wait {graceful_timeout} before cancelling activities"
698700
)

tests/testing/test_activity.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
import threading
33
import time
44
from contextvars import copy_context
5+
from unittest.mock import Mock
6+
7+
import pytest
58

69
from temporalio import activity
10+
from temporalio.client import Client
711
from temporalio.exceptions import CancelledError
812
from temporalio.testing import ActivityEnvironment
913

@@ -122,3 +126,44 @@ async def assert_equals(a: str, b: str) -> None:
122126

123127
assert type(expected_err) == type(actual_err)
124128
assert str(expected_err) == str(actual_err)
129+
130+
131+
async def test_error_on_access_client_in_activity_environment_without_client():
132+
saw_error: bool = False
133+
134+
async def my_activity() -> None:
135+
with pytest.raises(RuntimeError, match="No client available"):
136+
activity.client()
137+
nonlocal saw_error
138+
saw_error = True
139+
140+
env = ActivityEnvironment()
141+
await env.run(my_activity)
142+
assert saw_error
143+
144+
145+
async def test_access_client_in_activity_environment_with_client():
146+
got_client: bool = False
147+
148+
async def my_activity() -> None:
149+
nonlocal got_client
150+
if activity.client():
151+
got_client = True
152+
153+
env = ActivityEnvironment(client=Mock(spec=Client))
154+
await env.run(my_activity)
155+
assert got_client
156+
157+
158+
async def test_error_on_access_client_in_sync_activity_in_environment_with_client():
159+
saw_error: bool = False
160+
161+
def my_activity() -> None:
162+
with pytest.raises(RuntimeError, match="No client available"):
163+
activity.client()
164+
nonlocal saw_error
165+
saw_error = True
166+
167+
env = ActivityEnvironment(client=Mock(spec=Client))
168+
env.run(my_activity)
169+
assert saw_error

tests/worker/test_activity.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,49 @@ async def get_name(name: str) -> str:
9696
assert result.result == "Name: my custom activity name!"
9797

9898

99+
async def test_client_available_in_async_activities(
100+
client: Client, worker: ExternalWorker
101+
):
102+
with pytest.raises(RuntimeError, match="Not in activity context"):
103+
activity.client()
104+
105+
captured_client: Optional[Client] = None
106+
107+
@activity.defn
108+
async def capture_client() -> None:
109+
nonlocal captured_client
110+
captured_client = activity.client()
111+
112+
await _execute_workflow_with_activity(client, worker, capture_client)
113+
assert captured_client is client
114+
115+
116+
async def test_client_not_available_in_sync_activities(
117+
client: Client, worker: ExternalWorker
118+
):
119+
saw_error = False
120+
121+
@activity.defn
122+
def some_activity() -> None:
123+
with pytest.raises(
124+
RuntimeError, match="The client is only available in `async def`"
125+
):
126+
activity.client()
127+
nonlocal saw_error
128+
saw_error = True
129+
130+
await _execute_workflow_with_activity(
131+
client,
132+
worker,
133+
some_activity,
134+
worker_config={
135+
"activity_executor": concurrent.futures.ThreadPoolExecutor(1),
136+
"max_concurrent_activities": 1,
137+
},
138+
)
139+
assert saw_error
140+
141+
99142
async def test_activity_info(
100143
client: Client, worker: ExternalWorker, env: WorkflowEnvironment
101144
):
@@ -614,7 +657,7 @@ async def some_activity(param1: SomeClass2, param2: str) -> str:
614657
result.result
615658
== "param1: <class 'tests.worker.test_activity.SomeClass2'>, param2: <class 'str'>"
616659
)
617-
assert activity_param1 == SomeClass2(foo="str1", bar=SomeClass1(foo=123))
660+
assert activity_param1 == SomeClass2(foo="str1", bar=SomeClass1(foo=123)) # type: ignore[reportUnboundVariable] # noqa
618661

619662

620663
async def test_activity_heartbeat_details(

0 commit comments

Comments
 (0)