Skip to content

Commit f61b400

Browse files
committed
Move shared configuration into plugin definition
1 parent 76cacf6 commit f61b400

File tree

3 files changed

+108
-50
lines changed

3 files changed

+108
-50
lines changed

temporalio/worker/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ._replayer import (
2525
Replayer,
2626
ReplayerConfig,
27+
ReplayerPlugin,
2728
WorkflowReplayResult,
2829
WorkflowReplayResults,
2930
)
@@ -68,6 +69,7 @@
6869
"WorkerDeploymentVersion",
6970
"Replayer",
7071
"ReplayerConfig",
72+
"ReplayerPlugin",
7173
"WorkflowReplayResult",
7274
"WorkflowReplayResults",
7375
"PollerBehavior",

temporalio/worker/_replayer.py

Lines changed: 97 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
from contextlib import asynccontextmanager
99
from dataclasses import dataclass
10-
from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, Union, cast
10+
from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, cast
1111

1212
from typing_extensions import TypedDict
1313

@@ -18,7 +18,6 @@
1818
import temporalio.converter
1919
import temporalio.runtime
2020
import temporalio.workflow
21-
from temporalio.client import ClientConfig
2221

2322
from ..common import HeaderCodecBehavior
2423
from ._interceptor import Interceptor
@@ -30,6 +29,88 @@
3029
logger = logging.getLogger(__name__)
3130

3231

32+
class ReplayerPlugin:
33+
"""Base class for replayer plugins that can modify replayer configuration."""
34+
35+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
36+
"""Configure the replayer.
37+
38+
Default implementation applies shared configuration from worker and client plugins.
39+
40+
Args:
41+
config: The replayer configuration to modify.
42+
43+
Returns:
44+
The modified replayer configuration.
45+
"""
46+
# If this plugin is also a worker plugin, apply shared worker config
47+
if isinstance(self, temporalio.worker.Plugin):
48+
# Create a minimal worker config with shared fields
49+
worker_config = cast(
50+
WorkerConfig,
51+
{
52+
"workflows": config["workflows"],
53+
"workflow_task_executor": config["workflow_task_executor"],
54+
"workflow_runner": config["workflow_runner"],
55+
"unsandboxed_workflow_runner": config[
56+
"unsandboxed_workflow_runner"
57+
],
58+
"interceptors": config["interceptors"],
59+
"build_id": config["build_id"],
60+
"identity": config["identity"],
61+
"workflow_failure_exception_types": config[
62+
"workflow_failure_exception_types"
63+
],
64+
"debug_mode": config["debug_mode"],
65+
"disable_safe_workflow_eviction": config[
66+
"disable_safe_workflow_eviction"
67+
],
68+
},
69+
)
70+
71+
modified_worker_config = self.configure_worker(worker_config)
72+
config["workflows"] = modified_worker_config["workflows"]
73+
config["workflow_task_executor"] = modified_worker_config[
74+
"workflow_task_executor"
75+
]
76+
config["workflow_runner"] = modified_worker_config["workflow_runner"]
77+
config["unsandboxed_workflow_runner"] = modified_worker_config[
78+
"unsandboxed_workflow_runner"
79+
]
80+
config["interceptors"] = modified_worker_config["interceptors"]
81+
config["build_id"] = modified_worker_config["build_id"]
82+
config["identity"] = modified_worker_config["identity"]
83+
config["workflow_failure_exception_types"] = modified_worker_config[
84+
"workflow_failure_exception_types"
85+
]
86+
config["debug_mode"] = modified_worker_config["debug_mode"]
87+
config["disable_safe_workflow_eviction"] = modified_worker_config[
88+
"disable_safe_workflow_eviction"
89+
]
90+
91+
# If this plugin is also a client plugin, apply shared client config
92+
if isinstance(self, temporalio.client.Plugin):
93+
# Only include fields that exist in both ReplayerConfig and ClientConfig
94+
# Note: interceptors are different types between client and worker, so excluded
95+
client_config = cast(
96+
temporalio.client.ClientConfig,
97+
{
98+
"namespace": config["namespace"],
99+
"data_converter": config["data_converter"],
100+
"header_codec_behavior": config["header_codec_behavior"],
101+
},
102+
)
103+
104+
modified_client_config = self.configure_client(client_config)
105+
config["namespace"] = modified_client_config["namespace"]
106+
config["data_converter"] = modified_client_config["data_converter"]
107+
config["header_codec_behavior"] = modified_client_config[
108+
"header_codec_behavior"
109+
]
110+
111+
return config
112+
113+
33114
class Replayer:
34115
"""Replayer to replay workflows from history."""
35116

