Skip to content

Commit 1741a6e

Browse files
committed
POC for replayer configuration from existing plugins
1 parent 62604ca commit 1741a6e

File tree

2 files changed

+66
-4
lines changed

2 files changed

+66
-4
lines changed

temporalio/worker/_replayer.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
from contextlib import asynccontextmanager
99
from dataclasses import dataclass
10-
from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type
10+
from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, Union
1111

1212
from typing_extensions import TypedDict
1313

@@ -19,9 +19,11 @@
1919
import temporalio.runtime
2020
import temporalio.workflow
2121

22+
2223
from ..common import HeaderCodecBehavior
2324
from ._interceptor import Interceptor
24-
from ._worker import load_default_build_id
25+
from ._worker import load_default_build_id, WorkerConfig
26+
from temporalio.client import ClientConfig
2527
from ._workflow import _WorkflowWorker
2628
from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
2729
from .workflow_sandbox import SandboxedWorkflowRunner
@@ -42,6 +44,7 @@ def __init__(
4244
namespace: str = "ReplayNamespace",
4345
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
4446
interceptors: Sequence[Interceptor] = [],
47+
plugins: Sequence[Union[temporalio.worker.Plugin, temporalio.client.Plugin]] = [],
4548
build_id: Optional[str] = None,
4649
identity: Optional[str] = None,
4750
workflow_failure_exception_types: Sequence[Type[BaseException]] = [],
@@ -62,8 +65,6 @@ def __init__(
6265
will be shared across all replay calls and never explicitly shut down.
6366
Users are encouraged to provide their own if needing more control.
6467
"""
65-
if not workflows:
66-
raise ValueError("At least one workflow must be specified")
6768
self._config = ReplayerConfig(
6869
workflows=list(workflows),
6970
workflow_task_executor=(
@@ -82,6 +83,24 @@ def __init__(
8283
disable_safe_workflow_eviction=disable_safe_workflow_eviction,
8384
header_codec_behavior=header_codec_behavior,
8485
)
86+
root_worker_plugin: temporalio.worker.Plugin = temporalio.worker._worker._RootPlugin()
87+
root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin()
88+
for plugin in reversed(plugins):
89+
root_worker_plugin = plugin.init_worker_plugin(root_worker_plugin)
90+
root_client_plugin = plugin.init_client_plugin(root_client_plugin)
91+
92+
# Allow plugins to configure shared configurations with worker
93+
worker_config = WorkerConfig(**{k: v for k, v in self._config.items() if k in WorkerConfig.__annotations__})
94+
worker_config = root_worker_plugin.configure_worker(worker_config)
95+
self._config.update({k: v for k, v in worker_config.items() if k in ReplayerConfig.__annotations__})
96+
97+
# Allow plugins to configure shared configurations with client
98+
client_config = ClientConfig(**{k: v for k, v in self._config.items() if k in ClientConfig.__annotations__})
99+
client_config = root_client_plugin.configure_client(client_config)
100+
self._config.update({k: v for k, v in client_config.items() if k in ReplayerConfig.__annotations__})
101+
102+
if not self._config["workflows"]:
103+
raise ValueError("At least one workflow must be specified")
85104

86105
def config(self) -> ReplayerConfig:
87106
"""Config, as a dictionary, used to create this replayer.

tests/test_plugins.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
import dataclasses
2+
import uuid
23
import warnings
34
from typing import cast
45

56
import pytest
67

78
import temporalio.client
89
import temporalio.worker
10+
from temporalio import workflow
911
from temporalio.client import Client, ClientConfig, OutboundInterceptor
12+
from temporalio.contrib.pydantic import pydantic_data_converter
1013
from temporalio.testing import WorkflowEnvironment
1114
from temporalio.worker import Worker, WorkerConfig
1215
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
1316
from tests.worker.test_worker import never_run_activity
17+
from temporalio.worker import Replayer
18+
from tests.helpers import new_worker
1419

1520

1621
class TestClientInterceptor(temporalio.client.Interceptor):
@@ -136,3 +141,41 @@ async def test_worker_sandbox_restrictions(client: Client) -> None:
136141
SandboxedWorkflowRunner, worker.config().get("workflow_runner")
137142
).restrictions.passthrough_modules
138143
)
144+
145+
class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
146+
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
147+
config["workflows"] = list(config["workflows"]) + [HelloWorkflow]
148+
return super().configure_worker(config)
149+
150+
def configure_client(self, config: ClientConfig) -> ClientConfig:
151+
config["data_converter"] = pydantic_data_converter
152+
return super().configure_client(config)
153+
154+
@workflow.defn
155+
class HelloWorkflow:
156+
@workflow.run
157+
async def run(self, name: str) -> str:
158+
return f"Hello, {name}!"
159+
160+
async def test_replay(client: Client) -> None:
161+
plugin = ReplayCheckPlugin()
162+
new_config = client.config()
163+
new_config["plugins"] = [plugin]
164+
client = Client(**new_config)
165+
166+
async with new_worker(client) as worker:
167+
handle = await client.start_workflow(
168+
HelloWorkflow.run,
169+
"Tim",
170+
id=f"workflow-{uuid.uuid4()}",
171+
task_queue=worker.task_queue,
172+
)
173+
await handle.result()
174+
replayer = Replayer(
175+
workflows=[],
176+
plugins=[plugin]
177+
)
178+
assert len(replayer.config()["workflows"])==1
179+
assert replayer.config()["data_converter"] == pydantic_data_converter
180+
181+
await replayer.replay_workflow(await handle.fetch_history())

0 commit comments

Comments
 (0)