1+ use std:: iter;
2+
3+ use either:: Either ;
14use hir:: { AsName , EnumVariant , Module , ModuleDef , Name } ;
25use ide_db:: { defs:: Definition , search:: Reference , RootDatabase } ;
36use rustc_hash:: { FxHashMap , FxHashSet } ;
@@ -31,40 +34,32 @@ pub(crate) fn extract_struct_from_enum_variant(
3134 ctx : & AssistContext ,
3235) -> Option < ( ) > {
3336 let variant = ctx. find_node_at_offset :: < ast:: Variant > ( ) ?;
34- let field_list = match variant. kind ( ) {
35- ast:: StructKind :: Tuple ( field_list) => field_list,
36- _ => return None ,
37- } ;
38-
39- // skip 1-tuple variants
40- if field_list. fields ( ) . count ( ) == 1 {
41- return None ;
42- }
37+ let field_list = extract_field_list_if_applicable ( & variant) ?;
4338
4439 let variant_name = variant. name ( ) ?;
4540 let variant_hir = ctx. sema . to_def ( & variant) ?;
46- if existing_struct_def ( ctx. db ( ) , & variant_name, & variant_hir) {
41+ if existing_definition ( ctx. db ( ) , & variant_name, & variant_hir) {
4742 return None ;
4843 }
44+
4945 let enum_ast = variant. parent_enum ( ) ;
50- let visibility = enum_ast. visibility ( ) ;
5146 let enum_hir = ctx. sema . to_def ( & enum_ast) ?;
52- let variant_hir_name = variant_hir. name ( ctx. db ( ) ) ;
53- let enum_module_def = ModuleDef :: from ( enum_hir) ;
54- let current_module = enum_hir. module ( ctx. db ( ) ) ;
5547 let target = variant. syntax ( ) . text_range ( ) ;
5648 acc. add (
5749 AssistId ( "extract_struct_from_enum_variant" , AssistKind :: RefactorRewrite ) ,
5850 "Extract struct from enum variant" ,
5951 target,
6052 |builder| {
61- let definition = Definition :: ModuleDef ( ModuleDef :: EnumVariant ( variant_hir) ) ;
62- let res = definition. usages ( & ctx. sema ) . all ( ) ;
53+ let variant_hir_name = variant_hir. name ( ctx. db ( ) ) ;
54+ let enum_module_def = ModuleDef :: from ( enum_hir) ;
55+ let usages =
56+ Definition :: ModuleDef ( ModuleDef :: EnumVariant ( variant_hir) ) . usages ( & ctx. sema ) . all ( ) ;
6357
6458 let mut visited_modules_set = FxHashSet :: default ( ) ;
59+ let current_module = enum_hir. module ( ctx. db ( ) ) ;
6560 visited_modules_set. insert ( current_module) ;
6661 let mut rewriters = FxHashMap :: default ( ) ;
67- for reference in res {
62+ for reference in usages {
6863 let rewriter = rewriters
6964 . entry ( reference. file_range . file_id )
7065 . or_insert_with ( SyntaxRewriter :: default) ;
@@ -86,26 +81,49 @@ pub(crate) fn extract_struct_from_enum_variant(
8681 builder. rewrite ( rewriter) ;
8782 }
8883 builder. edit_file ( ctx. frange . file_id ) ;
89- update_variant ( & mut rewriter, & variant_name , & field_list ) ;
84+ update_variant ( & mut rewriter, & variant ) ;
9085 extract_struct_def (
9186 & mut rewriter,
9287 & enum_ast,
9388 variant_name. clone ( ) ,
9489 & field_list,
9590 & variant. parent_enum ( ) . syntax ( ) . clone ( ) . into ( ) ,
96- visibility,
91+ enum_ast . visibility ( ) ,
9792 ) ;
9893 builder. rewrite ( rewriter) ;
9994 } ,
10095 )
10196}
10297
103- fn existing_struct_def ( db : & RootDatabase , variant_name : & ast:: Name , variant : & EnumVariant ) -> bool {
98+ fn extract_field_list_if_applicable (
99+ variant : & ast:: Variant ,
100+ ) -> Option < Either < ast:: RecordFieldList , ast:: TupleFieldList > > {
101+ match variant. kind ( ) {
102+ ast:: StructKind :: Record ( field_list) if field_list. fields ( ) . next ( ) . is_some ( ) => {
103+ Some ( Either :: Left ( field_list) )
104+ }
105+ ast:: StructKind :: Tuple ( field_list) if field_list. fields ( ) . count ( ) > 1 => {
106+ Some ( Either :: Right ( field_list) )
107+ }
108+ _ => None ,
109+ }
110+ }
111+
112+ fn existing_definition ( db : & RootDatabase , variant_name : & ast:: Name , variant : & EnumVariant ) -> bool {
104113 variant
105114 . parent_enum ( db)
106115 . module ( db)
107116 . scope ( db, None )
108117 . into_iter ( )
118+ . filter ( |( _, def) | match def {
119+ // only check type-namespace
120+ hir:: ScopeDef :: ModuleDef ( def) => matches ! ( def,
121+ ModuleDef :: Module ( _) | ModuleDef :: Adt ( _) |
122+ ModuleDef :: EnumVariant ( _) | ModuleDef :: Trait ( _) |
123+ ModuleDef :: TypeAlias ( _) | ModuleDef :: BuiltinType ( _)
124+ ) ,
125+ _ => false ,
126+ } )
109127 . any ( |( name, _) | name == variant_name. as_name ( ) )
110128}
111129
@@ -133,19 +151,29 @@ fn extract_struct_def(
133151 rewriter : & mut SyntaxRewriter ,
134152 enum_ : & ast:: Enum ,
135153 variant_name : ast:: Name ,
136- variant_list : & ast:: TupleFieldList ,
154+ field_list : & Either < ast:: RecordFieldList , ast :: TupleFieldList > ,
137155 start_offset : & SyntaxElement ,
138156 visibility : Option < ast:: Visibility > ,
139157) -> Option < ( ) > {
140- let variant_list = make:: tuple_field_list (
141- variant_list
142- . fields ( )
143- . flat_map ( |field| Some ( make:: tuple_field ( Some ( make:: visibility_pub ( ) ) , field. ty ( ) ?) ) ) ,
144- ) ;
158+ let pub_vis = Some ( make:: visibility_pub ( ) ) ;
159+ let field_list = match field_list {
160+ Either :: Left ( field_list) => {
161+ make:: record_field_list ( field_list. fields ( ) . flat_map ( |field| {
162+ Some ( make:: record_field ( pub_vis. clone ( ) , field. name ( ) ?, field. ty ( ) ?) )
163+ } ) )
164+ . into ( )
165+ }
166+ Either :: Right ( field_list) => make:: tuple_field_list (
167+ field_list
168+ . fields ( )
169+ . flat_map ( |field| Some ( make:: tuple_field ( pub_vis. clone ( ) , field. ty ( ) ?) ) ) ,
170+ )
171+ . into ( ) ,
172+ } ;
145173
146174 rewriter. insert_before (
147175 start_offset,
148- make:: struct_ ( visibility, variant_name, None , variant_list . into ( ) ) . syntax ( ) ,
176+ make:: struct_ ( visibility, variant_name, None , field_list ) . syntax ( ) ,
149177 ) ;
150178 rewriter. insert_before ( start_offset, & make:: tokens:: blank_line ( ) ) ;
151179
@@ -156,15 +184,14 @@ fn extract_struct_def(
156184 Some ( ( ) )
157185}
158186
159- fn update_variant (
160- rewriter : & mut SyntaxRewriter ,
161- variant_name : & ast:: Name ,
162- field_list : & ast:: TupleFieldList ,
163- ) -> Option < ( ) > {
164- let ( l, r) : ( SyntaxElement , SyntaxElement ) =
165- ( field_list. l_paren_token ( ) ?. into ( ) , field_list. r_paren_token ( ) ?. into ( ) ) ;
166- let replacement = vec ! [ l, variant_name. syntax( ) . clone( ) . into( ) , r] ;
167- rewriter. replace_with_many ( field_list. syntax ( ) , replacement) ;
187+ fn update_variant ( rewriter : & mut SyntaxRewriter , variant : & ast:: Variant ) -> Option < ( ) > {
188+ let name = variant. name ( ) ?;
189+ let tuple_field = make:: tuple_field ( None , make:: ty ( name. text ( ) ) ) ;
190+ let replacement = make:: variant (
191+ name,
192+ Some ( ast:: FieldList :: TupleFieldList ( make:: tuple_field_list ( iter:: once ( tuple_field) ) ) ) ,
193+ ) ;
194+ rewriter. replace ( variant. syntax ( ) , replacement. syntax ( ) ) ;
168195 Some ( ( ) )
169196}
170197
@@ -211,12 +238,47 @@ mod tests {
211238 use super :: * ;
212239
213240 #[ test]
214- fn test_extract_struct_several_fields ( ) {
241+ fn test_extract_struct_several_fields_tuple ( ) {
215242 check_assist (
216243 extract_struct_from_enum_variant,
217244 "enum A { <|>One(u32, u32) }" ,
218245 r#"struct One(pub u32, pub u32);
219246
247+ enum A { One(One) }"# ,
248+ ) ;
249+ }
250+
251+ #[ test]
252+ fn test_extract_struct_several_fields_named ( ) {
253+ check_assist (
254+ extract_struct_from_enum_variant,
255+ "enum A { <|>One { foo: u32, bar: u32 } }" ,
256+ r#"struct One{ pub foo: u32, pub bar: u32 }
257+
258+ enum A { One(One) }"# ,
259+ ) ;
260+ }
261+
262+ #[ test]
263+ fn test_extract_struct_one_field_named ( ) {
264+ check_assist (
265+ extract_struct_from_enum_variant,
266+ "enum A { <|>One { foo: u32 } }" ,
267+ r#"struct One{ pub foo: u32 }
268+
269+ enum A { One(One) }"# ,
270+ ) ;
271+ }
272+
273+ #[ test]
274+ fn test_extract_enum_variant_name_value_namespace ( ) {
275+ check_assist (
276+ extract_struct_from_enum_variant,
277+ r#"const One: () = ();
278+ enum A { <|>One(u32, u32) }"# ,
279+ r#"const One: () = ();
280+ struct One(pub u32, pub u32);
281+
220282enum A { One(One) }"# ,
221283 ) ;
222284 }
@@ -298,12 +360,22 @@ fn another_fn() {
298360 fn test_extract_enum_not_applicable_if_struct_exists ( ) {
299361 check_not_applicable (
300362 r#"struct One;
301- enum A { <|>One(u8) }"# ,
363+ enum A { <|>One(u8, u32 ) }"# ,
302364 ) ;
303365 }
304366
305367 #[ test]
306368 fn test_extract_not_applicable_one_field ( ) {
307369 check_not_applicable ( r"enum A { <|>One(u32) }" ) ;
308370 }
371+
372+ #[ test]
373+ fn test_extract_not_applicable_no_field_tuple ( ) {
374+ check_not_applicable ( r"enum A { <|>None() }" ) ;
375+ }
376+
377+ #[ test]
378+ fn test_extract_not_applicable_no_field_named ( ) {
379+ check_not_applicable ( r"enum A { <|>None {} }" ) ;
380+ }
309381}
0 commit comments