33import asyncio
44import uuid
55from datetime import timedelta
6+ from typing import Optional
67
78import nexusrpc .handler
89import pytest
910
1011from temporalio import workflow
1112from temporalio .client import Client
13+ from temporalio .worker import FixedSizeSlotSupplier , WorkerTuner
1214from tests .helpers import new_worker
1315from tests .helpers .nexus import create_nexus_endpoint , make_nexus_endpoint_name
1416
@@ -41,6 +43,77 @@ async def test_max_concurrent_nexus_tasks(
4143 num_nexus_operations : int ,
4244 expected_num_executed : int ,
4345):
46+ """Test max_concurrent_nexus_tasks parameter."""
47+ await _test_nexus_concurrency_helper (
48+ client = client ,
49+ num_nexus_operations = num_nexus_operations ,
50+ expected_num_executed = expected_num_executed ,
51+ max_concurrent_nexus_tasks = max_concurrent_nexus_tasks ,
52+ )
53+
54+
55+ @pytest .mark .parametrize (
56+ ["num_nexus_operations" , "nexus_slots" , "expected_num_executed" ],
57+ [(1 , 1 , 1 ), (2 , 1 , 1 ), (43 , 42 , 42 ), (43 , 44 , 43 )],
58+ )
59+ async def test_max_concurrent_nexus_tasks_with_tuner (
60+ client : Client ,
61+ nexus_slots : int ,
62+ num_nexus_operations : int ,
63+ expected_num_executed : int ,
64+ ):
65+ """Test nexus concurrency using a WorkerTuner."""
66+ tuner = WorkerTuner .create_fixed (
67+ workflow_slots = 10 ,
68+ activity_slots = 10 ,
69+ local_activity_slots = 10 ,
70+ nexus_slots = nexus_slots ,
71+ )
72+ await _test_nexus_concurrency_helper (
73+ client = client ,
74+ num_nexus_operations = num_nexus_operations ,
75+ expected_num_executed = expected_num_executed ,
76+ tuner = tuner ,
77+ )
78+
79+
80+ @pytest .mark .parametrize (
81+ ["num_nexus_operations" , "nexus_supplier" , "expected_num_executed" ],
82+ [
83+ (1 , FixedSizeSlotSupplier (1 ), 1 ),
84+ (2 , FixedSizeSlotSupplier (1 ), 1 ),
85+ (43 , FixedSizeSlotSupplier (42 ), 42 ),
86+ ],
87+ )
88+ async def test_max_concurrent_nexus_tasks_with_composite_tuner (
89+ client : Client ,
90+ nexus_supplier : FixedSizeSlotSupplier ,
91+ num_nexus_operations : int ,
92+ expected_num_executed : int ,
93+ ):
94+ """Test nexus concurrency using a composite WorkerTuner with nexus_supplier."""
95+ tuner = WorkerTuner .create_composite (
96+ workflow_supplier = FixedSizeSlotSupplier (10 ),
97+ activity_supplier = FixedSizeSlotSupplier (10 ),
98+ local_activity_supplier = FixedSizeSlotSupplier (10 ),
99+ nexus_supplier = nexus_supplier ,
100+ )
101+ await _test_nexus_concurrency_helper (
102+ client = client ,
103+ num_nexus_operations = num_nexus_operations ,
104+ expected_num_executed = expected_num_executed ,
105+ tuner = tuner ,
106+ )
107+
108+
109+ async def _test_nexus_concurrency_helper (
110+ client : Client ,
111+ num_nexus_operations : int ,
112+ expected_num_executed : int ,
113+ max_concurrent_nexus_tasks : Optional [int ] = None ,
114+ tuner : Optional [WorkerTuner ] = None ,
115+ ):
116+ assert (max_concurrent_nexus_tasks is None ) != (tuner is None )
44117 ids = []
45118 event = asyncio .Event ()
46119
@@ -58,6 +131,7 @@ async def op(
58131 NexusCallerWorkflow ,
59132 nexus_service_handlers = [MaxConcurrentTestService ()],
60133 max_concurrent_nexus_tasks = max_concurrent_nexus_tasks ,
134+ tuner = tuner ,
61135 ) as worker :
62136 await create_nexus_endpoint (worker .task_queue , client )
63137
0 commit comments