Skip to content

Commit 6bc1534

Browse files
committed
-Zfused-futures
1 parent 8d72d3e commit 6bc1534

9 files changed

+303
-11
lines changed

compiler/rustc_abi/src/layout/ty.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ rustc_index::newtype_index! {
6565
const FIRST_VARIANT = 0;
6666
}
6767
}
68+
69+
impl VariantIdx {
70+
/// The second variant, at index 1.
71+
///
72+
/// For use alongside [`VariantIdx::ZERO`].
73+
pub const ONE: VariantIdx = VariantIdx::from_u32(1);
74+
}
75+
6876
#[derive(Copy, Clone, PartialEq, Eq, Hash, HashStable_Generic)]
6977
#[rustc_pass_by_value]
7078
pub struct Layout<'a>(pub Interned<'a, LayoutData<FieldIdx, VariantIdx>>);

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
//! For coroutines with state 1 (returned) and state 2 (poisoned) it panics.
4646
//! Otherwise it continues the execution from the last suspension point.
4747
//!
48+
//! If -Zfused-futures is given however, then `Future::poll` from the state 1 (returned)
49+
//! will not panic and will instead return `Poll::Pending`.
50+
//!
4851
//! The other function is the drop glue for the coroutine.
4952
//! For coroutines with state 0 (unresumed) it drops the upvars of the coroutine.
5053
//! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
@@ -218,10 +221,17 @@ impl<'tcx> TransformVisitor<'tcx> {
218221
let source_info = SourceInfo::outermost(body.span);
219222

220223
let none_value = match self.coroutine_kind {
224+
CoroutineKind::Coroutine(_) => span_bug!(body.span, "`Coroutine`s cannot be fused"),
225+
// Fused futures continue to return `Poll::Pending`.
221226
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
222-
span_bug!(body.span, "`Future`s are not fused inherently")
227+
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, body.span);
228+
make_aggregate_adt(
229+
poll_def_id,
230+
VariantIdx::ONE,
231+
self.tcx.mk_args(&[self.old_ret_ty.into()]),
232+
IndexVec::new(),
233+
)
223234
}
224-
CoroutineKind::Coroutine(_) => span_bug!(body.span, "`Coroutine`s cannot be fused"),
225235
// `gen` continues return `None`
226236
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
227237
let option_def_id = self.tcx.require_lang_item(LangItem::Option, body.span);
@@ -278,7 +288,7 @@ impl<'tcx> TransformVisitor<'tcx> {
278288
statements: &mut Vec<Statement<'tcx>>,
279289
) {
280290
const ZERO: VariantIdx = VariantIdx::ZERO;
281-
const ONE: VariantIdx = VariantIdx::from_usize(1);
291+
const ONE: VariantIdx = VariantIdx::ONE;
282292
let rvalue = match self.coroutine_kind {
283293
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
284294
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, source_info.span);
@@ -1099,7 +1109,7 @@ fn return_poll_ready_assign<'tcx>(tcx: TyCtxt<'tcx>, source_info: SourceInfo) ->
10991109
const_: Const::zero_sized(tcx.types.unit),
11001110
}));
11011111
let ready_val = Rvalue::Aggregate(
1102-
Box::new(AggregateKind::Adt(poll_def_id, VariantIdx::from_usize(0), args, None, None)),
1112+
Box::new(AggregateKind::Adt(poll_def_id, VariantIdx::ZERO, args, None, None)),
11031113
IndexVec::from_raw(vec![val]),
11041114
);
11051115
Statement::new(source_info, StatementKind::Assign(Box::new((Place::return_place(), ready_val))))
@@ -1253,17 +1263,23 @@ fn create_coroutine_resume_function<'tcx>(
12531263

