Skip to content

Commit f18eb33

Browse files
committed
Renumber locals after state transform.
1 parent ebe2d0a commit f18eb33

File tree

4 files changed

+136
-103
lines changed

4 files changed

+136
-103
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 93 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ use rustc_hir::lang_items::LangItem;
6868
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
6969
use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
7070
use rustc_index::{Idx, IndexVec};
71-
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
71+
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
7272
use rustc_middle::mir::*;
7373
use rustc_middle::ty::util::Discr;
7474
use rustc_middle::ty::{
@@ -111,6 +111,8 @@ impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
111111
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
112112
if *local == self.from {
113113
*local = self.to;
114+
} else if *local == self.to {
115+
*local = self.from;
114116
}
115117
}
116118

@@ -160,13 +162,15 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
160162
}
161163
}
162164

165+
#[tracing::instrument(level = "trace", skip(tcx))]
163166
fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
164167
place.local = new_base.local;
165168

166169
let mut new_projection = new_base.projection.to_vec();
167170
new_projection.append(&mut place.projection.to_vec());
168171

169172
place.projection = tcx.mk_place_elems(&new_projection);
173+
tracing::trace!(?place);
170174
}
171175

172176
const SELF_ARG: Local = Local::from_u32(1);
@@ -205,8 +209,8 @@ struct TransformVisitor<'tcx> {
205209
// The set of locals that have no `StorageLive`/`StorageDead` annotations.
206210
always_live_locals: DenseBitSet<Local>,
207211

208-
// The original RETURN_PLACE local
209-
old_ret_local: Local,
212+
// New local we just create to hold the `CoroutineState` value.
213+
new_ret_local: Local,
210214

211215
old_yield_ty: Ty<'tcx>,
212216

@@ -271,6 +275,7 @@ impl<'tcx> TransformVisitor<'tcx> {
271275
// `core::ops::CoroutineState` only has single element tuple variants,
272276
// so we can just write to the downcasted first field and then set the
273277
// discriminant to the appropriate variant.
278+
#[tracing::instrument(level = "trace", skip(self, statements))]
274279
fn make_state(
275280
&self,
276281
val: Operand<'tcx>,
@@ -344,11 +349,12 @@ impl<'tcx> TransformVisitor<'tcx> {
344349

345350
statements.push(Statement::new(
346351
source_info,
347-
StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
352+
StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
348353
));
349354
}
350355

351356
// Create a Place referencing a coroutine struct field
357+
#[tracing::instrument(level = "trace", skip(self), ret)]
352358
fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
353359
let self_place = Place::from(SELF_ARG);
354360
let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
@@ -359,6 +365,7 @@ impl<'tcx> TransformVisitor<'tcx> {
359365
}
360366

361367
// Create a statement which changes the discriminant
368+
#[tracing::instrument(level = "trace", skip(self))]
362369
fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
363370
let self_place = Place::from(SELF_ARG);
364371
Statement::new(
@@ -371,6 +378,7 @@ impl<'tcx> TransformVisitor<'tcx> {
371378
}
372379

373380
// Create a statement which reads the discriminant into a temporary
381+
#[tracing::instrument(level = "trace", skip(self, body))]
374382
fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
375383
let temp_decl = LocalDecl::new(self.discr_ty, body.span);
376384
let local_decls_len = body.local_decls.push(temp_decl);
@@ -383,29 +391,48 @@ impl<'tcx> TransformVisitor<'tcx> {
383391
);
384392
(assign, temp)
385393
}
394+
395+
/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
396+
///
397+
/// `local` will be changed to a new local decl with type `ty`.
398+
///
399+
/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
400+
/// valid value to it before its first use.
401+
#[tracing::instrument(level = "trace", skip(self, body))]
402+
fn replace_local(&mut self, local: Local, new_local: Local, body: &mut Body<'tcx>) -> Local {
403+
body.local_decls.swap(local, new_local);
404+
405+
let mut visitor = RenameLocalVisitor { from: local, to: new_local, tcx: self.tcx };
406+
visitor.visit_body(body);
407+
for suspension in &mut self.suspension_points {
408+
let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
409+
let location = Location { block: START_BLOCK, statement_index: 0 };
410+
visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
411+
}
412+
413+
new_local
414+
}
386415
}
387416

388417
impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
389418
fn tcx(&self) -> TyCtxt<'tcx> {
390419
self.tcx
391420
}
392421

393-
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
422+
#[tracing::instrument(level = "trace", skip(self), ret)]
423+
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
394424
assert!(!self.remap.contains(*local));
395425
}
396426

397-
fn visit_place(
398-
&mut self,
399-
place: &mut Place<'tcx>,
400-
_context: PlaceContext,
401-
_location: Location,
402-
) {
427+
#[tracing::instrument(level = "trace", skip(self), ret)]
428+
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
403429
// Replace an Local in the remap with a coroutine struct access
404430
if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
405431
replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
406432
}
407433
}
408434

