Skip to content

Commit 633a905

Browse files
authored
Add stream_select macro (#2262)
1 parent ea07b4b commit 633a905

File tree

6 files changed

+219
-1
lines changed

6 files changed

+219
-1
lines changed

futures-macro/src/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use proc_macro::TokenStream;
1919
mod executor;
2020
mod join;
2121
mod select;
22+
mod stream_select;
2223

2324
/// The `join!` macro.
2425
#[cfg_attr(fn_like_proc_macro, proc_macro)]
@@ -54,3 +55,12 @@ pub fn select_biased_internal(input: TokenStream) -> TokenStream {
5455
pub fn test_internal(input: TokenStream, item: TokenStream) -> TokenStream {
5556
crate::executor::test(input, item)
5657
}
58+
59+
/// The `stream_select!` macro.
60+
#[cfg_attr(fn_like_proc_macro, proc_macro)]
61+
#[cfg_attr(not(fn_like_proc_macro), proc_macro_hack::proc_macro_hack)]
62+
pub fn stream_select_internal(input: TokenStream) -> TokenStream {
63+
crate::stream_select::stream_select(input.into())
64+
.unwrap_or_else(syn::Error::into_compile_error)
65+
.into()
66+
}

futures-macro/src/stream_select.rs

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
use proc_macro2::TokenStream;
2+
use quote::{format_ident, quote, ToTokens};
3+
use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token};
4+
5+
/// The `stream_select!` macro.
6+
pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> {
7+
let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?;
8+
if args.len() < 2 {
9+
return Ok(quote! {
10+
compile_error!("stream select macro needs at least two arguments.")
11+
});
12+
}
13+
let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
14+
let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>();
15+
let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::<Vec<_>>();
16+
let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>();
17+
let args = args.iter().map(|e| e.to_token_stream());
18+
19+
Ok(quote! {
20+
{
21+
#[derive(Debug)]
22+
struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*);
23+
24+
enum StreamEnum<#(#generic_idents),*> {
25+
#(
26+
#generic_idents(#generic_idents)
27+
),*,
28+
None,
29+
}
30+
31+
impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*>
32+
where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
33+
{
34+
type Item = ITEM;
35+
36+
fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
37+
match self.get_mut() {
38+
#(
39+
Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx)
40+
),*,
41+
Self::None => panic!("StreamEnum::None should never be polled!"),
42+
}
43+
}
44+
}
45+
46+
impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*>
47+
where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
48+
{
49+
type Item = ITEM;
50+
51+
fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
52+
let Self(#(ref mut #field_idents),*) = self.get_mut();
53+
#(
54+
let mut #field_idents_2 = false;
55+
)*
56+
let mut any_pending = false;
57+
{
58+
let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*];
59+
__futures_crate::async_await::shuffle(&mut stream_array);
60+
61+
for mut s in stream_array {
62+
if let StreamEnum::None = s {
63+
continue;
64+
} else {
65+
match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) {
66+
r @ __futures_crate::task::Poll::Ready(Some(_)) => {
67+
return r;
68+
},
69+
__futures_crate::task::Poll::Pending => {
70+
any_pending = true;
71+
},
72+
__futures_crate::task::Poll::Ready(None) => {
73+
match s {
74+
#(
75+
StreamEnum::#generic_idents(_) => { #field_idents_2 = true; }
76+
),*,
77+
StreamEnum::None => panic!("StreamEnum::None should never be polled!"),
78+
}
79+
},
80+
}
81+
}
82+
}
83+
}
84+
#(
85+
if #field_idents_2 {
86+
*#field_idents = None;
87+
}
88+
)*
89+
if any_pending {
90+
__futures_crate::task::Poll::Pending
91+
} else {
92+
__futures_crate::task::Poll::Ready(None)
93+
}
94+
}
95+
96+
fn size_hint(&self) -> (usize, Option<usize>) {
97+
let mut s = (0, Some(0));
98+
#(
99+
if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) {
100+
s.0 += new_hint.0;
101+
// We can change this out for `.zip` when the MSRV is 1.46.0 or higher.
102+
s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b));
103+
}
104+
)*
105+
s
106+
}
107+
}
108+
109+
StreamSelect(#(Some(#args)),*)
110+
111+
}
112+
})
113+
}

futures-util/src/async_await/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ mod select_mod;
3030
#[cfg(feature = "async-await-macro")]
3131
pub use self::select_mod::*;
3232

