diff --git a/compiler/rustc_mir_build/src/builder/expr/into.rs b/compiler/rustc_mir_build/src/builder/expr/into.rs index 971acf56ad671..40a881a82d7d4 100644 --- a/compiler/rustc_mir_build/src/builder/expr/into.rs +++ b/compiler/rustc_mir_build/src/builder/expr/into.rs @@ -1,5 +1,6 @@ //! See docs in build/expr/mod.rs +use rustc_abi::VariantIdx; use rustc_ast::{AsmMacro, InlineAsmOptions}; use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::stack::ensure_sufficient_stack; @@ -8,6 +9,9 @@ use rustc_middle::mir::*; use rustc_middle::span_bug; use rustc_middle::thir::*; use rustc_middle::ty::CanonicalUserTypeAnnotation; +use rustc_middle::ty::util::Discr; +use rustc_pattern_analysis::constructor::Constructor; +use rustc_pattern_analysis::rustc::RustcPatCtxt; use rustc_span::source_map::Spanned; use tracing::{debug, instrument}; @@ -244,6 +248,31 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ExprKind::LoopMatch { state, region_scope, ref arms, .. } => { // FIXME add diagram + let dropless_arena = rustc_arena::DroplessArena::default(); + let typeck_results = this.tcx.typeck(this.def_id); + + // FIXME use the lint level from `ExprKind::LoopMatch` + let lint_level = this.tcx.local_def_id_to_hir_id(this.def_id); + + // the PatCtxt is normally used in pattern exhaustiveness checking, but reused here + // because it performs normalization and const evaluation. + let cx = RustcPatCtxt { + tcx: this.tcx, + typeck_results, + module: this.tcx.parent_module(lint_level).to_def_id(), + // FIXME(#132279): We're in a body, should handle opaques. + typing_env: rustc_middle::ty::TypingEnv::non_body_analysis( + this.tcx, + this.def_id, + ), + dropless_arena: &dropless_arena, + match_lint_level: lint_level, + whole_match_span: Some(rustc_span::Span::default()), + scrut_span: rustc_span::Span::default(), + refutable: true, + known_valid_scrutinee: true, + }; + let loop_block = this.cfg.start_new_block(); // Start the loop. @@ -264,77 +293,99 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { this.diverge_from(loop_block); let state_place = unpack!(body_block = this.as_place(body_block, state)); + let state_ty = this.thir.exprs[state].ty; + + // the type of the value that is switched on by the `SwitchInt` + let discr_ty = match state_ty { + ty if ty.is_enum() => ty.discriminant_ty(this.tcx), + ty if ty.is_integral() => ty, + _ => todo!(), + }; + + let rvalue = match state_ty { + ty if ty.is_enum() => Rvalue::Discriminant(state_place), + ty if ty.is_integral() => Rvalue::Use(Operand::Copy(state_place)), + _ => todo!(), + }; + + // block and arm of the wildcard pattern (if any) + let mut otherwise = None; unpack!( body_block = this.in_scope( (region_scope, source_info), LintLevel::Inherited, move |this| { - let unreachable_block = this.cfg.start_new_block(); - this.cfg.terminate( - unreachable_block, - source_info, - TerminatorKind::Unreachable, - ); + let mut arm_blocks = Vec::with_capacity(arms.len()); + for &arm in arms { + let pat = &this.thir[arm].pattern; + let deconstructed_pat = cx.lower_pat(pat); + + match deconstructed_pat.ctor() { + Constructor::Variant(variant_index) => { + let PatKind::Variant { adt_def, .. } = pat.kind else { + unreachable!() + }; + + let discr = adt_def + .discriminant_for_variant(this.tcx, *variant_index); + + let block = this.cfg.start_new_block(); + arm_blocks.push((*variant_index, discr, block, arm)) + } + Constructor::IntRange(int_range) => { + assert!(int_range.is_singleton()); + + let bits = state_ty.primitive_size(this.tcx).bits(); + let value = int_range.lo.as_finite_int(bits).unwrap(); - let mut arm_blocks = arms - .iter() - .map(|&arm| { - let block = this.cfg.start_new_block(); - match &this.thir[arm].pattern.kind { - PatKind::Variant { - adt_def, - args: _, - variant_index, - subpatterns, - } => { - assert!(subpatterns.is_empty()); - - let discr = adt_def - .discriminants(this.tcx) - .find(|(var, _discr)| var == variant_index) - .unwrap() - .1; - - (discr, block, arm) - } - _ => panic!(), + let discr = + Discr { val: value, ty: **deconstructed_pat.ty() }; + + let block = this.cfg.start_new_block(); + arm_blocks.push((VariantIdx::ZERO, discr, block, arm)) + } + Constructor::Wildcard => { + otherwise = Some((this.cfg.start_new_block(), arm)); } - }) - .collect::>(); - arm_blocks.sort_by_cached_key(|&(discr, _, _)| { - match &this.thir[arms[0]].pattern.kind { - PatKind::Variant { adt_def, .. } => adt_def - .discriminants(this.tcx) - .position(|(_, i)| discr.val == i.val) - .unwrap(), - _ => panic!(), + other => todo!("{:?}", other), } - }); + } + + // if we're matching on an enum, the discriminant order in the `SwitchInt` + // targets should match the order yielded by `AdtDef::discriminants`. + if state_ty.is_enum() { + arm_blocks.sort_by_key(|(variant_idx, ..)| *variant_idx); + } let targets = SwitchTargets::new( arm_blocks .iter() - .map(|&(discr, block, _arm)| (discr.val, block)), - unreachable_block, + .map(|&(_, discr, block, _arm)| (discr.val, block)), + if let Some((block, _)) = otherwise { + block + } else { + let unreachable_block = this.cfg.start_new_block(); + this.cfg.terminate( + unreachable_block, + source_info, + TerminatorKind::Unreachable, + ); + unreachable_block + }, ); + this.in_const_continuable_scope( loop_block, targets.clone(), state_place, |this| { - let discr_ty = match &this.thir[arms[0]].pattern.kind { - PatKind::Variant { adt_def, .. } => { - adt_def.discriminants(this.tcx).next().unwrap().1.ty - } - _ => panic!(), - }; let discr = this.temp(discr_ty, source_info.span); this.cfg.push_assign( body_block, source_info, discr, - Rvalue::Discriminant(state_place), + rvalue, ); let discr = Operand::Copy(discr); this.cfg.terminate( @@ -343,7 +394,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { TerminatorKind::SwitchInt { discr, targets }, ); - for (_discr, mut block, arm) in arm_blocks { + let it = arm_blocks + .into_iter() + .map(|(_, _, block, arm)| (block, arm)) + .chain(otherwise); + + for (mut block, arm) in it { let empty_place = this.get_unit_temp(); unpack!( block = this.expr_into_dest( diff --git a/compiler/rustc_mir_build/src/builder/scope.rs b/compiler/rustc_mir_build/src/builder/scope.rs index f031a3f514f41..6bbc3aad322f1 100644 --- a/compiler/rustc_mir_build/src/builder/scope.rs +++ b/compiler/rustc_mir_build/src/builder/scope.rs @@ -83,6 +83,7 @@ that contains only loops and breakable blocks. It tracks where a `break`, use std::mem; +use rustc_ast::LitKind; use rustc_data_structures::fx::FxHashMap; use rustc_hir::HirId; use rustc_index::{IndexSlice, IndexVec}; @@ -731,43 +732,59 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .unwrap_or_else(|| { span_bug!(span, "no enclosing const-continuable scope found") }); - let state_place = self.scopes.const_continuable_scopes[break_index].state_place; let rustc_middle::thir::ExprKind::Scope { value, .. } = self.thir[value.unwrap()].kind else { panic!(); }; - let rustc_middle::thir::ExprKind::Adt(value_adt) = &self.thir[value].kind else { - panic!(); + + let scope = &self.scopes.const_continuable_scopes[break_index]; + + let state_ty = self.local_decls[scope.state_place.as_local().unwrap()].ty; + let discriminant_ty = match state_ty { + ty if ty.is_enum() => ty.discriminant_ty(self.tcx), + ty if ty.is_integral() => ty, + _ => todo!(), }; - //dbg!(&self.thir[value], value_adt); + let rvalue = match state_ty { + ty if ty.is_enum() => Rvalue::Discriminant(scope.state_place), + ty if ty.is_integral() => Rvalue::Use(Operand::Copy(scope.state_place)), + _ => todo!(), + }; + + let real_target = match &self.thir[value].kind { + rustc_middle::thir::ExprKind::Adt(value_adt) => scope + .match_arms + .target_for_value(u128::from(value_adt.variant_index.as_u32())), + rustc_middle::thir::ExprKind::Literal { lit, neg } => match lit.node { + LitKind::Int(n, _) => { + let n = if *neg { + (n.get() as i128).overflowing_neg().0 as u128 + } else { + n.get() + }; + let result = state_ty.primitive_size(self.tcx).truncate(n); + scope.match_arms.target_for_value(result) + } + _ => todo!(), + }, + other => todo!("{other:?}"), + }; self.block_context.push(BlockFrame::SubExpr); + let state_place = scope.state_place; block = self.expr_into_dest(state_place, block, value).into_block(); self.block_context.pop(); - let discr_ty = - self.local_decls[state_place.as_local().unwrap()].ty.discriminant_ty(self.tcx); - let discr = self.temp(discr_ty, source_info.span); + let discr = self.temp(discriminant_ty, source_info.span); let scope = &self.scopes.const_continuable_scopes[break_index]; - self.cfg.push_assign( - block, - source_info, - discr, - Rvalue::Discriminant(scope.state_place), - ); + self.cfg.push_assign(block, source_info, discr, rvalue); self.cfg.terminate( block, source_info, - TerminatorKind::FalseEdge { - real_target: self.scopes.const_continuable_scopes[break_index] - .match_arms - .target_for_value(u128::from(value_adt.variant_index.as_u32())), - imaginary_target: self.scopes.const_continuable_scopes[break_index] - .loop_head, - }, + TerminatorKind::FalseEdge { real_target, imaginary_target: scope.loop_head }, ); return self.cfg.start_new_block().unit(); diff --git a/tests/ui/match/loop-match-integer.rs b/tests/ui/match/loop-match-integer.rs new file mode 100644 index 0000000000000..3a5b9523ad020 --- /dev/null +++ b/tests/ui/match/loop-match-integer.rs @@ -0,0 +1,30 @@ +//@ run-pass + +#![feature(loop_match)] + +fn main() { + let mut state = 0; + #[loop_match] + 'a: loop { + state = 'blk: { + match state { + -1 => { + if true { + #[const_continue] + break 'blk 2; + } else { + // No drops allowed at this point + #[const_continue] + break 'blk 0; + } + } + 0 => { + #[const_continue] + break 'blk -1; + } + 2 => break 'a, + _ => break 'a, + } + } + } +}