Skip to content

Commit bd6eeff

Browse files
bors[bot]Veykril
andauthored
Merge #6456
6456: Support record variants in extract_struct_from_enum_variant r=matklad a=Veykril As requested :) This also prevents the assist from being disabled if a definition in the value namespace exists with the same name as our new struct since that won't cause a collision #4468 Co-authored-by: Lukas Wirth <[email protected]>
2 parents 99a8e59 + 6145234 commit bd6eeff

File tree

3 files changed

+137
-41
lines changed

3 files changed

+137
-41
lines changed

crates/assists/src/handlers/extract_struct_from_enum_variant.rs

Lines changed: 110 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
use std::iter;
2+
3+
use either::Either;
14
use hir::{AsName, EnumVariant, Module, ModuleDef, Name};
25
use ide_db::{defs::Definition, search::Reference, RootDatabase};
36
use rustc_hash::{FxHashMap, FxHashSet};
@@ -31,40 +34,32 @@ pub(crate) fn extract_struct_from_enum_variant(
3134
ctx: &AssistContext,
3235
) -> Option<()> {
3336
let variant = ctx.find_node_at_offset::<ast::Variant>()?;
34-
let field_list = match variant.kind() {
35-
ast::StructKind::Tuple(field_list) => field_list,
36-
_ => return None,
37-
};
38-
39-
// skip 1-tuple variants
40-
if field_list.fields().count() == 1 {
41-
return None;
42-
}
37+
let field_list = extract_field_list_if_applicable(&variant)?;
4338

4439
let variant_name = variant.name()?;
4540
let variant_hir = ctx.sema.to_def(&variant)?;
46-
if existing_struct_def(ctx.db(), &variant_name, &variant_hir) {
41+
if existing_definition(ctx.db(), &variant_name, &variant_hir) {
4742
return None;
4843
}
44+
4945
let enum_ast = variant.parent_enum();
50-
let visibility = enum_ast.visibility();
5146
let enum_hir = ctx.sema.to_def(&enum_ast)?;
52-
let variant_hir_name = variant_hir.name(ctx.db());
53-
let enum_module_def = ModuleDef::from(enum_hir);
54-
let current_module = enum_hir.module(ctx.db());
5547
let target = variant.syntax().text_range();
5648
acc.add(
5749
AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite),
5850
"Extract struct from enum variant",
5951
target,
6052
|builder| {
61-
let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir));
62-
let res = definition.usages(&ctx.sema).all();
53+
let variant_hir_name = variant_hir.name(ctx.db());
54+
let enum_module_def = ModuleDef::from(enum_hir);
55+
let usages =
56+
Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)).usages(&ctx.sema).all();
6357

6458
let mut visited_modules_set = FxHashSet::default();
59+
let current_module = enum_hir.module(ctx.db());
6560
visited_modules_set.insert(current_module);
6661
let mut rewriters = FxHashMap::default();
67-
for reference in res {
62+
for reference in usages {
6863
let rewriter = rewriters
6964
.entry(reference.file_range.file_id)
7065
.or_insert_with(SyntaxRewriter::default);
@@ -86,26 +81,49 @@ pub(crate) fn extract_struct_from_enum_variant(
8681
builder.rewrite(rewriter);
8782
}
8883
builder.edit_file(ctx.frange.file_id);
89-
update_variant(&mut rewriter, &variant_name, &field_list);
84+
update_variant(&mut rewriter, &variant);
9085
extract_struct_def(
9186
&mut rewriter,
9287
&enum_ast,
9388
variant_name.clone(),
9489
&field_list,
9590
&variant.parent_enum().syntax().clone().into(),
96-
visibility,
91+
enum_ast.visibility(),
9792
);
9893
builder.rewrite(rewriter);
9994
},
10095
)
10196
}
10297