435+
#[tracing::instrument(level = "trace", skip(self, data), ret)]
409436
fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
410437
// Remove StorageLive and StorageDead statements for remapped locals
411438
for s in &mut data.statements {
@@ -416,29 +443,35 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
416443
}
417444
}
418445

419-
let ret_val = match data.terminator().kind {
446+
for (statement_index, statement) in data.statements.iter_mut().enumerate() {
447+
let location = Location { block, statement_index };
448+
self.visit_statement(statement, location);
449+
}
450+
451+
let location = Location { block, statement_index: data.statements.len() };
452+
let mut terminator = data.terminator.take().unwrap();
453+
let source_info = terminator.source_info;
454+
match terminator.kind {
420455
TerminatorKind::Return => {
421-
Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None))
422-
}
423-
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
424-
Some((false, Some((resume, resume_arg)), value.clone(), drop))
456+
let mut v = Operand::Move(Place::return_place());
457+
self.visit_operand(&mut v, location);
458+
// We must assign the value first in case it gets declared dead below
459+
self.make_state(v, source_info, true, &mut data.statements);
460+
// State for returned
461+
let state = VariantIdx::new(CoroutineArgs::RETURNED);
462+
data.statements.push(self.set_discr(state, source_info));
463+
terminator.kind = TerminatorKind::Return;
425464
}
426-
_ => None,
427-
};
428-
429-
if let Some((is_return, resume, v, drop)) = ret_val {
430-
let source_info = data.terminator().source_info;
431-
// We must assign the value first in case it gets declared dead below
432-
self.make_state(v, source_info, is_return, &mut data.statements);
433-
let state = if let Some((resume, mut resume_arg)) = resume {
434-
// Yield
435-
let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();
436-
465+
TerminatorKind::Yield { mut value, resume, mut resume_arg, drop } => {
437466
// The resume arg target location might itself be remapped if its base local is
438467
// live across a yield.
439-
if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) {
440-
replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
441-
}
468+
self.visit_operand(&mut value, location);
469+
let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
470+
self.visit_place(&mut resume_arg, ctxt, location);
471+
// We must assign the value first in case it gets declared dead below
472+
self.make_state(value.clone(), source_info, false, &mut data.statements);
473+
// Yield
474+
let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();
442475

443476
let storage_liveness: GrowableBitSet<Local> =
444477
self.storage_liveness[block].clone().unwrap().into();
@@ -453,7 +486,6 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
453486
.push(Statement::new(source_info, StatementKind::StorageDead(l)));
454487
}
455488
}
456-
457489
self.suspension_points.push(SuspensionPoint {
458490
state,
459491
resume,
@@ -462,16 +494,13 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
462494
storage_liveness,
463495
});
464496

465-
VariantIdx::new(state)
466-
} else {
467-
// Return
468-
VariantIdx::new(CoroutineArgs::RETURNED) // state for returned
469-
};
470-
data.statements.push(self.set_discr(state, source_info));
471-
data.terminator_mut().kind = TerminatorKind::Return;
472-
}
473-
474-
self.super_basic_block_data(block, data);
497+
let state = VariantIdx::new(state);
498+
data.statements.push(self.set_discr(state, source_info));
499+
terminator.kind = TerminatorKind::Return;
500+
}
501+
_ => self.visit_terminator(&mut terminator, location),
502+
};
503+
data.terminator = Some(terminator);
475504
}
476505
}
477506

@@ -484,6 +513,7 @@ fn make_aggregate_adt<'tcx>(
484513
Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
485514
}
486515

516+
#[tracing::instrument(level = "trace", skip(tcx, body))]
487517
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
488518
let coroutine_ty = body.local_decls.raw[1].ty;
489519

@@ -496,6 +526,7 @@ fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Bo
496526
SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body);
497527
}
498528

529+
#[tracing::instrument(level = "trace", skip(tcx, body))]
499530
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
500531
let ref_coroutine_ty = body.local_decls.raw[1].ty;
501532

@@ -512,27 +543,6 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
512543
.visit_body(body);
513544
}
514545

515-
/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
516-
///
517-
/// `local` will be changed to a new local decl with type `ty`.
518-
///
519-
/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
520-
/// valid value to it before its first use.
521-
fn replace_local<'tcx>(
522-
local: Local,
523-
ty: Ty<'tcx>,
524-
body: &mut Body<'tcx>,
525-
tcx: TyCtxt<'tcx>,
526-
) -> Local {
527-
let new_decl = LocalDecl::new(ty, body.span);
528-
let new_local = body.local_decls.push(new_decl);
529-
body.local_decls.swap(local, new_local);
530-
531-
RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body);
532-
533-
new_local
534-
}
535-
536546
/// Transforms the `body` of the coroutine applying the following transforms:
537547
///
538548
/// - Eliminates all the `get_context` calls that async lowering created.
@@ -554,6 +564,7 @@ fn replace_local<'tcx>(
554564
/// The async lowering step and the type / lifetime inference / checking are
555565
/// still using the `ResumeTy` indirection for the time being, and that indirection
556566
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
567+
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
557568
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
558569
let context_mut_ref = Ty::new_task_context(tcx);
559570