@@ -43,9 +124,7 @@ def __init__(
43124
namespace: str = "ReplayNamespace",
44125
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
45126
interceptors: Sequence[Interceptor] = [],
46-
plugins: Sequence[
47-
Union[temporalio.worker.Plugin, temporalio.client.Plugin]
48-
] = [],
127+
plugins: Sequence[ReplayerPlugin] = [],
49128
build_id: Optional[str] = None,
50129
identity: Optional[str] = None,
51130
workflow_failure_exception_types: Sequence[Type[BaseException]] = [],
@@ -85,62 +164,35 @@ def __init__(
85164
header_codec_behavior=header_codec_behavior,
86165
)
87166

88-
# Allow plugins to configure shared configurations with worker
89-
root_worker_plugin: temporalio.worker.Plugin = temporalio.worker._worker._RootPlugin()
90-
for plugin in reversed(
167+
# Initialize all worker plugins
168+
root_worker_plugin: temporalio.worker.Plugin = (
169+
temporalio.worker._worker._RootPlugin()
170+
)
171+
for worker_plugin in reversed(
91172
[
92-
plugin
173+
cast(temporalio.worker.Plugin, plugin)
93174
for plugin in plugins
94175
if isinstance(plugin, temporalio.worker.Plugin)
95176
]
96177
):
97-
root_worker_plugin = plugin.init_worker_plugin(root_worker_plugin)
98-
99-
worker_config = cast(
100-
WorkerConfig,
101-
{
102-
k: v
103-
for k, v in self._config.items()
104-
if k in WorkerConfig.__annotations__
105-
},
106-
)
107-
108-
worker_config = root_worker_plugin.configure_worker(worker_config)
109-
self._config.update(
110-
cast(ReplayerConfig, {
111-
k: v
112-
for k, v in worker_config.items()
113-
if k in ReplayerConfig.__annotations__
114-
})
115-
)
178+
root_worker_plugin = worker_plugin.init_worker_plugin(root_worker_plugin)
116179

117-
# Allow plugins to configure shared configurations with client
180+
# Initialize all client plugins
118181
root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin()
119182
for client_plugin in reversed(
120183
[
121-
plugin
184+
cast(temporalio.client.Plugin, plugin)
122185
for plugin in plugins
123186
if isinstance(plugin, temporalio.client.Plugin)
124187
]
125188
):
126189
root_client_plugin = client_plugin.init_client_plugin(root_client_plugin)
127190

128-
client_config = cast(ClientConfig,
129-
{
130-
k: v
131-
for k, v in self._config.items()
132-
if k in ClientConfig.__annotations__
133-
}
134-
)
135-
client_config = root_client_plugin.configure_client(client_config)
136-
self._config.update(
137-
cast(ReplayerConfig, {
138-
k: v
139-
for k, v in client_config.items()
140-
if k in ReplayerConfig.__annotations__
141-
})
142-
)
191+
# Apply plugin configuration
192+
for plugin in plugins:
193+
self._config = plugin.configure_replayer(self._config)
143194

195+
# Validate workflows after plugin configuration
144196
if not self._config["workflows"]:
145197
raise ValueError("At least one workflow must be specified")
146198

tests/test_plugins.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from temporalio.client import Client, ClientConfig, OutboundInterceptor
1212
from temporalio.contrib.pydantic import pydantic_data_converter
1313
from temporalio.testing import WorkflowEnvironment
14-
from temporalio.worker import Replayer, Worker, WorkerConfig
14+
from temporalio.worker import Replayer, ReplayerConfig, Worker, WorkerConfig
1515
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
1616
from tests.helpers import new_worker
1717
from tests.worker.test_worker import never_run_activity
@@ -26,7 +26,7 @@ def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor:
2626
return super().intercept_client(next)
2727

2828

29-
class MyClientPlugin(temporalio.client.Plugin):
29+
class MyClientPlugin(temporalio.worker.ReplayerPlugin, temporalio.client.Plugin):
3030
def __init__(self):
3131
self.interceptor = TestClientInterceptor()
3232

@@ -62,13 +62,15 @@ async def test_client_plugin(client: Client, env: WorkflowEnvironment):
6262
assert new_client.service_client.config.api_key == "replaced key"
6363

6464

65-
class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
65+
class MyCombinedPlugin(
66+
temporalio.worker.ReplayerPlugin, temporalio.client.Plugin, temporalio.worker.Plugin
67+
):
6668
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
6769
config["task_queue"] = "combined"
6870
return super().configure_worker(config)
6971

7072

71-
class MyWorkerPlugin(temporalio.worker.Plugin):
73+
class MyWorkerPlugin(temporalio.worker.ReplayerPlugin, temporalio.worker.Plugin):
7274
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
7375
config["task_queue"] = "replaced_queue"
7476
runner = config.get("workflow_runner")
@@ -142,7 +144,9 @@ async def test_worker_sandbox_restrictions(client: Client) -> None:
142144
)
143145

144146

145-
class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
147+
class ReplayCheckPlugin(
148+
temporalio.worker.ReplayerPlugin, temporalio.client.Plugin, temporalio.worker.Plugin
149+
):
146150
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
147151
config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow]
148152
return super().configure_worker(config)

0 commit comments

Comments
 (0)