Skip to content

Commit 3a5c954

Browse files
committed
Filter costly chains after simplification.
1 parent 8b77331 commit 3a5c954

File tree

3 files changed

+96
-177
lines changed

3 files changed

+96
-177
lines changed

compiler/rustc_mir_transform/src/jump_threading.rs

Lines changed: 79 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@
5151
//!
5252
//! [libfirm]: <https://pp.ipd.kit.edu/uploads/publikationen/priesner17masterarbeit.pdf>
5353
54-
use std::cell::OnceCell;
55-
5654
use itertools::Itertools as _;
5755
use rustc_const_eval::const_eval::DummyMachine;
5856
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
@@ -100,7 +98,6 @@ impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
10098
map: Map::new(tcx, body, Some(MAX_PLACES)),
10199
loop_headers: loop_headers(body),
102100
entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
103-
costs: IndexVec::from_elem(OnceCell::new(), &body.basic_blocks),
104101
};
105102

106103
for (bb, bbdata) in traversal::postorder(body) {
@@ -136,6 +133,7 @@ impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
136133

137134
let mut entry_states = finder.entry_states;
138135
simplify_conditions(body, &mut entry_states);
136+
remove_costly_conditions(tcx, typing_env, body, &mut entry_states);
139137

140138
if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
141139
opportunities.apply();
@@ -159,8 +157,6 @@ struct TOFinder<'a, 'tcx> {
159157
// Invariant: for each `bb`, each condition in `entry_states[bb]` has a `chain` that
160158
// starts with `bb`.
161159
entry_states: IndexVec<BasicBlock, ConditionSet>,
162-
/// Pre-computed cost of duplicating each block.
163-
costs: IndexVec<BasicBlock, OnceCell<usize>>,
164160
}
165161

166162
rustc_index::newtype_index! {
@@ -219,7 +215,6 @@ struct ConditionSet {
219215
active: Vec<(ConditionIndex, Condition)>,
220216
fulfilled: Vec<ConditionIndex>,
221217
targets: IndexVec<ConditionIndex, Vec<ConditionTarget>>,
222-
costs: IndexVec<ConditionIndex, u8>,
223218
}
224219

225220
impl ConditionSet {
@@ -230,7 +225,6 @@ impl ConditionSet {
230225
#[tracing::instrument(level = "trace", skip(self))]
231226
fn push_condition(&mut self, c: Condition, succ: BasicBlock) {
232227
let index = self.targets.push(vec![ConditionTarget::Goto(succ)]);
233-
self.costs.push(0);
234228
self.active.push((index, c));
235229
}
236230

@@ -290,21 +284,18 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
290284
active: Vec::with_capacity(state_len),
291285
targets: IndexVec::with_capacity(state_len),
292286
fulfilled: Vec::new(),
293-
costs: IndexVec::with_capacity(state_len),
294287
};
295288

296289
// Use an index-set to deduplicate conditions coming from different successor blocks.
297290
let mut known_conditions =
298291
FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
299-
let mut insert = |condition, succ_block, succ_cond, cost| {
292+
let mut insert = |condition, succ_block, succ_cond| {
300293
let (index, new) = known_conditions.insert_full(condition);
301294
let index = ConditionIndex::from_usize(index);
302295
if new {
303296
state.active.push((index, condition));
304297
let _index = state.targets.push(Vec::new());
305298
debug_assert_eq!(_index, index);
306-
let _index = state.costs.push(u8::MAX);
307-
debug_assert_eq!(_index, index);
308299
}
309300
let target = ConditionTarget::Chain(succ_block, succ_cond);
310301
debug_assert!(
@@ -313,7 +304,6 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
313304
&state.targets[index],
314305
);
315306
state.targets[index].push(target);
316-
state.costs[index] = std::cmp::min(state.costs[index], cost);
317307
};
318308

319309
// A given block may have several times the same successor.
@@ -328,35 +318,19 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
328318
continue;
329319
}
330320

331-
let succ_cost = self.cost(succ);
332321
for &(succ_index, cond) in self.entry_states[succ].active.iter() {
333-
let cost = self.entry_states[succ].costs[succ_index];
334-
if let Ok(cost) = ((cost as usize) + succ_cost).try_into()
335-
&& cost < MAX_COST
336-
{
337-
insert(cond, succ, succ_index, cost);
338-
}
322+
insert(cond, succ, succ_index);
339323
}
340324
}
341325

342326
let num_conditions = known_conditions.len();
343327
debug_assert_eq!(num_conditions, state.active.len());
344328
debug_assert_eq!(num_conditions, state.targets.len());
345-
debug_assert_eq!(num_conditions, state.costs.len());
346329
state.fulfilled.reserve(num_conditions);
347330

348331
state
349332
}
350333

351-
fn cost(&self, bb: BasicBlock) -> usize {
352-
*self.costs[bb].get_or_init(|| {
353-
let bbdata = &self.body[bb];
354-
let mut cost = CostChecker::new(self.tcx, self.typing_env, None, self.body);
355-
cost.visit_basic_block_data(bb, bbdata);
356-
cost.cost()
357-
})
358-
}
359-
360334
/// Remove all conditions in the state that alias given place.
361335
fn flood_state(
362336
&self,
@@ -753,8 +727,6 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
753727
// Fulfilling `index` may thread conditions that we do not want,
754728
// so create a brand new index to immediately mark fulfilled.
755729
let index = state.targets.push(new_edges);
756-
let _index = state.costs.push(0);
757-
debug_assert_eq!(_index, index);
758730
state.fulfilled.push(index);
759731
}
760732
}
@@ -867,6 +839,82 @@ fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock,
867839
}
868840
}
869841

