Skip to content

Commit 8a45031

Browse files
Auto merge of #147493 - cjgillot:single-pin, r=<try>
StateTransform: Only load pin field once.
2 parents 7a52736 + 3aa474a commit 8a45031

12 files changed

+198
-270
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ struct SelfArgVisitor<'tcx> {
130130
}
131131

132132
impl<'tcx> SelfArgVisitor<'tcx> {
133-
fn new(tcx: TyCtxt<'tcx>, elem: ProjectionElem<Local, Ty<'tcx>>) -> Self {
134-
Self { tcx, new_base: Place { local: SELF_ARG, projection: tcx.mk_place_elems(&[elem]) } }
133+
fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
134+
Self { tcx, new_base }
135135
}
136136
}
137137

@@ -144,16 +144,14 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
144144
assert_ne!(*local, SELF_ARG);
145145
}
146146

147-
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
147+
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
148148
if place.local == SELF_ARG {
149149
replace_base(place, self.new_base, self.tcx);
150-
} else {
151-
self.visit_local(&mut place.local, context, location);
150+
}
152151

153-
for elem in place.projection.iter() {
154-
if let PlaceElem::Index(local) = elem {
155-
assert_ne!(local, SELF_ARG);
156-
}
152+
for elem in place.projection.iter() {
153+
if let PlaceElem::Index(local) = elem {
154+
assert_ne!(local, SELF_ARG);
157155
}
158156
}
159157
}
@@ -484,31 +482,50 @@ fn make_aggregate_adt<'tcx>(
484482
}
485483

486484
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
487-
let coroutine_ty = body.local_decls.raw[1].ty;
485+
let coroutine_ty = body.local_decls[SELF_ARG].ty;
488486

489487
let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
490488

491489
// Replace the by value coroutine argument
492-
body.local_decls.raw[1].ty = ref_coroutine_ty;
490+
body.local_decls[SELF_ARG].ty = ref_coroutine_ty;
493491

494492
// Add a deref to accesses of the coroutine state
495-
SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body);
493+
SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
496494
}
497495

498496
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
499-
let ref_coroutine_ty = body.local_decls.raw[1].ty;
497+
let coroutine_ty = body.local_decls[SELF_ARG].ty;
498+
499+
let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
500500

501501
let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
502502
let pin_adt_ref = tcx.adt_def(pin_did);
503503
let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
504504
let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);
505505

506506
// Replace the by ref coroutine argument
507-
body.local_decls.raw[1].ty = pin_ref_coroutine_ty;
507+
body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;
508+
509+
let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));
508510

509511
// Add the Pin field access to accesses of the coroutine state
510-
SelfArgVisitor::new(tcx, ProjectionElem::Field(FieldIdx::ZERO, ref_coroutine_ty))
511-
.visit_body(body);
512+
SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);
513+
514+
let source_info = SourceInfo::outermost(body.span);
515+
body.basic_blocks_mut()[START_BLOCK].statements.insert(
516+
0,
517+
Statement::new(
518+
source_info,
519+
StatementKind::Assign(Box::new((
520+
unpinned_local.into(),
521+
Rvalue::CopyForDeref(tcx.mk_place_field(
522+
SELF_ARG.into(),
523+
FieldIdx::ZERO,
524+
ref_coroutine_ty,
525+
)),
526+
))),
527+
),
528+
);
512529
}
513530

514531
/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
@@ -1274,8 +1291,6 @@ fn create_coroutine_resume_function<'tcx>(
12741291
let default_block = insert_term_block(body, TerminatorKind::Unreachable);
12751292
insert_switch(body, cases, &transform, default_block);
12761293

1277-
make_coroutine_state_argument_indirect(tcx, body);
1278-
12791294
match transform.coroutine_kind {
12801295
CoroutineKind::Coroutine(_)
12811296
| CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
@@ -1284,7 +1299,9 @@ fn create_coroutine_resume_function<'tcx>(
12841299
}
12851300
// Iterator::next doesn't accept a pinned argument,
12861301
// unlike for all other coroutine kinds.
1287-
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
1302+
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
1303+
make_coroutine_state_argument_indirect(tcx, body);
1304+
}
12881305
}
12891306

12901307
// Make sure we remove dead blocks to remove

compiler/rustc_mir_transform/src/coroutine/drop.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -676,12 +676,13 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
676676
let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
677677
body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
678678

679-
make_coroutine_state_argument_indirect(tcx, &mut body);
680-
681679
match transform.coroutine_kind {
682680
// Iterator::next doesn't accept a pinned argument,
683681
// unlike for all other coroutine kinds.
684-
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
682+
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
683+
make_coroutine_state_argument_indirect(tcx, &mut body);
684+
}
685+
685686
_ => {
686687
make_coroutine_state_argument_pinned(tcx, &mut body);
687688
}

