Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 104 additions & 48 deletions compiler/rustc_mir_build/src/builder/expr/into.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! See docs in build/expr/mod.rs

use rustc_abi::VariantIdx;
use rustc_ast::{AsmMacro, InlineAsmOptions};
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::stack::ensure_sufficient_stack;
Expand All @@ -8,6 +9,9 @@ use rustc_middle::mir::*;
use rustc_middle::span_bug;
use rustc_middle::thir::*;
use rustc_middle::ty::CanonicalUserTypeAnnotation;
use rustc_middle::ty::util::Discr;
use rustc_pattern_analysis::constructor::Constructor;
use rustc_pattern_analysis::rustc::RustcPatCtxt;
use rustc_span::source_map::Spanned;
use tracing::{debug, instrument};

Expand Down Expand Up @@ -244,6 +248,31 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
ExprKind::LoopMatch { state, region_scope, ref arms, .. } => {
// FIXME add diagram

let dropless_arena = rustc_arena::DroplessArena::default();
let typeck_results = this.tcx.typeck(this.def_id);

// FIXME use the lint level from `ExprKind::LoopMatch`
let lint_level = this.tcx.local_def_id_to_hir_id(this.def_id);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a commented out lint_level field in LoopMatch. That one should probably be uncommented and used here instead. The current code would use the lint level for the whole function. Alternatively adding a FIXME would be fine.


// the PatCtxt is normally used in pattern exhaustiveness checking, but reused here
// because it performs normalization and const evaluation.
let cx = RustcPatCtxt {
tcx: this.tcx,
typeck_results,
module: this.tcx.parent_module(lint_level).to_def_id(),
// FIXME(#132279): We're in a body, should handle opaques.
typing_env: rustc_middle::ty::TypingEnv::non_body_analysis(
this.tcx,
this.def_id,
),
dropless_arena: &dropless_arena,
match_lint_level: lint_level,
whole_match_span: Some(rustc_span::Span::default()),
scrut_span: rustc_span::Span::default(),
refutable: true,
known_valid_scrutinee: true,
};

let loop_block = this.cfg.start_new_block();

// Start the loop.
Expand All @@ -264,77 +293,99 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
this.diverge_from(loop_block);

let state_place = unpack!(body_block = this.as_place(body_block, state));
let state_ty = this.thir.exprs[state].ty;

// the type of the value that is switched on by the `SwitchInt`
let discr_ty = match state_ty {
ty if ty.is_enum() => ty.discriminant_ty(this.tcx),
ty if ty.is_integral() => ty,
_ => todo!(),
};

let rvalue = match state_ty {
ty if ty.is_enum() => Rvalue::Discriminant(state_place),
ty if ty.is_integral() => Rvalue::Use(Operand::Copy(state_place)),
_ => todo!(),
};

// block and arm of the wildcard pattern (if any)
let mut otherwise = None;

unpack!(
body_block = this.in_scope(
(region_scope, source_info),
LintLevel::Inherited,
move |this| {
let unreachable_block = this.cfg.start_new_block();
this.cfg.terminate(
unreachable_block,
source_info,
TerminatorKind::Unreachable,
);
let mut arm_blocks = Vec::with_capacity(arms.len());
for &arm in arms {
let pat = &this.thir[arm].pattern;
let deconstructed_pat = cx.lower_pat(pat);

match deconstructed_pat.ctor() {
Constructor::Variant(variant_index) => {
let PatKind::Variant { adt_def, .. } = pat.kind else {
unreachable!()
};

let discr = adt_def
.discriminant_for_variant(this.tcx, *variant_index);

let block = this.cfg.start_new_block();
arm_blocks.push((*variant_index, discr, block, arm))
}
Constructor::IntRange(int_range) => {
assert!(int_range.is_singleton());

let bits = state_ty.primitive_size(this.tcx).bits();
let value = int_range.lo.as_finite_int(bits).unwrap();

let mut arm_blocks = arms
.iter()
.map(|&arm| {
let block = this.cfg.start_new_block();
match &this.thir[arm].pattern.kind {
PatKind::Variant {
adt_def,
args: _,
variant_index,
subpatterns,
} => {
assert!(subpatterns.is_empty());

let discr = adt_def
.discriminants(this.tcx)
.find(|(var, _discr)| var == variant_index)
.unwrap()
.1;

(discr, block, arm)
}
_ => panic!(),
let discr =
Discr { val: value, ty: **deconstructed_pat.ty() };

let block = this.cfg.start_new_block();
arm_blocks.push((VariantIdx::ZERO, discr, block, arm))
}
Constructor::Wildcard => {
otherwise = Some((this.cfg.start_new_block(), arm));
}
})
.collect::<Vec<_>>();
arm_blocks.sort_by_cached_key(|&(discr, _, _)| {
match &this.thir[arms[0]].pattern.kind {
PatKind::Variant { adt_def, .. } => adt_def
.discriminants(this.tcx)
.position(|(_, i)| discr.val == i.val)
.unwrap(),
_ => panic!(),
other => todo!("{:?}", other),
}
});
}

// if we're matching on an enum, the discriminant order in the `SwitchInt`
// targets should match the order yielded by `AdtDef::discriminants`.
if state_ty.is_enum() {
arm_blocks.sort_by_key(|(variant_idx, ..)| *variant_idx);
}

let targets = SwitchTargets::new(
arm_blocks
.iter()
.map(|&(discr, block, _arm)| (discr.val, block)),
unreachable_block,
.map(|&(_, discr, block, _arm)| (discr.val, block)),
if let Some((block, _)) = otherwise {
block
} else {
let unreachable_block = this.cfg.start_new_block();
this.cfg.terminate(
unreachable_block,
source_info,
TerminatorKind::Unreachable,
);
unreachable_block
},
);

this.in_const_continuable_scope(
loop_block,
targets.clone(),
state_place,
|this| {
let discr_ty = match &this.thir[arms[0]].pattern.kind {
PatKind::Variant { adt_def, .. } => {
adt_def.discriminants(this.tcx).next().unwrap().1.ty
}
_ => panic!(),
};
let discr = this.temp(discr_ty, source_info.span);
this.cfg.push_assign(
body_block,
source_info,
discr,
Rvalue::Discriminant(state_place),
rvalue,
);
let discr = Operand::Copy(discr);
this.cfg.terminate(
Expand All @@ -343,7 +394,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
TerminatorKind::SwitchInt { discr, targets },
);

for (_discr, mut block, arm) in arm_blocks {
let it = arm_blocks
.into_iter()
.map(|(_, _, block, arm)| (block, arm))
.chain(otherwise);

for (mut block, arm) in it {
let empty_place = this.get_unit_temp();
unpack!(
block = this.expr_into_dest(
Expand Down
57 changes: 37 additions & 20 deletions compiler/rustc_mir_build/src/builder/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ that contains only loops and breakable blocks. It tracks where a `break`,

use std::mem;

use rustc_ast::LitKind;
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::HirId;
use rustc_index::{IndexSlice, IndexVec};
Expand Down Expand Up @@ -731,43 +732,59 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.unwrap_or_else(|| {
span_bug!(span, "no enclosing const-continuable scope found")
});
let state_place = self.scopes.const_continuable_scopes[break_index].state_place;

let rustc_middle::thir::ExprKind::Scope { value, .. } =
self.thir[value.unwrap()].kind
else {
panic!();
};
let rustc_middle::thir::ExprKind::Adt(value_adt) = &self.thir[value].kind else {
panic!();

let scope = &self.scopes.const_continuable_scopes[break_index];

let state_ty = self.local_decls[scope.state_place.as_local().unwrap()].ty;
let discriminant_ty = match state_ty {
ty if ty.is_enum() => ty.discriminant_ty(self.tcx),
ty if ty.is_integral() => ty,
_ => todo!(),
};

//dbg!(&self.thir[value], value_adt);
let rvalue = match state_ty {
ty if ty.is_enum() => Rvalue::Discriminant(scope.state_place),
ty if ty.is_integral() => Rvalue::Use(Operand::Copy(scope.state_place)),
_ => todo!(),
};

let real_target = match &self.thir[value].kind {
rustc_middle::thir::ExprKind::Adt(value_adt) => scope
.match_arms
.target_for_value(u128::from(value_adt.variant_index.as_u32())),
rustc_middle::thir::ExprKind::Literal { lit, neg } => match lit.node {
LitKind::Int(n, _) => {
let n = if *neg {
(n.get() as i128).overflowing_neg().0 as u128
} else {
n.get()
};
let result = state_ty.primitive_size(self.tcx).truncate(n);
scope.match_arms.target_for_value(result)
}
_ => todo!(),
},
other => todo!("{other:?}"),
};

self.block_context.push(BlockFrame::SubExpr);
let state_place = scope.state_place;
block = self.expr_into_dest(state_place, block, value).into_block();
self.block_context.pop();

let discr_ty =
self.local_decls[state_place.as_local().unwrap()].ty.discriminant_ty(self.tcx);
let discr = self.temp(discr_ty, source_info.span);
let discr = self.temp(discriminant_ty, source_info.span);
let scope = &self.scopes.const_continuable_scopes[break_index];
self.cfg.push_assign(
block,
source_info,
discr,
Rvalue::Discriminant(scope.state_place),
);
self.cfg.push_assign(block, source_info, discr, rvalue);
self.cfg.terminate(
block,
source_info,
TerminatorKind::FalseEdge {
real_target: self.scopes.const_continuable_scopes[break_index]
.match_arms
.target_for_value(u128::from(value_adt.variant_index.as_u32())),
imaginary_target: self.scopes.const_continuable_scopes[break_index]
.loop_head,
},
TerminatorKind::FalseEdge { real_target, imaginary_target: scope.loop_head },
);

return self.cfg.start_new_block().unit();
Expand Down
30 changes: 30 additions & 0 deletions tests/ui/match/loop-match-integer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//@ run-pass

#![feature(loop_match)]

fn main() {
let mut state = 0;
#[loop_match]
'a: loop {
state = 'blk: {
match state {
-1 => {
if true {
#[const_continue]
break 'blk 2;
} else {
// No drops allowed at this point
#[const_continue]
break 'blk 0;
}
}
0 => {
#[const_continue]
break 'blk -1;
}
2 => break 'a,
_ => break 'a,
}
}
}
}
Loading