Skip to content

Commit bce76c5

Browse files
authored
task: add try_join_next and try_join_next_with_id on JoinQueue (#7636)
1 parent b48586f commit bce76c5

File tree

2 files changed

+194
-3
lines changed

2 files changed

+194
-3
lines changed

tokio-util/src/task/join_queue.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,59 @@ impl<T> JoinQueue<T> {
183183
std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
184184
}
185185

186+
/// Tries to poll an `AbortOnDropHandle` without blocking or yielding.
187+
///
188+
/// Note that on success the handle will panic on subsequent polls
189+
/// since it becomes consumed.
190+
fn try_poll_handle(jh: &mut AbortOnDropHandle<T>) -> Option<Result<T, JoinError>> {
191+
let waker = futures_util::task::noop_waker();
192+
let mut cx = Context::from_waker(&waker);
193+
194+
// Since this function is not async and cannot be forced to yield, we should
195+
// disable budgeting when we want to check for the `JoinHandle` readiness.
196+
let jh = std::pin::pin!(tokio::task::coop::unconstrained(jh));
197+
if let Poll::Ready(res) = jh.poll(&mut cx) {
198+
Some(res)
199+
} else {
200+
None
201+
}
202+
}
203+
204+
/// Tries to join the next task in FIFO order if it has completed.
205+
///
206+
/// Returns `None` if the queue is empty or if the next task is not yet ready.
207+
pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
208+
let jh = self.0.front_mut()?;
209+
let res = Self::try_poll_handle(jh)?;
210+
// Use `detach` to avoid calling `abort` on a task that has already completed.
211+
// Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
212+
// we only need to drop the `JoinHandle` for cleanup.
213+
drop(self.0.pop_front().unwrap().detach());
214+
Some(res)
215+
}
216+
217+
/// Tries to join the next task in FIFO order if it has completed and return its output,
218+
/// along with its [task ID].
219+
///
220+
/// Returns `None` if the queue is empty or if the next task is not yet ready.
221+
///
222+
/// When this method returns an error, then the id of the task that failed can be accessed
223+
/// using the [`JoinError::id`] method.
224+
///
225+
/// [task ID]: tokio::task::Id
226+
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
227+
pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
228+
let jh = self.0.front_mut()?;
229+
let res = Self::try_poll_handle(jh)?;
230+
// Use `detach` to avoid calling `abort` on a task that has already completed.
231+
// Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
232+
// we only need to drop the `JoinHandle` for cleanup.
233+
let jh = self.0.pop_front().unwrap().detach();
234+
let id = jh.id();
235+
drop(jh);
236+
Some(res.map(|output| (id, output)))
237+
}
238+
186239
/// Aborts all tasks and waits for them to finish shutting down.
187240
///
188241
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in

tokio-util/tests/task_join_queue.rs

