@@ -86,53 +86,21 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
86
86
target,
87
87
move |edit| {
88
88
let match_expr = {
89
- let else_arm = {
90
- match else_block {
91
- Some ( else_block) => {
92
- let pattern = match & * cond_bodies {
93
- [ ( Either :: Left ( pat) , _) ] => ctx
94
- . sema
95
- . type_of_pat ( & pat)
96
- . and_then ( |ty| TryEnum :: from_ty ( & ctx. sema , & ty) )
97
- . map ( |it| {
98
- if does_pat_match_variant ( & pat, & it. sad_pattern ( ) ) {
99
- it. happy_pattern ( )
100
- } else {
101
- it. sad_pattern ( )
102
- }
103
- } ) ,
104
- _ => None ,
105
- }
106
- . unwrap_or_else ( || make:: wildcard_pat ( ) . into ( ) ) ;
107
- make:: match_arm (
108
- iter:: once ( pattern) ,
109
- None ,
110
- unwrap_trivial_block ( else_block) ,
111
- )
89
+ let else_arm = make_else_arm ( else_block, & cond_bodies, ctx) ;
90
+ let make_match_arm = |( pat, body) : ( _ , ast:: BlockExpr ) | {
91
+ let body = body. reset_indent ( ) . indent ( IndentLevel ( 1 ) ) ;
92
+ match pat {
93
+ Either :: Left ( pat) => {
94
+ make:: match_arm ( iter:: once ( pat) , None , unwrap_trivial_block ( body) )
112
95
}
113
- None => make:: match_arm (
96
+ Either :: Right ( expr ) => make:: match_arm (
114
97
iter:: once ( make:: wildcard_pat ( ) . into ( ) ) ,
115
- None ,
116
- make :: expr_unit ( ) . into ( ) ,
98
+ Some ( expr ) ,
99
+ unwrap_trivial_block ( body ) ,
117
100
) ,
118
101
}
119
102
} ;
120
- let arms = cond_bodies
121
- . into_iter ( )
122
- . map ( |( pat, body) | {
123
- let body = body. reset_indent ( ) . indent ( IndentLevel ( 1 ) ) ;
124
- match pat {
125
- Either :: Left ( pat) => {
126
- make:: match_arm ( iter:: once ( pat) , None , unwrap_trivial_block ( body) )
127
- }
128
- Either :: Right ( expr) => make:: match_arm (
129
- iter:: once ( make:: wildcard_pat ( ) . into ( ) ) ,
130
- Some ( expr) ,
131
- unwrap_trivial_block ( body) ,
132
- ) ,
133
- }
134
- } )
135
- . chain ( iter:: once ( else_arm) ) ;
103
+ let arms = cond_bodies. into_iter ( ) . map ( make_match_arm) . chain ( iter:: once ( else_arm) ) ;
136
104
let match_expr = make:: expr_match ( scrutinee_to_be_expr, make:: match_arm_list ( arms) ) ;
137
105
match_expr. indent ( IndentLevel :: from_node ( if_expr. syntax ( ) ) )
138
106
} ;
@@ -150,6 +118,36 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
150
118
)
151
119
}
152
120
121
+ fn make_else_arm (
122
+ else_block : Option < ast:: BlockExpr > ,
123
+ cond_bodies : & Vec < ( Either < ast:: Pat , ast:: Expr > , ast:: BlockExpr ) > ,
124
+ ctx : & AssistContext ,
125
+ ) -> ast:: MatchArm {
126
+ if let Some ( else_block) = else_block {
127
+ let pattern = if let [ ( Either :: Left ( pat) , _) ] = & * * cond_bodies {
128
+ ctx. sema
129
+ . type_of_pat ( & pat)
130
+ . and_then ( |ty| TryEnum :: from_ty ( & ctx. sema , & ty) )
131
+ . zip ( Some ( pat) )
132
+ } else {
133
+ None
134
+ } ;
135
+ let pattern = match pattern {
136
+ Some ( ( it, pat) ) => {
137
+ if does_pat_match_variant ( & pat, & it. sad_pattern ( ) ) {
138
+ it. happy_pattern ( )
139
+ } else {
140
+ it. sad_pattern ( )
141
+ }
142
+ }
143
+ None => make:: wildcard_pat ( ) . into ( ) ,
144
+ } ;
145
+ make:: match_arm ( iter:: once ( pattern) , None , unwrap_trivial_block ( else_block) )
146
+ } else {
147
+ make:: match_arm ( iter:: once ( make:: wildcard_pat ( ) . into ( ) ) , None , make:: expr_unit ( ) . into ( ) )
148
+ }
149
+ }
150
+
153
151
// Assist: replace_match_with_if_let
154
152
//
155
153
// Replaces a binary `match` with a wildcard pattern and no guards with an `if let` expression.
@@ -178,48 +176,47 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
178
176
// ```
179
177
pub ( crate ) fn replace_match_with_if_let ( acc : & mut Assists , ctx : & AssistContext ) -> Option < ( ) > {
180
178
let match_expr: ast:: MatchExpr = ctx. find_node_at_offset ( ) ?;
179
+
181
180
let mut arms = match_expr. match_arm_list ( ) ?. arms ( ) ;
182
- let first_arm = arms. next ( ) ?;
183
- let second_arm = arms. next ( ) ?;
181
+ let ( first_arm, second_arm) = ( arms. next ( ) ?, arms. next ( ) ?) ;
184
182
if arms. next ( ) . is_some ( ) || first_arm. guard ( ) . is_some ( ) || second_arm. guard ( ) . is_some ( ) {
185
183
return None ;
186
184
}
187
- let condition_expr = match_expr . expr ( ) ? ;
188
- let ( if_let_pat, then_expr, else_expr) = if is_pat_wildcard_or_sad ( & ctx . sema , & first_arm . pat ( ) ? )
189
- {
190
- ( second_arm . pat ( ) ?, second_arm . expr ( ) ? , first_arm . expr ( ) ? )
191
- } else if is_pat_wildcard_or_sad ( & ctx . sema , & second_arm. pat ( ) ?) {
192
- ( first_arm. pat ( ) ? , first_arm . expr ( ) ?, second_arm . expr ( ) ? )
193
- } else {
194
- return None ;
195
- } ;
185
+
186
+ let ( if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order (
187
+ & ctx . sema ,
188
+ first_arm . pat ( ) ?,
189
+ second_arm. pat ( ) ?,
190
+ first_arm. expr ( ) ?,
191
+ second_arm . expr ( ) ? ,
192
+ ) ? ;
193
+ let scrutinee = match_expr . expr ( ) ? ;
196
194
197
195
let target = match_expr. syntax ( ) . text_range ( ) ;
198
196
acc. add (
199
197
AssistId ( "replace_match_with_if_let" , AssistKind :: RefactorRewrite ) ,
200
198
"Replace with if let" ,
201
199
target,
202
200
move |edit| {
203
- let condition = make:: condition ( condition_expr , Some ( if_let_pat) ) ;
201
+ let condition = make:: condition ( scrutinee , Some ( if_let_pat) ) ;
204
202
let then_block = match then_expr. reset_indent ( ) {
205
203
ast:: Expr :: BlockExpr ( block) => block,
206
204
expr => make:: block_expr ( iter:: empty ( ) , Some ( expr) ) ,
207
205
} ;
208
206
let else_expr = match else_expr {
209
- ast:: Expr :: BlockExpr ( block)
210
- if block. statements ( ) . count ( ) == 0 && block. tail_expr ( ) . is_none ( ) =>
211
- {
212
- None
213
- }
214
- ast:: Expr :: TupleExpr ( tuple) if tuple. fields ( ) . count ( ) == 0 => None ,
207
+ ast:: Expr :: BlockExpr ( block) if block. is_empty ( ) => None ,
208
+ ast:: Expr :: TupleExpr ( tuple) if tuple. fields ( ) . next ( ) . is_none ( ) => None ,
215
209
expr => Some ( expr) ,
216
210
} ;
217
211
let if_let_expr = make:: expr_if (
218
212
condition,
219
213
then_block,
220
- else_expr. map ( |else_expr| {
221
- ast:: ElseBranch :: Block ( make:: block_expr ( iter:: empty ( ) , Some ( else_expr) ) )
222
- } ) ,
214
+ else_expr
215
+ . map ( |expr| match expr {
216
+ ast:: Expr :: BlockExpr ( block) => block,
217
+ expr => ( make:: block_expr ( iter:: empty ( ) , Some ( expr) ) ) ,
218
+ } )
219
+ . map ( ast:: ElseBranch :: Block ) ,
223
220
)
224
221
. indent ( IndentLevel :: from_node ( match_expr. syntax ( ) ) ) ;
225
222
@@ -228,11 +225,51 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext)
228
225
)
229
226
}
230
227
231
- fn is_pat_wildcard_or_sad ( sema : & hir:: Semantics < RootDatabase > , pat : & ast:: Pat ) -> bool {
228
+ /// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order.
229
+ fn pick_pattern_and_expr_order (
230
+ sema : & hir:: Semantics < RootDatabase > ,
231
+ pat : ast:: Pat ,
232
+ pat2 : ast:: Pat ,
233
+ expr : ast:: Expr ,
234
+ expr2 : ast:: Expr ,
235
+ ) -> Option < ( ast:: Pat , ast:: Expr , ast:: Expr ) > {
236
+ let res = match ( pat, pat2) {
237
+ ( ast:: Pat :: WildcardPat ( _) , _) => return None ,
238
+ ( pat, sad_pat) if is_sad_pat ( sema, & sad_pat) => ( pat, expr, expr2) ,
239
+ ( sad_pat, pat) if is_sad_pat ( sema, & sad_pat) => ( pat, expr2, expr) ,
240
+ ( pat, pat2) => match ( binds_name ( & pat) , binds_name ( & pat2) ) {
241
+ ( true , true ) => return None ,
242
+ ( true , false ) => ( pat, expr, expr2) ,
243
+ ( false , true ) => ( pat2, expr2, expr) ,
244
+ ( false , false ) => ( pat, expr, expr2) ,
245
+ } ,
246
+ } ;
247
+ Some ( res)
248
+ }
249
+
250
+ fn binds_name ( pat : & ast:: Pat ) -> bool {
251
+ let binds_name_v = |pat| binds_name ( & pat) ;
252
+ match pat {
253
+ ast:: Pat :: IdentPat ( _) => true ,
254
+ ast:: Pat :: MacroPat ( _) => true ,
255
+ ast:: Pat :: OrPat ( pat) => pat. pats ( ) . any ( binds_name_v) ,
256
+ ast:: Pat :: SlicePat ( pat) => pat. pats ( ) . any ( binds_name_v) ,
257
+ ast:: Pat :: TuplePat ( it) => it. fields ( ) . any ( binds_name_v) ,
258
+ ast:: Pat :: TupleStructPat ( it) => it. fields ( ) . any ( binds_name_v) ,
259
+ ast:: Pat :: RecordPat ( it) => it
260
+ . record_pat_field_list ( )
261
+ . map_or ( false , |rpfl| rpfl. fields ( ) . flat_map ( |rpf| rpf. pat ( ) ) . any ( binds_name_v) ) ,
262
+ ast:: Pat :: RefPat ( pat) => pat. pat ( ) . map_or ( false , binds_name_v) ,
263
+ ast:: Pat :: BoxPat ( pat) => pat. pat ( ) . map_or ( false , binds_name_v) ,
264
+ ast:: Pat :: ParenPat ( pat) => pat. pat ( ) . map_or ( false , binds_name_v) ,
265
+ _ => false ,
266
+ }
267
+ }
268
+
269
+ fn is_sad_pat ( sema : & hir:: Semantics < RootDatabase > , pat : & ast:: Pat ) -> bool {
232
270
sema. type_of_pat ( pat)
233
271
. and_then ( |ty| TryEnum :: from_ty ( sema, & ty) )
234
- . map ( |it| it. sad_pattern ( ) . syntax ( ) . text ( ) == pat. syntax ( ) . text ( ) )
235
- . unwrap_or_else ( || matches ! ( pat, ast:: Pat :: WildcardPat ( _) ) )
272
+ . map_or ( false , |it| does_pat_match_variant ( pat, & it. sad_pattern ( ) ) )
236
273
}
237
274
238
275
#[ cfg( test) ]
@@ -664,4 +701,79 @@ fn main() {
664
701
"# ,
665
702
)
666
703
}
704
+
705
+ #[ test]
706
+ fn replace_match_with_if_let_exhaustive ( ) {
707
+ check_assist (
708
+ replace_match_with_if_let,
709
+ r#"
710
+ fn print_source(def_source: ModuleSource) {
711
+ match def_so$0urce {
712
+ ModuleSource::SourceFile(..) => { println!("source file"); }
713
+ ModuleSource::Module(..) => { println!("module"); }
714
+ }
715
+ }
716
+ "# ,
717
+ r#"
718
+ fn print_source(def_source: ModuleSource) {
719
+ if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); }
720
+ }
721
+ "# ,
722
+ )
723
+ }
724
+
725
+ #[ test]
726
+ fn replace_match_with_if_let_prefer_name_bind ( ) {
727
+ check_assist (
728
+ replace_match_with_if_let,
729
+ r#"
730
+ fn foo() {
731
+ match $0Foo(0) {
732
+ Foo(_) => (),
733
+ Bar(bar) => println!("bar {}", bar),
734
+ }
735
+ }
736
+ "# ,
737
+ r#"
738
+ fn foo() {
739
+ if let Bar(bar) = Foo(0) {
740
+ println!("bar {}", bar)
741
+ }
742
+ }
743
+ "# ,
744
+ ) ;
745
+ check_assist (
746
+ replace_match_with_if_let,
747
+ r#"
748
+ fn foo() {
749
+ match $0Foo(0) {
750
+ Bar(bar) => println!("bar {}", bar),
751
+ Foo(_) => (),
752
+ }
753
+ }
754
+ "# ,
755
+ r#"
756
+ fn foo() {
757
+ if let Bar(bar) = Foo(0) {
758
+ println!("bar {}", bar)
759
+ }
760
+ }
761
+ "# ,
762
+ ) ;
763
+ }
764
+
765
+ #[ test]
766
+ fn replace_match_with_if_let_rejects_double_name_bindings ( ) {
767
+ check_assist_not_applicable (
768
+ replace_match_with_if_let,
769
+ r#"
770
+ fn foo() {
771
+ match $0Foo(0) {
772
+ Foo(foo) => println!("bar {}", foo),
773
+ Bar(bar) => println!("bar {}", bar),
774
+ }
775
+ }
776
+ "# ,
777
+ ) ;
778
+ }
667
779
}
0 commit comments