Skip to content

Commit 37d92da

Browse files
committed
Renumber locals after state transform.
1 parent 66a0010 commit 37d92da

File tree

4 files changed

+136
-103
lines changed

4 files changed

+136
-103
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 93 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ use rustc_hir::lang_items::LangItem;
6868
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
6969
use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
7070
use rustc_index::{Idx, IndexVec};
71-
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
71+
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
7272
use rustc_middle::mir::*;
7373
use rustc_middle::ty::util::Discr;
7474
use rustc_middle::ty::{
@@ -110,6 +110,8 @@ impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
110110
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
111111
if *local == self.from {
112112
*local = self.to;
113+
} else if *local == self.to {
114+
*local = self.from;
113115
}
114116
}
115117

@@ -159,13 +161,15 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
159161
}
160162
}
161163

164+
#[tracing::instrument(level = "trace", skip(tcx))]
162165
fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
163166
place.local = new_base.local;
164167

165168
let mut new_projection = new_base.projection.to_vec();
166169
new_projection.append(&mut place.projection.to_vec());
167170

168171
place.projection = tcx.mk_place_elems(&new_projection);
172+
tracing::trace!(?place);
169173
}
170174

171175
const SELF_ARG: Local = Local::from_u32(1);
@@ -204,8 +208,8 @@ struct TransformVisitor<'tcx> {
204208
// The set of locals that have no `StorageLive`/`StorageDead` annotations.
205209
always_live_locals: DenseBitSet<Local>,
206210

207-
// The original RETURN_PLACE local
208-
old_ret_local: Local,
211+
// New local we just create to hold the `CoroutineState` value.
212+
new_ret_local: Local,
209213

210214
old_yield_ty: Ty<'tcx>,
211215

@@ -270,6 +274,7 @@ impl<'tcx> TransformVisitor<'tcx> {
270274
// `core::ops::CoroutineState` only has single element tuple variants,
271275
// so we can just write to the downcasted first field and then set the
272276
// discriminant to the appropriate variant.
277+
#[tracing::instrument(level = "trace", skip(self, statements))]
273278
fn make_state(
274279
&self,
275280
val: Operand<'tcx>,
@@ -343,11 +348,12 @@ impl<'tcx> TransformVisitor<'tcx> {
343348

344349
statements.push(Statement::new(
345350
source_info,
346-
StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
351+
StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
347352
));
348353
}
349354

350355
// Create a Place referencing a coroutine struct field
356+
#[tracing::instrument(level = "trace", skip(self), ret)]
351357
fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
352358
let self_place = Place::from(SELF_ARG);
353359
let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
@@ -358,6 +364,7 @@ impl<'tcx> TransformVisitor<'tcx> {
358364
}
359365

360366
// Create a statement which changes the discriminant
367+
#[tracing::instrument(level = "trace", skip(self))]
361368
fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
362369
let self_place = Place::from(SELF_ARG);
363370
Statement::new(
@@ -370,6 +377,7 @@ impl<'tcx> TransformVisitor<'tcx> {
370377
}
371378

372379
// Create a statement which reads the discriminant into a temporary
380+
#[tracing::instrument(level = "trace", skip(self, body))]
373381
fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
374382
let temp_decl = LocalDecl::new(self.discr_ty, body.span);
375383
let local_decls_len = body.local_decls.push(temp_decl);
@@ -382,29 +390,48 @@ impl<'tcx> TransformVisitor<'tcx> {
382390
);
383391
(assign, temp)
384392
}
393+
394+
/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
395+
///
396+
/// `local` will be changed to a new local decl with type `ty`.
397+
///
398+
/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
399+
/// valid value to it before its first use.
400+
#[tracing::instrument(level = "trace", skip(self, body))]
401+
fn replace_local(&mut self, local: Local, new_local: Local, body: &mut Body<'tcx>) -> Local {
402+
body.local_decls.swap(local, new_local);
403+
404+
let mut visitor = RenameLocalVisitor { from: local, to: new_local, tcx: self.tcx };
405+
visitor.visit_body(body);
406+
for suspension in &mut self.suspension_points {
407+
let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
408+
let location = Location { block: START_BLOCK, statement_index: 0 };
409+
visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
410+
}
411+
412+
new_local
413+
}
385414
}
386415

387416
impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
388417
fn tcx(&self) -> TyCtxt<'tcx> {
389418
self.tcx
390419
}
391420

392-
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
421+
#[tracing::instrument(level = "trace", skip(self), ret)]
422+
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
393423
assert!(!self.remap.contains(*local));
394424
}
395425

396-
fn visit_place(
397-
&mut self,
398-
place: &mut Place<'tcx>,
399-
_context: PlaceContext,
400-
_location: Location,
401-
) {
426+
#[tracing::instrument(level = "trace", skip(self), ret)]
427+
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
402428
// Replace an Local in the remap with a coroutine struct access
403429
if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
404430
replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
405431
}
406432
}
407433

