Skip to content

Commit 3c2c1bc

Browse files
committed
Expose client as activity.client()
1 parent b147771 commit 3c2c1bc

File tree

6 files changed

+134
-17
lines changed

6 files changed

+134
-17
lines changed

temporalio/activity.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from dataclasses import dataclass
2020
from datetime import datetime, timedelta
2121
from typing import (
22+
TYPE_CHECKING,
2223
Any,
2324
Callable,
2425
Iterator,
@@ -42,6 +43,9 @@
4243

4344
from .types import CallableType
4445

46+
if TYPE_CHECKING:
47+
from temporalio.client import Client
48+
4549

4650
@overload
4751
def defn(fn: CallableType) -> CallableType: ...
@@ -179,6 +183,7 @@ class _Context:
179183
temporalio.converter.PayloadConverter,
180184
]
181185
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
186+
client: Optional[Client]
182187
cancellation_details: _ActivityCancellationDetailsHolder
183188
_logger_details: Optional[Mapping[str, Any]] = None
184189
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
@@ -271,13 +276,32 @@ def wait_sync(self, timeout: Optional[float] = None) -> None:
271276
self.thread_event.wait(timeout)
272277

273278

279+
def client() -> Client:
280+
"""Return a Temporal Client for use in the current activity.
281+
282+
Returns:
283+
:py:class:`temporalio.client.Client` for use in the current activity.
284+
285+
Raises:
286+
RuntimeError: When not in an activity.
287+
"""
288+
client = _Context.current().client
289+
if not client:
290+
raise RuntimeError(
291+
"No client available. The client is only available in async "
292+
"(i.e. `async def`) activities; not in sync (i.e. `def`) activities. "
293+
"In tests you can pass a client when creating ActivityEnvironment."
294+
)
295+
return client
296+
297+
274298
def in_activity() -> bool:
275299
"""Whether the current code is inside an activity.
276300
277301
Returns:
278302
True if in an activity, False otherwise.
279303
"""
280-
return not _current_context.get(None) is None
304+
return _current_context.get(None) is not None
281305

282306

283307
def info() -> Info:
@@ -574,8 +598,10 @@ def _apply_to_callable(
574598
fn=fn,
575599
# iscoroutinefunction does not return true for async __call__
576600
# TODO(cretz): Why can't MyPy handle this?
577-
is_async=inspect.iscoroutinefunction(fn)
578-
or inspect.iscoroutinefunction(fn.__call__), # type: ignore
601+
is_async=(
602+
inspect.iscoroutinefunction(fn)
603+
or inspect.iscoroutinefunction(fn.__call__) # type: ignore
604+
),
579605
no_thread_cancel_exception=no_thread_cancel_exception,
580606
),
581607
)

temporalio/testing/_activity.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import temporalio.converter
1717
import temporalio.exceptions
1818
import temporalio.worker._activity
19+
from temporalio.client import Client
1920

2021
_Params = ParamSpec("_Params")
2122
_Return = TypeVar("_Return")
@@ -63,7 +64,7 @@ class ActivityEnvironment:
6364
take effect. Default is noop.
6465
"""
6566

66-
def __init__(self) -> None:
67+
def __init__(self, client: Optional[Client] = None) -> None:
6768
"""Create an ActivityEnvironment for running activity code."""
6869
self.info = _default_info
6970
self.on_heartbeat: Callable[..., None] = lambda *args: None
@@ -74,6 +75,7 @@ def __init__(self) -> None:
7475
self._cancelled = False
7576
self._worker_shutdown = False
7677
self._activities: Set[_Activity] = set()
78+
self._client = client
7779
self._cancellation_details = (
7880
temporalio.activity._ActivityCancellationDetailsHolder()
7981
)
@@ -128,18 +130,21 @@ def run(
128130
The callable's result.
129131
"""
130132
# Create an activity and run it
131-
return _Activity(self, fn).run(*args, **kwargs)
133+
return _Activity(self, fn, self._client).run(*args, **kwargs)
132134

133135

134136
class _Activity:
135137
def __init__(
136138
self,
137139
env: ActivityEnvironment,
138140
fn: Callable,
141+
client: Optional[Client],
139142
) -> None:
140143
self.env = env
141144
self.fn = fn
142-
self.is_async = inspect.iscoroutinefunction(fn)
145+
self.is_async = inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(
146+
fn.__call__ # type: ignore
147+
)
143148
self.cancel_thread_raiser: Optional[
144149
temporalio.worker._activity._ThreadExceptionRaiser
145150
] = None
@@ -163,11 +168,14 @@ def __init__(
163168
thread_event=threading.Event(),
164169
async_event=asyncio.Event() if self.is_async else None,
165170
),
166-
shield_thread_cancel_exception=None
167-
if not self.cancel_thread_raiser
168-
else self.cancel_thread_raiser.shielded,
171+
shield_thread_cancel_exception=(
172+
None
173+
if not self.cancel_thread_raiser
174+
else self.cancel_thread_raiser.shielded
175+
),
169176
payload_converter_class_or_instance=env.payload_converter,
170177
runtime_metric_meter=env.metric_meter,
178+
client=client if self.is_async else None,
171179
cancellation_details=env._cancellation_details,
172180
)
173181
self.task: Optional[asyncio.Task] = None

