Skip to content

Commit d529deb

Browse files
committed
Extend to Tuners
1 parent 480bb2f commit d529deb

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

tests/nexus/test_worker.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import asyncio
66
import uuid
77
from datetime import timedelta
8+
from typing import Optional
89

910
import nexusrpc.handler
1011
import pytest
1112

1213
from temporalio import workflow
1314
from temporalio.client import Client
15+
from temporalio.worker import FixedSizeSlotSupplier, WorkerTuner
1416
from tests.helpers import new_worker
1517
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
1618

@@ -43,6 +45,77 @@ async def test_max_concurrent_nexus_tasks(
4345
num_nexus_operations: int,
4446
expected_num_executed: int,
4547
):
48+
"""Test max_concurrent_nexus_tasks parameter."""
49+
await _test_nexus_concurrency_helper(
50+
client=client,
51+
num_nexus_operations=num_nexus_operations,
52+
expected_num_executed=expected_num_executed,
53+
max_concurrent_nexus_tasks=max_concurrent_nexus_tasks,
54+
)
55+
56+
57+
@pytest.mark.parametrize(
58+
["num_nexus_operations", "nexus_slots", "expected_num_executed"],
59+
[(1, 1, 1), (2, 1, 1), (43, 42, 42), (43, 44, 43)],
60+
)
61+
async def test_max_concurrent_nexus_tasks_with_tuner(
62+
client: Client,
63+
nexus_slots: int,
64+
num_nexus_operations: int,
65+
expected_num_executed: int,
66+
):
67+
"""Test nexus concurrency using a WorkerTuner."""
68+
tuner = WorkerTuner.create_fixed(
69+
workflow_slots=10,
70+
activity_slots=10,
71+
local_activity_slots=10,
72+
nexus_slots=nexus_slots,
73+
)
74+
await _test_nexus_concurrency_helper(
75+
client=client,
76+
num_nexus_operations=num_nexus_operations,
77+
expected_num_executed=expected_num_executed,
78+
tuner=tuner,
79+
)
80+
81+
82+
@pytest.mark.parametrize(
83+
["num_nexus_operations", "nexus_supplier", "expected_num_executed"],
84+
[
85+
(1, FixedSizeSlotSupplier(1), 1),
86+
(2, FixedSizeSlotSupplier(1), 1),
87+
(43, FixedSizeSlotSupplier(42), 42),
88+
],
89+
)
90+
async def test_max_concurrent_nexus_tasks_with_composite_tuner(
91+
client: Client,
92+
nexus_supplier: FixedSizeSlotSupplier,
93+
num_nexus_operations: int,
94+
expected_num_executed: int,
95+
):
96+
"""Test nexus concurrency using a composite WorkerTuner with nexus_supplier."""
97+
tuner = WorkerTuner.create_composite(
98+
workflow_supplier=FixedSizeSlotSupplier(10),
99+
activity_supplier=FixedSizeSlotSupplier(10),
100+
local_activity_supplier=FixedSizeSlotSupplier(10),
101+
nexus_supplier=nexus_supplier,
102+
)
103+
await _test_nexus_concurrency_helper(
104+
client=client,
105+
num_nexus_operations=num_nexus_operations,
106+
expected_num_executed=expected_num_executed,
107+
tuner=tuner,
108+
)
109+
110+
111+
async def _test_nexus_concurrency_helper(
112+
client: Client,
113+
num_nexus_operations: int,
114+
expected_num_executed: int,
115+
max_concurrent_nexus_tasks: Optional[int] = None,
116+
tuner: Optional[WorkerTuner] = None,
117+
):
118+
assert (max_concurrent_nexus_tasks is None) != (tuner is None)
46119
ids = []
47120
event = asyncio.Event()
48121

@@ -60,6 +133,7 @@ async def op(
60133
NexusCallerWorkflow,
61134
nexus_service_handlers=[MaxConcurrentTestService()],
62135
max_concurrent_nexus_tasks=max_concurrent_nexus_tasks,
136+
tuner=tuner,
63137
) as worker:
64138
await create_nexus_endpoint(worker.task_queue, client)
65139

0 commit comments

Comments
 (0)