33import asyncio
44import uuid
55from datetime import timedelta
6- from typing import Any , Sequence
6+ from typing import Any
77
88import nexusrpc .handler
99import pytest
@@ -37,19 +37,17 @@ async def run(self, n: int) -> None:
3737
3838
3939@pytest .mark .parametrize (
40- ["num_nexus_operations" , "max_concurrent_nexus_tasks" , "expect_timeout" ],
40+ ["num_nexus_operations" , "max_concurrent_nexus_tasks" ],
4141 [
42- (1 , 1 , False ),
43- (1 , 3 , False ),
44- (3 , 3 , False ),
45- (4 , 3 , True ),
42+ (1 , 1 ),
43+ (3 , 3 ),
44+ (4 , 3 ),
4645 ],
4746)
4847async def test_max_concurrent_nexus_tasks (
4948 env : WorkflowEnvironment ,
5049 max_concurrent_nexus_tasks : int ,
5150 num_nexus_operations : int ,
52- expect_timeout : bool ,
5351):
5452 if env .supports_time_skipping :
5553 pytest .skip ("Nexus tests don't work with Javas test server" )
@@ -59,12 +57,8 @@ def __init__(self, size: int) -> None:
5957 self .size = size
6058 self .event = asyncio .Event ()
6159
62- @property
63- def waiters (self ) -> Sequence [Any ]:
64- return getattr (self .event , "_waiters" )
65-
6660 async def wait (self ) -> None :
67- if len (self .waiters ) >= self .size - 1 :
61+ if len (self .event . _waiters ) >= self .size - 1 :
6862 self .event .set ()
6963 else :
7064 await self .event .wait ()
@@ -93,7 +87,9 @@ async def op(
9387 id = str (uuid .uuid4 ()),
9488 task_queue = worker .task_queue ,
9589 )
96- if expect_timeout :
90+ if num_nexus_operations <= max_concurrent_nexus_tasks :
91+ await execute_operations_concurrently
92+ else :
9793 try :
9894 await asyncio .wait_for (execute_operations_concurrently , timeout = 10 )
9995 except TimeoutError :
@@ -104,5 +100,3 @@ async def op(
104100 f"max_concurrent_nexus_tasks={ max_concurrent_nexus_tasks } , "
105101 f"num_nexus_operations={ num_nexus_operations } "
106102 )
107- else :
108- await execute_operations_concurrently
0 commit comments