Skip to content

Commit b5db959

Browse files
committed
Some PR updates, refactoring tests, added name
1 parent 6bcd2fd commit b5db959

File tree

5 files changed

+128
-72
lines changed

5 files changed

+128
-72
lines changed

temporalio/client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,14 @@ def __init__(
228228
for plugin in reversed(list(plugins)):
229229
root_plugin = plugin.init_client_plugin(root_plugin)
230230

231-
self._config = root_plugin.on_create_client(config)
231+
self._init_from_config(root_plugin.on_create_client(config))
232+
233+
def _init_from_config(self, config: ClientConfig):
234+
self._config = config
232235

233236
# Iterate over interceptors in reverse building the impl
234237
self._impl: OutboundInterceptor = _ClientImpl(self)
235-
for interceptor in reversed(list(interceptors)):
238+
for interceptor in reversed(list(self._config["interceptors"])):
236239
self._impl = interceptor.intercept_client(self._impl)
237240

238241
def config(self) -> ClientConfig:
@@ -7388,6 +7391,9 @@ async def _decode_user_metadata(
73887391

73897392

73907393
class Plugin:
7394+
def name(self) -> str:
7395+
return type(self).__module__ + "." + type(self).__qualname__
7396+
73917397
def init_client_plugin(self, next: Plugin) -> Plugin:
73927398
self.next_client_plugin = next
73937399
return self

temporalio/worker/_worker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def _to_bridge(self) -> temporalio.bridge.worker.PollerBehavior:
9797

9898

9999
class Plugin:
100+
def name(self) -> str:
101+
return type(self).__module__ + "." + type(self).__qualname__
102+
100103
def init_worker_plugin(self, next: Plugin) -> Plugin:
101104
self.next_worker_plugin = next
102105
return self
@@ -388,6 +391,11 @@ def __init__(
388391
List[Plugin],
389392
[p for p in client.config()["plugins"] if isinstance(p, Plugin)],
390393
)
394+
for client_plugin in plugins_from_client:
395+
if type(client_plugin) in [type(p) for p in plugins]:
396+
warnings.warn(
397+
f"The same plugin type {type(client_plugin)} is present from both client and worker. It may run twice and may not be the intended behavior."
398+
)
391399
plugins = plugins_from_client + list(plugins)
392400

393401
root_plugin: Plugin = _RootPlugin()

tests/test_client.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,30 +1501,3 @@ async def test_cloud_client_simple():
15011501
GetNamespaceRequest(namespace=os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"])
15021502
)
15031503
assert os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"] == result.namespace.namespace
1504-
1505-
1506-
class MyPlugin(Plugin):
1507-
def on_create_client(self, config: ClientConfig) -> ClientConfig:
1508-
config["namespace"] = "replaced_namespace"
1509-
return super().on_create_client(config)
1510-
1511-
async def connect_service_client(
1512-
self, config: temporalio.service.ConnectConfig
1513-
) -> temporalio.service.ServiceClient:
1514-
config.api_key = "replaced key"
1515-
return await super().connect_service_client(config)
1516-
1517-
1518-
async def test_client_plugin(client: Client, env: WorkflowEnvironment):
1519-
if env.supports_time_skipping:
1520-
pytest.skip("Client connect is only designed for local")
1521-
1522-
config = client.config()
1523-
config["plugins"] = [MyPlugin()]
1524-
new_client = Client(**config)
1525-
assert new_client.namespace == "replaced_namespace"
1526-
1527-
new_client = await Client.connect(
1528-
client.service_client.config.target_host, plugins=[MyPlugin()]
1529-
)
1530-
assert new_client.service_client.config.api_key == "replaced key"

tests/test_plugins.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import warnings
2+
3+
import pytest
4+
5+
import temporalio.client
6+
import temporalio.worker
7+
from temporalio.client import Client, ClientConfig, OutboundInterceptor, Plugin
8+
from temporalio.testing import WorkflowEnvironment
9+
from temporalio.worker import Worker, WorkerConfig
10+
from tests.worker.test_worker import never_run_activity
11+
12+
13+
class TestClientInterceptor(temporalio.client.Interceptor):
14+
intercepted = False
15+
16+
def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor:
17+
self.intercepted = True
18+
return super().intercept_client(next)
19+
20+
21+
class MyClientPlugin(Plugin):
22+
def __init__(self):
23+
self.interceptor = TestClientInterceptor()
24+
25+
def on_create_client(self, config: ClientConfig) -> ClientConfig:
26+
config["namespace"] = "replaced_namespace"
27+
config["interceptors"] = list(config.get("interceptors") or []) + [
28+
self.interceptor
29+
]
30+
return super().on_create_client(config)
31+
32+
async def connect_service_client(
33+
self, config: temporalio.service.ConnectConfig
34+
) -> temporalio.service.ServiceClient:
35+
config.api_key = "replaced key"
36+
return await super().connect_service_client(config)
37+
38+
39+
async def test_client_plugin(client: Client, env: WorkflowEnvironment):
40+
if env.supports_time_skipping:
41+
pytest.skip("Client connect is only designed for local")
42+
43+
plugin = MyClientPlugin()
44+
config = client.config()
45+
config["plugins"] = [plugin]
46+
new_client = Client(**config)
47+
assert new_client.namespace == "replaced_namespace"
48+
assert plugin.interceptor.intercepted
49+
assert plugin.name() == "tests.test_plugins.MyClientPlugin"
50+
51+
new_client = await Client.connect(
52+
client.service_client.config.target_host, plugins=[MyClientPlugin()]
53+
)
54+
assert new_client.service_client.config.api_key == "replaced key"
55+
56+
57+
class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
58+
def on_create_worker(self, config: WorkerConfig) -> WorkerConfig:
59+
config["task_queue"] = "combined"
60+
return super().on_create_worker(config)
61+
62+
63+
class MyWorkerPlugin(temporalio.worker.Plugin):
64+
def on_create_worker(self, config: WorkerConfig) -> WorkerConfig:
65+
config["task_queue"] = "replaced_queue"
66+
return super().on_create_worker(config)
67+
68+
async def run_worker(self, worker: Worker) -> None:
69+
await super().run_worker(worker)
70+
71+
72+
async def test_worker_plugin_basic_config(client: Client) -> None:
73+
worker = Worker(
74+
client,
75+
task_queue="queue",
76+
activities=[never_run_activity],
77+
plugins=[MyWorkerPlugin()],
78+
)
79+
assert worker.config().get("task_queue") == "replaced_queue"
80+
81+
# Test client plugin propagation to worker plugins
82+
new_config = client.config()
83+
new_config["plugins"] = [MyCombinedPlugin()]
84+
client = Client(**new_config)
85+
worker = Worker(client, task_queue="queue", activities=[never_run_activity])
86+
assert worker.config().get("task_queue") == "combined"
87+
88+
# Test both. Client propagated plugins are called first, so the worker plugin overrides in this case
89+
worker = Worker(
90+
client,
91+
task_queue="queue",
92+
activities=[never_run_activity],
93+
plugins=[MyWorkerPlugin()],
94+
)
95+
assert worker.config().get("task_queue") == "replaced_queue"
96+
97+
98+
async def test_worker_duplicated_plugin(client: Client) -> None:
99+
new_config = client.config()
100+
new_config["plugins"] = [MyCombinedPlugin()]
101+
client = Client(**new_config)
102+
103+
with warnings.catch_warnings(record=True) as warning_list:
104+
worker = Worker(
105+
client,
106+
task_queue="queue",
107+
activities=[never_run_activity],
108+
plugins=[MyCombinedPlugin()],
109+
)
110+
111+
assert len(warning_list) == 1
112+
assert "The same plugin type" in str(warning_list[0].message)

tests/worker/test_worker.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,46 +1189,3 @@ def shutdown(self) -> None:
11891189
if self.next_exception_task:
11901190
self.next_exception_task.cancel()
11911191
setattr(self.worker._bridge_worker, self.attr, self.orig_poll_call)
1192-
1193-
1194-
class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
1195-
def on_create_worker(self, config: WorkerConfig) -> WorkerConfig:
1196-
print("Create worker combined plugin")
1197-
config["task_queue"] = "combined"
1198-
return super().on_create_worker(config)
1199-
1200-
1201-
class MyWorkerPlugin(temporalio.worker.Plugin):
1202-
def on_create_worker(self, config: WorkerConfig) -> WorkerConfig:
1203-
print("Create worker worker plugin")
1204-
config["task_queue"] = "replaced_queue"
1205-
return super().on_create_worker(config)
1206-
1207-
async def run_worker(self, worker: Worker) -> None:
1208-
await super().run_worker(worker)
1209-
1210-
1211-
async def test_worker_plugin(client: Client) -> None:
1212-
worker = Worker(
1213-
client,
1214-
task_queue="queue",
1215-
activities=[never_run_activity],
1216-
plugins=[MyWorkerPlugin()],
1217-
)
1218-
assert worker.config().get("task_queue") == "replaced_queue"
1219-
1220-
# Test client plugin propagation to worker plugins
1221-
new_config = client.config()
1222-
new_config["plugins"] = [MyCombinedPlugin()]
1223-
client = Client(**new_config)
1224-
worker = Worker(client, task_queue="queue", activities=[never_run_activity])
1225-
assert worker.config().get("task_queue") == "combined"
1226-
1227-
# Test both. Client propagated plugins are called first, so the worker plugin overrides in this case
1228-
worker = Worker(
1229-
client,
1230-
task_queue="queue",
1231-
activities=[never_run_activity],
1232-
plugins=[MyWorkerPlugin()],
1233-
)
1234-
assert worker.config().get("task_queue") == "replaced_queue"

0 commit comments

Comments
 (0)