Skip to content

Commit f24c0ea

Browse files
committed
New mir-opt pass to simplify gotos with const values
Fixes #77355
1 parent 0dce3f6 commit f24c0ea

8 files changed

+260
-117
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
//! This pass optimizes the following sequence
2+
//! ```rust
3+
//! bb2: {
4+
//! _2 = const true;
5+
//! goto -> bb3;
6+
//! }
7+
//!
8+
//! bb3: {
9+
//! switchInt(_2) -> [false: bb4, otherwise: bb5];
10+
//! }
11+
//! ```
12+
//! into
13+
//! ```rust
14+
//! bb2: {
15+
//! _2 = const true;
16+
//! goto -> bb5;
17+
//! }
18+
//! ```
19+
20+
use crate::transform::MirPass;
21+
use rustc_middle::mir::*;
22+
use rustc_middle::ty::TyCtxt;
23+
use rustc_middle::{mir::visit::Visitor, ty::ParamEnv};
24+
25+
use super::simplify::{simplify_cfg, simplify_locals};
26+
27+
pub struct ConstGoto;
28+
29+
impl<'tcx> MirPass<'tcx> for ConstGoto {
30+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
31+
trace!("Running ConstGoto on {:?}", body.source);
32+
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
33+
let mut opt_finder =
34+
ConstGotoOptimizationFinder { tcx, body, optimizations: vec![], param_env };
35+
opt_finder.visit_body(body);
36+
let should_simplify = !opt_finder.optimizations.is_empty();
37+
for opt in opt_finder.optimizations {
38+
let terminator = body.basic_blocks_mut()[opt.bb_with_goto].terminator_mut();
39+
let new_goto = TerminatorKind::Goto { target: opt.target_to_use_in_goto };
40+
debug!("SUCCESS: replacing `{:?}` with `{:?}`", terminator.kind, new_goto);
41+
terminator.kind = new_goto;
42+
}
43+
44+
// if we applied optimizations, we potentially have some cfg to cleanup to
45+
// make it easier for further passes
46+
if should_simplify {
47+
simplify_cfg(body);
48+
simplify_locals(body, tcx);
49+
}
50+
}
51+
}
52+
53+
impl<'a, 'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'a, 'tcx> {
54+
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
55+
let mut bailer = || {
56+
match terminator.kind {
57+
TerminatorKind::Goto { target } => {
58+
// We only apply this optimization if the last statement is a const assignment
59+
let last_statement =
60+
self.body.basic_blocks()[location.block].statements.last()?;
61+
62+
match &last_statement.kind {
63+
StatementKind::Assign(box (place, Rvalue::Use(op))) => {
64+
let _const = op.constant()?;
65+
// We found a constant being assigned to `place`.
66+
// Now check that the target of this Goto switches on this place.
67+
let target_bb = &self.body.basic_blocks()[target];
68+
if !target_bb.statements.is_empty() {
69+
return None;
70+
}
71+
72+
let target_bb_terminator = target_bb.terminator();
73+
match &target_bb_terminator.kind {
74+
TerminatorKind::SwitchInt { discr, switch_ty, targets }
75+
if discr.place() == Some(*place) =>
76+
{
77+
// We now know that the Switch matches on the const place, and it is statementless
78+
// Now find which value in the Switch matches the const value.
79+
let const_value = _const.literal.eval_bits(
80+
self.tcx,
81+
self.param_env,
82+
switch_ty,
83+
);
84+
let found_value_idx_option = targets
85+
.iter()
86+
.enumerate()
87+
.find(|(_, x)| const_value == x.0)
88+
.map(|(idx, _)| idx);
89+
90+
let target_to_use_in_goto =
91+
if let Some(found_value_idx) = found_value_idx_option {
92+
targets.iter().nth(found_value_idx).unwrap().1
93+
} else {
94+
// If we did not find the const value in values, it must be the otherwise case
95+
targets.otherwise()
96+
};
97+
98+
self.optimizations.push(OptimizationToApply {
99+
bb_with_goto: location.block,
100+
target_to_use_in_goto,
101+
});
102+
}
103+
_ => {}
104+
}
105+
}
106+
_ => {}
107+
}
108+
}
109+
_ => {}
110+
}
111+
return Some(());
112+
};
113+
let _ = bailer();
114+
115+
self.super_terminator(terminator, location);
116+
}
117+
}
118+
119+
struct OptimizationToApply {
120+
bb_with_goto: BasicBlock,
121+
target_to_use_in_goto: BasicBlock,
122+
}
123+
124+
pub struct ConstGotoOptimizationFinder<'a, 'tcx> {
125+
tcx: TyCtxt<'tcx>,
126+
body: &'a Body<'tcx>,
127+
param_env: ParamEnv<'tcx>,
128+
optimizations: Vec<OptimizationToApply>,
129+
}

