Skip to content

Commit 251f0c6

Browse files
committed
replace_match_with_if_let works on more binary matches
1 parent f181952 commit 251f0c6

File tree

2 files changed

+141
-24
lines changed

2 files changed

+141
-24
lines changed

crates/ide_assists/src/handlers/replace_if_let_with_match.rs

Lines changed: 137 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -176,48 +176,47 @@ fn make_else_arm(
176176
// ```
177177
pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
178178
let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?;
179+
179180
let mut arms = match_expr.match_arm_list()?.arms();
180-
let first_arm = arms.next()?;
181-
let second_arm = arms.next()?;
181+
let (first_arm, second_arm) = (arms.next()?, arms.next()?);
182182
if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() {
183183
return None;
184184
}
185-
let condition_expr = match_expr.expr()?;
186-
let (if_let_pat, then_expr, else_expr) = if is_pat_wildcard_or_sad(&ctx.sema, &first_arm.pat()?)
187-
{
188-
(second_arm.pat()?, second_arm.expr()?, first_arm.expr()?)
189-
} else if is_pat_wildcard_or_sad(&ctx.sema, &second_arm.pat()?) {
190-
(first_arm.pat()?, first_arm.expr()?, second_arm.expr()?)
191-
} else {
192-
return None;
193-
};
185+
186+
let (if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order(
187+
&ctx.sema,
188+
first_arm.pat()?,
189+
second_arm.pat()?,
190+
first_arm.expr()?,
191+
second_arm.expr()?,
192+
)?;
193+
let scrutinee = match_expr.expr()?;
194194

195195
let target = match_expr.syntax().text_range();
196196
acc.add(
197197
AssistId("replace_match_with_if_let", AssistKind::RefactorRewrite),
198198
"Replace with if let",
199199
target,
200200
move |edit| {
201-
let condition = make::condition(condition_expr, Some(if_let_pat));
201+
let condition = make::condition(scrutinee, Some(if_let_pat));
202202
let then_block = match then_expr.reset_indent() {
203203
ast::Expr::BlockExpr(block) => block,
204204
expr => make::block_expr(iter::empty(), Some(expr)),
205205
};
206206
let else_expr = match else_expr {
207-
ast::Expr::BlockExpr(block)
208-
if block.statements().count() == 0 && block.tail_expr().is_none() =>
209-
{
210-
None
211-
}
212-
ast::Expr::TupleExpr(tuple) if tuple.fields().count() == 0 => None,
207+
ast::Expr::BlockExpr(block) if block.is_empty() => None,
208+
ast::Expr::TupleExpr(tuple) if tuple.fields().next().is_none() => None,
213209
expr => Some(expr),
214210
};
215211
let if_let_expr = make::expr_if(
216212
condition,
217213
then_block,
218-
else_expr.map(|else_expr| {
219-
ast::ElseBranch::Block(make::block_expr(iter::empty(), Some(else_expr)))
220-
}),
214+
else_expr
215+
.map(|expr| match expr {
216+
ast::Expr::BlockExpr(block) => block,
217+
expr => (make::block_expr(iter::empty(), Some(expr))),
218+
})
219+
.map(ast::ElseBranch::Block),
221220
)
222221
.indent(IndentLevel::from_node(match_expr.syntax()));
223222

@@ -226,11 +225,50 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext)
226225
)
227226
}
228227

229-
fn is_pat_wildcard_or_sad(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
228+
/// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order.
229+
fn pick_pattern_and_expr_order(
230+
sema: &hir::Semantics<RootDatabase>,
231+
pat: ast::Pat,
232+
pat2: ast::Pat,
233+
expr: ast::Expr,
234+
expr2: ast::Expr,
235+
) -> Option<(ast::Pat, ast::Expr, ast::Expr)> {
236+
let res = match (pat, pat2) {
237+
(ast::Pat::WildcardPat(_), _) => return None,
238+
(pat, sad_pat) if is_sad_pat(sema, &sad_pat) => (pat, expr, expr2),
239+
(sad_pat, pat) if is_sad_pat(sema, &sad_pat) => (pat, expr2, expr),
240+
(pat, pat2) => match (binds_name(&pat), binds_name(&pat2)) {
241+
(true, true) => return None,
242+
(true, false) => (pat, expr, expr2),
243+
(false, true) => (pat2, expr2, expr),
244+
(false, false) => (pat, expr, expr2),
245+
},
246+
};
247+
Some(res)
248+
}
249+
250+
fn binds_name(pat: &ast::Pat) -> bool {
251+
let binds_name_v = |pat| binds_name(&pat);
252+
match pat {
253+
ast::Pat::IdentPat(_) => true,
254+
ast::Pat::MacroPat(_) => true,
255+
ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v),
256+
ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v),
257+
ast::Pat::TuplePat(it) => it.fields().any(binds_name_v),
258+
ast::Pat::TupleStructPat(it) => it.fields().any(binds_name_v),
259+
ast::Pat::RecordPat(it) => it
260+
.record_pat_field_list()
261+
.map_or(false, |rpfl| rpfl.fields().flat_map(|rpf| rpf.pat()).any(binds_name_v)),
262+
ast::Pat::RefPat(pat) => pat.pat().map_or(false, binds_name_v),
263+
ast::Pat::BoxPat(pat) => pat.pat().map_or(false, binds_name_v),
264+
ast::Pat::ParenPat(pat) => pat.pat().map_or(false, binds_name_v),
265+
_ => false,
266+
}
267+
}
268+
fn is_sad_pat(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
230269
sema.type_of_pat(pat)
231270
.and_then(|ty| TryEnum::from_ty(sema, &ty))
232-
.map(|it| it.sad_pattern().syntax().text() == pat.syntax().text())
233-
.unwrap_or_else(|| matches!(pat, ast::Pat::WildcardPat(_)))
271+
.map_or(false, |it| it.sad_pattern().syntax().text() == pat.syntax().text())
234272
}
235273

