Skip to content

Commit 2a9cec3

Browse files
committed
Add static plugin constructor
1 parent a31886d commit 2a9cec3

File tree

2 files changed

+233
-0
lines changed

2 files changed

+233
-0
lines changed

temporalio/plugin.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import abc
2+
import dataclasses
3+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
4+
from typing import Any, AsyncIterator, Callable, Optional, Sequence, Set, Type
5+
6+
import temporalio.client
7+
import temporalio.converter
8+
import temporalio.worker
9+
from temporalio.client import ClientConfig, WorkflowHistory
10+
from temporalio.worker import (
11+
Replayer,
12+
ReplayerConfig,
13+
Worker,
14+
WorkerConfig,
15+
WorkflowReplayResult,
16+
)
17+
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
18+
19+
20+
class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin, abc.ABC):
21+
pass
22+
23+
24+
def create_plugin(
25+
*,
26+
data_converter: Optional[temporalio.converter.DataConverter] = None,
27+
client_interceptors: Optional[Sequence[temporalio.client.Interceptor]] = None,
28+
activities: Optional[Sequence[Callable]] = None,
29+
nexus_service_handlers: Optional[Sequence[Any]] = None,
30+
workflows: Optional[Sequence[Type]] = None,
31+
passthrough_modules: Optional[Set[str]] = None,
32+
worker_interceptors: Optional[Sequence[temporalio.worker.Interceptor]] = None,
33+
workflow_failure_exception_types: Optional[Sequence[Type[BaseException]]] = None,
34+
run_context: Optional[AbstractAsyncContextManager[None]] = None,
35+
) -> Plugin:
36+
return _StaticPlugin(
37+
data_converter=data_converter,
38+
client_interceptors=client_interceptors,
39+
activities=activities,
40+
nexus_service_handlers=nexus_service_handlers,
41+
workflows=workflows,
42+
passthrough_modules=passthrough_modules,
43+
worker_interceptors=worker_interceptors,
44+
workflow_failure_exception_types=workflow_failure_exception_types,
45+
run_context=run_context,
46+
)
47+
48+
49+
class _StaticPlugin(Plugin):
50+
def __init__(
51+
self,
52+
*,
53+
data_converter: Optional[temporalio.converter.DataConverter] = None,
54+
client_interceptors: Optional[Sequence[temporalio.client.Interceptor]] = None,
55+
activities: Optional[Sequence[Callable]] = None,
56+
nexus_service_handlers: Optional[Sequence[Any]] = None,
57+
workflows: Optional[Sequence[Type]] = None,
58+
passthrough_modules: Optional[Set[str]] = None,
59+
worker_interceptors: Optional[Sequence[temporalio.worker.Interceptor]] = None,
60+
workflow_failure_exception_types: Optional[
61+
Sequence[Type[BaseException]]
62+
] = None,
63+
run_context: Optional[AbstractAsyncContextManager[None]] = None,
64+
) -> None:
65+
self.data_converter = data_converter
66+
self.client_interceptors = client_interceptors
67+
self.activities = activities
68+
self.nexus_service_handlers = nexus_service_handlers
69+
self.workflows = workflows
70+
self.passthrough_modules = passthrough_modules
71+
self.worker_interceptors = worker_interceptors
72+
self.workflow_failure_exception_types = workflow_failure_exception_types
73+
self.run_context = run_context
74+
75+
def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
76+
self.next_worker_plugin = next
77+
78+
def init_client_plugin(self, next: temporalio.client.Plugin) -> None:
79+
self.next_client_plugin = next
80+
81+
def configure_client(self, config: ClientConfig) -> ClientConfig:
82+
if self.data_converter:
83+
if not config["data_converter"] == temporalio.converter.default():
84+
raise ValueError(
85+
"Static Plugin was configured with a data converter, but the client was as well."
86+
)
87+
else:
88+
config["data_converter"] = self.data_converter
89+
90+
if self.client_interceptors:
91+
config["interceptors"] = list(config.get("interceptors", [])) + list(
92+
self.client_interceptors
93+
)
94+
95+
return self.next_client_plugin.configure_client(config)
96+
97+
async def connect_service_client(
98+
self, config: temporalio.service.ConnectConfig
99+
) -> temporalio.service.ServiceClient:
100+
return await self.next_client_plugin.connect_service_client(config)
101+
102+
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
103+
if self.activities:
104+
config["activities"] = list(config.get("activities", [])) + list(
105+
self.activities
106+
)
107+
108+
if self.nexus_service_handlers:
109+
config["nexus_service_handlers"] = list(
110+
config.get("nexus_service_handlers", [])
111+
) + list(self.nexus_service_handlers)
112+
113+
if self.workflows:
114+
config["workflows"] = list(config.get("workflows", [])) + list(
115+
self.workflows
116+
)
117+
118+
if self.passthrough_modules:
119+
runner = config.get("workflow_runner")
120+
if runner and isinstance(runner, SandboxedWorkflowRunner):
121+
config["workflow_runner"] = dataclasses.replace(
122+
runner,
123+
restrictions=runner.restrictions.with_passthrough_modules(
124+
*self.passthrough_modules
125+
),
126+
)
127+
128+
if self.worker_interceptors:
129+
config["interceptors"] = list(config.get("interceptors", [])) + list(
130+
self.worker_interceptors
131+
)
132+
133+
if self.workflow_failure_exception_types:
134+
config["workflow_failure_exception_types"] = list(
135+
config.get("workflow_failure_exception_types", [])
136+
) + list(self.workflow_failure_exception_types)
137+
138+
return config
139+
140+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
141+
if self.data_converter:
142+
if not config["data_converter"] == temporalio.converter.default():
143+
raise ValueError(
144+
"Static Plugin was configured with a data converter, but the client was as well."
145+
)
146+
else:
147+
config["data_converter"] = self.data_converter
148+
149+
if self.workflows:
150+
config["workflows"] = list(config.get("workflows", [])) + list(
151+
self.workflows
152+
)
153+
154+
if self.passthrough_modules:
155+
runner = config.get("workflow_runner")
156+
if runner and isinstance(runner, SandboxedWorkflowRunner):
157+
config["workflow_runner"] = dataclasses.replace(
158+
runner,
159+
restrictions=runner.restrictions.with_passthrough_modules(
160+
*self.passthrough_modules
161+
),
162+
)
163+
164+
if self.worker_interceptors:
165+
config["interceptors"] = list(config.get("interceptors", [])) + list(
166+
self.worker_interceptors
167+
)
168+
169+
if self.workflow_failure_exception_types:
170+
config["workflow_failure_exception_types"] = list(
171+
config.get("workflow_failure_exception_types", [])
172+
) + list(self.workflow_failure_exception_types)
173+
174+
return config
175+
176+
async def run_worker(self, worker: Worker) -> None:
177+
if self.run_context:
178+
async with self.run_context:
179+
await self.next_worker_plugin.run_worker(worker)
180+
else:
181+
await self.next_worker_plugin.run_worker(worker)
182+
183+
@asynccontextmanager
184+
async def run_replayer(
185+
self,
186+
replayer: Replayer,
187+
histories: AsyncIterator[WorkflowHistory],
188+
) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]:
189+
if self.run_context:
190+
async with self.run_context:
191+
async with self.next_worker_plugin.run_replayer(
192+
replayer, histories
193+
) as results:
194+
yield results
195+
else:
196+
async with self.next_worker_plugin.run_replayer(
197+
replayer, histories
198+
) as results:
199+
yield results

