|
| 1 | +use crate::MirPass; |
| 2 | +use rustc_middle::mir::patch::MirPatch; |
| 3 | +use rustc_middle::mir::*; |
| 4 | +use rustc_middle::ty::TyCtxt; |
| 5 | + |
| 6 | +pub struct RefCmpSimplify; |
| 7 | + |
| 8 | +impl<'tcx> MirPass<'tcx> for RefCmpSimplify { |
| 9 | + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
| 10 | + self.simplify_ref_cmp(tcx, body) |
| 11 | + } |
| 12 | +} |
| 13 | + |
| 14 | +#[derive(Clone, Copy, Debug, PartialEq, Eq)] |
| 15 | +enum MatchState { |
| 16 | + Empty, |
| 17 | + Deref { src_statement_idx: usize, dst: Local, src: Local }, |
| 18 | + CopiedFrom { src_statement_idx: usize, dst: Local, real_src: Local }, |
| 19 | + Completed { src_statement_idx: usize, dst: Local, real_src: Local }, |
| 20 | +} |
| 21 | + |
| 22 | +impl RefCmpSimplify { |
| 23 | + fn simplify_ref_cmp<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
| 24 | + debug!("body: {:#?}", body); |
| 25 | + |
| 26 | + let n_bbs = body.basic_blocks.len() as u32; |
| 27 | + for bb in 0..n_bbs { |
| 28 | + let bb = BasicBlock::from_u32(bb); |
| 29 | + let mut max = Local::MAX; |
| 30 | + 'repeat: loop { |
| 31 | + let mut state = MatchState::Empty; |
| 32 | + let bb_data = &body.basic_blocks[bb]; |
| 33 | + for (i, stmt) in bb_data.statements.iter().enumerate().rev() { |
| 34 | + state = match (state, &stmt.kind) { |
| 35 | + ( |
| 36 | + MatchState::Empty, |
| 37 | + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Copy(rhs)))), |
| 38 | + ) if rhs.has_deref() && lhs.ty(body, tcx).ty.is_primitive() => { |
| 39 | + let Some(dst) = lhs.as_local() else { |
| 40 | + continue |
| 41 | + }; |
| 42 | + let Some(src) = rhs.local_or_deref_local() else { |
| 43 | + continue; |
| 44 | + }; |
| 45 | + if max <= dst { |
| 46 | + continue; |
| 47 | + } |
| 48 | + max = dst; |
| 49 | + MatchState::Deref { dst, src, src_statement_idx: i } |
| 50 | + } |
| 51 | + ( |
| 52 | + MatchState::Deref { src, dst, src_statement_idx }, |
| 53 | + StatementKind::Assign(box (lhs, Rvalue::CopyForDeref(rhs))), |
| 54 | + ) if lhs.as_local() == Some(src) && rhs.has_deref() => { |
| 55 | + let Some(real_src) = rhs.local_or_deref_local() else{ |
| 56 | + continue; |
| 57 | + }; |
| 58 | + MatchState::CopiedFrom { src_statement_idx, dst, real_src } |
| 59 | + } |
| 60 | + ( |
| 61 | + MatchState::CopiedFrom { src_statement_idx, dst, real_src }, |
| 62 | + StatementKind::Assign(box ( |
| 63 | + lhs, |
| 64 | + Rvalue::Ref(_, BorrowKind::Shared | BorrowKind::Shallow, rhs), |
| 65 | + )), |
| 66 | + ) if lhs.as_local() == Some(real_src) => { |
| 67 | + let Some(real_src) = rhs.as_local() else { |
| 68 | + continue; |
| 69 | + }; |
| 70 | + MatchState::Completed { dst, real_src, src_statement_idx } |
| 71 | + } |
| 72 | + _ => continue, |
| 73 | + }; |
| 74 | + if let MatchState::Completed { dst, real_src, src_statement_idx } = state { |
| 75 | + let mut patch = MirPatch::new(&body); |
| 76 | + let src = Place::from(real_src); |
| 77 | + let src = src.project_deeper(&[PlaceElem::Deref], tcx); |
| 78 | + let dst = Place::from(dst); |
| 79 | + let new_stmt = |
| 80 | + StatementKind::Assign(Box::new((dst, Rvalue::Use(Operand::Copy(src))))); |
| 81 | + patch.add_statement( |
| 82 | + Location { block: bb, statement_index: src_statement_idx + 1 }, |
| 83 | + new_stmt, |
| 84 | + ); |
| 85 | + patch.apply(body); |
| 86 | + continue 'repeat; |
| 87 | + } |
| 88 | + } |
| 89 | + break; |
| 90 | + } |
| 91 | + } |
| 92 | + } |
| 93 | +} |
0 commit comments