434+
#[tracing::instrument(level = "trace", skip(self, data), ret)]
408435
fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
409436
// Remove StorageLive and StorageDead statements for remapped locals
410437
for s in &mut data.statements {
@@ -415,29 +442,35 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
415442
}
416443
}
417444

418-
let ret_val = match data.terminator().kind {
445+
for (statement_index, statement) in data.statements.iter_mut().enumerate() {
446+
let location = Location { block, statement_index };
447+
self.visit_statement(statement, location);
448+
}
449+
450+
let location = Location { block, statement_index: data.statements.len() };
451+
let mut terminator = data.terminator.take().unwrap();
452+
let source_info = terminator.source_info;
453+
match terminator.kind {
419454
TerminatorKind::Return => {
420-
Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None))
421-
}
422-
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
423-
Some((false, Some((resume, resume_arg)), value.clone(), drop))
455+
let mut v = Operand::Move(Place::return_place());
456+
self.visit_operand(&mut v, location);
457+
// We must assign the value first in case it gets declared dead below
458+
self.make_state(v, source_info, true, &mut data.statements);
459+
// State for returned
460+
let state = VariantIdx::new(CoroutineArgs::RETURNED);
461+
data.statements.push(self.set_discr(state, source_info));
462+
terminator.kind = TerminatorKind::Return;
424463
}
425-
_ => None,
426-
};
427-
428-
if let Some((is_return, resume, v, drop)) = ret_val {
429-
let source_info = data.terminator().source_info;
430-
// We must assign the value first in case it gets declared dead below
431-
self.make_state(v, source_info, is_return, &mut data.statements);
432-
let state = if let Some((resume, mut resume_arg)) = resume {
433-
// Yield
434-
let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();
435-
464+
TerminatorKind::Yield { mut value, resume, mut resume_arg, drop } => {
436465
// The resume arg target location might itself be remapped if its base local is
437466
// live across a yield.
438-
if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) {
439-
replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
440-
}
467+
self.visit_operand(&mut value, location);
468+
let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
469+
self.visit_place(&mut resume_arg, ctxt, location);
470+
// We must assign the value first in case it gets declared dead below
471+
self.make_state(value.clone(), source_info, false, &mut data.statements);
472+
// Yield
473+
let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();
441474

442475
let storage_liveness: GrowableBitSet<Local> =
443476
self.storage_liveness[block].clone().unwrap().into();
@@ -452,7 +485,6 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
452485
.push(Statement::new(source_info, StatementKind::StorageDead(l)));
453486
}
454487
}
455-
456488
self.suspension_points.push(SuspensionPoint {
457489
state,
458490
resume,
@@ -461,16 +493,13 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
461493
storage_liveness,
462494
});
463495

464-
VariantIdx::new(state)
465-
} else {
466-
// Return
467-
VariantIdx::new(CoroutineArgs::RETURNED) // state for returned
468-
};
469-
data.statements.push(self.set_discr(state, source_info));
470-
data.terminator_mut().kind = TerminatorKind::Return;
471-
}
472-
473-
self.super_basic_block_data(block, data);
496+
let state = VariantIdx::new(state);
497+
data.statements.push(self.set_discr(state, source_info));
498+
terminator.kind = TerminatorKind::Return;
499+
}
500+
_ => self.visit_terminator(&mut terminator, location),
501+
};
502+
data.terminator = Some(terminator);
474503
}
475504
}
476505

@@ -483,6 +512,7 @@ fn make_aggregate_adt<'tcx>(
483512
Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
484513
}
485514

515+
#[tracing::instrument(level = "trace", skip(tcx, body))]
486516
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
487517
let coroutine_ty = body.local_decls.raw[1].ty;
488518

@@ -495,6 +525,7 @@ fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Bo
495525
SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body);
496526
}
497527

528+
#[tracing::instrument(level = "trace", skip(tcx, body))]
498529
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
499530
let ref_coroutine_ty = body.local_decls.raw[1].ty;
500531

@@ -511,27 +542,6 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
511542
.visit_body(body);
512543
}
513544

514-
/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
515-
///
516-
/// `local` will be changed to a new local decl with type `ty`.
517-
///
518-
/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
519-
/// valid value to it before its first use.
520-
fn replace_local<'tcx>(
521-
local: Local,
522-
ty: Ty<'tcx>,
523-
body: &mut Body<'tcx>,
524-
tcx: TyCtxt<'tcx>,
525-
) -> Local {
526-
let new_decl = LocalDecl::new(ty, body.span);
527-
let new_local = body.local_decls.push(new_decl);
528-
body.local_decls.swap(local, new_local);
529-
530-
RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body);
531-
532-
new_local
533-
}
534-
535545
/// Transforms the `body` of the coroutine applying the following transforms:
536546
///
537547
/// - Eliminates all the `get_context` calls that async lowering created.
@@ -553,6 +563,7 @@ fn replace_local<'tcx>(
553563
/// The async lowering step and the type / lifetime inference / checking are
554564
/// still using the `ResumeTy` indirection for the time being, and that indirection
555565
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
566+
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
556567
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
557568
let context_mut_ref = Ty::new_task_context(tcx);
558569