842+
#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
843+
fn remove_costly_conditions<'tcx>(
844+
tcx: TyCtxt<'tcx>,
845+
typing_env: ty::TypingEnv<'tcx>,
846+
body: &Body<'tcx>,
847+
entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
848+
) {
849+
let basic_blocks = &body.basic_blocks;
850+
851+
let mut costs = IndexVec::from_elem(None, basic_blocks);
852+
let mut cost = |bb: BasicBlock| -> u8 {
853+
let c = *costs[bb].get_or_insert_with(|| {
854+
let bbdata = &basic_blocks[bb];
855+
let mut cost = CostChecker::new(tcx, typing_env, None, body);
856+
cost.visit_basic_block_data(bb, bbdata);
857+
cost.cost().try_into().unwrap_or(MAX_COST)
858+
});
859+
trace!("cost[{bb:?}] = {c}");
860+
c
861+
};
862+
863+
// Initialize costs with `MAX_COST`: if we have a cycle, the cyclic `bb` has infinite costs.
864+
let mut condition_cost = IndexVec::from_fn_n(
865+
|bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
866+
entry_states.len(),
867+
);
868+
869+
let reverse_postorder = basic_blocks.reverse_postorder();
870+
871+
for &bb in reverse_postorder.iter().rev() {
872+
let state = &entry_states[bb];
873+
trace!(?bb, ?state);
874+
875+
let mut current_costs = IndexVec::from_elem(0u8, &state.targets);
876+
877+
for (condition, targets) in state.targets.iter_enumerated() {
878+
for &target in targets {
879+
match target {
880+
// A `Goto` has cost 0.
881+
ConditionTarget::Goto(_) => {}
882+
// Chaining into an already-fulfilled condition is nop.
883+
ConditionTarget::Chain(target, target_condition)
884+
if entry_states[target].fulfilled.contains(&target_condition) => {}
885+
// When chaining, use `cost[target][target_condition] + cost(target)`.
886+
ConditionTarget::Chain(target, target_condition) => {
887+
// Cost associated with duplicating `target`.
888+
let duplication_cost = cost(target);
889+
// Cost associated with the rest of the chain.
890+
let target_cost =
891+
*condition_cost[target].get(target_condition).unwrap_or(&MAX_COST);
892+
let cost = current_costs[condition]
893+
.saturating_add(duplication_cost)
894+
.saturating_add(target_cost);
895+
trace!(?condition, ?target, ?duplication_cost, ?target_cost);
896+
current_costs[condition] = cost;
897+
}
898+
}
899+
}
900+
}
901+
902+
trace!("condition_cost[{bb:?}] = {:?}", current_costs);
903+
condition_cost[bb] = current_costs;
904+
}
905+
906+
trace!(?condition_cost);
907+
908+
for &bb in reverse_postorder {
909+
for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
910+
if condition_cost[bb][index] >= MAX_COST {
911+
trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
912+
targets.clear()
913+
}
914+
}
915+
}
916+
}
917+
870918
struct OpportunitySet<'a, 'tcx> {
871919
basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
872920
entry_states: IndexVec<BasicBlock, ConditionSet>,
@@ -889,7 +937,6 @@ impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
889937
// Free some memory, because we will need to clone condition sets.
890938
for state in entry_states.iter_mut() {
891939
state.active = Default::default();
892-
state.costs = Default::default();
893940
}
894941
let duplicates = Default::default();
895942
let basic_blocks = body.basic_blocks.as_mut();

tests/mir-opt/jump_threading.chained_conditions.JumpThreading.panic-abort.diff

Lines changed: 9 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
bb3: {
7272
_1 = chained_conditions::BacktraceStyle::Off;
7373
- goto -> bb18;
74-
+ goto -> bb35;
74+
+ goto -> bb23;
7575
}
7676

7777
bb4: {
@@ -131,7 +131,7 @@
131131
StorageDead(_5);
132132
StorageDead(_6);
133133
- goto -> bb18;
134-
+ goto -> bb22;
134+
+ goto -> bb21;
135135
}
136136

137137
bb8: {
@@ -154,14 +154,14 @@
154154
StorageDead(_13);
155155
_1 = chained_conditions::BacktraceStyle::Short;
156156
- goto -> bb18;
157-
+ goto -> bb31;
157+
+ goto -> bb23;
158158
}
159159

160160
bb10: {
161161
StorageDead(_12);
162162
StorageDead(_13);
163163
- goto -> bb18;
164-
+ goto -> bb27;
164+
+ goto -> bb21;
165165
}
166166

167167
bb11: {
@@ -219,88 +219,20 @@
219219
+
220220
+ bb21: {
221221
+ _24 = discriminant(_2);
222-
+ switchInt(move _24) -> [1: bb16, otherwise: bb15];
222+
+ switchInt(move _24) -> [1: bb22, otherwise: bb15];
223223
+ }
224224
+
225225
+ bb22: {
226-
+ _24 = discriminant(_2);
227-
+ switchInt(move _24) -> [1: bb25, otherwise: bb23];
226+
+ goto -> bb15;
228227
+ }
229228
+
230229
+ bb23: {
231-
+ _22 = const false;
232-
+ _23 = const false;
233-
+ StorageDead(_2);
234-
+ _19 = discriminant(_1);
235-
+ goto -> bb11;
236-
+ }
237-
+
238-
+ bb24: {
239-
+ switchInt(copy _23) -> [0: bb15, otherwise: bb17];
240-
+ }
241-
+
242-
+ bb25: {
243-
+ goto -> bb23;
244-
+ }
245-
+
246-
+ bb26: {
247-
+ _24 = discriminant(_2);
248-
+ switchInt(move _24) -> [1: bb16, otherwise: bb15];
249-
+ }
250-
+
251-
+ bb27: {
252230
+ _24 = discriminant(_2);
253-
+ switchInt(move _24) -> [1: bb30, otherwise: bb28];
254-
+ }
255-
+
256-
+ bb28: {
257-
+ _22 = const false;
258-
+ _23 = const false;
259-
+ StorageDead(_2);
260-
+ _19 = discriminant(_1);
261-
+ goto -> bb13;
262-
+ }
263-
+
264-
+ bb29: {
265-
+ switchInt(copy _23) -> [0: bb15, otherwise: bb17];
231+
+ switchInt(move _24) -> [1: bb24, otherwise: bb15];
266232
+ }
267233
+
268-
+ bb30: {
269-
+ goto -> bb28;
270-
+ }
271-
+
272-
+ bb31: {
273-
+ _24 = discriminant(_2);
274-
+ switchInt(move _24) -> [1: bb33, otherwise: bb32];
275-
+ }
276-
+
277-
+ bb32: {
278-
+ _22 = const false;
279-
+ _23 = const false;
280-
+ StorageDead(_2);
281-
+ _19 = discriminant(_1);
282-
+ goto -> bb12;
283-
+ }
284-
+
285-
+ bb33: {
286-
+ switchInt(copy _23) -> [0: bb32, otherwise: bb34];
287-
+ }
288-
+
289-
+ bb34: {
290-
+ drop(((_2 as Some).0: std::string::String)) -> [return: bb32, unwind unreachable];
291-
+ }
292-
+
293-
+ bb35: {
294-
+ _24 = discriminant(_2);
295-
+ switchInt(move _24) -> [1: bb36, otherwise: bb28];
296-
+ }
297-
+
298-
+ bb36: {
299-
+ goto -> bb37;
300-
+ }
301-
+
302-
+ bb37: {
303-
+ drop(((_2 as Some).0: std::string::String)) -> [return: bb28, unwind unreachable];
234+
+ bb24: {
235+
+ goto -> bb17;
304236
}
305237
}
306238

0 commit comments

Comments
 (0)