tests/mir-opt/async_drop_live_dead.a-{closure#0}.coroutine_drop_async.0.panic-abort.mir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
44
debug _task_context => _2;
5-
debug x => ((*(_1.0: &mut {async fn body of a<T>()})).0: T);
5+
debug x => ((*_20).0: T);
66
let mut _0: std::task::Poll<()>;
77
let _3: T;
88
let mut _4: impl std::future::Future<Output = ()>;
@@ -21,12 +21,14 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
2121
let mut _17: std::pin::Pin<&mut impl std::future::Future<Output = ()>>;
2222
let mut _18: isize;
2323
let mut _19: u32;
24+
let mut _20: &mut {async fn body of a<T>()};
2425
scope 1 {
25-
debug x => (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).0: T);
26+
debug x => (((*_20) as variant#4).0: T);
2627
}
2728

2829
bb0: {
29-
_19 = discriminant((*(_1.0: &mut {async fn body of a<T>()})));
30+
_20 = deref_copy (_1.0: &mut {async fn body of a<T>()});
31+
_19 = discriminant((*_20));
3032
switchInt(move _19) -> [0: bb9, 3: bb12, 4: bb13, otherwise: bb14];
3133
}
3234

@@ -43,13 +45,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
4345

4446
bb3: {
4547
_0 = Poll::<()>::Pending;
46-
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 4;
48+
discriminant((*_20)) = 4;
4749
return;
4850
}
4951

5052
bb4: {
5153
StorageLive(_17);
52-
_16 = &mut (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).1: impl std::future::Future<Output = ()>);
54+
_16 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
5355
_17 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _16) -> [return: bb7, unwind unreachable];
5456
}
5557

@@ -81,7 +83,7 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
8183
}
8284

8385
bb11: {
84-
drop(((*(_1.0: &mut {async fn body of a<T>()})).0: T)) -> [return: bb10, unwind unreachable];
86+
drop(((*_20).0: T)) -> [return: bb10, unwind unreachable];
8587
}
8688

8789
bb12: {

tests/mir-opt/async_drop_live_dead.a-{closure#0}.coroutine_drop_async.0.panic-unwind.mir

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
44
debug _task_context => _2;
5-
debug x => ((*(_1.0: &mut {async fn body of a<T>()})).0: T);
5+
debug x => ((*_20).0: T);
66
let mut _0: std::task::Poll<()>;
77
let _3: T;
88
let mut _4: impl std::future::Future<Output = ()>;
@@ -21,12 +21,14 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
2121
let mut _17: std::pin::Pin<&mut impl std::future::Future<Output = ()>>;
2222
let mut _18: isize;
2323
let mut _19: u32;
24+
let mut _20: &mut {async fn body of a<T>()};
2425
scope 1 {
25-
debug x => (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).0: T);
26+
debug x => (((*_20) as variant#4).0: T);
2627
}
2728

2829
bb0: {
29-
_19 = discriminant((*(_1.0: &mut {async fn body of a<T>()})));
30+
_20 = deref_copy (_1.0: &mut {async fn body of a<T>()});
31+
_19 = discriminant((*_20));
3032
switchInt(move _19) -> [0: bb12, 2: bb18, 3: bb16, 4: bb17, otherwise: bb19];
3133
}
3234

@@ -57,13 +59,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
5759

5860
bb6: {
5961
_0 = Poll::<()>::Pending;
60-
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 4;
62+
discriminant((*_20)) = 4;
6163
return;
6264
}
6365

6466
bb7: {
6567
StorageLive(_17);
66-
_16 = &mut (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).1: impl std::future::Future<Output = ()>);
68+
_16 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
6769
_17 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _16) -> [return: bb10, unwind: bb15];
6870
}
6971

@@ -95,11 +97,11 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
9597
}
9698

9799
bb14: {
98-
drop(((*(_1.0: &mut {async fn body of a<T>()})).0: T)) -> [return: bb13, unwind: bb4];
100+
drop(((*_20).0: T)) -> [return: bb13, unwind: bb4];
99101
}
100102

101103
bb15 (cleanup): {
102-
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 2;
104+
discriminant((*_20)) = 2;
103105
resume;
104106
}
105107

tests/mir-opt/building/async_await.a-{closure#0}.coroutine_resume.0.mir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) ->
1414
let mut _0: std::task::Poll<()>;
1515
let mut _3: ();
1616
let mut _4: u32;
17+
let mut _5: &mut {async fn body of a()};
1718

1819
bb0: {
19-
_4 = discriminant((*(_1.0: &mut {async fn body of a()})));
20+
_5 = deref_copy (_1.0: &mut {async fn body of a()});
21+
_4 = discriminant((*_5));
2022
switchInt(move _4) -> [0: bb1, 1: bb4, otherwise: bb5];
2123
}
2224

@@ -27,7 +29,7 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) ->
2729

