Skip to content

Commit eab0865

Browse files
committed
Filter costly chains after simplification.
1 parent e5b514d commit eab0865

File tree

3 files changed

+96
-177
lines changed

3 files changed

+96
-177
lines changed

compiler/rustc_mir_transform/src/jump_threading.rs

Lines changed: 79 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@
5151
//!
5252
//! [libfirm]: <https://pp.ipd.kit.edu/uploads/publikationen/priesner17masterarbeit.pdf>
5353
54-
use std::cell::OnceCell;
55-
5654
use itertools::Itertools as _;
5755
use rustc_const_eval::const_eval::DummyMachine;
5856
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
@@ -100,7 +98,6 @@ impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
10098
map: Map::new(tcx, body, Some(MAX_PLACES)),
10199
maybe_loop_headers: maybe_loop_headers(body),
102100
entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
103-
costs: IndexVec::from_elem(OnceCell::new(), &body.basic_blocks),
104101
};
105102

106103
for (bb, bbdata) in traversal::postorder(body) {
@@ -136,6 +133,7 @@ impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
136133

137134
let mut entry_states = finder.entry_states;
138135
simplify_conditions(body, &mut entry_states);
136+
remove_costly_conditions(tcx, typing_env, body, &mut entry_states);
139137

140138
if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
141139
opportunities.apply();
@@ -159,8 +157,6 @@ struct TOFinder<'a, 'tcx> {
159157
// Invariant: for each `bb`, each condition in `entry_states[bb]` has a `chain` that
160158
// starts with `bb`.
161159
entry_states: IndexVec<BasicBlock, ConditionSet>,
162-
/// Pre-computed cost of duplicating each block.
163-
costs: IndexVec<BasicBlock, OnceCell<usize>>,
164160
}
165161

166162
rustc_index::newtype_index! {
@@ -222,7 +218,6 @@ struct ConditionSet {
222218
active: Vec<(ConditionIndex, Condition)>,
223219
fulfilled: Vec<ConditionIndex>,
224220
targets: IndexVec<ConditionIndex, Vec<EdgeEffect>>,
225-
costs: IndexVec<ConditionIndex, u8>,
226221
}
227222

228223
impl ConditionSet {
@@ -233,7 +228,6 @@ impl ConditionSet {
233228
#[tracing::instrument(level = "trace", skip(self))]
234229
fn push_condition(&mut self, c: Condition, target: BasicBlock) {
235230
let index = self.targets.push(vec![EdgeEffect::Goto { target }]);
236-
self.costs.push(0);
237231
self.active.push((index, c));
238232
}
239233

@@ -293,21 +287,18 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
293287
active: Vec::with_capacity(state_len),
294288
targets: IndexVec::with_capacity(state_len),
295289
fulfilled: Vec::new(),
296-
costs: IndexVec::with_capacity(state_len),
297290
};
298291

299292
// Use an index-set to deduplicate conditions coming from different successor blocks.
300293
let mut known_conditions =
301294
FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
302-
let mut insert = |condition, succ_block, succ_condition, cost| {
295+
let mut insert = |condition, succ_block, succ_condition| {
303296
let (index, new) = known_conditions.insert_full(condition);
304297
let index = ConditionIndex::from_usize(index);
305298
if new {
306299
state.active.push((index, condition));
307300
let _index = state.targets.push(Vec::new());
308301
debug_assert_eq!(_index, index);
309-
let _index = state.costs.push(u8::MAX);
310-
debug_assert_eq!(_index, index);
311302
}
312303
let target = EdgeEffect::Chain { succ_block, succ_condition };
313304
debug_assert!(
@@ -316,7 +307,6 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
316307
&state.targets[index],
317308
);
318309
state.targets[index].push(target);
319-
state.costs[index] = std::cmp::min(state.costs[index], cost);
320310
};
321311

322312
// A given block may have several times the same successor.
@@ -331,35 +321,19 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
331321
continue;
332322
}
333323

334-
let succ_cost = self.cost(succ);
335324
for &(succ_index, cond) in self.entry_states[succ].active.iter() {
336-
let cost = self.entry_states[succ].costs[succ_index];
337-
if let Ok(cost) = ((cost as usize) + succ_cost).try_into()
338-
&& cost < MAX_COST
339-
{
340-
insert(cond, succ, succ_index, cost);
341-
}
325+
insert(cond, succ, succ_index);
342326
}
343327
}
344328

345329
let num_conditions = known_conditions.len();
346330
debug_assert_eq!(num_conditions, state.active.len());
347331
debug_assert_eq!(num_conditions, state.targets.len());
348-
debug_assert_eq!(num_conditions, state.costs.len());
349332
state.fulfilled.reserve(num_conditions);
350333

351334
state
352335
}
353336

354-
fn cost(&self, bb: BasicBlock) -> usize {
355-
*self.costs[bb].get_or_init(|| {
356-
let bbdata = &self.body[bb];
357-
let mut cost = CostChecker::new(self.tcx, self.typing_env, None, self.body);
358-
cost.visit_basic_block_data(bb, bbdata);
359-
cost.cost()
360-
})
361-
}
362-
363337
/// Remove all conditions in the state that alias given place.
364338
fn flood_state(
365339
&self,
@@ -756,8 +730,6 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
756730
// Fulfilling `index` may thread conditions that we do not want,
757731
// so create a brand new index to immediately mark fulfilled.
758732
let index = state.targets.push(new_edges);
759-
let _index = state.costs.push(0);
760-
debug_assert_eq!(_index, index);
761733
state.fulfilled.push(index);
762734
}
763735
}
@@ -870,6 +842,82 @@ fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock,
870842
}
871843
}
872844

