@@ -5,10 +5,9 @@ use hir::{AsName, EnumVariant, Module, ModuleDef, Name};
55use ide_db:: { defs:: Definition , search:: Reference , RootDatabase } ;
66use rustc_hash:: { FxHashMap , FxHashSet } ;
77use syntax:: {
8- algo:: find_node_at_offset,
9- algo:: SyntaxRewriter ,
10- ast:: { self , edit:: IndentLevel , make, ArgListOwner , AstNode , NameOwner , VisibilityOwner } ,
11- SourceFile , SyntaxElement ,
8+ algo:: { find_node_at_offset, SyntaxRewriter } ,
9+ ast:: { self , edit:: IndentLevel , make, AstNode , NameOwner , VisibilityOwner } ,
10+ SourceFile , SyntaxElement , SyntaxNode , T ,
1211} ;
1312
1413use crate :: {
@@ -130,17 +129,17 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &En
130129fn insert_import (
131130 ctx : & AssistContext ,
132131 rewriter : & mut SyntaxRewriter ,
133- path : & ast :: PathExpr ,
132+ scope_node : & SyntaxNode ,
134133 module : & Module ,
135134 enum_module_def : & ModuleDef ,
136135 variant_hir_name : & Name ,
137136) -> Option < ( ) > {
138137 let db = ctx. db ( ) ;
139138 let mod_path = module. find_use_path ( db, enum_module_def. clone ( ) ) ;
140- if let Some ( mut mod_path) = mod_path {
139+ if let Some ( mut mod_path) = mod_path. filter ( |path| path . len ( ) > 1 ) {
141140 mod_path. segments . pop ( ) ;
142141 mod_path. segments . push ( variant_hir_name. clone ( ) ) ;
143- let scope = ImportScope :: find_insert_use_container ( path . syntax ( ) , ctx) ?;
142+ let scope = ImportScope :: find_insert_use_container ( scope_node , ctx) ?;
144143
145144 * rewriter += insert_use ( & scope, mod_path_to_ast ( & mod_path) , ctx. config . insert_use . merge ) ;
146145 }
@@ -204,27 +203,31 @@ fn update_reference(
204203 variant_hir_name : & Name ,
205204 visited_modules_set : & mut FxHashSet < Module > ,
206205) -> Option < ( ) > {
207- let path_expr: ast:: PathExpr = find_node_at_offset :: < ast:: PathExpr > (
208- source_file. syntax ( ) ,
209- reference. file_range . range . start ( ) ,
210- ) ?;
211- let call = path_expr. syntax ( ) . parent ( ) . and_then ( ast:: CallExpr :: cast) ?;
212- let list = call. arg_list ( ) ?;
213- let segment = path_expr. path ( ) ?. segment ( ) ?;
214- let module = ctx. sema . scope ( & path_expr. syntax ( ) ) . module ( ) ?;
206+ let offset = reference. file_range . range . start ( ) ;
207+ let ( segment, expr) = if let Some ( path_expr) =
208+ find_node_at_offset :: < ast:: PathExpr > ( source_file. syntax ( ) , offset)
209+ {
210+ // tuple variant
211+ ( path_expr. path ( ) ?. segment ( ) ?, path_expr. syntax ( ) . parent ( ) ?. clone ( ) )
212+ } else if let Some ( record_expr) =
213+ find_node_at_offset :: < ast:: RecordExpr > ( source_file. syntax ( ) , offset)
214+ {
215+ // record variant
216+ ( record_expr. path ( ) ?. segment ( ) ?, record_expr. syntax ( ) . clone ( ) )
217+ } else {
218+ return None ;
219+ } ;
220+
221+ let module = ctx. sema . scope ( & expr) . module ( ) ?;
215222 if !visited_modules_set. contains ( & module) {
216- if insert_import ( ctx, rewriter, & path_expr, & module, enum_module_def, variant_hir_name)
217- . is_some ( )
223+ if insert_import ( ctx, rewriter, & expr, & module, enum_module_def, variant_hir_name) . is_some ( )
218224 {
219225 visited_modules_set. insert ( module) ;
220226 }
221227 }
222-
223- let lparen = syntax:: SyntaxElement :: from ( list. l_paren_token ( ) ?) ;
224- let rparen = syntax:: SyntaxElement :: from ( list. r_paren_token ( ) ?) ;
225- rewriter. insert_after ( & lparen, segment. syntax ( ) ) ;
226- rewriter. insert_after ( & lparen, & lparen) ;
227- rewriter. insert_before ( & rparen, & rparen) ;
228+ rewriter. insert_after ( segment. syntax ( ) , & make:: token ( T ! [ '(' ] ) ) ;
229+ rewriter. insert_after ( segment. syntax ( ) , segment. syntax ( ) ) ;
230+ rewriter. insert_after ( & expr, & make:: token ( T ! [ ')' ] ) ) ;
228231 Some ( ( ) )
229232}
230233
@@ -345,6 +348,33 @@ fn another_fn() {
345348 ) ;
346349 }
347350
351+ #[ test]
352+ fn extract_record_fix_references ( ) {
353+ check_assist (
354+ extract_struct_from_enum_variant,
355+ r#"
356+ enum E {
357+ <|>V { i: i32, j: i32 }
358+ }
359+
360+ fn f() {
361+ let e = E::V { i: 9, j: 2 };
362+ }
363+ "# ,
364+ r#"
365+ struct V{ pub i: i32, pub j: i32 }
366+
367+ enum E {
368+ V(V)
369+ }
370+
371+ fn f() {
372+ let e = E::V(V { i: 9, j: 2 });
373+ }
374+ "# ,
375+ )
376+ }
377+
348378 #[ test]
349379 fn test_several_files ( ) {
350380 check_assist (
@@ -372,8 +402,6 @@ enum E {
372402mod foo;
373403
374404//- /foo.rs
375- use V;
376-
377405use crate::E;
378406fn f() {
379407 let e = E::V(V(9, 2));
@@ -384,7 +412,6 @@ fn f() {
384412
385413 #[ test]
386414 fn test_several_files_record ( ) {
387- // FIXME: this should fix the usage as well!
388415 check_assist (
389416 extract_struct_from_enum_variant,
390417 r#"
@@ -401,13 +428,19 @@ fn f() {
401428}
402429"# ,
403430 r#"
431+ //- /main.rs
404432struct V{ pub i: i32, pub j: i32 }
405433
406434enum E {
407435 V(V)
408436}
409437mod foo;
410438
439+ //- /foo.rs
440+ use crate::E;
441+ fn f() {
442+ let e = E::V(V { i: 9, j: 2 });
443+ }
411444"# ,
412445 )
413446 }
0 commit comments