1+ use std:: iter;
2+
3+ use either:: Either ;
14use hir:: { AsName , EnumVariant , Module , ModuleDef , Name } ;
25use ide_db:: { defs:: Definition , search:: Reference , RootDatabase } ;
36use rustc_hash:: { FxHashMap , FxHashSet } ;
@@ -31,48 +34,32 @@ pub(crate) fn extract_struct_from_enum_variant(
3134 ctx : & AssistContext ,
3235) -> Option < ( ) > {
3336 let variant = ctx. find_node_at_offset :: < ast:: Variant > ( ) ?;
34-
35- fn is_applicable_variant ( variant : & ast:: Variant ) -> bool {
36- 1 < match variant. kind ( ) {
37- ast:: StructKind :: Record ( field_list) => field_list. fields ( ) . count ( ) ,
38- ast:: StructKind :: Tuple ( field_list) => field_list. fields ( ) . count ( ) ,
39- ast:: StructKind :: Unit => 0 ,
40- }
41- }
42-
43- if !is_applicable_variant ( & variant) {
44- return None ;
45- }
46-
47- let field_list = match variant. kind ( ) {
48- ast:: StructKind :: Tuple ( field_list) => field_list,
49- _ => return None ,
50- } ;
37+ let field_list = extract_field_list_if_applicable ( & variant) ?;
5138
5239 let variant_name = variant. name ( ) ?;
5340 let variant_hir = ctx. sema . to_def ( & variant) ?;
5441 if existing_definition ( ctx. db ( ) , & variant_name, & variant_hir) {
5542 return None ;
5643 }
44+
5745 let enum_ast = variant. parent_enum ( ) ;
58- let visibility = enum_ast. visibility ( ) ;
5946 let enum_hir = ctx. sema . to_def ( & enum_ast) ?;
60- let variant_hir_name = variant_hir. name ( ctx. db ( ) ) ;
61- let enum_module_def = ModuleDef :: from ( enum_hir) ;
62- let current_module = enum_hir. module ( ctx. db ( ) ) ;
6347 let target = variant. syntax ( ) . text_range ( ) ;
6448 acc. add (
6549 AssistId ( "extract_struct_from_enum_variant" , AssistKind :: RefactorRewrite ) ,
6650 "Extract struct from enum variant" ,
6751 target,
6852 |builder| {
69- let definition = Definition :: ModuleDef ( ModuleDef :: EnumVariant ( variant_hir) ) ;
70- let res = definition. usages ( & ctx. sema ) . all ( ) ;
53+ let variant_hir_name = variant_hir. name ( ctx. db ( ) ) ;
54+ let enum_module_def = ModuleDef :: from ( enum_hir) ;
55+ let usages =
56+ Definition :: ModuleDef ( ModuleDef :: EnumVariant ( variant_hir) ) . usages ( & ctx. sema ) . all ( ) ;
7157
7258 let mut visited_modules_set = FxHashSet :: default ( ) ;
59+ let current_module = enum_hir. module ( ctx. db ( ) ) ;
7360 visited_modules_set. insert ( current_module) ;
7461 let mut rewriters = FxHashMap :: default ( ) ;
75- for reference in res {
62+ for reference in usages {
7663 let rewriter = rewriters
7764 . entry ( reference. file_range . file_id )
7865 . or_insert_with ( SyntaxRewriter :: default) ;
@@ -94,20 +81,34 @@ pub(crate) fn extract_struct_from_enum_variant(
9481 builder. rewrite ( rewriter) ;
9582 }
9683 builder. edit_file ( ctx. frange . file_id ) ;
97- update_variant ( & mut rewriter, & variant_name , & field_list ) ;
84+ update_variant ( & mut rewriter, & variant ) ;
9885 extract_struct_def (
9986 & mut rewriter,
10087 & enum_ast,
10188 variant_name. clone ( ) ,
10289 & field_list,
10390 & variant. parent_enum ( ) . syntax ( ) . clone ( ) . into ( ) ,
104- visibility,
91+ enum_ast . visibility ( ) ,
10592 ) ;
10693 builder. rewrite ( rewriter) ;
10794 } ,
10895 )
10996}
11097
98+ fn extract_field_list_if_applicable (
99+ variant : & ast:: Variant ,
100+ ) -> Option < Either < ast:: RecordFieldList , ast:: TupleFieldList > > {
101+ match variant. kind ( ) {
102+ ast:: StructKind :: Record ( field_list) if field_list. fields ( ) . next ( ) . is_some ( ) => {
103+ Some ( Either :: Left ( field_list) )
104+ }
105+ ast:: StructKind :: Tuple ( field_list) if field_list. fields ( ) . count ( ) > 1 => {
106+ Some ( Either :: Right ( field_list) )
107+ }
108+ _ => None ,
109+ }
110+ }
111+
111112fn existing_definition ( db : & RootDatabase , variant_name : & ast:: Name , variant : & EnumVariant ) -> bool {
112113 variant
113114 . parent_enum ( db)
@@ -150,19 +151,29 @@ fn extract_struct_def(
150151 rewriter : & mut SyntaxRewriter ,
151152 enum_ : & ast:: Enum ,
152153 variant_name : ast:: Name ,
153- variant_list : & ast:: TupleFieldList ,
154+ field_list : & Either < ast:: RecordFieldList , ast :: TupleFieldList > ,
154155 start_offset : & SyntaxElement ,
155156 visibility : Option < ast:: Visibility > ,
156157) -> Option < ( ) > {
157- let variant_list = make:: tuple_field_list (
158- variant_list
159- . fields ( )
160- . flat_map ( |field| Some ( make:: tuple_field ( Some ( make:: visibility_pub ( ) ) , field. ty ( ) ?) ) ) ,
161- ) ;
158+ let pub_vis = Some ( make:: visibility_pub ( ) ) ;
159+ let field_list = match field_list {
160+ Either :: Left ( field_list) => {
161+ make:: record_field_list ( field_list. fields ( ) . flat_map ( |field| {
162+ Some ( make:: record_field ( pub_vis. clone ( ) , field. name ( ) ?, field. ty ( ) ?) )
163+ } ) )
164+ . into ( )
165+ }
166+ Either :: Right ( field_list) => make:: tuple_field_list (
167+ field_list
168+ . fields ( )
169+ . flat_map ( |field| Some ( make:: tuple_field ( pub_vis. clone ( ) , field. ty ( ) ?) ) ) ,
170+ )
171+ . into ( ) ,
172+ } ;
162173
163174 rewriter. insert_before (
164175 start_offset,
165- make:: struct_ ( visibility, variant_name, None , variant_list . into ( ) ) . syntax ( ) ,
176+ make:: struct_ ( visibility, variant_name, None , field_list ) . syntax ( ) ,
166177 ) ;
167178 rewriter. insert_before ( start_offset, & make:: tokens:: blank_line ( ) ) ;
168179
@@ -173,15 +184,14 @@ fn extract_struct_def(
173184 Some ( ( ) )
174185}
175186
176- fn update_variant (
177- rewriter : & mut SyntaxRewriter ,
178- variant_name : & ast:: Name ,
179- field_list : & ast:: TupleFieldList ,
180- ) -> Option < ( ) > {
181- let ( l, r) : ( SyntaxElement , SyntaxElement ) =
182- ( field_list. l_paren_token ( ) ?. into ( ) , field_list. r_paren_token ( ) ?. into ( ) ) ;
183- let replacement = vec ! [ l, variant_name. syntax( ) . clone( ) . into( ) , r] ;
184- rewriter. replace_with_many ( field_list. syntax ( ) , replacement) ;
187+ fn update_variant ( rewriter : & mut SyntaxRewriter , variant : & ast:: Variant ) -> Option < ( ) > {
188+ let name = variant. name ( ) ?;
189+ let tuple_field = make:: tuple_field ( None , make:: ty ( name. text ( ) ) ) ;
190+ let replacement = make:: variant (
191+ name,
192+ Some ( ast:: FieldList :: TupleFieldList ( make:: tuple_field_list ( iter:: once ( tuple_field) ) ) ) ,
193+ ) ;
194+ rewriter. replace ( variant. syntax ( ) , replacement. syntax ( ) ) ;
185195 Some ( ( ) )
186196}
187197
@@ -243,10 +253,18 @@ enum A { One(One) }"#,
243253 check_assist (
244254 extract_struct_from_enum_variant,
245255 "enum A { <|>One { foo: u32, bar: u32 } }" ,
246- r#"struct One {
247- pub foo: u32,
248- pub bar: u32
249- }
256+ r#"struct One{ pub foo: u32, pub bar: u32 }
257+
258+ enum A { One(One) }"# ,
259+ ) ;
260+ }
261+
262+ #[ test]
263+ fn test_extract_struct_one_field_named ( ) {
264+ check_assist (
265+ extract_struct_from_enum_variant,
266+ "enum A { <|>One { foo: u32 } }" ,
267+ r#"struct One{ pub foo: u32 }
250268
251269enum A { One(One) }"# ,
252270 ) ;
@@ -350,4 +368,14 @@ fn another_fn() {
350368 fn test_extract_not_applicable_one_field ( ) {
351369 check_not_applicable ( r"enum A { <|>One(u32) }" ) ;
352370 }
371+
372+ #[ test]
373+ fn test_extract_not_applicable_no_field_tuple ( ) {
374+ check_not_applicable ( r"enum A { <|>None() }" ) ;
375+ }
376+
377+ #[ test]
378+ fn test_extract_not_applicable_no_field_named ( ) {
379+ check_not_applicable ( r"enum A { <|>None {} }" ) ;
380+ }
353381}
0 commit comments