tests/test_plugins.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from temporalio import workflow
1212
from temporalio.client import Client, ClientConfig, OutboundInterceptor, Plugin
1313
from temporalio.contrib.pydantic import pydantic_data_converter
14+
from temporalio.plugin import create_plugin
1415
from temporalio.testing import WorkflowEnvironment
1516
from temporalio.worker import (
1617
Replayer,
@@ -256,3 +257,36 @@ async def test_replay(client: Client) -> None:
256257
assert replayer.config().get("data_converter") == pydantic_data_converter
257258

258259
await replayer.replay_workflow(await handle.fetch_history())
260+
261+
async def test_static_plugins(client: Client) -> None:
262+
plugin = create_plugin(
263+
data_converter=pydantic_data_converter,
264+
workflows=[HelloWorkflow],
265+
)
266+
config = client.config()
267+
config["plugins"] = [plugin]
268+
new_client = Client(**config)
269+
270+
assert new_client.data_converter == pydantic_data_converter
271+
272+
# Test without plugin registered in client
273+
worker = Worker(
274+
client,
275+
task_queue="queue",
276+
activities=[never_run_activity],
277+
plugins=[plugin],
278+
)
279+
assert worker.config().get("workflows") == [HelloWorkflow]
280+
281+
# Test with plugin registered in client
282+
worker = Worker(
283+
new_client,
284+
task_queue="queue",
285+
activities=[never_run_activity],
286+
plugins=[plugin],
287+
)
288+
assert worker.config().get("workflows") == [HelloWorkflow]
289+
290+
replayer = Replayer(workflows=[], plugins=[plugin])
291+
assert replayer.config().get("data_converter") == pydantic_data_converter
292+
assert replayer.config().get("workflows") == [HelloWorkflow]

0 commit comments

Comments
 (0)