compiler/rustc_mir/src/transform/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub mod check_consts;
2121
pub mod check_packed_ref;
2222
pub mod check_unsafety;
2323
pub mod cleanup_post_borrowck;
24+
pub mod const_goto;
2425
pub mod const_prop;
2526
pub mod deaggregator;
2627
pub mod dest_prop;
@@ -388,6 +389,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
388389

389390
// The main optimizations that we do on MIR.
390391
let optimizations: &[&dyn MirPass<'tcx>] = &[
392+
&const_goto::ConstGoto,
391393
&remove_unneeded_drops::RemoveUnneededDrops,
392394
&match_branches::MatchBranchSimplification,
393395
// inst combine is after MatchBranchSimplification to clean up Ne(_1, false)

compiler/rustc_mir/src/transform/simplify.rs

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -315,48 +315,51 @@ pub fn remove_dead_blocks(body: &mut Body<'_>) {
315315
}
316316
}
317317

318-
pub struct SimplifyLocals;
319-
320-
impl<'tcx> MirPass<'tcx> for SimplifyLocals {
321-
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
322-
trace!("running SimplifyLocals on {:?}", body.source);
323-
324-
// First, we're going to get a count of *actual* uses for every `Local`.
325-
// Take a look at `DeclMarker::visit_local()` to see exactly what is ignored.
326-
let mut used_locals = {
327-
let mut marker = DeclMarker::new(body);
328-
marker.visit_body(&body);
329-
330-
marker.local_counts
331-
};
332-
333-
let arg_count = body.arg_count;
318+
pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) {
319+
// First, we're going to get a count of *actual* uses for every `Local`.
320+
// Take a look at `DeclMarker::visit_local()` to see exactly what is ignored.
321+
let mut used_locals = {
322+
let mut marker = DeclMarker::new(body);
323+
marker.visit_body(&body);
324+
325+
marker.local_counts
326+
};
327+
328+
let arg_count = body.arg_count;
329+
330+
// Next, we're going to remove any `Local` with zero actual uses. When we remove those
331+
// `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
332+
// count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
333+
// `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
334+
// fixedpoint where there are no more unused locals.
335+
loop {
336+
let mut remove_statements = RemoveStatements::new(&mut used_locals, arg_count, tcx);
337+
remove_statements.visit_body(body);
338+
339+
if !remove_statements.modified {
340+
break;
341+
}
342+
}
334343

335-
// Next, we're going to remove any `Local` with zero actual uses. When we remove those
336-
// `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
337-
// count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
338-
// `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
339-
// fixedpoint where there are no more unused locals.
340-
loop {
341-
let mut remove_statements = RemoveStatements::new(&mut used_locals, arg_count, tcx);
342-
remove_statements.visit_body(body);
344+
// Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
345+
let map = make_local_map(&mut body.local_decls, used_locals, arg_count);
343346

344-
if !remove_statements.modified {
345-
break;
346-
}
347-
}
347+
// Only bother running the `LocalUpdater` if we actually found locals to remove.
348+
if map.iter().any(Option::is_none) {
349+
// Update references to all vars and tmps now
350+
let mut updater = LocalUpdater { map, tcx };
351+
updater.visit_body(body);
348352

349-
// Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
350-
let map = make_local_map(&mut body.local_decls, used_locals, arg_count);
353+
body.local_decls.shrink_to_fit();
354+
}
355+
}
351356

