Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions clippy_lints/src/ifs/branches_sharing_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use clippy_utils::{
use core::iter;
use core::ops::ControlFlow;
use rustc_errors::Applicability;
use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, LetStmt, Node, Stmt, StmtKind, intravisit};
use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, ItemKind, LetStmt, Node, Stmt, StmtKind, UseKind, intravisit};
use rustc_lint::LateContext;
use rustc_span::hygiene::walk_chain;
use rustc_span::source_map::SourceMap;
Expand Down Expand Up @@ -108,6 +108,7 @@ struct BlockEq {
/// The name and id of every local which can be moved at the beginning and the end.
moved_locals: Vec<(HirId, Symbol)>,
}

impl BlockEq {
fn start_span(&self, b: &Block<'_>, sm: &SourceMap) -> Option<Span> {
match &b.stmts[..self.start_end_eq] {
Expand All @@ -129,20 +130,32 @@ impl BlockEq {
}

/// If the statement is a local, checks if the bound names match the expected list of names.
fn eq_binding_names(s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool {
if let StmtKind::Let(l) = s.kind {
let mut i = 0usize;
let mut res = true;
l.pat.each_binding_or_first(&mut |_, _, _, name| {
if names.get(i).is_some_and(|&(_, n)| n == name.name) {
i += 1;
} else {
res = false;
}
});
res && i == names.len()
} else {
false
fn eq_binding_names(cx: &LateContext<'_>, s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool {
match s.kind {
StmtKind::Let(l) => {
let mut i = 0usize;
let mut res = true;
l.pat.each_binding_or_first(&mut |_, _, _, name| {
if names.get(i).is_some_and(|&(_, n)| n == name.name) {
i += 1;
} else {
res = false;
}
});
res && i == names.len()
},
StmtKind::Item(item_id)
if let item = cx.tcx.hir_item(item_id)
&& let ItemKind::Static(_, ident, ..)
| ItemKind::Const(ident, ..)
| ItemKind::Fn { ident, .. }
| ItemKind::TyAlias(ident, ..)
| ItemKind::Use(_, UseKind::Single(ident))
| ItemKind::Mod(ident, _) = item.kind =>
{
names.last().is_some_and(|&(_, n)| n == ident.name)
},
_ => false,
}
}

Expand All @@ -164,6 +177,7 @@ fn modifies_any_local<'tcx>(cx: &LateContext<'tcx>, s: &'tcx Stmt<'_>, locals: &
/// Checks if the given statement should be considered equal to the statement in the same
/// position for each block.
fn eq_stmts(
cx: &LateContext<'_>,
stmt: &Stmt<'_>,
blocks: &[&Block<'_>],
get_stmt: impl for<'a> Fn(&'a Block<'a>) -> Option<&'a Stmt<'a>>,
Expand All @@ -178,7 +192,7 @@ fn eq_stmts(
let new_bindings = &moved_bindings[old_count..];
blocks
.iter()
.all(|b| get_stmt(b).is_some_and(|s| eq_binding_names(s, new_bindings)))
.all(|b| get_stmt(b).is_some_and(|s| eq_binding_names(cx, s, new_bindings)))
} else {
true
}) && blocks.iter().all(|b| get_stmt(b).is_some_and(|s| eq.eq_stmt(s, stmt)))
Expand Down Expand Up @@ -218,7 +232,7 @@ fn scan_block_for_eq<'tcx>(
return true;
}
modifies_any_local(cx, stmt, &cond_locals)
|| !eq_stmts(stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals)
|| !eq_stmts(cx, stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals)
})
.map_or(block.stmts.len(), |(i, stmt)| {
adjust_by_closest_callsite(i, stmt, block.stmts[..i].iter().enumerate().rev())
Expand Down Expand Up @@ -279,6 +293,7 @@ fn scan_block_for_eq<'tcx>(
}))
.fold(end_search_start, |init, (stmt, offset)| {
if eq_stmts(
cx,
stmt,
blocks,
|b| b.stmts.get(b.stmts.len() - offset),
Expand All @@ -290,11 +305,26 @@ fn scan_block_for_eq<'tcx>(
// Clear out all locals seen at the end so far. None of them can be moved.
let stmts = &blocks[0].stmts;
for stmt in &stmts[stmts.len() - init..=stmts.len() - offset] {
if let StmtKind::Let(l) = stmt.kind {
l.pat.each_binding_or_first(&mut |_, id, _, _| {
// FIXME(rust/#120456) - is `swap_remove` correct?
eq.locals.swap_remove(&id);
});
match stmt.kind {
StmtKind::Let(l) => {
l.pat.each_binding_or_first(&mut |_, id, _, _| {
// FIXME(rust/#120456) - is `swap_remove` correct?
eq.locals.swap_remove(&id);
});
},
StmtKind::Item(item_id) => {
let item = cx.tcx.hir_item(item_id);
if let ItemKind::Static(..)
| ItemKind::Const(..)
| ItemKind::Fn { .. }
| ItemKind::TyAlias(..)
| ItemKind::Use(..)
| ItemKind::Mod(..) = item.kind
{
eq.local_items.swap_remove(&item.owner_id.to_def_id());
}
},
_ => {},
}
}
moved_locals.truncate(moved_locals_at_start);
Expand Down
161 changes: 157 additions & 4 deletions clippy_utils/src/hir_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ use crate::source::{SpanRange, SpanRangeExt, walk_span_to_context};
use crate::tokenize_with_text;
use rustc_ast::ast;
use rustc_ast::ast::InlineAsmTemplatePiece;
use rustc_data_structures::fx::FxHasher;
use rustc_data_structures::fx::{FxHasher, FxIndexMap};
use rustc_hir::MatchSource::TryDesugar;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::def_id::DefId;
use rustc_hir::{
AssocItemConstraint, BinOpKind, BindingMode, Block, BodyId, Closure, ConstArg, ConstArgKind, Expr, ExprField,
ExprKind, FnRetTy, GenericArg, GenericArgs, HirId, HirIdMap, InlineAsmOperand, LetExpr, Lifetime, LifetimeKind,
Node, Pat, PatExpr, PatExprKind, PatField, PatKind, Path, PathSegment, PrimTy, QPath, Stmt, StmtKind,
StructTailExpr, TraitBoundModifiers, Ty, TyKind, TyPat, TyPatKind,
ExprKind, FnDecl, FnRetTy, FnSig, GenericArg, GenericArgs, GenericParam, GenericParamKind, GenericParamSource,
Generics, HirId, HirIdMap, InlineAsmOperand, ItemId, ItemKind, LetExpr, Lifetime, LifetimeKind, LifetimeParamKind,
Node, ParamName, Pat, PatExpr, PatExprKind, PatField, PatKind, Path, PathSegment, PrimTy, QPath, Stmt, StmtKind,
StructTailExpr, TraitBoundModifiers, Ty, TyKind, TyPat, TyPatKind, UseKind,
};
use rustc_lexer::{FrontmatterAllowed, TokenKind, tokenize};
use rustc_lint::LateContext;
Expand Down Expand Up @@ -106,6 +108,7 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
left_ctxt: SyntaxContext::root(),
right_ctxt: SyntaxContext::root(),
locals: HirIdMap::default(),
local_items: FxIndexMap::default(),
}
}

Expand Down Expand Up @@ -144,6 +147,7 @@ pub struct HirEqInterExpr<'a, 'b, 'tcx> {
// right. For example, when comparing `{ let x = 1; x + 2 }` and `{ let y = 1; y + 2 }`,
// these blocks are considered equal since `x` is mapped to `y`.
pub locals: HirIdMap<HirId>,
pub local_items: FxIndexMap<DefId, DefId>,
}

impl HirEqInterExpr<'_, '_, '_> {
Expand All @@ -168,6 +172,144 @@ impl HirEqInterExpr<'_, '_, '_> {
&& self.eq_pat(l.pat, r.pat)
},
(StmtKind::Expr(l), StmtKind::Expr(r)) | (StmtKind::Semi(l), StmtKind::Semi(r)) => self.eq_expr(l, r),
(StmtKind::Item(l), StmtKind::Item(r)) => self.eq_item(*l, *r),
_ => false,
}
}

pub fn eq_item(&mut self, l: ItemId, r: ItemId) -> bool {
let left = self.inner.cx.tcx.hir_item(l);
let right = self.inner.cx.tcx.hir_item(r);
let eq = match (left.kind, right.kind) {
(
ItemKind::Const(l_ident, l_generics, l_ty, l_body),
ItemKind::Const(r_ident, r_generics, r_ty, r_body),
) => {
l_ident.name == r_ident.name
&& self.eq_generics(l_generics, r_generics)
&& self.eq_ty(l_ty, r_ty)
&& self.eq_body(l_body, r_body)
},
(ItemKind::Static(l_mut, l_ident, l_ty, l_body), ItemKind::Static(r_mut, r_ident, r_ty, r_body)) => {
l_mut == r_mut && l_ident.name == r_ident.name && self.eq_ty(l_ty, r_ty) && self.eq_body(l_body, r_body)
},
(
ItemKind::Fn {
sig: l_sig,
ident: l_ident,
generics: l_generics,
body: l_body,
has_body: l_has_body,
},
ItemKind::Fn {
sig: r_sig,
ident: r_ident,
generics: r_generics,
body: r_body,
has_body: r_has_body,
},
) => {
l_ident.name == r_ident.name
&& self.eq_fn_sig(&l_sig, &r_sig)
&& self.eq_generics(l_generics, r_generics)
&& (l_has_body == r_has_body)
&& self.eq_body(l_body, r_body)
},
(ItemKind::TyAlias(l_ident, l_generics, l_ty), ItemKind::TyAlias(r_ident, r_generics, r_ty)) => {
l_ident.name == r_ident.name && self.eq_generics(l_generics, r_generics) && self.eq_ty(l_ty, r_ty)
},
(ItemKind::Use(l_path, l_kind), ItemKind::Use(r_path, r_kind)) => {
self.eq_path_segments(l_path.segments, r_path.segments)
&& match (l_kind, r_kind) {
(UseKind::Single(l_ident), UseKind::Single(r_ident)) => l_ident.name == r_ident.name,
(UseKind::Glob, UseKind::Glob) | (UseKind::ListStem, UseKind::ListStem) => true,
_ => false,
}
},
(ItemKind::Mod(l_ident, l_mod), ItemKind::Mod(r_ident, r_mod)) => {
l_ident.name == r_ident.name && over(l_mod.item_ids, r_mod.item_ids, |l, r| self.eq_item(*l, *r))
},
_ => false,
};
if eq {
self.local_items.insert(l.owner_id.to_def_id(), r.owner_id.to_def_id());
}
eq
}

fn eq_fn_sig(&mut self, left: &FnSig<'_>, right: &FnSig<'_>) -> bool {
left.header.safety == right.header.safety
&& left.header.constness == right.header.constness
&& left.header.asyncness == right.header.asyncness
&& left.header.abi == right.header.abi
&& self.eq_fn_decl(left.decl, right.decl)
}

fn eq_fn_decl(&mut self, left: &FnDecl<'_>, right: &FnDecl<'_>) -> bool {
over(left.inputs, right.inputs, |l, r| self.eq_ty(l, r))
&& (match (left.output, right.output) {
(FnRetTy::DefaultReturn(_), FnRetTy::DefaultReturn(_)) => true,
(FnRetTy::Return(l_ty), FnRetTy::Return(r_ty)) => self.eq_ty(l_ty, r_ty),
_ => false,
})
&& left.c_variadic == right.c_variadic
&& left.implicit_self == right.implicit_self
&& left.lifetime_elision_allowed == right.lifetime_elision_allowed
}

fn eq_generics(&mut self, left: &Generics<'_>, right: &Generics<'_>) -> bool {
self.eq_generics_param(left.params, right.params)
}

fn eq_generics_param(&mut self, left: &[GenericParam<'_>], right: &[GenericParam<'_>]) -> bool {
over(left, right, |l, r| {
(match (l.name, r.name) {
(ParamName::Plain(l_ident), ParamName::Plain(r_ident))
| (ParamName::Error(l_ident), ParamName::Error(r_ident)) => l_ident.name == r_ident.name,
(ParamName::Fresh, ParamName::Fresh) => true,
_ => false,
}) && l.pure_wrt_drop == r.pure_wrt_drop
&& self.eq_generics_param_kind(&l.kind, &r.kind)
&& (matches!(
(l.source, r.source),
(GenericParamSource::Generics, GenericParamSource::Generics)
| (GenericParamSource::Binder, GenericParamSource::Binder)
))
})
}

fn eq_generics_param_kind(&mut self, left: &GenericParamKind<'_>, right: &GenericParamKind<'_>) -> bool {
match (left, right) {
(GenericParamKind::Lifetime { kind: l_kind }, GenericParamKind::Lifetime { kind: r_kind }) => {
match (l_kind, r_kind) {
(LifetimeParamKind::Explicit, LifetimeParamKind::Explicit)
| (LifetimeParamKind::Error, LifetimeParamKind::Error) => true,
(LifetimeParamKind::Elided(l_lifetime_kind), LifetimeParamKind::Elided(r_lifetime_kind)) => {
l_lifetime_kind == r_lifetime_kind
},
_ => false,
}
},
(
GenericParamKind::Type {
default: l_default,
synthetic: l_synthetic,
},
GenericParamKind::Type {
default: r_default,
synthetic: r_synthetic,
},
) => both(*l_default, *r_default, |l, r| self.eq_ty(l, r)) && l_synthetic == r_synthetic,
(
GenericParamKind::Const {
ty: l_ty,
default: l_default,
},
GenericParamKind::Const {
ty: r_ty,
default: r_default,
},
) => self.eq_ty(l_ty, r_ty) && both(*l_default, *r_default, |l, r| self.eq_const_arg(l, r)),
_ => false,
}
}
Expand Down Expand Up @@ -564,6 +706,17 @@ impl HirEqInterExpr<'_, '_, '_> {
match (left.res, right.res) {
(Res::Local(l), Res::Local(r)) => l == r || self.locals.get(&l) == Some(&r),
(Res::Local(_), _) | (_, Res::Local(_)) => false,
(Res::Def(l_kind, l), Res::Def(r_kind, r))
if l_kind == r_kind
&& let DefKind::Const
| DefKind::Static { .. }
| DefKind::Fn
| DefKind::TyAlias
| DefKind::Use
| DefKind::Mod = l_kind =>
{
(l == r || self.local_items.get(&l) == Some(&r)) && self.eq_path_segments(left.segments, right.segments)
},
_ => self.eq_path_segments(left.segments, right.segments),
}
}
Expand Down
52 changes: 52 additions & 0 deletions tests/ui/branches_sharing_code/shared_at_bottom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,55 @@ fn issue15004() {
//~^ branches_sharing_code
};
}

pub fn issue15347<T>() -> isize {
if false {
static A: isize = 4;
return A;
} else {
static A: isize = 5;
return A;
}

if false {
//~^ branches_sharing_code
type ISize = isize;
return ISize::MAX;
} else {
type ISize = isize;
return ISize::MAX;
}

if false {
//~^ branches_sharing_code
fn foo() -> isize {
4
}
return foo();
} else {
fn foo() -> isize {
4
}
return foo();
}

if false {
//~^ branches_sharing_code
use std::num::NonZeroIsize;
return NonZeroIsize::new(4).unwrap().get();
} else {
use std::num::NonZeroIsize;
return NonZeroIsize::new(4).unwrap().get();
}

if false {
//~^ branches_sharing_code
const B: isize = 5;
return B;
} else {
const B: isize = 5;
return B;
}

todo!()
}
Loading