845+
#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
846+
fn remove_costly_conditions<'tcx>(
847+
tcx: TyCtxt<'tcx>,
848+
typing_env: ty::TypingEnv<'tcx>,
849+
body: &Body<'tcx>,
850+
entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
851+
) {
852+
let basic_blocks = &body.basic_blocks;
853+
854+
let mut costs = IndexVec::from_elem(None, basic_blocks);
855+
let mut cost = |bb: BasicBlock| -> u8 {
856+
let c = *costs[bb].get_or_insert_with(|| {
857+
let bbdata = &basic_blocks[bb];
858+
let mut cost = CostChecker::new(tcx, typing_env, None, body);
859+
cost.visit_basic_block_data(bb, bbdata);
860+
cost.cost().try_into().unwrap_or(MAX_COST)
861+
});
862+
trace!("cost[{bb:?}] = {c}");
863+
c
864+
};
865+
866+
// Initialize costs with `MAX_COST`: if we have a cycle, the cyclic `bb` has infinite costs.
867+
let mut condition_cost = IndexVec::from_fn_n(
868+
|bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
869+
entry_states.len(),
870+
);
871+
872+
let reverse_postorder = basic_blocks.reverse_postorder();
873+
874+
for &bb in reverse_postorder.iter().rev() {
875+
let state = &entry_states[bb];
876+
trace!(?bb, ?state);
877+
878+
let mut current_costs = IndexVec::from_elem(0u8, &state.targets);
879+
880+
for (condition, targets) in state.targets.iter_enumerated() {
881+
for &target in targets {
882+
match target {
883+
// A `Goto` has cost 0.
884+
EdgeEffect::Goto { .. } => {}
885+
// Chaining into an already-fulfilled condition is nop.
886+
EdgeEffect::Chain { succ_block, succ_condition }
887+
if entry_states[succ_block].fulfilled.contains(&succ_condition) => {}
888+
// When chaining, use `cost[succ_block][succ_condition] + cost(succ_block)`.
889+
EdgeEffect::Chain { succ_block, succ_condition } => {
890+
// Cost associated with duplicating `succ_block`.
891+
let duplication_cost = cost(succ_block);
892+
// Cost associated with the rest of the chain.
893+
let target_cost =
894+
*condition_cost[succ_block].get(succ_condition).unwrap_or(&MAX_COST);
895+
let cost = current_costs[condition]
896+
.saturating_add(duplication_cost)
897+
.saturating_add(target_cost);
898+
trace!(?condition, ?succ_block, ?duplication_cost, ?target_cost);
899+
current_costs[condition] = cost;
900+
}
901+
}
902+
}
903+
}
904+
905+
trace!("condition_cost[{bb:?}] = {:?}", current_costs);
906+
condition_cost[bb] = current_costs;
907+
}
908+
909+
trace!(?condition_cost);
910+
911+
for &bb in reverse_postorder {
912+
for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
913+
if condition_cost[bb][index] >= MAX_COST {
914+
trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
915+
targets.clear()
916+
}
917+
}
918+
}
919+
}
920+
873921
struct OpportunitySet<'a, 'tcx> {
874922
basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
875923
entry_states: IndexVec<BasicBlock, ConditionSet>,
@@ -892,7 +940,6 @@ impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
892940
// Free some memory, because we will need to clone condition sets.
893941
for state in entry_states.iter_mut() {
894942
state.active = Default::default();
895-
state.costs = Default::default();
896943
}
897944
let duplicates = Default::default();
898945
let basic_blocks = body.basic_blocks.as_mut();

tests/mir-opt/jump_threading.chained_conditions.JumpThreading.panic-abort.diff

Lines changed: 9 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
bb3: {
7272
_1 = chained_conditions::BacktraceStyle::Off;
7373
- goto -> bb18;
74-
+ goto -> bb35;
74+
+ goto -> bb23;
7575
}
7676

7777
bb4: {
@@ -131,7 +131,7 @@
131131
StorageDead(_5);
132132
StorageDead(_6);
133133
- goto -> bb18;
134-
+ goto -> bb22;
134+
+ goto -> bb21;
135135
}
136136

137137
bb8: {
@@ -154,14 +154,14 @@
154154
StorageDead(_13);
155155
_1 = chained_conditions::BacktraceStyle::Short;
156156
- goto -> bb18;
157-
+ goto -> bb31;
157+
+ goto -> bb23;
158158
}
159159

160160
bb10: {
161161
StorageDead(_12);
162162
StorageDead(_13);
163163
- goto -> bb18;
164-
+ goto -> bb27;
164+
+ goto -> bb21;
165165
}
166166

167167
bb11: {
@@ -219,88 +219,20 @@
219219
+
220220
+ bb21: {
221221
+ _24 = discriminant(_2);
222-
+ switchInt(move _24) -> [1: bb16, otherwise: bb15];
222+
+ switchInt(move _24) -> [1: bb22, otherwise: bb15];
223223
+ }
224224
+
225225
+ bb22: {
226-
+ _24 = discriminant(_2);
227-
+ switchInt(move _24) -> [1: bb25, otherwise: bb23];
226+
+ goto -> bb15;
228227
+ }
229228
+
230229
+ bb23: {
231-
+ _22 = const false;
232-
+ _23 = const false;
233-
+ StorageDead(_2);
234-
+ _19 = discriminant(_1);
235-
+ goto -> bb11;
236-
+ }
237-
+
238-
+ bb24: {
239-
+ switchInt(copy _23) -> [0: bb15, otherwise: bb17];
240-
+ }
241-
+
242-
+ bb25: {
243-
+ goto -> bb23;
244-
+ }
245-
+
246-
+ bb26: {
247-
+ _24 = discriminant(_2);
248-
+ switchInt(move _24) -> [1: bb16, otherwise: bb15];
249-
+ }
250-
+
251-
+ bb27: {
252230
+ _24 = discriminant(_2);
253-
+ switchInt(move _24) -> [1: bb30, otherwise: bb28];
254-
+ }
255-
+
256-
+ bb28: {
257-
+ _22 = const false;
258-
+ _23 = const false;
259-
+ StorageDead(_2);
260-
+ _19 = discriminant(_1);
261-
+ goto -> bb13;
262-
+ }
263-
+
264-
+ bb29: {
265-
+ switchInt(copy _23) -> [0: bb15, otherwise: bb17];
231+
+ switchInt(move _24) -> [1: bb24, otherwise: bb15];
266232
+ }
267233
+
268-
+ bb30: {
269-
+ goto -> bb28;
270-
+ }
271-
+
272-
+ bb31: {
273-
+ _24 = discriminant(_2);
274-
+ switchInt(move _24) -> [1: bb33, otherwise: bb32];
275-
+ }
276-
+
277-
+ bb32: {
278-
+ _22 = const false;
279-
+ _23 = const false;
280-
+ StorageDead(_2);
281-
+ _19 = discriminant(_1);
282-
+ goto -> bb12;
283-
+ }
284-
+
285-
+ bb33: {
286-
+ switchInt(copy _23) -> [0: bb32, otherwise: bb34];
287-
+ }
288-
+
289-
+ bb34: {
290-
+ drop(((_2 as Some).0: std::string::String)) -> [return: bb32, unwind unreachable];
291-
+ }
292-
+
293-
+ bb35: {
294-
+ _24 = discriminant(_2);
295-
+ switchInt(move _24) -> [1: bb36, otherwise: bb28];
296-
+ }
297-
+
298-
+ bb36: {
299-
+ goto -> bb37;
300-
+ }
301-
+
302-
+ bb37: {
303-
+ drop(((_2 as Some).0: std::string::String)) -> [return: bb28, unwind unreachable];
234+
+ bb24: {
235+
+ goto -> bb17;
304236
}
305237
}
306238

0 commit comments

Comments
 (0)