@@ -176,48 +176,47 @@ fn make_else_arm(
176
176
// ```
177
177
pub ( crate ) fn replace_match_with_if_let ( acc : & mut Assists , ctx : & AssistContext ) -> Option < ( ) > {
178
178
let match_expr: ast:: MatchExpr = ctx. find_node_at_offset ( ) ?;
179
+
179
180
let mut arms = match_expr. match_arm_list ( ) ?. arms ( ) ;
180
- let first_arm = arms. next ( ) ?;
181
- let second_arm = arms. next ( ) ?;
181
+ let ( first_arm, second_arm) = ( arms. next ( ) ?, arms. next ( ) ?) ;
182
182
if arms. next ( ) . is_some ( ) || first_arm. guard ( ) . is_some ( ) || second_arm. guard ( ) . is_some ( ) {
183
183
return None ;
184
184
}
185
- let condition_expr = match_expr . expr ( ) ? ;
186
- let ( if_let_pat, then_expr, else_expr) = if is_pat_wildcard_or_sad ( & ctx . sema , & first_arm . pat ( ) ? )
187
- {
188
- ( second_arm . pat ( ) ?, second_arm . expr ( ) ? , first_arm . expr ( ) ? )
189
- } else if is_pat_wildcard_or_sad ( & ctx . sema , & second_arm. pat ( ) ?) {
190
- ( first_arm. pat ( ) ? , first_arm . expr ( ) ?, second_arm . expr ( ) ? )
191
- } else {
192
- return None ;
193
- } ;
185
+
186
+ let ( if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order (
187
+ & ctx . sema ,
188
+ first_arm . pat ( ) ?,
189
+ second_arm. pat ( ) ?,
190
+ first_arm. expr ( ) ?,
191
+ second_arm . expr ( ) ? ,
192
+ ) ? ;
193
+ let scrutinee = match_expr . expr ( ) ? ;
194
194
195
195
let target = match_expr. syntax ( ) . text_range ( ) ;
196
196
acc. add (
197
197
AssistId ( "replace_match_with_if_let" , AssistKind :: RefactorRewrite ) ,
198
198
"Replace with if let" ,
199
199
target,
200
200
move |edit| {
201
- let condition = make:: condition ( condition_expr , Some ( if_let_pat) ) ;
201
+ let condition = make:: condition ( scrutinee , Some ( if_let_pat) ) ;
202
202
let then_block = match then_expr. reset_indent ( ) {
203
203
ast:: Expr :: BlockExpr ( block) => block,
204
204
expr => make:: block_expr ( iter:: empty ( ) , Some ( expr) ) ,
205
205
} ;
206
206
let else_expr = match else_expr {
207
- ast:: Expr :: BlockExpr ( block)
208
- if block. statements ( ) . count ( ) == 0 && block. tail_expr ( ) . is_none ( ) =>
209
- {
210
- None
211
- }
212
- ast:: Expr :: TupleExpr ( tuple) if tuple. fields ( ) . count ( ) == 0 => None ,
207
+ ast:: Expr :: BlockExpr ( block) if block. is_empty ( ) => None ,
208
+ ast:: Expr :: TupleExpr ( tuple) if tuple. fields ( ) . next ( ) . is_none ( ) => None ,
213
209
expr => Some ( expr) ,
214
210
} ;
215
211
let if_let_expr = make:: expr_if (
216
212
condition,
217
213
then_block,
218
- else_expr. map ( |else_expr| {
219
- ast:: ElseBranch :: Block ( make:: block_expr ( iter:: empty ( ) , Some ( else_expr) ) )
220
- } ) ,
214
+ else_expr
215
+ . map ( |expr| match expr {
216
+ ast:: Expr :: BlockExpr ( block) => block,
217
+ expr => ( make:: block_expr ( iter:: empty ( ) , Some ( expr) ) ) ,
218
+ } )
219
+ . map ( ast:: ElseBranch :: Block ) ,
221
220
)
222
221
. indent ( IndentLevel :: from_node ( match_expr. syntax ( ) ) ) ;
223
222
@@ -226,11 +225,50 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext)
226
225
)
227
226
}
228
227
229
- fn is_pat_wildcard_or_sad ( sema : & hir:: Semantics < RootDatabase > , pat : & ast:: Pat ) -> bool {
228
+ /// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order.
229
+ fn pick_pattern_and_expr_order (
230
+ sema : & hir:: Semantics < RootDatabase > ,
231
+ pat : ast:: Pat ,
232
+ pat2 : ast:: Pat ,
233
+ expr : ast:: Expr ,
234
+ expr2 : ast:: Expr ,
235
+ ) -> Option < ( ast:: Pat , ast:: Expr , ast:: Expr ) > {
236
+ let res = match ( pat, pat2) {
237
+ ( ast:: Pat :: WildcardPat ( _) , _) => return None ,
238
+ ( pat, sad_pat) if is_sad_pat ( sema, & sad_pat) => ( pat, expr, expr2) ,
239
+ ( sad_pat, pat) if is_sad_pat ( sema, & sad_pat) => ( pat, expr2, expr) ,
240
+ ( pat, pat2) => match ( binds_name ( & pat) , binds_name ( & pat2) ) {
241
+ ( true , true ) => return None ,
242
+ ( true , false ) => ( pat, expr, expr2) ,
243
+ ( false , true ) => ( pat2, expr2, expr) ,
244
+ ( false , false ) => ( pat, expr, expr2) ,
245
+ } ,
246
+ } ;
247
+ Some ( res)
248
+ }
249
+
250
+ fn binds_name ( pat : & ast:: Pat ) -> bool {
251
+ let binds_name_v = |pat| binds_name ( & pat) ;
252
+ match pat {
253
+ ast:: Pat :: IdentPat ( _) => true ,
254
+ ast:: Pat :: MacroPat ( _) => true ,
255
+ ast:: Pat :: OrPat ( pat) => pat. pats ( ) . any ( binds_name_v) ,
256
+ ast:: Pat :: SlicePat ( pat) => pat. pats ( ) . any ( binds_name_v) ,
257
+ ast:: Pat :: TuplePat ( it) => it. fields ( ) . any ( binds_name_v) ,
258
+ ast:: Pat :: TupleStructPat ( it) => it. fields ( ) . any ( binds_name_v) ,
259
+ ast:: Pat :: RecordPat ( it) => it
260
+ . record_pat_field_list ( )
261
+ . map_or ( false , |rpfl| rpfl. fields ( ) . flat_map ( |rpf| rpf. pat ( ) ) . any ( binds_name_v) ) ,
262
+ ast:: Pat :: RefPat ( pat) => pat. pat ( ) . map_or ( false , binds_name_v) ,
263
+ ast:: Pat :: BoxPat ( pat) => pat. pat ( ) . map_or ( false , binds_name_v) ,
264
+ ast:: Pat :: ParenPat ( pat) => pat. pat ( ) . map_or ( false , binds_name_v) ,
265
+ _ => false ,
266
+ }
267
+ }
268
+ fn is_sad_pat ( sema : & hir:: Semantics < RootDatabase > , pat : & ast:: Pat ) -> bool {
230
269
sema. type_of_pat ( pat)
231
270
. and_then ( |ty| TryEnum :: from_ty ( sema, & ty) )
232
- . map ( |it| it. sad_pattern ( ) . syntax ( ) . text ( ) == pat. syntax ( ) . text ( ) )
233
- . unwrap_or_else ( || matches ! ( pat, ast:: Pat :: WildcardPat ( _) ) )
271
+ . map_or ( false , |it| it. sad_pattern ( ) . syntax ( ) . text ( ) == pat. syntax ( ) . text ( ) )
234
272
}
235
273
236
274
#[ cfg( test) ]
@@ -662,4 +700,79 @@ fn main() {
662
700
"# ,
663
701
)
664
702
}
703
+
704
+ #[ test]
705
+ fn replace_match_with_if_let_exhaustive ( ) {
706
+ check_assist (
707
+ replace_match_with_if_let,
708
+ r#"
709
+ fn print_source(def_source: ModuleSource) {
710
+ match def_so$0urce {
711
+ ModuleSource::SourceFile(..) => { println!("source file"); }
712
+ ModuleSource::Module(..) => { println!("module"); }
713
+ }
714
+ }
715
+ "# ,
716
+ r#"
717
+ fn print_source(def_source: ModuleSource) {
718
+ if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); }
719
+ }
720
+ "# ,
721
+ )
722
+ }
723
+
724
+ #[ test]
725
+ fn replace_match_with_if_let_prefer_name_bind ( ) {
726
+ check_assist (
727
+ replace_match_with_if_let,
728
+ r#"
729
+ fn foo() {
730
+ match $0Foo(0) {
731
+ Foo(_) => (),
732
+ Bar(bar) => println!("bar {}", bar),
733
+ }
734
+ }
735
+ "# ,
736
+ r#"
737
+ fn foo() {
738
+ if let Bar(bar) = Foo(0) {
739
+ println!("bar {}", bar)
740
+ }
741
+ }
742
+ "# ,
743
+ ) ;
744
+ check_assist (
745
+ replace_match_with_if_let,
746
+ r#"
747
+ fn foo() {
748
+ match $0Foo(0) {
749
+ Bar(bar) => println!("bar {}", bar),
750
+ Foo(_) => (),
751
+ }
752
+ }
753
+ "# ,
754
+ r#"
755
+ fn foo() {
756
+ if let Bar(bar) = Foo(0) {
757
+ println!("bar {}", bar)
758
+ }
759
+ }
760
+ "# ,
761
+ ) ;
762
+ }
763
+
764
+ #[ test]
765
+ fn replace_match_with_if_let_rejects_double_name_bindings ( ) {
766
+ check_assist_not_applicable (
767
+ replace_match_with_if_let,
768
+ r#"
769
+ fn foo() {
770
+ match $0Foo(0) {
771
+ Foo(foo) => println!("bar {}", foo),
772
+ Bar(bar) => println!("bar {}", bar),
773
+ }
774
+ }
775
+ "# ,
776
+ ) ;
777
+ }
665
778
}
0 commit comments