Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 41 additions & 19 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ struct SelfArgVisitor<'tcx> {
}

impl<'tcx> SelfArgVisitor<'tcx> {
fn new(tcx: TyCtxt<'tcx>, elem: ProjectionElem<Local, Ty<'tcx>>) -> Self {
Self { tcx, new_base: Place { local: SELF_ARG, projection: tcx.mk_place_elems(&[elem]) } }
fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
Self { tcx, new_base }
}
}

Expand All @@ -144,16 +144,14 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
assert_ne!(*local, SELF_ARG);
}

fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
if place.local == SELF_ARG {
replace_base(place, self.new_base, self.tcx);
} else {
self.visit_local(&mut place.local, context, location);
}

for elem in place.projection.iter() {
if let PlaceElem::Index(local) = elem {
assert_ne!(local, SELF_ARG);
}
for elem in place.projection.iter() {
if let PlaceElem::Index(local) = elem {
assert_ne!(local, SELF_ARG);
}
}
}
Expand Down Expand Up @@ -484,31 +482,55 @@ fn make_aggregate_adt<'tcx>(
}

fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let coroutine_ty = body.local_decls.raw[1].ty;
let coroutine_ty = body.local_decls[SELF_ARG].ty;

let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);

// Replace the by value coroutine argument
body.local_decls.raw[1].ty = ref_coroutine_ty;
body.local_decls[SELF_ARG].ty = ref_coroutine_ty;

// Add a deref to accesses of the coroutine state
SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body);
SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
}

fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let ref_coroutine_ty = body.local_decls.raw[1].ty;
let coroutine_ty = body.local_decls[SELF_ARG].ty;

let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);

let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
let pin_adt_ref = tcx.adt_def(pin_did);
let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);

// Replace the by ref coroutine argument
body.local_decls.raw[1].ty = pin_ref_coroutine_ty;
body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;

let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));

// Add the Pin field access to accesses of the coroutine state
SelfArgVisitor::new(tcx, ProjectionElem::Field(FieldIdx::ZERO, ref_coroutine_ty))
.visit_body(body);
SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);

let source_info = SourceInfo::outermost(body.span);
let pin_field = tcx.mk_place_field(SELF_ARG.into(), FieldIdx::ZERO, ref_coroutine_ty);

let statements = &mut body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements;
// Miri requires retags to be the very first thing in the body.
// We insert this assignment just after.
let insert_point = statements
.iter()
.position(|stmt| !matches!(stmt.kind, StatementKind::Retag(..)))
.unwrap_or(statements.len());
statements.insert(
insert_point,
Statement::new(
source_info,
StatementKind::Assign(Box::new((
unpinned_local.into(),
Rvalue::Use(Operand::Copy(pin_field)),
))),
),
);
}

/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
Expand Down Expand Up @@ -1274,8 +1296,6 @@ fn create_coroutine_resume_function<'tcx>(
let default_block = insert_term_block(body, TerminatorKind::Unreachable);
insert_switch(body, cases, &transform, default_block);

make_coroutine_state_argument_indirect(tcx, body);

match transform.coroutine_kind {
CoroutineKind::Coroutine(_)
| CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
Expand All @@ -1284,7 +1304,9 @@ fn create_coroutine_resume_function<'tcx>(
}
// Iterator::next doesn't accept a pinned argument,
// unlike for all other coroutine kinds.
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
make_coroutine_state_argument_indirect(tcx, body);
}
}

// Make sure we remove dead blocks to remove
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_mir_transform/src/coroutine/drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,12 +676,13 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);

make_coroutine_state_argument_indirect(tcx, &mut body);

match transform.coroutine_kind {
// Iterator::next doesn't accept a pinned argument,
// unlike for all other coroutine kinds.
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
make_coroutine_state_argument_indirect(tcx, &mut body);
}

_ => {
make_coroutine_state_argument_pinned(tcx, &mut body);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
debug _task_context => _2;
debug x => ((*(_1.0: &mut {async fn body of a<T>()})).0: T);
debug x => ((*_20).0: T);
let mut _0: std::task::Poll<()>;
let _3: T;
let mut _4: impl std::future::Future<Output = ()>;
Expand All @@ -21,12 +21,14 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
let mut _17: std::pin::Pin<&mut impl std::future::Future<Output = ()>>;
let mut _18: isize;
let mut _19: u32;
let mut _20: &mut {async fn body of a<T>()};
scope 1 {
debug x => (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).0: T);
debug x => (((*_20) as variant#4).0: T);
}

bb0: {
_19 = discriminant((*(_1.0: &mut {async fn body of a<T>()})));
_20 = copy (_1.0: &mut {async fn body of a<T>()});
_19 = discriminant((*_20));
switchInt(move _19) -> [0: bb9, 3: bb12, 4: bb13, otherwise: bb14];
}

Expand All @@ -43,13 +45,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)

bb3: {
_0 = Poll::<()>::Pending;
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 4;
discriminant((*_20)) = 4;
return;
}

bb4: {
StorageLive(_17);
_16 = &mut (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).1: impl std::future::Future<Output = ()>);
_16 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
_17 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _16) -> [return: bb7, unwind unreachable];
}

Expand Down Expand Up @@ -81,7 +83,7 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
}

bb11: {
drop(((*(_1.0: &mut {async fn body of a<T>()})).0: T)) -> [return: bb10, unwind unreachable];
drop(((*_20).0: T)) -> [return: bb10, unwind unreachable];
}

