Skip to content

Commit 6145234

Browse files
committed
Support struct variants in extract_struct_from_enum_variant
1 parent dc9842b commit 6145234

File tree

3 files changed

+101
-49
lines changed

3 files changed

+101
-49
lines changed

crates/assists/src/handlers/extract_struct_from_enum_variant.rs

Lines changed: 74 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
use std::iter;
2+
3+
use either::Either;
14
use hir::{AsName, EnumVariant, Module, ModuleDef, Name};
25
use ide_db::{defs::Definition, search::Reference, RootDatabase};
36
use rustc_hash::{FxHashMap, FxHashSet};
@@ -31,48 +34,32 @@ pub(crate) fn extract_struct_from_enum_variant(
3134
ctx: &AssistContext,
3235
) -> Option<()> {
3336
let variant = ctx.find_node_at_offset::<ast::Variant>()?;
34-
35-
fn is_applicable_variant(variant: &ast::Variant) -> bool {
36-
1 < match variant.kind() {
37-
ast::StructKind::Record(field_list) => field_list.fields().count(),
38-
ast::StructKind::Tuple(field_list) => field_list.fields().count(),
39-
ast::StructKind::Unit => 0,
40-
}
41-
}
42-
43-
if !is_applicable_variant(&variant) {
44-
return None;
45-
}
46-
47-
let field_list = match variant.kind() {
48-
ast::StructKind::Tuple(field_list) => field_list,
49-
_ => return None,
50-
};
37+
let field_list = extract_field_list_if_applicable(&variant)?;
5138

5239
let variant_name = variant.name()?;
5340
let variant_hir = ctx.sema.to_def(&variant)?;
5441
if existing_definition(ctx.db(), &variant_name, &variant_hir) {
5542
return None;
5643
}
44+
5745
let enum_ast = variant.parent_enum();
58-
let visibility = enum_ast.visibility();
5946
let enum_hir = ctx.sema.to_def(&enum_ast)?;
60-
let variant_hir_name = variant_hir.name(ctx.db());
61-
let enum_module_def = ModuleDef::from(enum_hir);
62-
let current_module = enum_hir.module(ctx.db());
6347
let target = variant.syntax().text_range();
6448
acc.add(
6549
AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite),
6650
"Extract struct from enum variant",
6751
target,
6852
|builder| {
69-
let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir));
70-
let res = definition.usages(&ctx.sema).all();
53+
let variant_hir_name = variant_hir.name(ctx.db());
54+
let enum_module_def = ModuleDef::from(enum_hir);
55+
let usages =
56+
Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)).usages(&ctx.sema).all();
7157

7258
let mut visited_modules_set = FxHashSet::default();
59+
let current_module = enum_hir.module(ctx.db());
7360
visited_modules_set.insert(current_module);
7461
let mut rewriters = FxHashMap::default();
75-
for reference in res {
62+
for reference in usages {
7663
let rewriter = rewriters
7764
.entry(reference.file_range.file_id)
7865
.or_insert_with(SyntaxRewriter::default);
@@ -94,20 +81,34 @@ pub(crate) fn extract_struct_from_enum_variant(
9481
builder.rewrite(rewriter);
9582
}
9683
builder.edit_file(ctx.frange.file_id);
97-
update_variant(&mut rewriter, &variant_name, &field_list);
84+
update_variant(&mut rewriter, &variant);
9885
extract_struct_def(
9986
&mut rewriter,
10087
&enum_ast,
10188
variant_name.clone(),
10289
&field_list,
10390
&variant.parent_enum().syntax().clone().into(),
104-
visibility,
91+
enum_ast.visibility(),
10592
);
10693
builder.rewrite(rewriter);
10794
},
10895
)
10996
}
11097

