@@ -33,7 +33,7 @@ async def run(self, id: int) -> None:
3333
3434@pytest .mark .parametrize (
3535 ["num_nexus_operations" , "max_concurrent_nexus_tasks" , "expected_num_executed" ],
36- [(1 , 1 , 1 ), (2 , 1 , 1 ), (43 , 42 , 42 ), (43 , 44 , 43 )],
36+ [(1 , 1 , 1 ), (2 , 1 , 1 ), (18 , 17 , 17 ), (18 , 19 , 18 )],
3737)
3838async def test_max_concurrent_nexus_tasks (
3939 client : Client ,
@@ -61,19 +61,29 @@ async def op(
6161 ) as worker :
6262 await create_nexus_endpoint (worker .task_queue , client )
6363
64- coros = [
65- client .execute_workflow (
66- NexusCallerWorkflow .run ,
67- i ,
68- id = str (uuid .uuid4 ()),
69- task_queue = worker .task_queue ,
64+ tasks = [
65+ asyncio .create_task (
66+ client .execute_workflow (
67+ NexusCallerWorkflow .run ,
68+ i ,
69+ id = str (uuid .uuid4 ()),
70+ task_queue = worker .task_queue ,
71+ )
7072 )
7173 for i in range (num_nexus_operations )
7274 ]
73- try :
74- await asyncio .wait_for (asyncio .gather (* coros ), timeout = 5 )
75- except asyncio .TimeoutError :
76- pass
77- event .set ()
78- assert len (set (ids )) == len (ids )
75+
76+ for _ in range (50 ): # 5 seconds max
77+ if len (ids ) >= expected_num_executed :
78+ break
79+ await asyncio .sleep (0.1 )
80+
81+ await asyncio .sleep (0.1 )
7982 assert len (ids ) == expected_num_executed
83+ assert len (set (ids )) == len (ids )
84+
85+ event .set ()
86+ for task in tasks :
87+ if not task .done ():
88+ task .cancel ()
89+ await asyncio .gather (* tasks , return_exceptions = True )
0 commit comments