Skip to content

Commit 4a0527f

Browse files
committed
split ted from gen_trait_fn_body
1 parent 48ccbe0 commit 4a0527f

File tree

4 files changed

+65
-88
lines changed

4 files changed

+65
-88
lines changed

crates/ide-assists/src/handlers/add_missing_impl_members.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use hir::HasSource;
22
use syntax::{
33
Edition,
44
ast::{self, AstNode, make},
5+
ted,
56
};
67

78
use crate::{
@@ -157,19 +158,21 @@ fn add_missing_impl_members_inner(
157158
&target_scope,
158159
);
159160

161+
let mut editor = edit.make_editor(impl_def.syntax());
160162
if let Some(cap) = ctx.config.snippet_cap {
161163
let mut placeholder = None;
162164
if let DefaultMethods::No = mode {
163165
if let ast::AssocItem::Fn(func) = &first_new_item {
164-
if try_gen_trait_body(
166+
if let Some(body) = try_gen_trait_body(
165167
ctx,
166168
func,
167169
trait_ref,
168170
&impl_def,
169171
target_scope.krate().edition(ctx.sema.db),
170-
)
171-
.is_none()
172+
) && let Some(func_body) = func.body()
172173
{
174+
ted::replace(func_body.syntax(), body.syntax());
175+
} else {
173176
if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast)
174177
{
175178
if m.syntax().text() == "todo!()" {
@@ -195,7 +198,7 @@ fn try_gen_trait_body(
195198
trait_ref: hir::TraitRef<'_>,
196199
impl_def: &ast::Impl,
197200
edition: Edition,
198-
) -> Option<()> {
201+
) -> Option<ast::BlockExpr> {
199202
let trait_path = make::ext::ident_path(
200203
&trait_ref.trait_().name(ctx.db()).display(ctx.db(), edition).to_string(),
201204
);

crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -135,47 +135,29 @@ fn add_assist(
135135
&annotated_name,
136136
trait_,
137137
replace_trait_path,
138+
impl_is_unsafe,
138139
);
139140
update_attribute(builder, old_derives, old_tree, old_trait_path, attr);
140141

141142
let trait_path = make::ty_path(replace_trait_path.clone());
142143

143144
match (ctx.config.snippet_cap, impl_def_with_items) {
144145
(None, None) => {
145-
let impl_def = generate_trait_impl(adt, trait_path);
146-
if impl_is_unsafe {
147-
ted::insert(
148-
Position::first_child_of(impl_def.syntax()),
149-
make::token(T![unsafe]),
150-
);
151-
}
146+
let impl_def = generate_trait_impl(impl_is_unsafe, adt, trait_path);
152147

153148
ted::insert_all(
154149
insert_after,
155150
vec![make::tokens::blank_line().into(), impl_def.syntax().clone().into()],
156151
);
157152
}
158153
(None, Some((impl_def, _))) => {
159-
if impl_is_unsafe {
160-
ted::insert(
161-
Position::first_child_of(impl_def.syntax()),
162-
make::token(T![unsafe]),
163-
);
164-
}
165154
ted::insert_all(
166155
insert_after,
167156
vec![make::tokens::blank_line().into(), impl_def.syntax().clone().into()],
168157
);
169158
}
170159
(Some(cap), None) => {
171-
let impl_def = generate_trait_impl(adt, trait_path);
172-
173-
if impl_is_unsafe {
174-
ted::insert(
175-
Position::first_child_of(impl_def.syntax()),
176-
make::token(T![unsafe]),
177-
);
178-
}
160+
let impl_def = generate_trait_impl(impl_is_unsafe, adt, trait_path);
179161

180162
if let Some(l_curly) = impl_def.assoc_item_list().and_then(|it| it.l_curly_token())
181163
{
@@ -188,26 +170,13 @@ fn add_assist(
188170
);
189171
}
190172
(Some(cap), Some((impl_def, first_assoc_item))) => {
191-
let mut added_snippet = false;
192-
193-
if impl_is_unsafe {
194-
ted::insert(
195-
Position::first_child_of(impl_def.syntax()),
196-
make::token(T![unsafe]),
197-
);
198-
}
199-
200-
if let ast::AssocItem::Fn(ref func) = first_assoc_item {
201-
if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast) {
202-
if m.syntax().text() == "todo!()" {
203-
// Make the `todo!()` a placeholder
204-
builder.add_placeholder_snippet(cap, m);
205-
added_snippet = true;
206-
}
207-
}
208-
}
209-
210-
if !added_snippet {
173+
if let ast::AssocItem::Fn(ref func) = first_assoc_item
174+
&& let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast)
175+
&& m.syntax().text() == "todo!()"
176+
{
177+
// Make the `todo!()` a placeholder
178+
builder.add_placeholder_snippet(cap, m);
179+
} else {
211180
// If we haven't already added a snippet, add a tabstop before the generated function
212181
builder.add_tabstop_before(cap, first_assoc_item);
213182
}
@@ -228,6 +197,7 @@ fn impl_def_from_trait(
228197
annotated_name: &ast::Name,
229198
trait_: Option<hir::Trait>,
230199
trait_path: &ast::Path,
200+
impl_is_unsafe: bool,
231201
) -> Option<(ast::Impl, ast::AssocItem)> {
232202
let trait_ = trait_?;
233203
let target_scope = sema.scope(annotated_name.syntax())?;
@@ -245,14 +215,18 @@ fn impl_def_from_trait(
245215
if trait_items.is_empty() {
246216
return None;
247217
}
248-
let impl_def = generate_trait_impl(adt, make::ty_path(trait_path.clone()));
218+
let impl_def = generate_trait_impl(impl_is_unsafe, adt, make::ty_path(trait_path.clone()));
249219

250220
let first_assoc_item =
251221
add_trait_assoc_items_to_impl(sema, config, &trait_items, trait_, &impl_def, &target_scope);
252222

253223
// Generate a default `impl` function body for the derived trait.
254224
if let ast::AssocItem::Fn(ref func) = first_assoc_item {
255-
let _ = gen_trait_fn_body(func, trait_path, adt, None);
225+
if let Some(body) = gen_trait_fn_body(func, trait_path, adt, None)
226+
&& let Some(func_body) = func.body()
227+
{
228+
ted::replace(func_body.syntax(), body.syntax());
229+
}
256230
};
257231

258232
Some((impl_def, first_assoc_item))

crates/ide-assists/src/utils.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -668,30 +668,31 @@ pub(crate) fn generate_impl_with_item(
668668
adt: &ast::Adt,
669669
body: Option<ast::AssocItemList>,
670670
) -> ast::Impl {
671-
generate_impl_inner(adt, None, true, body)
671+
generate_impl_inner(false, adt, None, true, body)
672672
}
673673

674674
pub(crate) fn generate_impl(adt: &ast::Adt) -> ast::Impl {
675-
generate_impl_inner(adt, None, true, None)
675+
generate_impl_inner(false, adt, None, true, None)
676676
}
677677

678678
/// Generates the corresponding `impl <trait> for Type {}` including type
679679
/// and lifetime parameters, with `<trait>` appended to `impl`'s generic parameters' bounds.
680680
///
681681
/// This is useful for traits like `PartialEq`, since `impl<T> PartialEq for U<T>` often requires `T: PartialEq`.
682-
pub(crate) fn generate_trait_impl(adt: &ast::Adt, trait_: ast::Type) -> ast::Impl {
683-
generate_impl_inner(adt, Some(trait_), true, None)
682+
pub(crate) fn generate_trait_impl(is_unsafe: bool, adt: &ast::Adt, trait_: ast::Type) -> ast::Impl {
683+
generate_impl_inner(is_unsafe, adt, Some(trait_), true, None)
684684
}
685685

686686
/// Generates the corresponding `impl <trait> for Type {}` including type
687687
/// and lifetime parameters, with `impl`'s generic parameters' bounds kept as-is.
688688
///
689689
/// This is useful for traits like `From<T>`, since `impl<T> From<T> for U<T>` doesn't require `T: From<T>`.
690690
pub(crate) fn generate_trait_impl_intransitive(adt: &ast::Adt, trait_: ast::Type) -> ast::Impl {
691-
generate_impl_inner(adt, Some(trait_), false, None)
691+
generate_impl_inner(false, adt, Some(trait_), false, None)
692692
}
693693

694694
fn generate_impl_inner(
695+
is_unsafe: bool,
695696
adt: &ast::Adt,
696697
trait_: Option<ast::Type>,
697698
trait_is_transitive: bool,
@@ -735,7 +736,7 @@ fn generate_impl_inner(
735736

736737
let impl_ = match trait_ {
737738
Some(trait_) => make::impl_trait(
738-
false,
739+
is_unsafe,
739740
None,
740741
None,
741742
generic_params,

crates/ide-assists/src/utils/gen_trait_fn_body.rs

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
//! This module contains functions to generate default trait impl function bodies where possible.
22
33
use hir::TraitRef;
4-
use syntax::{
5-
ast::{self, AstNode, BinaryOp, CmpOp, HasName, LogicOp, edit::AstNodeEdit, make},
6-
ted,
7-
};
4+
use syntax::ast::{self, AstNode, BinaryOp, CmpOp, HasName, LogicOp, edit::AstNodeEdit, make};
85

96
/// Generate custom trait bodies without default implementation where possible.
107
///
@@ -18,21 +15,33 @@ pub(crate) fn gen_trait_fn_body(
1815
trait_path: &ast::Path,
1916
adt: &ast::Adt,
2017
trait_ref: Option<TraitRef<'_>>,
21-
) -> Option<()> {
18+
) -> Option<ast::BlockExpr> {
19+
let _ = func.body()?;
2220
match trait_path.segment()?.name_ref()?.text().as_str() {
23-
"Clone" => gen_clone_impl(adt, func),
24-
"Debug" => gen_debug_impl(adt, func),
25-
"Default" => gen_default_impl(adt, func),
26-
"Hash" => gen_hash_impl(adt, func),
27-
"PartialEq" => gen_partial_eq(adt, func, trait_ref),
28-
"PartialOrd" => gen_partial_ord(adt, func, trait_ref),
21+
"Clone" => {
22+
stdx::always!(func.name().is_some_and(|name| name.text() == "clone"));
23+
gen_clone_impl(adt)
24+
}
25+
"Debug" => gen_debug_impl(adt),
26+
"Default" => gen_default_impl(adt),
27+
"Hash" => {
28+
stdx::always!(func.name().is_some_and(|name| name.text() == "hash"));
29+
gen_hash_impl(adt)
30+
}
31+
"PartialEq" => {
32+
stdx::always!(func.name().is_some_and(|name| name.text() == "eq"));
33+
gen_partial_eq(adt, trait_ref)
34+
}
35+
"PartialOrd" => {
36+
stdx::always!(func.name().is_some_and(|name| name.text() == "partial_cmp"));
37+
gen_partial_ord(adt, trait_ref)
38+
}
2939
_ => None,
3040
}
3141
}
3242

3343
/// Generate a `Clone` impl based on the fields and members of the target type.
34-
fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
35-
stdx::always!(func.name().is_some_and(|name| name.text() == "clone"));
44+
fn gen_clone_impl(adt: &ast::Adt) -> Option<ast::BlockExpr> {
3645
fn gen_clone_call(target: ast::Expr) -> ast::Expr {
3746
let method = make::name_ref("clone");
3847
make::expr_method_call(target, method, make::arg_list(None)).into()
@@ -139,12 +148,11 @@ fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
139148
}
140149
};
141150
let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
142-
ted::replace(func.body()?.syntax(), body.syntax());
143-
Some(())
151+
Some(body)
144152
}
145153

146154
/// Generate a `Debug` impl based on the fields and members of the target type.
147-
fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
155+
fn gen_debug_impl(adt: &ast::Adt) -> Option<ast::BlockExpr> {
148156
let annotated_name = adt.name()?;
149157
match adt {
150158
// `Debug` cannot be derived for unions, so no default impl can be provided.
@@ -248,8 +256,7 @@ fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
248256

249257
let body = make::block_expr(None, Some(match_expr.into()));
250258
let body = body.indent(ast::edit::IndentLevel(1));
251-
ted::replace(func.body()?.syntax(), body.syntax());
252-
Some(())
259+
Some(body)
253260
}
254261

255262
ast::Adt::Struct(strukt) => {
@@ -296,14 +303,13 @@ fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
296303
let method = make::name_ref("finish");
297304
let expr = make::expr_method_call(expr, method, make::arg_list(None)).into();
298305
let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
299-
ted::replace(func.body()?.syntax(), body.syntax());
300-
Some(())
306+
Some(body)
301307
}
302308
}
303309
}
304310

305311
/// Generate a `Debug` impl based on the fields and members of the target type.
306-
fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
312+
fn gen_default_impl(adt: &ast::Adt) -> Option<ast::BlockExpr> {
307313
fn gen_default_call() -> Option<ast::Expr> {
308314
let fn_name = make::ext::path_from_idents(["Default", "default"])?;
309315
Some(make::expr_call(make::expr_path(fn_name), make::arg_list(None)).into())
@@ -342,15 +348,13 @@ fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
342348
}
343349
};
344350
let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
345-
ted::replace(func.body()?.syntax(), body.syntax());
346-
Some(())
351+
Some(body)
347352
}
348353
}
349354
}
350355

351356
/// Generate a `Hash` impl based on the fields and members of the target type.
352-
fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
353-
stdx::always!(func.name().is_some_and(|name| name.text() == "hash"));
357+
fn gen_hash_impl(adt: &ast::Adt) -> Option<ast::BlockExpr> {
354358
fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
355359
let method = make::name_ref("hash");
356360
let arg = make::expr_path(make::ext::ident_path("state"));
@@ -400,13 +404,11 @@ fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
400404
},
401405
};
402406

403-
ted::replace(func.body()?.syntax(), body.syntax());
404-
Some(())
407+
Some(body)
405408
}
406409

407410
/// Generate a `PartialEq` impl based on the fields and members of the target type.
408-
fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn, trait_ref: Option<TraitRef<'_>>) -> Option<()> {
409-
stdx::always!(func.name().is_some_and(|name| name.text() == "eq"));
411+
fn gen_partial_eq(adt: &ast::Adt, trait_ref: Option<TraitRef<'_>>) -> Option<ast::BlockExpr> {
410412
fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
411413
match expr {
412414
Some(expr) => Some(make::expr_bin_op(expr, BinaryOp::LogicOp(LogicOp::And), cmp)),
@@ -595,12 +597,10 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn, trait_ref: Option<TraitRef<'_>
595597
},
596598
};
597599

598-
ted::replace(func.body()?.syntax(), body.syntax());
599-
Some(())
600+
Some(body)
600601
}
601602

602-
fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn, trait_ref: Option<TraitRef<'_>>) -> Option<()> {
603-
stdx::always!(func.name().is_some_and(|name| name.text() == "partial_cmp"));
603+
fn gen_partial_ord(adt: &ast::Adt, trait_ref: Option<TraitRef<'_>>) -> Option<ast::BlockExpr> {
604604
fn gen_partial_eq_match(match_target: ast::Expr) -> Option<ast::Stmt> {
605605
let mut arms = vec![];
606606

@@ -686,8 +686,7 @@ fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn, trait_ref: Option<TraitRef<'_
686686
},
687687
};
688688

689-
ted::replace(func.body()?.syntax(), body.syntax());
690-
Some(())
689+
Some(body)
691690
}
692691

693692
fn make_discriminant() -> Option<ast::Expr> {

0 commit comments

Comments
 (0)