98+
fn extract_field_list_if_applicable(
99+
variant: &ast::Variant,
100+
) -> Option<Either<ast::RecordFieldList, ast::TupleFieldList>> {
101+
match variant.kind() {
102+
ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => {
103+
Some(Either::Left(field_list))
104+
}
105+
ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => {
106+
Some(Either::Right(field_list))
107+
}
108+
_ => None,
109+
}
110+
}
111+
111112
fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool {
112113
variant
113114
.parent_enum(db)
@@ -150,19 +151,29 @@ fn extract_struct_def(
150151
rewriter: &mut SyntaxRewriter,
151152
enum_: &ast::Enum,
152153
variant_name: ast::Name,
153-
variant_list: &ast::TupleFieldList,
154+
field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
154155
start_offset: &SyntaxElement,
155156
visibility: Option<ast::Visibility>,
156157
) -> Option<()> {
157-
let variant_list = make::tuple_field_list(
158-
variant_list
159-
.fields()
160-
.flat_map(|field| Some(make::tuple_field(Some(make::visibility_pub()), field.ty()?))),
161-
);
158+
let pub_vis = Some(make::visibility_pub());
159+
let field_list = match field_list {
160+
Either::Left(field_list) => {
161+
make::record_field_list(field_list.fields().flat_map(|field| {
162+
Some(make::record_field(pub_vis.clone(), field.name()?, field.ty()?))
163+
}))
164+
.into()
165+
}
166+
Either::Right(field_list) => make::tuple_field_list(
167+
field_list
168+
.fields()
169+
.flat_map(|field| Some(make::tuple_field(pub_vis.clone(), field.ty()?))),
170+
)
171+
.into(),
172+
};
162173

163174
rewriter.insert_before(
164175
start_offset,
165-
make::struct_(visibility, variant_name, None, variant_list.into()).syntax(),
176+
make::struct_(visibility, variant_name, None, field_list).syntax(),
166177
);
167178
rewriter.insert_before(start_offset, &make::tokens::blank_line());
168179

@@ -173,15 +184,14 @@ fn extract_struct_def(
173184
Some(())
174185
}
175186

176-
fn update_variant(
177-
rewriter: &mut SyntaxRewriter,
178-
variant_name: &ast::Name,
179-
field_list: &ast::TupleFieldList,
180-
) -> Option<()> {
181-
let (l, r): (SyntaxElement, SyntaxElement) =
182-
(field_list.l_paren_token()?.into(), field_list.r_paren_token()?.into());
183-
let replacement = vec![l, variant_name.syntax().clone().into(), r];
184-
rewriter.replace_with_many(field_list.syntax(), replacement);
187+
fn update_variant(rewriter: &mut SyntaxRewriter, variant: &ast::Variant) -> Option<()> {
188+
let name = variant.name()?;
189+
let tuple_field = make::tuple_field(None, make::ty(name.text()));
190+
let replacement = make::variant(
191+
name,
192+
Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
193+
);
194+
rewriter.replace(variant.syntax(), replacement.syntax());
185195
Some(())
186196
}
187197

@@ -243,10 +253,18 @@ enum A { One(One) }"#,
243253
check_assist(
244254
extract_struct_from_enum_variant,
245255
"enum A { <|>One { foo: u32, bar: u32 } }",
246-
r#"struct One {
247-
pub foo: u32,
248-
pub bar: u32
249-
}
256+
r#"struct One{ pub foo: u32, pub bar: u32 }
257+
258+
enum A { One(One) }"#,
259+
);
260+
}
261+
262+
#[test]
263+
fn test_extract_struct_one_field_named() {
264+
check_assist(
265+
extract_struct_from_enum_variant,
266+
"enum A { <|>One { foo: u32 } }",
267+
r#"struct One{ pub foo: u32 }
250268
251269
enum A { One(One) }"#,
252270
);
@@ -350,4 +368,14 @@ fn another_fn() {
350368
fn test_extract_not_applicable_one_field() {
351369
check_not_applicable(r"enum A { <|>One(u32) }");
352370
}
371+
372+
#[test]
373+
fn test_extract_not_applicable_no_field_tuple() {
374+
check_not_applicable(r"enum A { <|>None() }");
375+
}
376+
377+
#[test]
378+
fn test_extract_not_applicable_no_field_named() {
379+
check_not_applicable(r"enum A { <|>None {} }");
380+
}
353381
}

crates/ide/src/diagnostics/fixes.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ fn missing_record_expr_field_fix(
157157
return None;
158158
}
159159
let new_field = make::record_field(
160-
record_expr_field.field_name()?,
160+
None,
161+
make::name(record_expr_field.field_name()?.text()),
161162
make::ty(&new_field_type.display_source_code(sema.db, module.into()).ok()?),
162163
);
163164

crates/syntax/src/ast/make.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,16 @@ pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::Re
110110
}
111111
}
112112

113-
pub fn record_field(name: ast::NameRef, ty: ast::Type) -> ast::RecordField {
114-
ast_from_text(&format!("struct S {{ {}: {}, }}", name, ty))
113+
pub fn record_field(
114+
visibility: Option<ast::Visibility>,
115+
name: ast::Name,
116+
ty: ast::Type,
117+
) -> ast::RecordField {
118+
let visibility = match visibility {
119+
None => String::new(),
120+
Some(it) => format!("{} ", it),
121+
};
122+
ast_from_text(&format!("struct S {{ {}{}: {}, }}", visibility, name, ty))
115123
}
116124

117125
pub fn block_expr(
@@ -360,6 +368,13 @@ pub fn tuple_field_list(fields: impl IntoIterator<Item = ast::TupleField>) -> as
360368
ast_from_text(&format!("struct f({});", fields))
361369
}
362370

371+
pub fn record_field_list(
372+
fields: impl IntoIterator<Item = ast::RecordField>,
373+
) -> ast::RecordFieldList {
374+
let fields = fields.into_iter().join(", ");
375+
ast_from_text(&format!("struct f {{ {} }}", fields))
376+
}
377+
363378
pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::TupleField {
364379
let visibility = match visibility {
365380
None => String::new(),
@@ -368,6 +383,14 @@ pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::T
368383
ast_from_text(&format!("struct f({}{});", visibility, ty))
369384
}
370385

386+
pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Variant {
387+
let field_list = match field_list {
388+
None => String::new(),
389+
Some(it) => format!("{}", it),
390+
};
391+
ast_from_text(&format!("enum f {{ {}{} }}", name, field_list))
392+
}
393+
371394
pub fn fn_(
372395
visibility: Option<ast::Visibility>,
373396
fn_name: ast::Name,

0 commit comments

Comments
 (0)