Skip to content

Commit e24cc77

Browse files
committed
Fix extract_struct_from_enum_variant not updating record references
1 parent d5775b3 commit e24cc77

File tree

1 file changed

+59
-26
lines changed

1 file changed

+59
-26
lines changed

crates/assists/src/handlers/extract_struct_from_enum_variant.rs

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@ use hir::{AsName, EnumVariant, Module, ModuleDef, Name};
55
use ide_db::{defs::Definition, search::Reference, RootDatabase};
66
use rustc_hash::{FxHashMap, FxHashSet};
77
use syntax::{
8-
algo::find_node_at_offset,
9-
algo::SyntaxRewriter,
10-
ast::{self, edit::IndentLevel, make, ArgListOwner, AstNode, NameOwner, VisibilityOwner},
11-
SourceFile, SyntaxElement,
8+
algo::{find_node_at_offset, SyntaxRewriter},
9+
ast::{self, edit::IndentLevel, make, AstNode, NameOwner, VisibilityOwner},
10+
SourceFile, SyntaxElement, SyntaxNode, T,
1211
};
1312

1413
use crate::{
@@ -130,17 +129,17 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &En
130129
fn insert_import(
131130
ctx: &AssistContext,
132131
rewriter: &mut SyntaxRewriter,
133-
path: &ast::PathExpr,
132+
scope_node: &SyntaxNode,
134133
module: &Module,
135134
enum_module_def: &ModuleDef,
136135
variant_hir_name: &Name,
137136
) -> Option<()> {
138137
let db = ctx.db();
139138
let mod_path = module.find_use_path(db, enum_module_def.clone());
140-
if let Some(mut mod_path) = mod_path {
139+
if let Some(mut mod_path) = mod_path.filter(|path| path.len() > 1) {
141140
mod_path.segments.pop();
142141
mod_path.segments.push(variant_hir_name.clone());
143-
let scope = ImportScope::find_insert_use_container(path.syntax(), ctx)?;
142+
let scope = ImportScope::find_insert_use_container(scope_node, ctx)?;
144143

145144
*rewriter += insert_use(&scope, mod_path_to_ast(&mod_path), ctx.config.insert_use.merge);
146145
}
@@ -204,27 +203,31 @@ fn update_reference(
204203
variant_hir_name: &Name,
205204
visited_modules_set: &mut FxHashSet<Module>,
206205
) -> Option<()> {
207-
let path_expr: ast::PathExpr = find_node_at_offset::<ast::PathExpr>(
208-
source_file.syntax(),
209-
reference.file_range.range.start(),
210-
)?;
211-
let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?;
212-
let list = call.arg_list()?;
213-
let segment = path_expr.path()?.segment()?;
214-
let module = ctx.sema.scope(&path_expr.syntax()).module()?;
206+
let offset = reference.file_range.range.start();
207+
let (segment, expr) = if let Some(path_expr) =
208+
find_node_at_offset::<ast::PathExpr>(source_file.syntax(), offset)
209+
{
210+
// tuple variant
211+
(path_expr.path()?.segment()?, path_expr.syntax().parent()?.clone())
212+
} else if let Some(record_expr) =
213+
find_node_at_offset::<ast::RecordExpr>(source_file.syntax(), offset)
214+
{
215+
// record variant
216+
(record_expr.path()?.segment()?, record_expr.syntax().clone())
217+
} else {
218+
return None;
219+
};
220+
221+
let module = ctx.sema.scope(&expr).module()?;
215222
if !visited_modules_set.contains(&module) {
216-
if insert_import(ctx, rewriter, &path_expr, &module, enum_module_def, variant_hir_name)
217-
.is_some()
223+
if insert_import(ctx, rewriter, &expr, &module, enum_module_def, variant_hir_name).is_some()
218224
{
219225
visited_modules_set.insert(module);
220226
}
221227
}
222-
223-
let lparen = syntax::SyntaxElement::from(list.l_paren_token()?);
224-
let rparen = syntax::SyntaxElement::from(list.r_paren_token()?);
225-
rewriter.insert_after(&lparen, segment.syntax());
226-
rewriter.insert_after(&lparen, &lparen);
227-
rewriter.insert_before(&rparen, &rparen);
228+
rewriter.insert_after(segment.syntax(), &make::token(T!['(']));
229+
rewriter.insert_after(segment.syntax(), segment.syntax());
230+
rewriter.insert_after(&expr, &make::token(T![')']));
228231
Some(())
229232
}
230233

@@ -345,6 +348,33 @@ fn another_fn() {
345348
);
346349
}
347350

351+
#[test]
352+
fn extract_record_fix_references() {
353+
check_assist(
354+
extract_struct_from_enum_variant,
355+
r#"
356+
enum E {
357+
<|>V { i: i32, j: i32 }
358+
}
359+
360+
fn f() {
361+
let e = E::V { i: 9, j: 2 };
362+
}
363+
"#,
364+
r#"
365+
struct V{ pub i: i32, pub j: i32 }
366+
367+
enum E {
368+
V(V)
369+
}
370+
371+
fn f() {
372+
let e = E::V(V { i: 9, j: 2 });
373+
}
374+
"#,
375+
)
376+
}
377+
348378
#[test]
349379
fn test_several_files() {
350380
check_assist(
@@ -372,8 +402,6 @@ enum E {
372402
mod foo;
373403
374404
//- /foo.rs
375-
use V;
376-
377405
use crate::E;
378406
fn f() {
379407
let e = E::V(V(9, 2));
@@ -384,7 +412,6 @@ fn f() {
384412

385413
#[test]
386414
fn test_several_files_record() {
387-
// FIXME: this should fix the usage as well!
388415
check_assist(
389416
extract_struct_from_enum_variant,
390417
r#"
@@ -401,13 +428,19 @@ fn f() {
401428
}
402429
"#,
403430
r#"
431+
//- /main.rs
404432
struct V{ pub i: i32, pub j: i32 }
405433
406434
enum E {
407435
V(V)
408436
}
409437
mod foo;
410438
439+
//- /foo.rs
440+
use crate::E;
441+
fn f() {
442+
let e = E::V(V { i: 9, j: 2 });
443+
}
411444
"#,
412445
)
413446
}

0 commit comments

Comments
 (0)