2830
bb2: {
2931
_0 = Poll::<()>::Ready(move _3);
30-
discriminant((*(_1.0: &mut {async fn body of a()}))) = 1;
32+
discriminant((*_5)) = 1;
3133
return;
3234
}
3335

tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,23 +86,25 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
8686
let mut _36: ();
8787
let mut _37: ();
8888
let mut _38: u32;
89+
let mut _39: &mut {async fn body of b()};
8990
scope 1 {
90-
debug __awaitee => (((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()});
91+
debug __awaitee => (((*_39) as variant#3).0: {async fn body of a()});
9192
let _17: ();
9293
scope 2 {
9394
debug result => _17;
9495
}
9596
}
9697
scope 3 {
97-
debug __awaitee => (((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()});
98+
debug __awaitee => (((*_39) as variant#4).0: {async fn body of a()});
9899
let _33: ();
99100
scope 4 {
100101
debug result => _33;
101102
}
102103
}
103104

104105
bb0: {
105-
_38 = discriminant((*(_1.0: &mut {async fn body of b()})));
106+
_39 = deref_copy (_1.0: &mut {async fn body of b()});
107+
_38 = discriminant((*_39));
106108
switchInt(move _38) -> [0: bb1, 1: bb29, 3: bb27, 4: bb28, otherwise: bb8];
107109
}
108110

@@ -121,7 +123,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
121123
StorageDead(_5);
122124
PlaceMention(_4);
123125
nop;
124-
(((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()}) = move _4;
126+
(((*_39) as variant#3).0: {async fn body of a()}) = move _4;
125127
goto -> bb4;
126128
}
127129

@@ -131,7 +133,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
131133
StorageLive(_10);
132134
StorageLive(_11);
133135
StorageLive(_12);
134-
_12 = &mut (((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()});
136+
_12 = &mut (((*_39) as variant#3).0: {async fn body of a()});
135137
_11 = &mut (*_12);
136138
_10 = Pin::<&mut {async fn body of a()}>::new_unchecked(move _11) -> [return: bb5, unwind unreachable];
137139
}
@@ -178,7 +180,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
178180
StorageDead(_4);
179181
StorageDead(_19);
180182
StorageDead(_20);
181-
discriminant((*(_1.0: &mut {async fn body of b()}))) = 3;
183+
discriminant((*_39)) = 3;
182184
return;
183185
}
184186

@@ -191,7 +193,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
191193
StorageDead(_12);
192194
StorageDead(_9);
193195
StorageDead(_8);
194-
drop((((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()})) -> [return: bb12, unwind unreachable];
196+
drop((((*_39) as variant#3).0: {async fn body of a()})) -> [return: bb12, unwind unreachable];
195197
}
196198

197199
bb11: {
@@ -223,7 +225,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
223225
StorageDead(_22);
224226
PlaceMention(_21);
225227
nop;
226-
(((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()}) = move _21;
228+
(((*_39) as variant#4).0: {async fn body of a()}) = move _21;
227229
goto -> bb16;
228230
}
229231

@@ -233,7 +235,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
233235
StorageLive(_26);
234236
StorageLive(_27);
235237
StorageLive(_28);
236-
_28 = &mut (((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()});
238+
_28 = &mut (((*_39) as variant#4).0: {async fn body of a()});
237239
_27 = &mut (*_28);
238240
_26 = Pin::<&mut {async fn body of a()}>::new_unchecked(move _27) -> [return: bb17, unwind unreachable];
239241
}
@@ -275,7 +277,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
275277
StorageDead(_21);
276278
StorageDead(_35);
277279
StorageDead(_36);
278-
discriminant((*(_1.0: &mut {async fn body of b()}))) = 4;
280+
discriminant((*_39)) = 4;
279281
return;
280282
}
281283

@@ -288,7 +290,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
288290
StorageDead(_28);
289291
StorageDead(_25);
290292
StorageDead(_24);
291-
drop((((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()})) -> [return: bb23, unwind unreachable];
293+
drop((((*_39) as variant#4).0: {async fn body of a()})) -> [return: bb23, unwind unreachable];
292294
}
293295

294296
bb22: {
@@ -311,7 +313,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
311313

312314
bb25: {
313315
_0 = Poll::<()>::Ready(move _37);
314-
discriminant((*(_1.0: &mut {async fn body of b()}))) = 1;
316+
discriminant((*_39)) = 1;
315317
return;
316318
}
317319

0 commit comments

Comments
 (0)