Skip to content

Commit 76cacf6

Browse files
committed
Fixing type checking
1 parent 3c92c40 commit 76cacf6

File tree

2 files changed

+65
-34
lines changed

2 files changed

+65
-34
lines changed

temporalio/worker/_replayer.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
from contextlib import asynccontextmanager
99
from dataclasses import dataclass
10-
from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, Union
10+
from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, Union, cast
1111

1212
from typing_extensions import TypedDict
1313

@@ -18,12 +18,11 @@
1818
import temporalio.converter
1919
import temporalio.runtime
2020
import temporalio.workflow
21-
21+
from temporalio.client import ClientConfig
2222

2323
from ..common import HeaderCodecBehavior
2424
from ._interceptor import Interceptor
25-
from ._worker import load_default_build_id, WorkerConfig
26-
from temporalio.client import ClientConfig
25+
from ._worker import WorkerConfig, load_default_build_id
2726
from ._workflow import _WorkflowWorker
2827
from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
2928
from .workflow_sandbox import SandboxedWorkflowRunner
@@ -44,7 +43,9 @@ def __init__(
4443
namespace: str = "ReplayNamespace",
4544
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
4645
interceptors: Sequence[Interceptor] = [],
47-
plugins: Sequence[Union[temporalio.worker.Plugin, temporalio.client.Plugin]] = [],
46+
plugins: Sequence[
47+
Union[temporalio.worker.Plugin, temporalio.client.Plugin]
48+
] = [],
4849
build_id: Optional[str] = None,
4950
identity: Optional[str] = None,
5051
workflow_failure_exception_types: Sequence[Type[BaseException]] = [],
@@ -86,21 +87,59 @@ def __init__(
8687

8788
# Allow plugins to configure shared configurations with worker
8889
root_worker_plugin: temporalio.worker.Plugin = temporalio.worker._worker._RootPlugin()
89-
for plugin in reversed([plugin for plugin in plugins if isinstance(plugin, temporalio.worker.Plugin)]):
90+
for plugin in reversed(
91+
[
92+
plugin
93+
for plugin in plugins
94+
if isinstance(plugin, temporalio.worker.Plugin)
95+
]
96+
):
9097
root_worker_plugin = plugin.init_worker_plugin(root_worker_plugin)
9198

92-
worker_config = WorkerConfig(**{k: v for k, v in self._config.items() if k in WorkerConfig.__annotations__})
99+
worker_config = cast(
100+
WorkerConfig,
101+
{
102+
k: v
103+
for k, v in self._config.items()
104+
if k in WorkerConfig.__annotations__
105+
},
106+
)
107+
93108
worker_config = root_worker_plugin.configure_worker(worker_config)
94-
self._config.update({k: v for k, v in worker_config.items() if k in ReplayerConfig.__annotations__})
109+
self._config.update(
110+
cast(ReplayerConfig, {
111+
k: v
112+
for k, v in worker_config.items()
113+
if k in ReplayerConfig.__annotations__
114+
})
115+
)
95116

96117
# Allow plugins to configure shared configurations with client
97118
root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin()
98-
for plugin in reversed([plugin for plugin in plugins if isinstance(plugin, temporalio.client.Plugin)]):
99-
root_client_plugin = plugin.init_client_plugin(root_client_plugin)
100-
101-
client_config = ClientConfig(**{k: v for k, v in self._config.items() if k in ClientConfig.__annotations__})
119+
for client_plugin in reversed(
120+
[
121+
plugin
122+
for plugin in plugins
123+
if isinstance(plugin, temporalio.client.Plugin)
124+
]
125+
):
126+
root_client_plugin = client_plugin.init_client_plugin(root_client_plugin)
127+
128+
client_config = cast(ClientConfig,
129+
{
130+
k: v
131+
for k, v in self._config.items()
132+
if k in ClientConfig.__annotations__
133+
}
134+
)
102135
client_config = root_client_plugin.configure_client(client_config)
103-
self._config.update({k: v for k, v in client_config.items() if k in ReplayerConfig.__annotations__})
136+
self._config.update(
137+
cast(ReplayerConfig, {
138+
k: v
139+
for k, v in client_config.items()
140+
if k in ReplayerConfig.__annotations__
141+
})
142+
)
104143

105144
if not self._config["workflows"]:
106145
raise ValueError("At least one workflow must be specified")

tests/test_plugins.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
from temporalio.client import Client, ClientConfig, OutboundInterceptor
1212
from temporalio.contrib.pydantic import pydantic_data_converter
1313
from temporalio.testing import WorkflowEnvironment
14-
from temporalio.worker import Worker, WorkerConfig
14+
from temporalio.worker import Replayer, Worker, WorkerConfig
1515
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
16-
from tests.worker.test_worker import never_run_activity
17-
from temporalio.worker import Replayer
1816
from tests.helpers import new_worker
17+
from tests.worker.test_worker import never_run_activity
1918

2019

2120
class TestClientInterceptor(temporalio.client.Interceptor):
@@ -142,21 +141,24 @@ async def test_worker_sandbox_restrictions(client: Client) -> None:
142141
).restrictions.passthrough_modules
143142
)
144143

144+
145145
class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
146146
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
147-
config["workflows"] = list(config["workflows"]) + [HelloWorkflow]
147+
config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow]
148148
return super().configure_worker(config)
149149

150150
def configure_client(self, config: ClientConfig) -> ClientConfig:
151151
config["data_converter"] = pydantic_data_converter
152152
return super().configure_client(config)
153153

154+
154155
@workflow.defn
155156
class HelloWorkflow:
156157
@workflow.run
157158
async def run(self, name: str) -> str:
158159
return f"Hello, {name}!"
159160

161+
160162
async def test_replay(client: Client) -> None:
161163
plugin = ReplayCheckPlugin()
162164
new_config = client.config()
@@ -171,28 +173,18 @@ async def test_replay(client: Client) -> None:
171173
task_queue=worker.task_queue,
172174
)
173175
await handle.result()
174-
replayer = Replayer(
175-
workflows=[],
176-
plugins=[plugin]
177-
)
178-
assert len(replayer.config()["workflows"])==1
179-
assert replayer.config()["data_converter"] == pydantic_data_converter
176+
replayer = Replayer(workflows=[], plugins=[plugin])
177+
assert len(replayer.config().get("workflows") or []) == 1
178+
assert replayer.config().get("data_converter") == pydantic_data_converter
180179

181180
await replayer.replay_workflow(await handle.fetch_history())
182181

182+
replayer = Replayer(workflows=[HelloWorkflow], plugins=[MyClientPlugin()])
183+
replayer = Replayer(workflows=[HelloWorkflow], plugins=[MyWorkerPlugin()])
183184
replayer = Replayer(
184-
workflows=[HelloWorkflow],
185-
plugins=[MyClientPlugin()]
186-
)
187-
replayer = Replayer(
188-
workflows=[HelloWorkflow],
189-
plugins=[MyWorkerPlugin()]
185+
workflows=[HelloWorkflow], plugins=[MyClientPlugin(), MyWorkerPlugin()]
190186
)
191187
replayer = Replayer(
192188
workflows=[HelloWorkflow],
193-
plugins=[MyClientPlugin(), MyWorkerPlugin()]
189+
plugins=[MyWorkerPlugin(), MyClientPlugin(), MyCombinedPlugin()],
194190
)
195-
replayer = Replayer(
196-
workflows=[HelloWorkflow],
197-
plugins=[MyWorkerPlugin(), MyClientPlugin(), MyCombinedPlugin()]
198-
)

0 commit comments

Comments
 (0)