Skip to content

Commit e6f6f91

Browse files
authored
Provide client in activity context (#740)
* Expose client as activity.client()
1 parent 1fec723 commit e6f6f91

File tree

7 files changed

+161
-25
lines changed

7 files changed

+161
-25
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,6 +1257,7 @@ calls in the `temporalio.activity` package make use of it. Specifically:
12571257

12581258
* `in_activity()` - Whether an activity context is present
12591259
* `info()` - Returns the immutable info of the currently running activity
1260+
* `client()` - Returns the Temporal client used by this worker. Only available in `async def` activities.
12601261
* `heartbeat(*details)` - Record a heartbeat
12611262
* `is_cancelled()` - Whether a cancellation has been requested on this activity
12621263
* `wait_for_cancelled()` - `async` call to wait for cancellation request

temporalio/activity.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from dataclasses import dataclass
2020
from datetime import datetime, timedelta
2121
from typing import (
22+
TYPE_CHECKING,
2223
Any,
2324
Callable,
2425
Iterator,
@@ -42,6 +43,9 @@
4243

4344
from .types import CallableType
4445

46+
if TYPE_CHECKING:
47+
from temporalio.client import Client
48+
4549

4650
@overload
4751
def defn(fn: CallableType) -> CallableType: ...
@@ -179,6 +183,7 @@ class _Context:
179183
temporalio.converter.PayloadConverter,
180184
]
181185
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
186+
client: Optional[Client]
182187
cancellation_details: _ActivityCancellationDetailsHolder
183188
_logger_details: Optional[Mapping[str, Any]] = None
184189
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
@@ -271,13 +276,37 @@ def wait_sync(self, timeout: Optional[float] = None) -> None:
271276
self.thread_event.wait(timeout)
272277

273278

279+
def client() -> Client:
280+
"""Return a Temporal Client for use in the current activity.
281+
282+
The client is only available in `async def` activities.
283+
284+
In tests it is not available automatically, but you can pass a client when creating a
285+
:py:class:`temporalio.testing.ActivityEnvironment`.
286+
287+
Returns:
288+
:py:class:`temporalio.client.Client` for use in the current activity.
289+
290+
Raises:
291+
RuntimeError: When the client is not available.
292+
"""
293+
client = _Context.current().client
294+
if not client:
295+
raise RuntimeError(
296+
"No client available. The client is only available in `async def` "
297+
"activities; not in `def` activities. In tests you can pass a "
298+
"client when creating ActivityEnvironment."
299+
)
300+
return client
301+
302+
274303
def in_activity() -> bool:
275304
"""Whether the current code is inside an activity.
276305
277306
Returns:
278307
True if in an activity, False otherwise.
279308
"""
280-
return not _current_context.get(None) is None
309+
return _current_context.get(None) is not None
281310

282311

283312
def info() -> Info:
@@ -574,8 +603,10 @@ def _apply_to_callable(
574603
fn=fn,
575604
# iscoroutinefunction does not return true for async __call__
576605
# TODO(cretz): Why can't MyPy handle this?
577-
is_async=inspect.iscoroutinefunction(fn)
578-
or inspect.iscoroutinefunction(fn.__call__), # type: ignore
606+
is_async=(
607+
inspect.iscoroutinefunction(fn)
608+
or inspect.iscoroutinefunction(fn.__call__) # type: ignore
609+
),
579610
no_thread_cancel_exception=no_thread_cancel_exception,
580611
),
581612
)

temporalio/testing/_activity.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import temporalio.converter
1717
import temporalio.exceptions
1818
import temporalio.worker._activity
19+
from temporalio.client import Client
1920

2021
_Params = ParamSpec("_Params")
2122
_Return = TypeVar("_Return")
@@ -63,7 +64,7 @@ class ActivityEnvironment:
6364
take effect. Default is noop.
6465
"""
6566

66-
def __init__(self) -> None:
67+
def __init__(self, client: Optional[Client] = None) -> None:
6768
"""Create an ActivityEnvironment for running activity code."""
6869
self.info = _default_info
6970
self.on_heartbeat: Callable[..., None] = lambda *args: None
@@ -74,6 +75,7 @@ def __init__(self) -> None:
7475
self._cancelled = False
7576
self._worker_shutdown = False
7677
self._activities: Set[_Activity] = set()
78+
self._client = client
7779
self._cancellation_details = (
7880
temporalio.activity._ActivityCancellationDetailsHolder()
7981
)
@@ -128,18 +130,21 @@ def run(
128130
The callable's result.
129131
"""
130132
# Create an activity and run it
131-
return _Activity(self, fn).run(*args, **kwargs)
133+
return _Activity(self, fn, self._client).run(*args, **kwargs)
132134

133135

