@@ -192,12 +192,12 @@ async fn test_join_queue_join_next_with_id() {
192192
193193 let ( send, recv) = tokio:: sync:: watch:: channel ( ( ) ) ;
194194
195- let mut set = JoinQueue :: new ( ) ;
195+ let mut queue = JoinQueue :: new ( ) ;
196196 let mut spawned = Vec :: with_capacity ( TASK_NUM as usize ) ;
197197
198198 for _ in 0 ..TASK_NUM {
199199 let mut recv = recv. clone ( ) ;
200- let handle = set . spawn ( async move { recv. changed ( ) . await . unwrap ( ) } ) ;
200+ let handle = queue . spawn ( async move { recv. changed ( ) . await . unwrap ( ) } ) ;
201201
202202 spawned. push ( handle. id ( ) ) ;
203203 }
@@ -208,7 +208,7 @@ async fn test_join_queue_join_next_with_id() {
208208
209209 let mut count = 0 ;
210210 let mut joined = Vec :: with_capacity ( TASK_NUM as usize ) ;
211- while let Some ( res) = set . join_next_with_id ( ) . await {
211+ while let Some ( res) = queue . join_next_with_id ( ) . await {
212212 match res {
213213 Ok ( ( id, ( ) ) ) => {
214214 count += 1 ;
@@ -221,3 +221,141 @@ async fn test_join_queue_join_next_with_id() {
221221 assert_eq ! ( count, TASK_NUM ) ;
222222 assert_eq ! ( joined, spawned) ;
223223}
224+
225+ #[ tokio:: test]
226+ async fn test_join_queue_try_join_next ( ) {
227+ let mut queue = JoinQueue :: new ( ) ;
228+ let ( tx1, rx1) = oneshot:: channel :: < ( ) > ( ) ;
229+ queue. spawn ( async {
230+ let _ = rx1. await ;
231+ } ) ;
232+ let ( tx2, rx2) = oneshot:: channel :: < ( ) > ( ) ;
233+ queue. spawn ( async {
234+ let _ = rx2. await ;
235+ } ) ;
236+ let ( tx3, rx3) = oneshot:: channel :: < ( ) > ( ) ;
237+ queue. spawn ( async {
238+ let _ = rx3. await ;
239+ } ) ;
240+
241+ // This function also checks that calling `queue.try_join_next()` repeatedly when
242+ // no task is ready is idempotent, i.e. that it does not change the queue state.
243+ fn check_try_join_next_is_noop ( queue : & mut JoinQueue < ( ) > ) {
244+ let len = queue. len ( ) ;
245+ for _ in 0 ..5 {
246+ assert ! ( queue. try_join_next( ) . is_none( ) ) ;
247+ assert_eq ! ( queue. len( ) , len) ;
248+ }
249+ }
250+
251+ assert_eq ! ( queue. len( ) , 3 ) ;
252+ check_try_join_next_is_noop ( & mut queue) ;
253+
254+ tx1. send ( ( ) ) . unwrap ( ) ;
255+ tokio:: task:: yield_now ( ) . await ;
256+
257+ assert_eq ! ( queue. len( ) , 3 ) ;
258+ assert ! ( queue. try_join_next( ) . is_some( ) ) ;
259+ assert_eq ! ( queue. len( ) , 2 ) ;
260+ check_try_join_next_is_noop ( & mut queue) ;
261+
262+ tx3. send ( ( ) ) . unwrap ( ) ;
263+ tokio:: task:: yield_now ( ) . await ;
264+
265+ assert_eq ! ( queue. len( ) , 2 ) ;
266+ check_try_join_next_is_noop ( & mut queue) ;
267+
268+ tx2. send ( ( ) ) . unwrap ( ) ;
269+ tokio:: task:: yield_now ( ) . await ;
270+
271+ assert_eq ! ( queue. len( ) , 2 ) ;
272+ assert ! ( queue. try_join_next( ) . is_some( ) ) ;
273+ assert_eq ! ( queue. len( ) , 1 ) ;
274+ assert ! ( queue. try_join_next( ) . is_some( ) ) ;
275+ assert ! ( queue. is_empty( ) ) ;
276+ check_try_join_next_is_noop ( & mut queue) ;
277+ }
278+
279+ #[ tokio:: test]
280+ async fn test_join_queue_try_join_next_disabled_coop ( ) {
281+ // This number is large enough to trigger coop. Without using `tokio::task::coop::unconstrained`
282+ // inside `try_join_next` this test fails on `assert!(coop_count == 0)`.
283+ const TASK_NUM : u32 = 1000 ;
284+
285+ let sem: std:: sync:: Arc < tokio:: sync:: Semaphore > =
286+ std:: sync:: Arc :: new ( tokio:: sync:: Semaphore :: new ( 0 ) ) ;
287+
288+ let mut queue = JoinQueue :: new ( ) ;
289+
290+ for _ in 0 ..TASK_NUM {
291+ let sem = sem. clone ( ) ;
292+ queue. spawn ( async move {
293+ sem. add_permits ( 1 ) ;
294+ } ) ;
295+ }
296+
297+ let _ = sem. acquire_many ( TASK_NUM ) . await . unwrap ( ) ;
298+
299+ let mut count = 0 ;
300+ let mut coop_count = 0 ;
301+ while !queue. is_empty ( ) {
302+ match queue. try_join_next ( ) {
303+ Some ( Ok ( ( ) ) ) => count += 1 ,
304+ Some ( Err ( err) ) => panic ! ( "failed: {err}" ) ,
305+ None => {
306+ coop_count += 1 ;
307+ tokio:: task:: yield_now ( ) . await ;
308+ }
309+ }
310+ }
311+ assert_eq ! ( coop_count, 0 ) ;
312+ assert_eq ! ( count, TASK_NUM ) ;
313+ }
314+
315+ #[ tokio:: test]
316+ async fn test_join_queue_try_join_next_with_id_disabled_coop ( ) {
317+ // Note that this number is large enough to trigger coop as in
318+ // `test_join_queue_try_join_next_coop` test. Without using
319+ // `tokio::task::coop::unconstrained` inside `try_join_next_with_id`
320+ // this test fails on `assert_eq!(count, TASK_NUM)`.
321+ const TASK_NUM : u32 = 1000 ;
322+
323+ let ( send, recv) = tokio:: sync:: watch:: channel ( ( ) ) ;
324+
325+ let mut queue = JoinQueue :: new ( ) ;
326+ let mut spawned = Vec :: with_capacity ( TASK_NUM as usize ) ;
327+
328+ for _ in 0 ..TASK_NUM {
329+ let mut recv = recv. clone ( ) ;
330+ let handle = queue. spawn ( async move { recv. changed ( ) . await . unwrap ( ) } ) ;
331+
332+ spawned. push ( handle. id ( ) ) ;
333+ }
334+ drop ( recv) ;
335+
336+ assert ! ( queue. try_join_next_with_id( ) . is_none( ) ) ;
337+
338+ send. send_replace ( ( ) ) ;
339+ send. closed ( ) . await ;
340+
341+ let mut count = 0 ;
342+ let mut coop_count = 0 ;
343+ let mut joined = Vec :: with_capacity ( TASK_NUM as usize ) ;
344+ while !queue. is_empty ( ) {
345+ match queue. try_join_next_with_id ( ) {
346+ Some ( Ok ( ( id, ( ) ) ) ) => {
347+ count += 1 ;
348+ joined. push ( id) ;
349+ }
350+ Some ( Err ( err) ) => panic ! ( "failed: {err}" ) ,
351+ None => {
352+ coop_count += 1 ;
353+ tokio:: task:: yield_now ( ) . await ;
354+ }
355+ }
356+ }
357+
358+ assert_eq ! ( coop_count, 0 ) ;
359+ assert_eq ! ( count, TASK_NUM ) ;
360+ assert_eq ! ( joined, spawned) ;
361+ }
0 commit comments