diff --git a/clippy_lints/src/ifs/branches_sharing_code.rs b/clippy_lints/src/ifs/branches_sharing_code.rs index eb1025f71498..866073208522 100644 --- a/clippy_lints/src/ifs/branches_sharing_code.rs +++ b/clippy_lints/src/ifs/branches_sharing_code.rs @@ -9,7 +9,7 @@ use clippy_utils::{ use core::iter; use core::ops::ControlFlow; use rustc_errors::Applicability; -use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, LetStmt, Node, Stmt, StmtKind, intravisit}; +use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, ItemKind, LetStmt, Node, Stmt, StmtKind, UseKind, intravisit}; use rustc_lint::LateContext; use rustc_span::hygiene::walk_chain; use rustc_span::source_map::SourceMap; @@ -108,6 +108,7 @@ struct BlockEq { /// The name and id of every local which can be moved at the beginning and the end. moved_locals: Vec<(HirId, Symbol)>, } + impl BlockEq { fn start_span(&self, b: &Block<'_>, sm: &SourceMap) -> Option { match &b.stmts[..self.start_end_eq] { @@ -129,20 +130,32 @@ impl BlockEq { } /// If the statement is a local, checks if the bound names match the expected list of names. -fn eq_binding_names(s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool { - if let StmtKind::Let(l) = s.kind { - let mut i = 0usize; - let mut res = true; - l.pat.each_binding_or_first(&mut |_, _, _, name| { - if names.get(i).is_some_and(|&(_, n)| n == name.name) { - i += 1; - } else { - res = false; - } - }); - res && i == names.len() - } else { - false +fn eq_binding_names(cx: &LateContext<'_>, s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool { + match s.kind { + StmtKind::Let(l) => { + let mut i = 0usize; + let mut res = true; + l.pat.each_binding_or_first(&mut |_, _, _, name| { + if names.get(i).is_some_and(|&(_, n)| n == name.name) { + i += 1; + } else { + res = false; + } + }); + res && i == names.len() + }, + StmtKind::Item(item_id) + if let item = cx.tcx.hir_item(item_id) + && let ItemKind::Static(_, ident, ..) + | ItemKind::Const(ident, ..) + | ItemKind::Fn { ident, .. } + | ItemKind::TyAlias(ident, ..) + | ItemKind::Use(_, UseKind::Single(ident)) + | ItemKind::Mod(ident, _) = item.kind => + { + names.last().is_some_and(|&(_, n)| n == ident.name) + }, + _ => false, } } @@ -164,6 +177,7 @@ fn modifies_any_local<'tcx>(cx: &LateContext<'tcx>, s: &'tcx Stmt<'_>, locals: & /// Checks if the given statement should be considered equal to the statement in the same /// position for each block. fn eq_stmts( + cx: &LateContext<'_>, stmt: &Stmt<'_>, blocks: &[&Block<'_>], get_stmt: impl for<'a> Fn(&'a Block<'a>) -> Option<&'a Stmt<'a>>, @@ -178,7 +192,7 @@ fn eq_stmts( let new_bindings = &moved_bindings[old_count..]; blocks .iter() - .all(|b| get_stmt(b).is_some_and(|s| eq_binding_names(s, new_bindings))) + .all(|b| get_stmt(b).is_some_and(|s| eq_binding_names(cx, s, new_bindings))) } else { true }) && blocks.iter().all(|b| get_stmt(b).is_some_and(|s| eq.eq_stmt(s, stmt))) @@ -218,7 +232,7 @@ fn scan_block_for_eq<'tcx>( return true; } modifies_any_local(cx, stmt, &cond_locals) - || !eq_stmts(stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals) + || !eq_stmts(cx, stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals) }) .map_or(block.stmts.len(), |(i, stmt)| { adjust_by_closest_callsite(i, stmt, block.stmts[..i].iter().enumerate().rev()) @@ -279,6 +293,7 @@ fn scan_block_for_eq<'tcx>( })) .fold(end_search_start, |init, (stmt, offset)| { if eq_stmts( + cx, stmt, blocks, |b| b.stmts.get(b.stmts.len() - offset), @@ -290,11 +305,26 @@ fn scan_block_for_eq<'tcx>( // Clear out all locals seen at the end so far. None of them can be moved. let stmts = &blocks[0].stmts; for stmt in &stmts[stmts.len() - init..=stmts.len() - offset] { - if let StmtKind::Let(l) = stmt.kind { - l.pat.each_binding_or_first(&mut |_, id, _, _| { - // FIXME(rust/#120456) - is `swap_remove` correct? - eq.locals.swap_remove(&id); - }); + match stmt.kind { + StmtKind::Let(l) => { + l.pat.each_binding_or_first(&mut |_, id, _, _| { + // FIXME(rust/#120456) - is `swap_remove` correct? + eq.locals.swap_remove(&id); + }); + }, + StmtKind::Item(item_id) => { + let item = cx.tcx.hir_item(item_id); + if let ItemKind::Static(..) + | ItemKind::Const(..) + | ItemKind::Fn { .. } + | ItemKind::TyAlias(..) + | ItemKind::Use(..) + | ItemKind::Mod(..) = item.kind + { + eq.local_items.swap_remove(&item.owner_id.to_def_id()); + } + }, + _ => {}, } } moved_locals.truncate(moved_locals_at_start); diff --git a/clippy_utils/src/hir_utils.rs b/clippy_utils/src/hir_utils.rs index a6d821c99c1b..5b309a5565fe 100644 --- a/clippy_utils/src/hir_utils.rs +++ b/clippy_utils/src/hir_utils.rs @@ -4,14 +4,16 @@ use crate::source::{SpanRange, SpanRangeExt, walk_span_to_context}; use crate::tokenize_with_text; use rustc_ast::ast; use rustc_ast::ast::InlineAsmTemplatePiece; -use rustc_data_structures::fx::FxHasher; +use rustc_data_structures::fx::{FxHasher, FxIndexMap}; use rustc_hir::MatchSource::TryDesugar; use rustc_hir::def::{DefKind, Res}; +use rustc_hir::def_id::DefId; use rustc_hir::{ AssocItemConstraint, BinOpKind, BindingMode, Block, BodyId, Closure, ConstArg, ConstArgKind, Expr, ExprField, - ExprKind, FnRetTy, GenericArg, GenericArgs, HirId, HirIdMap, InlineAsmOperand, LetExpr, Lifetime, LifetimeKind, - Node, Pat, PatExpr, PatExprKind, PatField, PatKind, Path, PathSegment, PrimTy, QPath, Stmt, StmtKind, - StructTailExpr, TraitBoundModifiers, Ty, TyKind, TyPat, TyPatKind, + ExprKind, FnDecl, FnRetTy, FnSig, GenericArg, GenericArgs, GenericParam, GenericParamKind, GenericParamSource, + Generics, HirId, HirIdMap, InlineAsmOperand, ItemId, ItemKind, LetExpr, Lifetime, LifetimeKind, LifetimeParamKind, + Node, ParamName, Pat, PatExpr, PatExprKind, PatField, PatKind, Path, PathSegment, PrimTy, QPath, Stmt, StmtKind, + StructTailExpr, TraitBoundModifiers, Ty, TyKind, TyPat, TyPatKind, UseKind, }; use rustc_lexer::{FrontmatterAllowed, TokenKind, tokenize}; use rustc_lint::LateContext; @@ -106,6 +108,7 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> { left_ctxt: SyntaxContext::root(), right_ctxt: SyntaxContext::root(), locals: HirIdMap::default(), + local_items: FxIndexMap::default(), } } @@ -144,6 +147,7 @@ pub struct HirEqInterExpr<'a, 'b, 'tcx> { // right. For example, when comparing `{ let x = 1; x + 2 }` and `{ let y = 1; y + 2 }`, // these blocks are considered equal since `x` is mapped to `y`. pub locals: HirIdMap, + pub local_items: FxIndexMap, } impl HirEqInterExpr<'_, '_, '_> { @@ -168,6 +172,144 @@ impl HirEqInterExpr<'_, '_, '_> { && self.eq_pat(l.pat, r.pat) }, (StmtKind::Expr(l), StmtKind::Expr(r)) | (StmtKind::Semi(l), StmtKind::Semi(r)) => self.eq_expr(l, r), + (StmtKind::Item(l), StmtKind::Item(r)) => self.eq_item(*l, *r), + _ => false, + } + } + + pub fn eq_item(&mut self, l: ItemId, r: ItemId) -> bool { + let left = self.inner.cx.tcx.hir_item(l); + let right = self.inner.cx.tcx.hir_item(r); + let eq = match (left.kind, right.kind) { + ( + ItemKind::Const(l_ident, l_generics, l_ty, l_body), + ItemKind::Const(r_ident, r_generics, r_ty, r_body), + ) => { + l_ident.name == r_ident.name + && self.eq_generics(l_generics, r_generics) + && self.eq_ty(l_ty, r_ty) + && self.eq_body(l_body, r_body) + }, + (ItemKind::Static(l_mut, l_ident, l_ty, l_body), ItemKind::Static(r_mut, r_ident, r_ty, r_body)) => { + l_mut == r_mut && l_ident.name == r_ident.name && self.eq_ty(l_ty, r_ty) && self.eq_body(l_body, r_body) + }, + ( + ItemKind::Fn { + sig: l_sig, + ident: l_ident, + generics: l_generics, + body: l_body, + has_body: l_has_body, + }, + ItemKind::Fn { + sig: r_sig, + ident: r_ident, + generics: r_generics, + body: r_body, + has_body: r_has_body, + }, + ) => { + l_ident.name == r_ident.name + && self.eq_fn_sig(&l_sig, &r_sig) + && self.eq_generics(l_generics, r_generics) + && (l_has_body == r_has_body) + && self.eq_body(l_body, r_body) + }, + (ItemKind::TyAlias(l_ident, l_generics, l_ty), ItemKind::TyAlias(r_ident, r_generics, r_ty)) => { + l_ident.name == r_ident.name && self.eq_generics(l_generics, r_generics) && self.eq_ty(l_ty, r_ty) + }, + (ItemKind::Use(l_path, l_kind), ItemKind::Use(r_path, r_kind)) => { + self.eq_path_segments(l_path.segments, r_path.segments) + && match (l_kind, r_kind) { + (UseKind::Single(l_ident), UseKind::Single(r_ident)) => l_ident.name == r_ident.name, + (UseKind::Glob, UseKind::Glob) | (UseKind::ListStem, UseKind::ListStem) => true, + _ => false, + } + }, + (ItemKind::Mod(l_ident, l_mod), ItemKind::Mod(r_ident, r_mod)) => { + l_ident.name == r_ident.name && over(l_mod.item_ids, r_mod.item_ids, |l, r| self.eq_item(*l, *r)) + }, + _ => false, + }; + if eq { + self.local_items.insert(l.owner_id.to_def_id(), r.owner_id.to_def_id()); + } + eq + } + + fn eq_fn_sig(&mut self, left: &FnSig<'_>, right: &FnSig<'_>) -> bool { + left.header.safety == right.header.safety + && left.header.constness == right.header.constness + && left.header.asyncness == right.header.asyncness + && left.header.abi == right.header.abi + && self.eq_fn_decl(left.decl, right.decl) + } + + fn eq_fn_decl(&mut self, left: &FnDecl<'_>, right: &FnDecl<'_>) -> bool { + over(left.inputs, right.inputs, |l, r| self.eq_ty(l, r)) + && (match (left.output, right.output) { + (FnRetTy::DefaultReturn(_), FnRetTy::DefaultReturn(_)) => true, + (FnRetTy::Return(l_ty), FnRetTy::Return(r_ty)) => self.eq_ty(l_ty, r_ty), + _ => false, + }) + && left.c_variadic == right.c_variadic + && left.implicit_self == right.implicit_self + && left.lifetime_elision_allowed == right.lifetime_elision_allowed + } + + fn eq_generics(&mut self, left: &Generics<'_>, right: &Generics<'_>) -> bool { + self.eq_generics_param(left.params, right.params) + } + + fn eq_generics_param(&mut self, left: &[GenericParam<'_>], right: &[GenericParam<'_>]) -> bool { + over(left, right, |l, r| { + (match (l.name, r.name) { + (ParamName::Plain(l_ident), ParamName::Plain(r_ident)) + | (ParamName::Error(l_ident), ParamName::Error(r_ident)) => l_ident.name == r_ident.name, + (ParamName::Fresh, ParamName::Fresh) => true, + _ => false, + }) && l.pure_wrt_drop == r.pure_wrt_drop + && self.eq_generics_param_kind(&l.kind, &r.kind) + && (matches!( + (l.source, r.source), + (GenericParamSource::Generics, GenericParamSource::Generics) + | (GenericParamSource::Binder, GenericParamSource::Binder) + )) + }) + } + + fn eq_generics_param_kind(&mut self, left: &GenericParamKind<'_>, right: &GenericParamKind<'_>) -> bool { + match (left, right) { + (GenericParamKind::Lifetime { kind: l_kind }, GenericParamKind::Lifetime { kind: r_kind }) => { + match (l_kind, r_kind) { + (LifetimeParamKind::Explicit, LifetimeParamKind::Explicit) + | (LifetimeParamKind::Error, LifetimeParamKind::Error) => true, + (LifetimeParamKind::Elided(l_lifetime_kind), LifetimeParamKind::Elided(r_lifetime_kind)) => { + l_lifetime_kind == r_lifetime_kind + }, + _ => false, + } + }, + ( + GenericParamKind::Type { + default: l_default, + synthetic: l_synthetic, + }, + GenericParamKind::Type { + default: r_default, + synthetic: r_synthetic, + }, + ) => both(*l_default, *r_default, |l, r| self.eq_ty(l, r)) && l_synthetic == r_synthetic, + ( + GenericParamKind::Const { + ty: l_ty, + default: l_default, + }, + GenericParamKind::Const { + ty: r_ty, + default: r_default, + }, + ) => self.eq_ty(l_ty, r_ty) && both(*l_default, *r_default, |l, r| self.eq_const_arg(l, r)), _ => false, } } @@ -564,6 +706,17 @@ impl HirEqInterExpr<'_, '_, '_> { match (left.res, right.res) { (Res::Local(l), Res::Local(r)) => l == r || self.locals.get(&l) == Some(&r), (Res::Local(_), _) | (_, Res::Local(_)) => false, + (Res::Def(l_kind, l), Res::Def(r_kind, r)) + if l_kind == r_kind + && let DefKind::Const + | DefKind::Static { .. } + | DefKind::Fn + | DefKind::TyAlias + | DefKind::Use + | DefKind::Mod = l_kind => + { + (l == r || self.local_items.get(&l) == Some(&r)) && self.eq_path_segments(left.segments, right.segments) + }, _ => self.eq_path_segments(left.segments, right.segments), } } diff --git a/tests/ui/branches_sharing_code/shared_at_bottom.rs b/tests/ui/branches_sharing_code/shared_at_bottom.rs index fa322dc28a78..3e1369890f69 100644 --- a/tests/ui/branches_sharing_code/shared_at_bottom.rs +++ b/tests/ui/branches_sharing_code/shared_at_bottom.rs @@ -300,3 +300,55 @@ fn issue15004() { //~^ branches_sharing_code }; } + +pub fn issue15347() -> isize { + if false { + static A: isize = 4; + return A; + } else { + static A: isize = 5; + return A; + } + + if false { + //~^ branches_sharing_code + type ISize = isize; + return ISize::MAX; + } else { + type ISize = isize; + return ISize::MAX; + } + + if false { + //~^ branches_sharing_code + fn foo() -> isize { + 4 + } + return foo(); + } else { + fn foo() -> isize { + 4 + } + return foo(); + } + + if false { + //~^ branches_sharing_code + use std::num::NonZeroIsize; + return NonZeroIsize::new(4).unwrap().get(); + } else { + use std::num::NonZeroIsize; + return NonZeroIsize::new(4).unwrap().get(); + } + + if false { + //~^ branches_sharing_code + const B: isize = 5; + return B; + } else { + const B: isize = 5; + return B; + } + + todo!() +} diff --git a/tests/ui/branches_sharing_code/shared_at_bottom.stderr b/tests/ui/branches_sharing_code/shared_at_bottom.stderr index 1c470fb0da5e..4ff3990232a5 100644 --- a/tests/ui/branches_sharing_code/shared_at_bottom.stderr +++ b/tests/ui/branches_sharing_code/shared_at_bottom.stderr @@ -202,5 +202,73 @@ LL ~ } LL ~ 1; | -error: aborting due to 12 previous errors +error: all if blocks contain the same code at the start + --> tests/ui/branches_sharing_code/shared_at_bottom.rs:313:5 + | +LL | / if false { +LL | | +LL | | type ISize = isize; +LL | | return ISize::MAX; + | |__________________________^ + | +help: consider moving these statements before the if + | +LL ~ type ISize = isize; +LL + return ISize::MAX; +LL + if false { + | + +error: all if blocks contain the same code at the start + --> tests/ui/branches_sharing_code/shared_at_bottom.rs:322:5 + | +LL | / if false { +LL | | +LL | | fn foo() -> isize { +LL | | 4 +LL | | } +LL | | return foo(); + | |_____________________^ + | +help: consider moving these statements before the if + | +LL ~ fn foo() -> isize { +LL + 4 +LL + } +LL + return foo(); +LL + if false { + | + +error: all if blocks contain the same code at the start + --> tests/ui/branches_sharing_code/shared_at_bottom.rs:335:5 + | +LL | / if false { +LL | | +LL | | use std::num::NonZeroIsize; +LL | | return NonZeroIsize::new(4).unwrap().get(); + | |___________________________________________________^ + | +help: consider moving these statements before the if + | +LL ~ use std::num::NonZeroIsize; +LL + return NonZeroIsize::new(4).unwrap().get(); +LL + if false { + | + +error: all if blocks contain the same code at the start + --> tests/ui/branches_sharing_code/shared_at_bottom.rs:344:5 + | +LL | / if false { +LL | | +LL | | const B: isize = 5; +LL | | return B; + | |_________________^ + | +help: consider moving these statements before the if + | +LL ~ const B: isize = 5; +LL + return B; +LL + if false { + | + +error: aborting due to 16 previous errors