134136
class _Activity:
135137
def __init__(
136138
self,
137139
env: ActivityEnvironment,
138140
fn: Callable,
141+
client: Optional[Client],
139142
) -> None:
140143
self.env = env
141144
self.fn = fn
142-
self.is_async = inspect.iscoroutinefunction(fn)
145+
self.is_async = inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(
146+
fn.__call__ # type: ignore
147+
)
143148
self.cancel_thread_raiser: Optional[
144149
temporalio.worker._activity._ThreadExceptionRaiser
145150
] = None
@@ -163,11 +168,14 @@ def __init__(
163168
thread_event=threading.Event(),
164169
async_event=asyncio.Event() if self.is_async else None,
165170
),
166-
shield_thread_cancel_exception=None
167-
if not self.cancel_thread_raiser
168-
else self.cancel_thread_raiser.shielded,
171+
shield_thread_cancel_exception=(
172+
None
173+
if not self.cancel_thread_raiser
174+
else self.cancel_thread_raiser.shielded
175+
),
169176
payload_converter_class_or_instance=env.payload_converter,
170177
runtime_metric_meter=env.metric_meter,
178+
client=client if self.is_async else None,
171179
cancellation_details=env._cancellation_details,
172180
)
173181
self.task: Optional[asyncio.Task] = None

temporalio/worker/_activity.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
data_converter: temporalio.converter.DataConverter,
7070
interceptors: Sequence[Interceptor],
7171
metric_meter: temporalio.common.MetricMeter,
72+
client: temporalio.client.Client,
7273
encode_headers: bool,
7374
) -> None:
7475
self._bridge_worker = bridge_worker
@@ -86,6 +87,7 @@ def __init__(
8687
None
8788
)
8889
self._seen_sync_activity = False
90+
self._client = client
8991

9092
# Validate and build activity dict
9193
self._activities: Dict[str, temporalio.activity._Definition] = {}
@@ -569,11 +571,14 @@ async def _execute_activity(
569571
heartbeat=None,
570572
cancelled_event=running_activity.cancelled_event,
571573
worker_shutdown_event=self._worker_shutdown_event,
572-
shield_thread_cancel_exception=None
573-
if not running_activity.cancel_thread_raiser
574-
else running_activity.cancel_thread_raiser.shielded,
574+
shield_thread_cancel_exception=(
575+
None
576+
if not running_activity.cancel_thread_raiser
577+
else running_activity.cancel_thread_raiser.shielded
578+
),
575579
payload_converter_class_or_instance=self._data_converter.payload_converter,
576580
runtime_metric_meter=None if sync_non_threaded else self._metric_meter,
581+
client=self._client if not running_activity.sync else None,
577582
cancellation_details=running_activity.cancellation_details,
578583
)
579584
)
@@ -679,7 +684,7 @@ def _raise_in_thread_if_pending_unlocked(self) -> None:
679684

680685

681686
class _ActivityInboundImpl(ActivityInboundInterceptor):
682-
def __init__(
687+
def __init__( # type: ignore[reportMissingSuperCall]
683688
self, worker: _ActivityWorker, running_activity: _RunningActivity
684689
) -> None:
685690
# We are intentionally not calling the base class's __init__ here
@@ -786,7 +791,7 @@ async def heartbeat_with_context(*details: Any) -> None:
786791

787792

788793
class _ActivityOutboundImpl(ActivityOutboundInterceptor):
789-
def __init__(self, worker: _ActivityWorker, info: temporalio.activity.Info) -> None:
794+
def __init__(self, worker: _ActivityWorker, info: temporalio.activity.Info) -> None: # type: ignore[reportMissingSuperCall]
790795
# We are intentionally not calling the base class's __init__ here
791796
self._worker = worker
792797
self._info = info
@@ -838,11 +843,12 @@ def _execute_sync_activity(
838843
worker_shutdown_event=temporalio.activity._CompositeEvent(
839844
thread_event=worker_shutdown_event, async_event=None
840845
),
841-
shield_thread_cancel_exception=None
842-
if not cancel_thread_raiser
843-
else cancel_thread_raiser.shielded,
846+
shield_thread_cancel_exception=(
847+
None if not cancel_thread_raiser else cancel_thread_raiser.shielded
848+
),
844849
payload_converter_class_or_instance=payload_converter_class_or_instance,
845850
runtime_metric_meter=runtime_metric_meter,
851+
client=None,
846852
cancellation_details=cancellation_details,
847853
)
848854
)

temporalio/worker/_worker.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,10 @@ def __init__(
411411
data_converter=client_config["data_converter"],
412412
interceptors=interceptors,
413413
metric_meter=self._runtime.metric_meter,
414-
encode_headers=client_config["header_codec_behavior"]
415-
== HeaderCodecBehavior.CODEC,
414+
client=client,
415+
encode_headers=(
416+
client_config["header_codec_behavior"] == HeaderCodecBehavior.CODEC
417+
),
416418
)
417419
self._nexus_worker: Optional[_NexusWorker] = None
418420
if nexus_service_handlers:
@@ -577,12 +579,12 @@ def config(self) -> WorkerConfig:
577579
@property
578580
def task_queue(self) -> str:
579581
"""Task queue this worker is on."""
580-
return self._config["task_queue"]
582+
return self._config["task_queue"] # type: ignore[reportTypedDictNotRequiredAccess]
581583

582584
@property
583585
def client(self) -> temporalio.client.Client:
584586
"""Client currently set on the worker."""
585-
return self._config["client"]
587+
return self._config["client"] # type: ignore[reportTypedDictNotRequiredAccess]
586588

