11//! FIXME: write short doc here
22
3- use std:: iter;
3+ use std:: { collections :: LinkedList , iter} ;
44
55use hir:: { Adt , HasSource , Semantics } ;
66use ra_ide_db:: RootDatabase ;
7- use ra_syntax:: ast:: { self , edit:: IndentLevel , make, AstNode , NameOwner } ;
87
98use crate :: { Assist , AssistCtx , AssistId } ;
9+ use ra_syntax:: {
10+ ast:: { self , edit:: IndentLevel , make, AstNode , NameOwner } ,
11+ SyntaxKind , SyntaxNode ,
12+ } ;
13+
14+ use ast:: { MatchArm , MatchGuard , Pat } ;
1015
1116// Assist: fill_match_arms
1217//
@@ -36,16 +41,6 @@ pub(crate) fn fill_match_arms(ctx: AssistCtx) -> Option<Assist> {
3641 let match_expr = ctx. find_node_at_offset :: < ast:: MatchExpr > ( ) ?;
3742 let match_arm_list = match_expr. match_arm_list ( ) ?;
3843
39- // We already have some match arms, so we don't provide any assists.
40- // Unless if there is only one trivial match arm possibly created
41- // by match postfix complete. Trivial match arm is the catch all arm.
42- let mut existing_arms = match_arm_list. arms ( ) ;
43- if let Some ( arm) = existing_arms. next ( ) {
44- if !is_trivial ( & arm) || existing_arms. next ( ) . is_some ( ) {
45- return None ;
46- }
47- } ;
48-
4944 let expr = match_expr. expr ( ) ?;
5045 let enum_def = resolve_enum_def ( & ctx. sema , & expr) ?;
5146 let module = ctx. sema . scope ( expr. syntax ( ) ) . module ( ) ?;
@@ -56,29 +51,113 @@ pub(crate) fn fill_match_arms(ctx: AssistCtx) -> Option<Assist> {
5651 }
5752
5853 let db = ctx. db ;
59-
6054 ctx. add_assist ( AssistId ( "fill_match_arms" ) , "Fill match arms" , |edit| {
61- let indent_level = IndentLevel :: from_node ( match_arm_list. syntax ( ) ) ;
55+ let mut arms: Vec < MatchArm > = match_arm_list. arms ( ) . collect ( ) ;
56+ if arms. len ( ) == 1 {
57+ if let Some ( Pat :: PlaceholderPat ( ..) ) = arms[ 0 ] . pat ( ) {
58+ arms. clear ( ) ;
59+ }
60+ }
6261
63- let new_arm_list = {
64- let arms = variants
65- . into_iter ( )
66- . filter_map ( |variant| build_pat ( db, module, variant) )
67- . map ( |pat| make:: match_arm ( iter:: once ( pat) , make:: expr_unit ( ) ) ) ;
68- indent_level. increase_indent ( make:: match_arm_list ( arms) )
69- } ;
62+ let mut has_partial_match = false ;
63+ let variants: Vec < MatchArm > = variants
64+ . into_iter ( )
65+ . filter_map ( |variant| build_pat ( db, module, variant) )
66+ . filter ( |variant_pat| {
67+ !arms. iter ( ) . filter_map ( |arm| arm. pat ( ) . map ( |_| arm) ) . any ( |arm| {
68+ let pat = arm. pat ( ) . unwrap ( ) ;
69+
70+ // Special casee OrPat as separate top-level pats
71+ let pats: Vec < Pat > = match Pat :: from ( pat. clone ( ) ) {
72+ Pat :: OrPat ( pats) => pats. pats ( ) . collect :: < Vec < _ > > ( ) ,
73+ _ => vec ! [ pat] ,
74+ } ;
75+
76+ pats. iter ( ) . any ( |pat| {
77+ match does_arm_pat_match_variant ( pat, arm. guard ( ) , variant_pat) {
78+ ArmMatch :: Yes => true ,
79+ ArmMatch :: No => false ,
80+ ArmMatch :: Partial => {
81+ has_partial_match = true ;
82+ true
83+ }
84+ }
85+ } )
86+ } )
87+ } )
88+ . map ( |pat| make:: match_arm ( iter:: once ( pat) , make:: expr_unit ( ) ) )
89+ . collect ( ) ;
90+
91+ arms. extend ( variants) ;
92+ if has_partial_match {
93+ arms. push ( make:: match_arm (
94+ iter:: once ( make:: placeholder_pat ( ) . into ( ) ) ,
95+ make:: expr_unit ( ) ,
96+ ) ) ;
97+ }
98+
99+ let indent_level = IndentLevel :: from_node ( match_arm_list. syntax ( ) ) ;
100+ let new_arm_list = indent_level. increase_indent ( make:: match_arm_list ( arms) ) ;
70101
71102 edit. target ( match_expr. syntax ( ) . text_range ( ) ) ;
72103 edit. set_cursor ( expr. syntax ( ) . text_range ( ) . start ( ) ) ;
73104 edit. replace_ast ( match_arm_list, new_arm_list) ;
74105 } )
75106}
76107
77- fn is_trivial ( arm : & ast:: MatchArm ) -> bool {
78- match arm. pat ( ) {
79- Some ( ast:: Pat :: PlaceholderPat ( ..) ) => true ,
80- _ => false ,
108+ enum ArmMatch {
109+ Yes ,
110+ No ,
111+ Partial ,
112+ }
113+
114+ fn does_arm_pat_match_variant ( arm : & Pat , arm_guard : Option < MatchGuard > , var : & Pat ) -> ArmMatch {
115+ let arm = flatten_pats ( arm. clone ( ) ) ;
116+ let var = flatten_pats ( var. clone ( ) ) ;
117+ let mut arm = arm. iter ( ) ;
118+ let mut var = var. iter ( ) ;
119+
120+ // If the first part of the Pat don't match, there's no match
121+ match ( arm. next ( ) , var. next ( ) ) {
122+ ( Some ( arm) , Some ( var) ) if arm. text ( ) == var. text ( ) => { }
123+ _ => return ArmMatch :: No ,
124+ }
125+
126+ // If we have a guard we automatically know we have a partial match
127+ if arm_guard. is_some ( ) {
128+ return ArmMatch :: Partial ;
129+ }
130+
131+ if arm. clone ( ) . count ( ) != var. clone ( ) . count ( ) {
132+ return ArmMatch :: Partial ;
133+ }
134+
135+ let direct_match = arm. zip ( var) . all ( |( arm, var) | {
136+ if arm. text ( ) == var. text ( ) {
137+ return true ;
138+ }
139+ match ( arm. kind ( ) , var. kind ( ) ) {
140+ ( SyntaxKind :: PLACEHOLDER_PAT , SyntaxKind :: PLACEHOLDER_PAT ) => true ,
141+ ( SyntaxKind :: DOT_DOT_PAT , SyntaxKind :: PLACEHOLDER_PAT ) => true ,
142+ ( SyntaxKind :: BIND_PAT , SyntaxKind :: PLACEHOLDER_PAT ) => true ,
143+ _ => false ,
144+ }
145+ } ) ;
146+
147+ match direct_match {
148+ true => ArmMatch :: Yes ,
149+ false => ArmMatch :: Partial ,
150+ }
151+ }
152+
153+ fn flatten_pats ( pat : Pat ) -> Vec < SyntaxNode > {
154+ let mut pats: LinkedList < SyntaxNode > = pat. syntax ( ) . children ( ) . collect ( ) ;
155+ let mut out: Vec < SyntaxNode > = vec ! [ ] ;
156+ while let Some ( p) = pats. pop_front ( ) {
157+ pats. extend ( p. children ( ) ) ;
158+ out. push ( p) ;
81159 }
160+ out
82161}
83162
84163fn resolve_enum_def ( sema : & Semantics < RootDatabase > , expr : & ast:: Expr ) -> Option < hir:: Enum > {
@@ -114,6 +193,183 @@ mod tests {
114193
115194 use super :: fill_match_arms;
116195
196+ #[ test]
197+ fn partial_fill_multi ( ) {
198+ check_assist (
199+ fill_match_arms,
200+ r#"
201+ enum A {
202+ As,
203+ Bs(i32, Option<i32>)
204+ }
205+ fn main() {
206+ match A::As<|> {
207+ A::Bs(_, Some(_)) => (),
208+ }
209+ }
210+ "# ,
211+ r#"
212+ enum A {
213+ As,
214+ Bs(i32, Option<i32>)
215+ }
216+ fn main() {
217+ match <|>A::As {
218+ A::Bs(_, Some(_)) => (),
219+ A::As => (),
220+ _ => (),
221+ }
222+ }
223+ "# ,
224+ ) ;
225+ }
226+
227+ #[ test]
228+ fn partial_fill_record ( ) {
229+ check_assist (
230+ fill_match_arms,
231+ r#"
232+ enum A {
233+ As,
234+ Bs{x:i32, y:Option<i32>},
235+ }
236+ fn main() {
237+ match A::As<|> {
238+ A::Bs{x,y:Some(_)} => (),
239+ }
240+ }
241+ "# ,
242+ r#"
243+ enum A {
244+ As,
245+ Bs{x:i32, y:Option<i32>},
246+ }
247+ fn main() {
248+ match <|>A::As {
249+ A::Bs{x,y:Some(_)} => (),
250+ A::As => (),
251+ _ => (),
252+ }
253+ }
254+ "# ,
255+ ) ;
256+ }
257+
258+ #[ test]
259+ fn partial_fill_or_pat ( ) {
260+ check_assist (
261+ fill_match_arms,
262+ r#"
263+ enum A {
264+ As,
265+ Bs,
266+ Cs(Option<i32>),
267+ }
268+ fn main() {
269+ match A::As<|> {
270+ A::Cs(_) | A::Bs => (),
271+ }
272+ }
273+ "# ,
274+ r#"
275+ enum A {
276+ As,
277+ Bs,
278+ Cs(Option<i32>),
279+ }
280+ fn main() {
281+ match <|>A::As {
282+ A::Cs(_) | A::Bs => (),
283+ A::As => (),
284+ }
285+ }
286+ "# ,
287+ ) ;
288+ }
289+
290+ #[ test]
291+ fn partial_fill_or_pat2 ( ) {
292+ check_assist (
293+ fill_match_arms,
294+ r#"
295+ enum A {
296+ As,
297+ Bs,
298+ Cs(Option<i32>),
299+ }
300+ fn main() {
301+ match A::As<|> {
302+ A::Cs(Some(_)) | A::Bs => (),
303+ }
304+ }
305+ "# ,
306+ r#"
307+ enum A {
308+ As,
309+ Bs,
310+ Cs(Option<i32>),
311+ }
312+ fn main() {
313+ match <|>A::As {
314+ A::Cs(Some(_)) | A::Bs => (),
315+ A::As => (),
316+ _ => (),
317+ }
318+ }
319+ "# ,
320+ ) ;
321+ }
322+
323+ #[ test]
324+ fn partial_fill ( ) {
325+ check_assist (
326+ fill_match_arms,
327+ r#"
328+ enum A {
329+ As,
330+ Bs,
331+ Cs,
332+ Ds(String),
333+ Es(B),
334+ }
335+ enum B {
336+ Xs,
337+ Ys,
338+ }
339+ fn main() {
340+ match A::As<|> {
341+ A::Bs if 0 < 1 => (),
342+ A::Ds(_value) => (),
343+ A::Es(B::Xs) => (),
344+ }
345+ }
346+ "# ,
347+ r#"
348+ enum A {
349+ As,
350+ Bs,
351+ Cs,
352+ Ds(String),
353+ Es(B),
354+ }
355+ enum B {
356+ Xs,
357+ Ys,
358+ }
359+ fn main() {
360+ match <|>A::As {
361+ A::Bs if 0 < 1 => (),
362+ A::Ds(_value) => (),
363+ A::Es(B::Xs) => (),
364+ A::As => (),
365+ A::Cs => (),
366+ _ => (),
367+ }
368+ }
369+ "# ,
370+ ) ;
371+ }
372+
117373 #[ test]
118374 fn fill_match_arms_empty_body ( ) {
119375 check_assist (
0 commit comments