Lines changed: 141 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,12 @@ async fn test_join_queue_join_next_with_id() {
192192

193193
let (send, recv) = tokio::sync::watch::channel(());
194194

195-
let mut set = JoinQueue::new();
195+
let mut queue = JoinQueue::new();
196196
let mut spawned = Vec::with_capacity(TASK_NUM as usize);
197197

198198
for _ in 0..TASK_NUM {
199199
let mut recv = recv.clone();
200-
let handle = set.spawn(async move { recv.changed().await.unwrap() });
200+
let handle = queue.spawn(async move { recv.changed().await.unwrap() });
201201

202202
spawned.push(handle.id());
203203
}
@@ -208,7 +208,7 @@ async fn test_join_queue_join_next_with_id() {
208208

209209
let mut count = 0;
210210
let mut joined = Vec::with_capacity(TASK_NUM as usize);
211-
while let Some(res) = set.join_next_with_id().await {
211+
while let Some(res) = queue.join_next_with_id().await {
212212
match res {
213213
Ok((id, ())) => {
214214
count += 1;
@@ -221,3 +221,141 @@ async fn test_join_queue_join_next_with_id() {
221221
assert_eq!(count, TASK_NUM);
222222
assert_eq!(joined, spawned);
223223
}
224+
225+
#[tokio::test]
226+
async fn test_join_queue_try_join_next() {
227+
let mut queue = JoinQueue::new();
228+
let (tx1, rx1) = oneshot::channel::<()>();
229+
queue.spawn(async {
230+
let _ = rx1.await;
231+
});
232+
let (tx2, rx2) = oneshot::channel::<()>();
233+
queue.spawn(async {
234+
let _ = rx2.await;
235+
});
236+
let (tx3, rx3) = oneshot::channel::<()>();
237+
queue.spawn(async {
238+
let _ = rx3.await;
239+
});
240+
241+
// This function also checks that calling `queue.try_join_next()` repeatedly when
242+
// no task is ready is idempotent, i.e. that it does not change the queue state.
243+
fn check_try_join_next_is_noop(queue: &mut JoinQueue<()>) {
244+
let len = queue.len();
245+
for _ in 0..5 {
246+
assert!(queue.try_join_next().is_none());
247+
assert_eq!(queue.len(), len);
248+
}
249+
}
250+
251+
assert_eq!(queue.len(), 3);
252+
check_try_join_next_is_noop(&mut queue);
253+
254+
tx1.send(()).unwrap();
255+
tokio::task::yield_now().await;
256+
257+
assert_eq!(queue.len(), 3);
258+
assert!(queue.try_join_next().is_some());
259+
assert_eq!(queue.len(), 2);
260+
check_try_join_next_is_noop(&mut queue);
261+
262+
tx3.send(()).unwrap();
263+
tokio::task::yield_now().await;
264+
265+
assert_eq!(queue.len(), 2);
266+
check_try_join_next_is_noop(&mut queue);
267+
268+
tx2.send(()).unwrap();
269+
tokio::task::yield_now().await;
270+
271+
assert_eq!(queue.len(), 2);
272+
assert!(queue.try_join_next().is_some());
273+
assert_eq!(queue.len(), 1);
274+
assert!(queue.try_join_next().is_some());
275+
assert!(queue.is_empty());
276+
check_try_join_next_is_noop(&mut queue);
277+
}
278+
279+
#[tokio::test]
280+
async fn test_join_queue_try_join_next_disabled_coop() {
281+
// This number is large enough to trigger coop. Without using `tokio::task::coop::unconstrained`
282+
// inside `try_join_next` this test fails on `assert!(coop_count == 0)`.
283+
const TASK_NUM: u32 = 1000;
284+
285+
let sem: std::sync::Arc<tokio::sync::Semaphore> =
286+
std::sync::Arc::new(tokio::sync::Semaphore::new(0));
287+
288+
let mut queue = JoinQueue::new();
289+
290+
for _ in 0..TASK_NUM {
291+
let sem = sem.clone();
292+
queue.spawn(async move {
293+
sem.add_permits(1);
294+
});
295+
}
296+
297+
let _ = sem.acquire_many(TASK_NUM).await.unwrap();
298+
299+
let mut count = 0;
300+
let mut coop_count = 0;
301+
while !queue.is_empty() {
302+
match queue.try_join_next() {
303+
Some(Ok(())) => count += 1,
304+
Some(Err(err)) => panic!("failed: {err}"),
305+
None => {
306+
coop_count += 1;
307+
tokio::task::yield_now().await;
308+
}
309+
}
310+
}
311+
assert_eq!(coop_count, 0);
312+
assert_eq!(count, TASK_NUM);
313+
}
314+
315+
#[tokio::test]
316+
async fn test_join_queue_try_join_next_with_id_disabled_coop() {
317+
// Note that this number is large enough to trigger coop as in
318+
// `test_join_queue_try_join_next_coop` test. Without using
319+
// `tokio::task::coop::unconstrained` inside `try_join_next_with_id`
320+
// this test fails on `assert_eq!(count, TASK_NUM)`.
321+
const TASK_NUM: u32 = 1000;
322+
323+
let (send, recv) = tokio::sync::watch::channel(());
324+
325+
let mut queue = JoinQueue::new();
326+
let mut spawned = Vec::with_capacity(TASK_NUM as usize);
327+
328+
for _ in 0..TASK_NUM {
329+
let mut recv = recv.clone();
330+
let handle = queue.spawn(async move { recv.changed().await.unwrap() });
331+
332+
spawned.push(handle.id());
333+
}
334+
drop(recv);
335+
336+
assert!(queue.try_join_next_with_id().is_none());
337+
338+
send.send_replace(());
339+
send.closed().await;
340+
341+
let mut count = 0;
342+
let mut coop_count = 0;
343+
let mut joined = Vec::with_capacity(TASK_NUM as usize);
344+
while !queue.is_empty() {
345+
match queue.try_join_next_with_id() {
346+
Some(Ok((id, ()))) => {
347+
count += 1;
348+
joined.push(id);
349+
}
350+
Some(Err(err)) => panic!("failed: {err}"),
351+
None => {
352+
coop_count += 1;
353+
tokio::task::yield_now().await;
354+
}
355+
}
356+
}
357+
358+
assert_eq!(coop_count, 0);
359+
assert_eq!(count, TASK_NUM);
360+
assert_eq!(joined, spawned);
361+
}

0 commit comments

Comments
 (0)