587589
@client.setter
588590
def client(self, value: temporalio.client.Client) -> None:
@@ -679,9 +681,9 @@ async def raise_on_shutdown():
679681
)
680682
if exception:
681683
logger.error("Worker failed, shutting down", exc_info=exception)
682-
if self._config["on_fatal_error"]:
684+
if self._config["on_fatal_error"]: # type: ignore[reportTypedDictNotRequiredAccess]
683685
try:
684-
await self._config["on_fatal_error"](exception)
686+
await self._config["on_fatal_error"](exception) # type: ignore[reportTypedDictNotRequiredAccess]
685687
except:
686688
logger.warning("Fatal error handler failed")
687689

@@ -692,7 +694,7 @@ async def raise_on_shutdown():
692694

693695
# Cancel the shutdown task (safe if already done)
694696
tasks[None].cancel()
695-
graceful_timeout = self._config["graceful_shutdown_timeout"]
697+
graceful_timeout = self._config["graceful_shutdown_timeout"] # type: ignore[reportTypedDictNotRequiredAccess]
696698
logger.info(
697699
f"Beginning worker shutdown, will wait {graceful_timeout} before cancelling activities"
698700
)

tests/testing/test_activity.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
import threading
33
import time
44
from contextvars import copy_context
5+
from unittest.mock import Mock
6+
7+
import pytest
58

69
from temporalio import activity
10+
from temporalio.client import Client
711
from temporalio.exceptions import CancelledError
812
from temporalio.testing import ActivityEnvironment
913

@@ -122,3 +126,44 @@ async def assert_equals(a: str, b: str) -> None:
122126

123127
assert type(expected_err) == type(actual_err)
124128
assert str(expected_err) == str(actual_err)
129+
130+
131+
async def test_error_on_access_client_in_activity_environment_without_client():
132+
saw_error: bool = False
133+
134+
async def my_activity() -> None:
135+
with pytest.raises(RuntimeError, match="No client available"):
136+
activity.client()
137+
nonlocal saw_error
138+
saw_error = True
139+
140+
env = ActivityEnvironment()
141+
await env.run(my_activity)
142+
assert saw_error
143+
144+
145+
async def test_access_client_in_activity_environment_with_client():
146+
got_client: bool = False
147+
148+
async def my_activity() -> None:
149+
nonlocal got_client
150+
if activity.client():
151+
got_client = True
152+
153+
env = ActivityEnvironment(client=Mock(spec=Client))
154+
await env.run(my_activity)
155+
assert got_client
156+
157+
158+
async def test_error_on_access_client_in_sync_activity_in_environment_with_client():
159+
saw_error: bool = False
160+
161+
def my_activity() -> None:
162+
with pytest.raises(RuntimeError, match="No client available"):
163+
activity.client()
164+
nonlocal saw_error
165+
saw_error = True
166+
167+
env = ActivityEnvironment(client=Mock(spec=Client))
168+
env.run(my_activity)
169+
assert saw_error

tests/worker/test_activity.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,49 @@ async def get_name(name: str) -> str:
9494
assert result.result == "Name: my custom activity name!"
9595

9696

97+
async def test_client_available_in_async_activities(
98+
client: Client, worker: ExternalWorker
99+
):
100+
with pytest.raises(RuntimeError, match="Not in activity context"):
101+
activity.client()
102+
103+
captured_client: Optional[Client] = None
104+
105+
@activity.defn
106+
async def capture_client() -> None:
107+
nonlocal captured_client
108+
captured_client = activity.client()
109+
110+
await _execute_workflow_with_activity(client, worker, capture_client)
111+
assert captured_client is client
112+
113+
114+
async def test_client_not_available_in_sync_activities(
115+
client: Client, worker: ExternalWorker
116+
):
117+
saw_error = False
118+
119+
@activity.defn
120+
def some_activity() -> None:
121+
with pytest.raises(
122+
RuntimeError, match="The client is only available in `async def`"
123+
):
124+
activity.client()
125+
nonlocal saw_error
126+
saw_error = True
127+
128+
await _execute_workflow_with_activity(
129+
client,
130+
worker,
131+
some_activity,
132+
worker_config={
133+
"activity_executor": concurrent.futures.ThreadPoolExecutor(1),
134+
"max_concurrent_activities": 1,
135+
},
136+
)
137+
assert saw_error
138+
139+
97140
async def test_activity_info(
98141
client: Client, worker: ExternalWorker, env: WorkflowEnvironment
99142
):
@@ -612,7 +655,7 @@ async def some_activity(param1: SomeClass2, param2: str) -> str:
612655
result.result
613656
== "param1: <class 'tests.worker.test_activity.SomeClass2'>, param2: <class 'str'>"
614657
)
615-
assert activity_param1 == SomeClass2(foo="str1", bar=SomeClass1(foo=123))
658+
assert activity_param1 == SomeClass2(foo="str1", bar=SomeClass1(foo=123)) # type: ignore[reportUnboundVariable] # noqa
616659

617660

618661
async def test_activity_heartbeat_details(

0 commit comments

Comments
 (0)