352-
// Only bother running the `LocalUpdater` if we actually found locals to remove.
353-
if map.iter().any(Option::is_none) {
354-
// Update references to all vars and tmps now
355-
let mut updater = LocalUpdater { map, tcx };
356-
updater.visit_body(body);
357+
pub struct SimplifyLocals;
357358

358-
body.local_decls.shrink_to_fit();
359-
}
359+
impl<'tcx> MirPass<'tcx> for SimplifyLocals {
360+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
361+
trace!("running SimplifyLocals on {:?}", body.source);
362+
simplify_locals(body, tcx);
360363
}
361364
}
362365

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
- // MIR for `issue_77355_opt` before ConstGoto
2+
+ // MIR for `issue_77355_opt` after ConstGoto
3+
4+
fn issue_77355_opt(_1: Foo) -> u64 {
5+
debug num => _1; // in scope 0 at $DIR/const_goto.rs:11:20: 11:23
6+
let mut _0: u64; // return place in scope 0 at $DIR/const_goto.rs:11:33: 11:36
7+
- let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
8+
- let mut _3: isize; // in scope 0 at $DIR/const_goto.rs:12:22: 12:28
9+
+ let mut _2: isize; // in scope 0 at $DIR/const_goto.rs:12:22: 12:28
10+
11+
bb0: {
12+
- StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
13+
- _3 = discriminant(_1); // scope 0 at $DIR/const_goto.rs:12:22: 12:28
14+
- switchInt(move _3) -> [1_isize: bb2, 2_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto.rs:12:22: 12:28
15+
+ _2 = discriminant(_1); // scope 0 at $DIR/const_goto.rs:12:22: 12:28
16+
+ switchInt(move _2) -> [1_isize: bb2, 2_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto.rs:12:22: 12:28
17+
}
18+
19+
bb1: {
20+
- _2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
21+
- goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
22+
- }
23+
-
24+
- bb2: {
25+
- _2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
26+
- goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
27+
- }
28+
-
29+
- bb3: {
30+
- switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
31+
- }
32+
-
33+
- bb4: {
34+
_0 = const 42_u64; // scope 0 at $DIR/const_goto.rs:12:53: 12:55
35+
- goto -> bb6; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
36+
+ goto -> bb3; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
37+
}
38+
39+
- bb5: {
40+
+ bb2: {
41+
_0 = const 23_u64; // scope 0 at $DIR/const_goto.rs:12:41: 12:43
42+
- goto -> bb6; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
43+
+ goto -> bb3; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
44+
}
45+
46+
- bb6: {
47+
- StorageDead(_2); // scope 0 at $DIR/const_goto.rs:13:1: 13:2
48+
+ bb3: {
49+
return; // scope 0 at $DIR/const_goto.rs:13:2: 13:2
50+
}
51+
}
52+

src/test/mir-opt/const_goto.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
pub enum Foo {
2+
A,
3+
B,
4+
C,
5+
D,
6+
E,
7+
F,
8+
}
9+
10+
// EMIT_MIR const_goto.issue_77355_opt.ConstGoto.diff
11+
fn issue_77355_opt(num: Foo) -> u64 {
12+
if matches!(num, Foo::B | Foo::C) { 23 } else { 42 }
13+
}
14+
fn main() {
15+
issue_77355_opt(Foo::A);
16+
}

src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.32bit.diff

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,42 +4,23 @@
44
fn foo(_1: Option<()>) -> () {
55
debug bar => _1; // in scope 0 at $DIR/matches_reduce_branches.rs:6:8: 6:11
66
let mut _0: (); // return place in scope 0 at $DIR/matches_reduce_branches.rs:6:25: 6:25
7-
let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
8-
let mut _3: isize; // in scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
9-
+ let mut _4: isize; // in scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
7+
let mut _2: isize; // in scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
108

119
bb0: {
12-
StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
13-
_3 = discriminant(_1); // scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
14-
- switchInt(move _3) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
15-
+ StorageLive(_4); // scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
16-
+ _4 = move _3; // scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
17-
+ _2 = Eq(_4, const 0_isize); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
18-
+ StorageDead(_4); // scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
19-
+ goto -> bb3; // scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
10+
_2 = discriminant(_1); // scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
11+
switchInt(move _2) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:7:22: 7:26
2012
}
2113

2214
bb1: {
23-
_2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
24-
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
15+
_0 = const (); // scope 0 at $DIR/matches_reduce_branches.rs:7:5: 9:6
16+
goto -> bb3; // scope 0 at $DIR/matches_reduce_branches.rs:7:5: 9:6
2517
}
2618

2719
bb2: {
28-
_2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
2920
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
3021
}
3122

3223
bb3: {
33-
switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/matches_reduce_branches.rs:7:5: 9:6
34-
}
35-
36-
bb4: {
37-
_0 = const (); // scope 0 at $DIR/matches_reduce_branches.rs:7:5: 9:6
38-
goto -> bb5; // scope 0 at $DIR/matches_reduce_branches.rs:7:5: 9:6
39-
}
40-
41-
bb5: {
42-
StorageDead(_2); // scope 0 at $DIR/matches_reduce_branches.rs:10:1: 10:2
4324
return; // scope 0 at $DIR/matches_reduce_branches.rs:10:2: 10:2
4425
}
4526
}

0 commit comments

Comments
 (0)