@@ -606,6 +617,7 @@ fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local
606617
}
607618

608619
#[cfg_attr(not(debug_assertions), allow(unused))]
620+
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
609621
fn replace_resume_ty_local<'tcx>(
610622
tcx: TyCtxt<'tcx>,
611623
body: &mut Body<'tcx>,
@@ -616,7 +628,7 @@ fn replace_resume_ty_local<'tcx>(
616628
// We have to replace the `ResumeTy` that is used for type and borrow checking
617629
// with `&mut Context<'_>` in MIR.
618630
#[cfg(debug_assertions)]
619-
{
631+
if local_ty != context_mut_ref {
620632
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
621633
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
622634
assert_eq!(*resume_ty_adt, expected_adt);
@@ -670,6 +682,7 @@ struct LivenessInfo {
670682
/// case none exist, the local is considered to be always live.
671683
/// - a local has to be stored if it is either directly used after the
672684
/// the suspend point, or if it is live and has been previously borrowed.
685+
#[tracing::instrument(level = "trace", skip(tcx, body))]
673686
fn locals_live_across_suspend_points<'tcx>(
674687
tcx: TyCtxt<'tcx>,
675688
body: &Body<'tcx>,
@@ -945,6 +958,7 @@ impl StorageConflictVisitor<'_, '_> {
945958
}
946959
}
947960

961+
#[tracing::instrument(level = "trace", skip(liveness, body))]
948962
fn compute_layout<'tcx>(
949963
liveness: LivenessInfo,
950964
body: &Body<'tcx>,
@@ -1049,7 +1063,9 @@ fn compute_layout<'tcx>(
10491063
variant_source_info,
10501064
storage_conflicts,
10511065
};
1066+
debug!(?remap);
10521067
debug!(?layout);
1068+
debug!(?storage_liveness);
10531069

10541070
(remap, layout, storage_liveness)
10551071
}
@@ -1221,6 +1237,7 @@ fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
12211237
}
12221238
}
12231239

1240+
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
12241241
fn create_coroutine_resume_function<'tcx>(
12251242
tcx: TyCtxt<'tcx>,
12261243
transform: TransformVisitor<'tcx>,
@@ -1299,7 +1316,7 @@ fn create_coroutine_resume_function<'tcx>(
12991316
}
13001317

13011318
/// An operation that can be performed on a coroutine.
1302-
#[derive(PartialEq, Copy, Clone)]
1319+
#[derive(PartialEq, Copy, Clone, Debug)]
13031320
enum Operation {
13041321
Resume,
13051322
Drop,
@@ -1314,6 +1331,7 @@ impl Operation {
13141331
}
13151332
}
13161333

1334+
#[tracing::instrument(level = "trace", skip(transform, body))]
13171335
fn create_cases<'tcx>(
13181336
body: &mut Body<'tcx>,
13191337
transform: &TransformVisitor<'tcx>,
@@ -1445,6 +1463,8 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
14451463
// This only applies to coroutines
14461464
return;
14471465
};
1466+
tracing::trace!(def_id = ?body.source.def_id());
1467+
14481468
let old_ret_ty = body.return_ty();
14491469

14501470
assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());
@@ -1491,10 +1511,6 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
14911511
}
14921512
};
14931513

1494-
// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
1495-
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1496-
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
1497-
14981514
// We need to insert clean drop for unresumed state and perform drop elaboration
14991515
// (finally in open_drop_for_tuple) before async drop expansion.
15001516
// Async drops, produced by this drop elaboration, will be expanded,
@@ -1519,6 +1535,11 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15191535
cleanup_async_drops(body);
15201536
}
15211537

1538+
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1539+
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1540+
let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
1541+
tracing::trace!(?new_ret_local);
1542+
15221543
let always_live_locals = always_storage_live_locals(body);
15231544
let movable = coroutine_kind.movability() == hir::Movability::Movable;
15241545
let liveness_info =
@@ -1553,13 +1574,16 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15531574
storage_liveness,
15541575
always_live_locals,
15551576
suspension_points: Vec::new(),
1556-
old_ret_local,
15571577
discr_ty,
1578+
new_ret_local,
15581579
old_ret_ty,
15591580
old_yield_ty,
15601581
};
15611582
transform.visit_body(body);
15621583

1584+
// Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
1585+
transform.replace_local(RETURN_PLACE, new_ret_local, body);
1586+
15631587
// MIR parameters are not explicitly assigned-to when entering the MIR body.
15641588
// If we want to save their values inside the coroutine state, we need to do so explicitly.
15651589
let source_info = SourceInfo::outermost(body.span);

0 commit comments

Comments
 (0)