@@ -607,6 +618,7 @@ fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local
607618
}
608619

609620
#[cfg_attr(not(debug_assertions), allow(unused))]
621+
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
610622
fn replace_resume_ty_local<'tcx>(
611623
tcx: TyCtxt<'tcx>,
612624
body: &mut Body<'tcx>,
@@ -617,7 +629,7 @@ fn replace_resume_ty_local<'tcx>(
617629
// We have to replace the `ResumeTy` that is used for type and borrow checking
618630
// with `&mut Context<'_>` in MIR.
619631
#[cfg(debug_assertions)]
620-
{
632+
if local_ty != context_mut_ref {
621633
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
622634
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
623635
assert_eq!(*resume_ty_adt, expected_adt);
@@ -671,6 +683,7 @@ struct LivenessInfo {
671683
/// case none exist, the local is considered to be always live.
672684
/// - a local has to be stored if it is either directly used after the
673685
/// the suspend point, or if it is live and has been previously borrowed.
686+
#[tracing::instrument(level = "trace", skip(tcx, body))]
674687
fn locals_live_across_suspend_points<'tcx>(
675688
tcx: TyCtxt<'tcx>,
676689
body: &Body<'tcx>,
@@ -946,6 +959,7 @@ impl StorageConflictVisitor<'_, '_> {
946959
}
947960
}
948961

962+
#[tracing::instrument(level = "trace", skip(liveness, body))]
949963
fn compute_layout<'tcx>(
950964
liveness: LivenessInfo,
951965
body: &Body<'tcx>,
@@ -1050,7 +1064,9 @@ fn compute_layout<'tcx>(
10501064
variant_source_info,
10511065
storage_conflicts,
10521066
};
1067+
debug!(?remap);
10531068
debug!(?layout);
1069+
debug!(?storage_liveness);
10541070

10551071
(remap, layout, storage_liveness)
10561072
}
@@ -1222,6 +1238,7 @@ fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
12221238
}
12231239
}
12241240

1241+
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
12251242
fn create_coroutine_resume_function<'tcx>(
12261243
tcx: TyCtxt<'tcx>,
12271244
transform: TransformVisitor<'tcx>,
@@ -1300,7 +1317,7 @@ fn create_coroutine_resume_function<'tcx>(
13001317
}
13011318

13021319
/// An operation that can be performed on a coroutine.
1303-
#[derive(PartialEq, Copy, Clone)]
1320+
#[derive(PartialEq, Copy, Clone, Debug)]
13041321
enum Operation {
13051322
Resume,
13061323
Drop,
@@ -1315,6 +1332,7 @@ impl Operation {
13151332
}
13161333
}
13171334

1335+
#[tracing::instrument(level = "trace", skip(transform, body))]
13181336
fn create_cases<'tcx>(
13191337
body: &mut Body<'tcx>,
13201338
transform: &TransformVisitor<'tcx>,
@@ -1446,6 +1464,8 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
14461464
// This only applies to coroutines
14471465
return;
14481466
};
1467+
tracing::trace!(def_id = ?body.source.def_id());
1468+
14491469
let old_ret_ty = body.return_ty();
14501470

14511471
assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());
@@ -1492,10 +1512,6 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
14921512
}
14931513
};
14941514

1495-
// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
1496-
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1497-
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
1498-
14991515
// We need to insert clean drop for unresumed state and perform drop elaboration
15001516
// (finally in open_drop_for_tuple) before async drop expansion.
15011517
// Async drops, produced by this drop elaboration, will be expanded,
@@ -1520,6 +1536,11 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15201536
cleanup_async_drops(body);
15211537
}
15221538

1539+
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1540+
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1541+
let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
1542+
tracing::trace!(?new_ret_local);
1543+
15231544
let always_live_locals = always_storage_live_locals(body);
15241545
let movable = coroutine_kind.movability() == hir::Movability::Movable;
15251546
let liveness_info =
@@ -1554,13 +1575,16 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15541575
storage_liveness,
15551576
always_live_locals,
15561577
suspension_points: Vec::new(),
1557-
old_ret_local,
15581578
discr_ty,
1579+
new_ret_local,
15591580
old_ret_ty,
15601581
old_yield_ty,
15611582
};
15621583
transform.visit_body(body);
15631584

1585+
// Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
1586+
transform.replace_local(RETURN_PLACE, new_ret_local, body);
1587+
15641588
// MIR parameters are not explicitly assigned-to when entering the MIR body.
15651589
// If we want to save their values inside the coroutine state, we need to do so explicitly.
15661590
let source_info = SourceInfo::outermost(body.span);

0 commit comments

Comments
 (0)