33import asyncio
44import uuid
55from datetime import timedelta
6- from typing import Any
6+ from typing import Any , Sequence
77
88import nexusrpc .handler
99import pytest
@@ -37,17 +37,19 @@ async def run(self, n: int) -> None:
3737
3838
3939@pytest .mark .parametrize (
40- ["num_nexus_operations" , "max_concurrent_nexus_tasks" ],
40+ ["num_nexus_operations" , "max_concurrent_nexus_tasks" , "expect_timeout" ],
4141 [
42- (1 , 1 ),
43- (3 , 3 ),
44- (4 , 3 ),
42+ (1 , 1 , False ),
43+ (1 , 3 , False ),
44+ (3 , 3 , False ),
45+ (4 , 3 , True ),
4546 ],
4647)
4748async def test_max_concurrent_nexus_tasks (
4849 env : WorkflowEnvironment ,
4950 max_concurrent_nexus_tasks : int ,
5051 num_nexus_operations : int ,
52+ expect_timeout : bool ,
5153):
5254 if env .supports_time_skipping :
5355 pytest .skip ("Nexus tests don't work with Javas test server" )
@@ -57,8 +59,12 @@ def __init__(self, size: int) -> None:
5759 self .size = size
5860 self .event = asyncio .Event ()
5961
62+ @property
63+ def waiters (self ) -> Sequence [Any ]:
64+ return getattr (self .event , "_waiters" )
65+
6066 async def wait (self ) -> None :
61- if len (self .event . _waiters ) >= self .size - 1 :
67+ if len (self .waiters ) >= self .size - 1 :
6268 self .event .set ()
6369 else :
6470 await self .event .wait ()
@@ -87,9 +93,7 @@ async def op(
8793 id = str (uuid .uuid4 ()),
8894 task_queue = worker .task_queue ,
8995 )
90- if num_nexus_operations <= max_concurrent_nexus_tasks :
91- await execute_operations_concurrently
92- else :
96+ if expect_timeout :
9397 try :
9498 await asyncio .wait_for (execute_operations_concurrently , timeout = 10 )
9599 except TimeoutError :
@@ -100,3 +104,5 @@ async def op(
100104 f"max_concurrent_nexus_tasks={ max_concurrent_nexus_tasks } , "
101105 f"num_nexus_operations={ num_nexus_operations } "
102106 )
107+ else :
108+ await execute_operations_concurrently
0 commit comments