103-
fn existing_struct_def(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool {
98+
fn extract_field_list_if_applicable(
99+
variant: &ast::Variant,
100+
) -> Option<Either<ast::RecordFieldList, ast::TupleFieldList>> {
101+
match variant.kind() {
102+
ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => {
103+
Some(Either::Left(field_list))
104+
}
105+
ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => {
106+
Some(Either::Right(field_list))
107+
}
108+
_ => None,
109+
}
110+
}
111+
112+
fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool {
104113
variant
105114
.parent_enum(db)
106115
.module(db)
107116
.scope(db, None)
108117
.into_iter()
118+
.filter(|(_, def)| match def {
119+
// only check type-namespace
120+
hir::ScopeDef::ModuleDef(def) => matches!(def,
121+
ModuleDef::Module(_) | ModuleDef::Adt(_) |
122+
ModuleDef::EnumVariant(_) | ModuleDef::Trait(_) |
123+
ModuleDef::TypeAlias(_) | ModuleDef::BuiltinType(_)
124+
),
125+
_ => false,
126+
})
109127
.any(|(name, _)| name == variant_name.as_name())
110128
}
111129

@@ -133,19 +151,29 @@ fn extract_struct_def(
133151
rewriter: &mut SyntaxRewriter,
134152
enum_: &ast::Enum,
135153
variant_name: ast::Name,
136-
variant_list: &ast::TupleFieldList,
154+
field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
137155
start_offset: &SyntaxElement,
138156
visibility: Option<ast::Visibility>,
139157
) -> Option<()> {
140-
let variant_list = make::tuple_field_list(
141-
variant_list
142-
.fields()
143-
.flat_map(|field| Some(make::tuple_field(Some(make::visibility_pub()), field.ty()?))),
144-
);
158+
let pub_vis = Some(make::visibility_pub());
159+
let field_list = match field_list {
160+
Either::Left(field_list) => {
161+
make::record_field_list(field_list.fields().flat_map(|field| {
162+
Some(make::record_field(pub_vis.clone(), field.name()?, field.ty()?))
163+
}))
164+
.into()
165+
}
166+
Either::Right(field_list) => make::tuple_field_list(
167+
field_list
168+
.fields()
169+
.flat_map(|field| Some(make::tuple_field(pub_vis.clone(), field.ty()?))),
170+
)
171+
.into(),
172+
};
145173

146174
rewriter.insert_before(
147175
start_offset,
148-
make::struct_(visibility, variant_name, None, variant_list.into()).syntax(),
176+
make::struct_(visibility, variant_name, None, field_list).syntax(),
149177
);
150178
rewriter.insert_before(start_offset, &make::tokens::blank_line());
151179

@@ -156,15 +184,14 @@ fn extract_struct_def(
156184
Some(())
157185
}
158186

159-
fn update_variant(
160-
rewriter: &mut SyntaxRewriter,
161-
variant_name: &ast::Name,
162-
field_list: &ast::TupleFieldList,
163-
) -> Option<()> {
164-
let (l, r): (SyntaxElement, SyntaxElement) =
165-
(field_list.l_paren_token()?.into(), field_list.r_paren_token()?.into());
166-
let replacement = vec![l, variant_name.syntax().clone().into(), r];
167-
rewriter.replace_with_many(field_list.syntax(), replacement);
187+
fn update_variant(rewriter: &mut SyntaxRewriter, variant: &ast::Variant) -> Option<()> {
188+
let name = variant.name()?;
189+
let tuple_field = make::tuple_field(None, make::ty(name.text()));
190+
let replacement = make::variant(
191+
name,
192+
Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
193+
);
194+
rewriter.replace(variant.syntax(), replacement.syntax());
168195
Some(())
169196
}
170197

@@ -211,12 +238,47 @@ mod tests {
211238
use super::*;
212239

213240
#[test]
214-
fn test_extract_struct_several_fields() {
241+
fn test_extract_struct_several_fields_tuple() {
215242
check_assist(
216243
extract_struct_from_enum_variant,
217244
"enum A { <|>One(u32, u32) }",
218245
r#"struct One(pub u32, pub u32);
219246
247+
enum A { One(One) }"#,
248+
);
249+
}
250+
251+
#[test]
252+
fn test_extract_struct_several_fields_named() {
253+
check_assist(
254+
extract_struct_from_enum_variant,
255+
"enum A { <|>One { foo: u32, bar: u32 } }",
256+
r#"struct One{ pub foo: u32, pub bar: u32 }
257+
258+
enum A { One(One) }"#,
259+
);
260+
}
261+
262+
#[test]
263+
fn test_extract_struct_one_field_named() {
264+
check_assist(
265+
extract_struct_from_enum_variant,
266+
"enum A { <|>One { foo: u32 } }",
267+
r#"struct One{ pub foo: u32 }
268+
269+
enum A { One(One) }"#,
270+
);
271+
}
272+
273+
#[test]
274+
fn test_extract_enum_variant_name_value_namespace() {
275+
check_assist(
276+
extract_struct_from_enum_variant,
277+
r#"const One: () = ();
278+
enum A { <|>One(u32, u32) }"#,
279+
r#"const One: () = ();
280+
struct One(pub u32, pub u32);
281+
220282
enum A { One(One) }"#,
221283
);
222284
}
@@ -298,12 +360,22 @@ fn another_fn() {
298360
fn test_extract_enum_not_applicable_if_struct_exists() {
299361
check_not_applicable(
300362
r#"struct One;
301-
enum A { <|>One(u8) }"#,
363+
enum A { <|>One(u8, u32) }"#,
302364
);
303365
}
304366

305367
#[test]
306368
fn test_extract_not_applicable_one_field() {
307369
check_not_applicable(r"enum A { <|>One(u32) }");
308370
}
371+
372+
#[test]
373+
fn test_extract_not_applicable_no_field_tuple() {
374+
check_not_applicable(r"enum A { <|>None() }");
375+
}
376+
377+
#[test]
378+
fn test_extract_not_applicable_no_field_named() {
379+
check_not_applicable(r"enum A { <|>None {} }");
380+
}
309381
}

crates/ide/src/diagnostics/fixes.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ fn missing_record_expr_field_fix(
157157
return None;
158158
}
159159
let new_field = make::record_field(
160-
record_expr_field.field_name()?,
160+
None,
161+
make::name(record_expr_field.field_name()?.text()),
161162
make::ty(&new_field_type.display_source_code(sema.db, module.into()).ok()?),
162163
);
163164

crates/syntax/src/ast/make.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,16 @@ pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::Re
110110
}
111111
}
112112

113-
pub fn record_field(name: ast::NameRef, ty: ast::Type) -> ast::RecordField {
114-
ast_from_text(&format!("struct S {{ {}: {}, }}", name, ty))
113+
pub fn record_field(
114+
visibility: Option<ast::Visibility>,
115+
name: ast::Name,
116+
ty: ast::Type,
117+
) -> ast::RecordField {
118+
let visibility = match visibility {
119+
None => String::new(),
120+
Some(it) => format!("{} ", it),
121+
};
122+
ast_from_text(&format!("struct S {{ {}{}: {}, }}", visibility, name, ty))
115123
}
116124

117125
pub fn block_expr(
@@ -360,6 +368,13 @@ pub fn tuple_field_list(fields: impl IntoIterator<Item = ast::TupleField>) -> as
360368
ast_from_text(&format!("struct f({});", fields))
361369
}
362370

371+
pub fn record_field_list(
372+
fields: impl IntoIterator<Item = ast::RecordField>,
373+
) -> ast::RecordFieldList {
374+
let fields = fields.into_iter().join(", ");
375+
ast_from_text(&format!("struct f {{ {} }}", fields))
376+
}
377+
363378
pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::TupleField {
364379
let visibility = match visibility {
365380
None => String::new(),
@@ -368,6 +383,14 @@ pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::T
368383
ast_from_text(&format!("struct f({}{});", visibility, ty))
369384
}
370385

386+
pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Variant {
387+
let field_list = match field_list {
388+
None => String::new(),
389+
Some(it) => format!("{}", it),
390+
};
391+
ast_from_text(&format!("enum f {{ {}{} }}", name, field_list))
392+
}
393+
371394
pub fn fn_(
372395
visibility: Option<ast::Visibility>,
373396
fn_name: ast::Name,

0 commit comments

Comments
 (0)