bb12: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
debug _task_context => _2;
debug x => ((*(_1.0: &mut {async fn body of a<T>()})).0: T);
debug x => ((*_20).0: T);
let mut _0: std::task::Poll<()>;
let _3: T;
let mut _4: impl std::future::Future<Output = ()>;
Expand All @@ -21,12 +21,14 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
let mut _17: std::pin::Pin<&mut impl std::future::Future<Output = ()>>;
let mut _18: isize;
let mut _19: u32;
let mut _20: &mut {async fn body of a<T>()};
scope 1 {
debug x => (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).0: T);
debug x => (((*_20) as variant#4).0: T);
}

bb0: {
_19 = discriminant((*(_1.0: &mut {async fn body of a<T>()})));
_20 = copy (_1.0: &mut {async fn body of a<T>()});
_19 = discriminant((*_20));
switchInt(move _19) -> [0: bb12, 2: bb18, 3: bb16, 4: bb17, otherwise: bb19];
}

Expand Down Expand Up @@ -57,13 +59,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)

bb6: {
_0 = Poll::<()>::Pending;
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 4;
discriminant((*_20)) = 4;
return;
}

bb7: {
StorageLive(_17);
_16 = &mut (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).1: impl std::future::Future<Output = ()>);
_16 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
_17 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _16) -> [return: bb10, unwind: bb15];
}

Expand Down Expand Up @@ -95,11 +97,11 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
}

bb14: {
drop(((*(_1.0: &mut {async fn body of a<T>()})).0: T)) -> [return: bb13, unwind: bb4];
drop(((*_20).0: T)) -> [return: bb13, unwind: bb4];
}

bb15 (cleanup): {
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 2;
discriminant((*_20)) = 2;
resume;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) ->
let mut _0: std::task::Poll<()>;
let mut _3: ();
let mut _4: u32;
let mut _5: &mut {async fn body of a()};

bb0: {
_4 = discriminant((*(_1.0: &mut {async fn body of a()})));
_5 = copy (_1.0: &mut {async fn body of a()});
_4 = discriminant((*_5));
switchInt(move _4) -> [0: bb1, 1: bb4, otherwise: bb5];
}

Expand All @@ -27,7 +29,7 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) ->

bb2: {
_0 = Poll::<()>::Ready(move _3);
discriminant((*(_1.0: &mut {async fn body of a()}))) = 1;
discriminant((*_5)) = 1;
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,25 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
let mut _36: ();
let mut _37: ();
let mut _38: u32;
let mut _39: &mut {async fn body of b()};
scope 1 {
debug __awaitee => (((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()});
debug __awaitee => (((*_39) as variant#3).0: {async fn body of a()});
let _17: ();
scope 2 {
debug result => _17;
}
}
scope 3 {
debug __awaitee => (((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()});
debug __awaitee => (((*_39) as variant#4).0: {async fn body of a()});
let _33: ();
scope 4 {
debug result => _33;
}
}

bb0: {
_38 = discriminant((*(_1.0: &mut {async fn body of b()})));
_39 = copy (_1.0: &mut {async fn body of b()});
_38 = discriminant((*_39));
switchInt(move _38) -> [0: bb1, 1: bb29, 3: bb27, 4: bb28, otherwise: bb8];
}

Expand All @@ -121,7 +123,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
StorageDead(_5);
PlaceMention(_4);
nop;
(((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()}) = move _4;
(((*_39) as variant#3).0: {async fn body of a()}) = move _4;
goto -> bb4;
}

Expand All @@ -131,7 +133,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
StorageLive(_10);
StorageLive(_11);
StorageLive(_12);
_12 = &mut (((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()});
_12 = &mut (((*_39) as variant#3).0: {async fn body of a()});
_11 = &mut (*_12);
_10 = Pin::<&mut {async fn body of a()}>::new_unchecked(move _11) -> [return: bb5, unwind unreachable];
}
Expand Down Expand Up @@ -178,7 +180,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
StorageDead(_4);
StorageDead(_19);
StorageDead(_20);
discriminant((*(_1.0: &mut {async fn body of b()}))) = 3;
discriminant((*_39)) = 3;
return;
}

Expand All @@ -191,7 +193,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
StorageDead(_12);
StorageDead(_9);
StorageDead(_8);
drop((((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()})) -> [return: bb12, unwind unreachable];
drop((((*_39) as variant#3).0: {async fn body of a()})) -> [return: bb12, unwind unreachable];
}

bb11: {
Expand Down Expand Up @@ -223,7 +225,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
StorageDead(_22);
PlaceMention(_21);
nop;
(((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()}) = move _21;
(((*_39) as variant#4).0: {async fn body of a()}) = move _21;
goto -> bb16;
}

Expand All @@ -233,7 +235,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
StorageLive(_26);
StorageLive(_27);
StorageLive(_28);
_28 = &mut (((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()});
_28 = &mut (((*_39) as variant#4).0: {async fn body of a()});
_27 = &mut (*_28);
_26 = Pin::<&mut {async fn body of a()}>::new_unchecked(move _27) -> [return: bb17, unwind unreachable];
}
Expand Down Expand Up @@ -275,7 +277,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
StorageDead(_21);
StorageDead(_35);
StorageDead(_36);
discriminant((*(_1.0: &mut {async fn body of b()}))) = 4;
discriminant((*_39)) = 4;
return;
}

Expand All @@ -288,7 +290,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
StorageDead(_28);
StorageDead(_25);
StorageDead(_24);
drop((((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()})) -> [return: bb23, unwind unreachable];
drop((((*_39) as variant#4).0: {async fn body of a()})) -> [return: bb23, unwind unreachable];
}

bb22: {
Expand All @@ -311,7 +313,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->

bb25: {
_0 = Poll::<()>::Ready(move _37);
discriminant((*(_1.0: &mut {async fn body of b()}))) = 1;
discriminant((*_39)) = 1;
return;
}

Expand Down
Loading
Loading