236274
#[cfg(test)]
@@ -662,4 +700,79 @@ fn main() {
662700
"#,
663701
)
664702
}
703+
704+
#[test]
705+
fn replace_match_with_if_let_exhaustive() {
706+
check_assist(
707+
replace_match_with_if_let,
708+
r#"
709+
fn print_source(def_source: ModuleSource) {
710+
match def_so$0urce {
711+
ModuleSource::SourceFile(..) => { println!("source file"); }
712+
ModuleSource::Module(..) => { println!("module"); }
713+
}
714+
}
715+
"#,
716+
r#"
717+
fn print_source(def_source: ModuleSource) {
718+
if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); }
719+
}
720+
"#,
721+
)
722+
}
723+
724+
#[test]
725+
fn replace_match_with_if_let_prefer_name_bind() {
726+
check_assist(
727+
replace_match_with_if_let,
728+
r#"
729+
fn foo() {
730+
match $0Foo(0) {
731+
Foo(_) => (),
732+
Bar(bar) => println!("bar {}", bar),
733+
}
734+
}
735+
"#,
736+
r#"
737+
fn foo() {
738+
if let Bar(bar) = Foo(0) {
739+
println!("bar {}", bar)
740+
}
741+
}
742+
"#,
743+
);
744+
check_assist(
745+
replace_match_with_if_let,
746+
r#"
747+
fn foo() {
748+
match $0Foo(0) {
749+
Bar(bar) => println!("bar {}", bar),
750+
Foo(_) => (),
751+
}
752+
}
753+
"#,
754+
r#"
755+
fn foo() {
756+
if let Bar(bar) = Foo(0) {
757+
println!("bar {}", bar)
758+
}
759+
}
760+
"#,
761+
);
762+
}
763+
764+
#[test]
765+
fn replace_match_with_if_let_rejects_double_name_bindings() {
766+
check_assist_not_applicable(
767+
replace_match_with_if_let,
768+
r#"
769+
fn foo() {
770+
match $0Foo(0) {
771+
Foo(foo) => println!("bar {}", foo),
772+
Bar(bar) => println!("bar {}", bar),
773+
}
774+
}
775+
"#,
776+
);
777+
}
665778
}

crates/syntax/src/ast/node_ext.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ impl ast::BlockExpr {
4949
pub fn items(&self) -> AstChildren<ast::Item> {
5050
support::children(self.syntax())
5151
}
52+
53+
pub fn is_empty(&self) -> bool {
54+
self.statements().next().is_none() && self.tail_expr().is_none()
55+
}
5256
}
5357

5458
impl ast::Expr {

0 commit comments

Comments
 (0)