Skip to content

Commit 27096aa

Browse files
authored
Store initial configuration and provide it in config() by default (#1226)
* Store initial configuration and provide it in config() by default * Include replayer * Use active config in a few places * Kwarg
1 parent dbcbc08 commit 27096aa

File tree

4 files changed

+54
-25
lines changed

4 files changed

+54
-25
lines changed

temporalio/client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def __init__(
247247
default_workflow_query_reject_condition=default_workflow_query_reject_condition,
248248
header_codec_behavior=header_codec_behavior,
249249
)
250+
self._initial_config = config.copy()
250251

251252
for plugin in plugins:
252253
config = plugin.configure_client(config)
@@ -261,12 +262,16 @@ def _init_from_config(self, config: ClientConfig):
261262
for interceptor in reversed(list(self._config["interceptors"])):
262263
self._impl = interceptor.intercept_client(self._impl)
263264

264-
def config(self) -> ClientConfig:
265+
def config(self, *, active_config: bool = False) -> ClientConfig:
265266
"""Config, as a dictionary, used to create this client.
266267
268+
Args:
269+
active_config: If true, return the modified configuration in use rather than the initial one
270+
provided to the client.
271+
267272
This makes a shallow copy of the config each call.
268273
"""
269-
config = self._config.copy()
274+
config = self._config.copy() if active_config else self._initial_config.copy()
270275
config["interceptors"] = list(config["interceptors"])
271276
return config
272277

@@ -4290,7 +4295,8 @@ async def _to_proto(
42904295
await _apply_headers(
42914296
self.headers,
42924297
action.start_workflow.header.fields,
4293-
client.config()["header_codec_behavior"] == HeaderCodecBehavior.CODEC
4298+
client.config(active_config=True)["header_codec_behavior"]
4299+
== HeaderCodecBehavior.CODEC
42944300
and not self._from_raw,
42954301
client.data_converter.payload_codec,
42964302
)
@@ -6917,7 +6923,8 @@ async def _apply_headers(
69176923
await _apply_headers(
69186924
source,
69196925
dest,
6920-
self._client.config()["header_codec_behavior"] == HeaderCodecBehavior.CODEC,
6926+
self._client.config(active_config=True)["header_codec_behavior"]
6927+
== HeaderCodecBehavior.CODEC,
69216928
self._client.data_converter.payload_codec,
69226929
)
69236930

temporalio/worker/_replayer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
disable_safe_workflow_eviction=disable_safe_workflow_eviction,
8282
header_codec_behavior=header_codec_behavior,
8383
)
84+
self._initial_config = self._config.copy()
8485

8586
# Apply plugin configuration
8687
self.plugins = plugins
@@ -91,13 +92,17 @@ def __init__(
9192
if not self._config["workflows"]:
9293
raise ValueError("At least one workflow must be specified")
9394

94-
def config(self) -> ReplayerConfig:
95+
def config(self, *, active_config: bool = False) -> ReplayerConfig:
9596
"""Config, as a dictionary, used to create this replayer.
9697
98+
Args:
99+
active_config: If true, return the modified configuration in use rather than the initial one
100+
provided to the client.
101+
97102
Returns:
98103
Configuration, shallow-copied.
99104
"""
100-
config = self._config.copy()
105+
config = self._config.copy() if active_config else self._initial_config.copy()
101106
config["workflows"] = list(config["workflows"])
102107
return config
103108

temporalio/worker/_worker.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,9 @@ def __init__(
376376
f"The same plugin type {type(client_plugin)} is present from both client and worker. It may run twice and may not be the intended behavior."
377377
)
378378
plugins = plugins_from_client + list(plugins)
379+
self._initial_config = config.copy()
379380

380-
self.plugins = plugins
381+
self._plugins = plugins
381382
for plugin in plugins:
382383
config = plugin.configure_worker(config)
383384

@@ -388,7 +389,6 @@ def _init_from_config(self, client: temporalio.client.Client, config: WorkerConf
388389
Client is safe to take separately since it can't be modified by worker plugins.
389390
"""
390391
self._config = config
391-
392392
if not (
393393
config["activities"]
394394
or config["nexus_service_handlers"]
@@ -409,7 +409,7 @@ def _init_from_config(self, client: temporalio.client.Client, config: WorkerConf
409409
)
410410

411411
# Prepend applicable client interceptors to the given ones
412-
client_config = config["client"].config()
412+
client_config = config["client"].config(active_config=True)
413413
interceptors_from_client = cast(
414414
List[Interceptor],
415415
[i for i in client_config["interceptors"] if isinstance(i, Interceptor)],
@@ -555,7 +555,7 @@ def check_activity(activity):
555555
maximum=config["max_concurrent_activity_task_polls"]
556556
)
557557

558-
deduped_plugin_names = list(set([plugin.name() for plugin in self.plugins]))
558+
deduped_plugin_names = list(set([plugin.name() for plugin in self._plugins]))
559559

560560
# Create bridge worker last. We have empirically observed that if it is
561561
# created before an error is raised from the activity worker
@@ -623,13 +623,17 @@ def check_activity(activity):
623623
),
624624
)
625625

626-
def config(self) -> WorkerConfig:
626+
def config(self, *, active_config: bool = False) -> WorkerConfig:
627627
"""Config, as a dictionary, used to create this worker.
628628
629+
Args:
630+
active_config: If true, return the modified configuration in use rather than the initial one
631+
provided to the worker.
632+
629633
Returns:
630634
Configuration, shallow-copied.
631635
"""
632-
config = self._config.copy()
636+
config = self._config.copy() if active_config else self._initial_config.copy()
633637
config["activities"] = list(config.get("activities", []))
634638
config["workflows"] = list(config.get("workflows", []))
635639
return config
@@ -703,7 +707,7 @@ def make_lambda(plugin, next):
703707
return lambda w: plugin.run_worker(w, next)
704708

705709
next_function = lambda w: w._run()
706-
for plugin in reversed(self.plugins):
710+
for plugin in reversed(self._plugins):
707711
next_function = make_lambda(plugin, next_function)
708712

709713
await next_function(self)

tests/test_plugins.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ async def test_worker_plugin_basic_config(client: Client) -> None:
150150
activities=[never_run_activity],
151151
plugins=[MyWorkerPlugin()],
152152
)
153-
task_queue = worker.config().get("task_queue")
153+
task_queue = worker.config(active_config=True).get("task_queue")
154154
assert task_queue is not None and task_queue.startswith("replaced_queue")
155155

156156
# Test client plugin propagation to worker plugins
@@ -160,7 +160,7 @@ async def test_worker_plugin_basic_config(client: Client) -> None:
160160
worker = Worker(
161161
client, task_queue="queue" + str(uuid.uuid4()), activities=[never_run_activity]
162162
)
163-
task_queue = worker.config().get("task_queue")
163+
task_queue = worker.config(active_config=True).get("task_queue")
164164
assert task_queue is not None and task_queue.startswith("combined")
165165

166166
# Test both. Client propagated plugins are called first, so the worker plugin overrides in this case
@@ -170,7 +170,7 @@ async def test_worker_plugin_basic_config(client: Client) -> None:
170170
activities=[never_run_activity],
171171
plugins=[MyWorkerPlugin()],
172172
)
173-
task_queue = worker.config().get("task_queue")
173+
task_queue = worker.config(active_config=True).get("task_queue")
174174
assert task_queue is not None and task_queue.startswith("replaced_queue")
175175

176176

@@ -202,7 +202,8 @@ async def test_worker_sandbox_restrictions(client: Client) -> None:
202202
assert (
203203
"my_module"
204204
in cast(
205-
SandboxedWorkflowRunner, worker.config().get("workflow_runner")
205+
SandboxedWorkflowRunner,
206+
worker.config(active_config=True).get("workflow_runner"),
206207
).restrictions.passthrough_modules
207208
)
208209

@@ -276,8 +277,11 @@ async def test_replay(client: Client) -> None:
276277
)
277278
await handle.result()
278279
replayer = Replayer(workflows=[], plugins=[plugin])
279-
assert len(replayer.config().get("workflows") or []) == 1
280-
assert replayer.config().get("data_converter") == pydantic_data_converter
280+
assert len(replayer.config(active_config=True).get("workflows") or []) == 1
281+
assert (
282+
replayer.config(active_config=True).get("data_converter")
283+
== pydantic_data_converter
284+
)
281285

282286
await replayer.replay_workflow(await handle.fetch_history())
283287

@@ -303,19 +307,28 @@ async def test_simple_plugins(client: Client) -> None:
303307
plugins=[plugin],
304308
)
305309
# On a sequence, a value is appended
306-
assert worker.config().get("workflows") == [HelloWorkflow, HelloWorkflow2]
310+
assert worker.config(active_config=True).get("workflows") == [
311+
HelloWorkflow,
312+
HelloWorkflow2,
313+
]
307314

308315
# Test with plugin registered in client
309316
worker = Worker(
310317
new_client,
311318
task_queue="queue" + str(uuid.uuid4()),
312319
activities=[never_run_activity],
313320
)
314-
assert worker.config().get("workflows") == [HelloWorkflow2]
321+
assert worker.config(active_config=True).get("workflows") == [HelloWorkflow2]
315322

316323
replayer = Replayer(workflows=[HelloWorkflow], plugins=[plugin])
317-
assert replayer.config().get("data_converter") == pydantic_data_converter
318-
assert replayer.config().get("workflows") == [HelloWorkflow, HelloWorkflow2]
324+
assert (
325+
replayer.config(active_config=True).get("data_converter")
326+
== pydantic_data_converter
327+
)
328+
assert replayer.config(active_config=True).get("workflows") == [
329+
HelloWorkflow,
330+
HelloWorkflow2,
331+
]
319332

320333

321334
async def test_simple_plugins_callables(client: Client) -> None:
@@ -350,7 +363,7 @@ def converter(old: Optional[DataConverter]):
350363
activities=[never_run_activity],
351364
plugins=[plugin],
352365
)
353-
assert worker.config().get("workflows") == []
366+
assert worker.config(active_config=True).get("workflows") == []
354367

355368

356369
class MediumPlugin(SimplePlugin):
@@ -371,5 +384,5 @@ async def test_medium_plugin(client: Client) -> None:
371384
plugins=[plugin],
372385
workflows=[HelloWorkflow],
373386
)
374-
task_queue = worker.config().get("task_queue")
387+
task_queue = worker.config(active_config=True).get("task_queue")
375388
assert task_queue is not None and task_queue.startswith("override")

0 commit comments

Comments
 (0)