temporalio/worker/_activity.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
data_converter: temporalio.converter.DataConverter,
7070
interceptors: Sequence[Interceptor],
7171
metric_meter: temporalio.common.MetricMeter,
72+
client: temporalio.client.Client,
7273
encode_headers: bool,
7374
) -> None:
7475
self._bridge_worker = bridge_worker
@@ -86,6 +87,7 @@ def __init__(
8687
None
8788
)
8889
self._seen_sync_activity = False
90+
self._client = client
8991

9092
# Validate and build activity dict
9193
self._activities: Dict[str, temporalio.activity._Definition] = {}
@@ -569,11 +571,14 @@ async def _execute_activity(
569571
heartbeat=None,
570572
cancelled_event=running_activity.cancelled_event,
571573
worker_shutdown_event=self._worker_shutdown_event,
572-
shield_thread_cancel_exception=None
573-
if not running_activity.cancel_thread_raiser
574-
else running_activity.cancel_thread_raiser.shielded,
574+
shield_thread_cancel_exception=(
575+
None
576+
if not running_activity.cancel_thread_raiser
577+
else running_activity.cancel_thread_raiser.shielded
578+
),
575579
payload_converter_class_or_instance=self._data_converter.payload_converter,
576580
runtime_metric_meter=None if sync_non_threaded else self._metric_meter,
581+
client=self._client if not running_activity.sync else None,
577582
cancellation_details=running_activity.cancellation_details,
578583
)
579584
)
@@ -838,11 +843,12 @@ def _execute_sync_activity(
838843
worker_shutdown_event=temporalio.activity._CompositeEvent(
839844
thread_event=worker_shutdown_event, async_event=None
840845
),
841-
shield_thread_cancel_exception=None
842-
if not cancel_thread_raiser
843-
else cancel_thread_raiser.shielded,
846+
shield_thread_cancel_exception=(
847+
None if not cancel_thread_raiser else cancel_thread_raiser.shielded
848+
),
844849
payload_converter_class_or_instance=payload_converter_class_or_instance,
845850
runtime_metric_meter=runtime_metric_meter,
851+
client=None,
846852
cancellation_details=cancellation_details,
847853
)
848854
)

temporalio/worker/_worker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,10 @@ def __init__(
411411
data_converter=client_config["data_converter"],
412412
interceptors=interceptors,
413413
metric_meter=self._runtime.metric_meter,
414-
encode_headers=client_config["header_codec_behavior"]
415-
== HeaderCodecBehavior.CODEC,
414+
client=client,
415+
encode_headers=(
416+
client_config["header_codec_behavior"] == HeaderCodecBehavior.CODEC
417+
),
416418
)
417419
self._nexus_worker: Optional[_NexusWorker] = None
418420
if nexus_service_handlers:

tests/testing/test_activity.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
import threading
33
import time
44
from contextvars import copy_context
5+
from unittest.mock import Mock
6+
7+
import pytest
58

69
from temporalio import activity
10+
from temporalio.client import Client
711
from temporalio.exceptions import CancelledError
812
from temporalio.testing import ActivityEnvironment
913

@@ -122,3 +126,30 @@ async def assert_equals(a: str, b: str) -> None:
122126

123127
assert type(expected_err) == type(actual_err)
124128
assert str(expected_err) == str(actual_err)
129+
130+
131+
async def test_error_on_access_client_in_activity_environment_without_client():
132+
saw_error: bool = False
133+
134+
def my_activity() -> None:
135+
with pytest.raises(RuntimeError, match="No client available"):
136+
activity.client()
137+
nonlocal saw_error
138+
saw_error = True
139+
140+
env = ActivityEnvironment()
141+
env.run(my_activity)
142+
assert saw_error
143+
144+
145+
async def test_access_client_in_activity_environment_with_client():
146+
got_client: bool = False
147+
148+
def my_activity() -> None:
149+
nonlocal got_client
150+
if activity.client():
151+
got_client = True
152+
153+
env = ActivityEnvironment(client=Mock(spec=Client))
154+
env.run(my_activity)
155+
assert got_client

tests/worker/test_activity.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,50 @@ async def get_name(name: str) -> str:
9494
assert result.result == "Name: my custom activity name!"
9595

9696

97+
async def test_client_available_in_async_activities(
98+
client: Client, worker: ExternalWorker
99+
):
100+
with pytest.raises(RuntimeError) as err:
101+
activity.client()
102+
assert str(err.value) == "Not in activity context"
103+
104+
captured_client: Optional[Client] = None
105+
106+
@activity.defn
107+
async def capture_client() -> None:
108+
nonlocal captured_client
109+
captured_client = activity.client()
110+
111+
await _execute_workflow_with_activity(client, worker, capture_client)
112+
assert captured_client is client
113+
114+
115+
async def test_client_not_available_in_sync_activities(
116+
client: Client, worker: ExternalWorker
117+
):
118+
saw_error = False
119+
120+
@activity.defn
121+
def some_activity() -> None:
122+
with pytest.raises(
123+
RuntimeError, match="The client is only available in async"
124+
) as err:
125+
activity.client()
126+
nonlocal saw_error
127+
saw_error = True
128+
129+
await _execute_workflow_with_activity(
130+
client,
131+
worker,
132+
some_activity,
133+
worker_config={
134+
"activity_executor": concurrent.futures.ThreadPoolExecutor(1),
135+
"max_concurrent_activities": 1,
136+
},
137+
)
138+
assert saw_error
139+
140+
97141
async def test_activity_info(
98142
client: Client, worker: ExternalWorker, env: WorkflowEnvironment
99143
):

0 commit comments

Comments
 (0)