33+
// Primary export is a macro
34+
#[cfg(feature = "async-await-macro")]
35+
mod stream_select_mod;
36+
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/64762
37+
#[cfg(feature = "async-await-macro")]
38+
pub use self::stream_select_mod::*;
39+
3340
#[cfg(feature = "std")]
3441
#[cfg(feature = "async-await-macro")]
3542
mod random;
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//! The `stream_select` macro.
2+
3+
#[cfg(feature = "std")]
4+
#[allow(unreachable_pub)]
5+
#[doc(hidden)]
6+
#[cfg_attr(not(fn_like_proc_macro), proc_macro_hack::proc_macro_hack(support_nested))]
7+
pub use futures_macro::stream_select_internal;
8+
9+
/// Combines several streams, all producing the same `Item` type, into one stream.
10+
/// This is similar to `select_all` but does not require the streams to all be the same type.
11+
/// It also keeps the streams inline, and does not require `Box<dyn Stream>`s to be allocated.
12+
/// Streams passed to this macro must be `Unpin`.
13+
///
14+
/// If multiple streams are ready, one will be pseudo randomly selected at runtime.
15+
///
16+
/// This macro is gated behind the `async-await` feature of this library, which is activated by default.
17+
/// Note that `stream_select!` relies on `proc-macro-hack`, and may require to set the compiler's recursion
18+
/// limit very high, e.g. `#![recursion_limit="1024"]`.
19+
///
20+
/// # Examples
21+
///
22+
/// ```
23+
/// # futures::executor::block_on(async {
24+
/// use futures::{stream, StreamExt, stream_select};
25+
/// let endless_ints = |i| stream::iter(vec![i].into_iter().cycle()).fuse();
26+
///
27+
/// let mut endless_numbers = stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3));
28+
/// match endless_numbers.next().await {
29+
/// Some(1) => println!("Got a 1"),
30+
/// Some(2) => println!("Got a 2"),
31+
/// Some(3) => println!("Got a 3"),
32+
/// _ => unreachable!(),
33+
/// }
34+
/// # });
35+
/// ```
36+
#[cfg(feature = "std")]
37+
#[macro_export]
38+
macro_rules! stream_select {
39+
($($tokens:tt)*) => {{
40+
use $crate::__private as __futures_crate;
41+
$crate::stream_select_internal! {
42+
$( $tokens )*
43+
}
44+
}}
45+
}

futures/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ pub use futures_util::{join, pending, poll, select_biased, try_join}; // Async-a
137137
#[doc(inline)]
138138
pub use futures_util::{future, sink, stream, task};
139139

140+
#[cfg(feature = "std")]
141+
#[cfg(feature = "async-await")]
142+
pub use futures_util::stream_select;
143+
144+
#[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))]
140145
#[cfg(feature = "alloc")]
141146
#[doc(inline)]
142147
pub use futures_channel as channel;

futures/tests/async_await_macros.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ use futures::future::{self, poll_fn, FutureExt};
44
use futures::sink::SinkExt;
55
use futures::stream::StreamExt;
66
use futures::task::{Context, Poll};
7-
use futures::{join, pending, pin_mut, poll, select, select_biased, try_join};
7+
use futures::{
8+
join, pending, pin_mut, poll, select, select_biased, stream, stream_select, try_join,
9+
};
810
use std::mem;
911

1012
#[test]
@@ -308,6 +310,42 @@ fn select_on_mutable_borrowing_future_with_same_borrow_in_block_and_default() {
308310
});
309311
}
310312

313+
#[test]
314+
#[allow(unused_assignments)]
315+
fn stream_select() {
316+
// stream_select! macro
317+
block_on(async {
318+
let endless_ints = |i| stream::iter(vec![i].into_iter().cycle());
319+
320+
let mut endless_ones = stream_select!(endless_ints(1i32), stream::pending());
321+
assert_eq!(endless_ones.next().await, Some(1));
322+
assert_eq!(endless_ones.next().await, Some(1));
323+
324+
let mut finite_list =
325+
stream_select!(stream::iter(vec![1].into_iter()), stream::iter(vec![1].into_iter()));
326+
assert_eq!(finite_list.next().await, Some(1));
327+
assert_eq!(finite_list.next().await, Some(1));
328+
assert_eq!(finite_list.next().await, None);
329+
330+
let endless_mixed = stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3));
331+
// Take 1000, and assert a somewhat even distribution of values.
332+
// The fairness is randomized, but over 1000 samples we should be pretty close to even.
333+
// This test may be a bit flaky. Feel free to adjust the margins as you see fit.
334+
let mut count = 0;
335+
let results = endless_mixed
336+
.take_while(move |_| {
337+
count += 1;
338+
let ret = count < 1000;
339+
async move { ret }
340+
})
341+
.collect::<Vec<_>>()
342+
.await;
343+
assert!(results.iter().filter(|x| **x == 1).count() >= 299);
344+
assert!(results.iter().filter(|x| **x == 2).count() >= 299);
345+
assert!(results.iter().filter(|x| **x == 3).count() >= 299);
346+
});
347+
}
348+
311349
#[test]
312350
fn join_size() {
313351
let fut = async {

0 commit comments

Comments
 (0)