12541264
if can_return {
12551265
let block = match transform.coroutine_kind {
1266+
CoroutineKind::Coroutine(_) => {
1267+
insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
1268+
}
12561269
CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
1257-
| CoroutineKind::Coroutine(_) => {
1270+
if tcx.is_async_drop_in_place_coroutine(body.source.def_id()) =>
1271+
{
12581272
// For `async_drop_in_place<T>::{closure}` we just keep return Poll::Ready,
12591273
// because async drop of such coroutine keeps polling original coroutine
1260-
if tcx.is_async_drop_in_place_coroutine(body.source.def_id()) {
1261-
insert_poll_ready_block(tcx, body)
1262-
} else {
1263-
insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
1264-
}
1274+
insert_poll_ready_block(tcx, body)
12651275
}
1266-
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
1276+
CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
1277+
if !tcx.sess.opts.unstable_opts.fused_futures =>
1278+
{
1279+
insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
1280+
}
1281+
CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
1282+
| CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
12671283
| CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
12681284
transform.insert_none_ret_block(body)
12691285
}

compiler/rustc_session/src/options.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2336,6 +2336,8 @@ options! {
23362336
"replace returns with jumps to `__x86_return_thunk` (default: `keep`)"),
23372337
function_sections: Option<bool> = (None, parse_opt_bool, [TRACKED],
23382338
"whether each function should go in its own section"),
2339+
fused_futures: bool = (false, parse_bool, [TRACKED],
2340+
"make compiler-generated futures return `Poll::Pending` and not `panic!` when polled after completion"),
23392341
future_incompat_test: bool = (false, parse_bool, [UNTRACKED],
23402342
"forces all lints to be future incompatible, used for internal testing (default: no)"),
23412343
graphviz_dark_mode: bool = (false, parse_bool, [UNTRACKED],
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// MIR for `future::{closure#0}` 0 coroutine_resume
2+
/* coroutine_layout = CoroutineLayout {
3+
field_tys: {},
4+
variant_fields: {
5+
Unresumed(0): [],
6+
Returned (1): [],
7+
Panicked (2): [],
8+
},
9+
storage_conflicts: BitMatrix(0x0) {},
10+
} */
11+
12+
fn future::{closure#0}(_1: Pin<&mut {async fn body of future()}>, _2: &mut Context<'_>) -> Poll<u32> {
13+
debug _task_context => _2;
14+
let mut _0: std::task::Poll<u32>;
15+
let mut _3: u32;
16+
let mut _4: u32;
17+
18+
bb0: {
19+
_4 = discriminant((*(_1.0: &mut {async fn body of future()})));
20+
switchInt(move _4) -> [0: bb1, 1: bb4, otherwise: bb5];
21+
}
22+
23+
bb1: {
24+
_3 = const 42_u32;
25+
goto -> bb3;
26+
}
27+
28+
bb2: {
29+
_0 = Poll::<u32>::Ready(move _3);
30+
discriminant((*(_1.0: &mut {async fn body of future()}))) = 1;
31+
return;
32+
}
33+
34+
bb3: {
35+
goto -> bb2;
36+
}
37+
38+
bb4: {
39+
_0 = Poll::<u32>::Pending;
40+
return;
41+
}
42+
43+
bb5: {
44+
unreachable;
45+
}
46+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// MIR for `future::{closure#0}` 0 coroutine_resume
2+
/* coroutine_layout = CoroutineLayout {
3+
field_tys: {},
4+
variant_fields: {
5+
Unresumed(0): [],
6+
Returned (1): [],
7+
Panicked (2): [],
8+
},
9+
storage_conflicts: BitMatrix(0x0) {},
10+
} */
11+
12+
fn future::{closure#0}(_1: Pin<&mut {async fn body of future()}>, _2: &mut Context<'_>) -> Poll<u32> {
13+
debug _task_context => _2;
14+
let mut _0: std::task::Poll<u32>;
15+
let mut _3: u32;
16+
let mut _4: u32;
17+
18+
bb0: {
19+
_4 = discriminant((*(_1.0: &mut {async fn body of future()})));
20+
switchInt(move _4) -> [0: bb1, 1: bb4, otherwise: bb5];
21+
}
22+
23+
bb1: {
24+
_3 = const 42_u32;
25+
goto -> bb3;
26+
}
27+
28+
bb2: {
29+
_0 = Poll::<u32>::Ready(move _3);
30+
discriminant((*(_1.0: &mut {async fn body of future()}))) = 1;
31+
return;
32+
}
33+
34+
bb3: {
35+
goto -> bb2;
36+
}
37+
38+
bb4: {
39+
_0 = Poll::<u32>::Pending;
40+
return;
41+
}
42+
43+
bb5: {
44+
unreachable;
45+
}
46+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// MIR for `main::{closure#0}` 0 coroutine_resume
2+
/* coroutine_layout = CoroutineLayout {
3+
field_tys: {},
4+
variant_fields: {
5+
Unresumed(0): [],
6+
Returned (1): [],
7+
Panicked (2): [],
8+
Suspend0 (3): [],
9+
},
10+
storage_conflicts: BitMatrix(0x0) {},
11+
} */
12+
13+
fn main::{closure#0}(_1: Pin<&mut {coroutine@$DIR/fused_futures.rs:17:5: 17:7}>, _2: ()) -> CoroutineState<i32, &str> {
14+
let mut _0: std::ops::CoroutineState<i32, &str>;
15+
let mut _3: !;
16+
let _4: ();
17+
let mut _5: &str;
18+
let mut _6: u32;
19+
20+
bb0: {
21+
_6 = discriminant((*(_1.0: &mut {coroutine@$DIR/fused_futures.rs:17:5: 17:7})));
22+
switchInt(move _6) -> [0: bb1, 1: bb6, 3: bb5, otherwise: bb7];
23+
}
24+
25+
bb1: {
26+
StorageLive(_4);
27+
_0 = CoroutineState::<i32, &str>::Yielded(const 1_i32);
28+
StorageDead(_4);
29+
discriminant((*(_1.0: &mut {coroutine@$DIR/fused_futures.rs:17:5: 17:7}))) = 3;
30+
return;
31+
}
32+
33+
bb2: {
34+
StorageDead(_4);
35+
_5 = const "foo";
36+
goto -> bb4;
37+
}
38+
39+
bb3: {
40+
_0 = CoroutineState::<i32, &str>::Complete(move _5);
41+
discriminant((*(_1.0: &mut {coroutine@$DIR/fused_futures.rs:17:5: 17:7}))) = 1;
42+
return;
43+
}
44+
45+
bb4: {
46+
goto -> bb3;
47+
}
48+
49+
bb5: {
50+
StorageLive(_4);
51+
_4 = move _2;
52+
goto -> bb2;
53+
}
54+
55+
bb6: {
56+
assert(const false, "coroutine resumed after completion") -> [success: bb6, unwind unreachable];
57+
}
58+
59+
bb7: {
60+
unreachable;
61+
}
62+
}
63+
64+
ALLOC0 (size: 3, align: 1) {
65+
66 6f 6f │ foo
66+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// MIR for `main::{closure#0}` 0 coroutine_resume
2+
/* coroutine_layout = CoroutineLayout {
3+
field_tys: {},
4+
variant_fields: {
5+
Unresumed(0): [],
6+
Returned (1): [],
7+
Panicked (2): [],
8+
Suspend0 (3): [],
9+
},
10+
storage_conflicts: BitMatrix(0x0) {},
11+
} */
12+
13+
fn main::{closure#0}(_1: Pin<&mut {coroutine@$DIR/fused_futures.rs:17:5: 17:7}>, _2: ()) -> CoroutineState<i32, &str> {
14+
let mut _0: std::ops::CoroutineState<i32, &str>;
15+
let mut _3: !;
16+
let _4: ();
17+
let mut _5: &str;
18+
let mut _6: u32;
19+
20+
bb0: {
21+
_6 = discriminant((*(_1.0: &mut {coroutine@$DIR/fused_futures.rs:17:5: 17:7})));
22+
switchInt(move _6) -> [0: bb1, 1: bb6, 3: bb5, otherwise: bb7];
23+
}
24+
25+
bb1: {
26+
StorageLive(_4);
27+
_0 = CoroutineState::<i32, &str>::Yielded(const 1_i32);
28+
StorageDead(_4);
29+
discriminant((*(_1.0: &mut {coroutine@$DIR/fused_futures.rs:17:5: 17:7}))) = 3;
30+
return;
31+
}
32+
33+
bb2: {
34+
StorageDead(_4);
35+
_5 = const "foo";
36+
goto -> bb4;
37+
}
38+
39+
bb3: {
40+
_0 = CoroutineState::<i32, &str>::Complete(move _5);
41+
discriminant((*(_1.0: &mut {coroutine@$DIR/fused_futures.rs:17:5: 17:7}))) = 1;
42+
return;
43+
}
44+
45+
bb4: {
46+
goto -> bb3;
47+
}
48+
49+
bb5: {
50+
StorageLive(_4);
51+
_4 = move _2;
52+
goto -> bb2;
53+
}
54+
55+
bb6: {
56+
assert(const false, "coroutine resumed after completion") -> [success: bb6, unwind continue];
57+
}
58+
59+
bb7: {
60+
unreachable;
61+
}
62+
}
63+
64+
ALLOC0 (size: 3, align: 1) {
65+
66 6f 6f │ foo
66+
}

tests/mir-opt/fused_futures.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//@ edition:2024
2+
//@ compile-flags: -Zfused-futures
3+
// skip-filecheck
4+
// EMIT_MIR_FOR_EACH_PANIC_STRATEGY
5+
6+
#![feature(coroutines, stmt_expr_attributes)]
7+
#![allow(unused)]
8+
9+
// EMIT_MIR fused_futures.future-{closure#0}.coroutine_resume.0.mir
10+
pub async fn future() -> u32 {
11+
42
12+
}
13+
14+
// EMIT_MIR fused_futures.main-{closure#0}.coroutine_resume.0.mir
15+
fn main() {
16+
let mut coroutine = #[coroutine]
17+
|| {
18+
yield 1;
19+
return "foo";
20+
};
21+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//@ run-pass
2+
//@ edition: 2024
3+
//@ compile-flags: -Zfused-futures
4+
5+
use std::pin::pin;
6+
use std::task::{Context, Poll, Waker};
7+
8+
async fn foo() -> u8 {
9+
12
10+
}
11+
12+
const N: usize = 10;
13+
14+
fn main() {
15+
let cx = &mut Context::from_waker(Waker::noop());
16+
let mut x = pin!(foo());
17+
assert_eq!(x.as_mut().poll(cx), Poll::Ready(12));
18+
for _ in 0..N {
19+
assert_eq!(x.as_mut().poll(cx), Poll::Pending);
20+
}
21+
}

0 commit comments

Comments
 (0)