55import asyncio
66import uuid
77from datetime import timedelta
8+ from typing import Optional
89
910import nexusrpc .handler
1011import pytest
1112
1213from temporalio import workflow
1314from temporalio .client import Client
15+ from temporalio .worker import FixedSizeSlotSupplier , WorkerTuner
1416from tests .helpers import new_worker
1517from tests .helpers .nexus import create_nexus_endpoint , make_nexus_endpoint_name
1618
@@ -43,6 +45,77 @@ async def test_max_concurrent_nexus_tasks(
4345 num_nexus_operations : int ,
4446 expected_num_executed : int ,
4547):
48+ """Test max_concurrent_nexus_tasks parameter."""
49+ await _test_nexus_concurrency_helper (
50+ client = client ,
51+ num_nexus_operations = num_nexus_operations ,
52+ expected_num_executed = expected_num_executed ,
53+ max_concurrent_nexus_tasks = max_concurrent_nexus_tasks ,
54+ )
55+
56+
57+ @pytest .mark .parametrize (
58+ ["num_nexus_operations" , "nexus_slots" , "expected_num_executed" ],
59+ [(1 , 1 , 1 ), (2 , 1 , 1 ), (43 , 42 , 42 ), (43 , 44 , 43 )],
60+ )
61+ async def test_max_concurrent_nexus_tasks_with_tuner (
62+ client : Client ,
63+ nexus_slots : int ,
64+ num_nexus_operations : int ,
65+ expected_num_executed : int ,
66+ ):
67+ """Test nexus concurrency using a WorkerTuner."""
68+ tuner = WorkerTuner .create_fixed (
69+ workflow_slots = 10 ,
70+ activity_slots = 10 ,
71+ local_activity_slots = 10 ,
72+ nexus_slots = nexus_slots ,
73+ )
74+ await _test_nexus_concurrency_helper (
75+ client = client ,
76+ num_nexus_operations = num_nexus_operations ,
77+ expected_num_executed = expected_num_executed ,
78+ tuner = tuner ,
79+ )
80+
81+
82+ @pytest .mark .parametrize (
83+ ["num_nexus_operations" , "nexus_supplier" , "expected_num_executed" ],
84+ [
85+ (1 , FixedSizeSlotSupplier (1 ), 1 ),
86+ (2 , FixedSizeSlotSupplier (1 ), 1 ),
87+ (43 , FixedSizeSlotSupplier (42 ), 42 ),
88+ ],
89+ )
90+ async def test_max_concurrent_nexus_tasks_with_composite_tuner (
91+ client : Client ,
92+ nexus_supplier : FixedSizeSlotSupplier ,
93+ num_nexus_operations : int ,
94+ expected_num_executed : int ,
95+ ):
96+ """Test nexus concurrency using a composite WorkerTuner with nexus_supplier."""
97+ tuner = WorkerTuner .create_composite (
98+ workflow_supplier = FixedSizeSlotSupplier (10 ),
99+ activity_supplier = FixedSizeSlotSupplier (10 ),
100+ local_activity_supplier = FixedSizeSlotSupplier (10 ),
101+ nexus_supplier = nexus_supplier ,
102+ )
103+ await _test_nexus_concurrency_helper (
104+ client = client ,
105+ num_nexus_operations = num_nexus_operations ,
106+ expected_num_executed = expected_num_executed ,
107+ tuner = tuner ,
108+ )
109+
110+
111+ async def _test_nexus_concurrency_helper (
112+ client : Client ,
113+ num_nexus_operations : int ,
114+ expected_num_executed : int ,
115+ max_concurrent_nexus_tasks : Optional [int ] = None ,
116+ tuner : Optional [WorkerTuner ] = None ,
117+ ):
118+ assert (max_concurrent_nexus_tasks is None ) != (tuner is None )
46119 ids = []
47120 event = asyncio .Event ()
48121
@@ -60,6 +133,7 @@ async def op(
60133 NexusCallerWorkflow ,
61134 nexus_service_handlers = [MaxConcurrentTestService ()],
62135 max_concurrent_nexus_tasks = max_concurrent_nexus_tasks ,
136+ tuner = tuner ,
63137 ) as worker :
64138 await create_nexus_endpoint (worker .task_queue , client )
65139
0 commit comments