Skip to content

Commit 914bdd1

Browse files
committed
feat: extract_struct_from_function_signature
made it work for normal parameters in methods
1 parent 9c11270 commit 914bdd1

File tree

1 file changed

+108
-57
lines changed

1 file changed

+108
-57
lines changed

crates/ide-assists/src/handlers/extract_struct_from_function_signature.rs

Lines changed: 108 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ use ide_db::{
1414
use itertools::Itertools;
1515
use syntax::{
1616
AstNode, Edition, SyntaxElement, SyntaxKind, SyntaxNode, T,
17+
algo::find_node_at_range,
1718
ast::{
18-
self, CallExpr, HasArgList, HasAttrs, HasGenericArgs, HasGenericParams, HasName,
19-
HasVisibility, RecordExprField,
19+
self, HasArgList, HasAttrs, HasGenericParams, HasName, HasVisibility, MethodCallExpr,
20+
RecordExprField,
2021
edit::{AstNodeEdit, IndentLevel},
2122
make,
2223
},
@@ -111,10 +112,9 @@ pub(crate) fn extract_struct_from_function_signature(
111112
references,
112113
name.clone()
113114
);
114-
processed.into_iter().for_each(|(path, node, import)| {
115-
apply_references(ctx.config.insert_use, path, node, import, edition, used_params_range.clone(), &field_list,
115+
processed.into_iter().for_each(|(path, import)| {
116+
apply_references(ctx.config.insert_use, path ,import, edition, used_params_range.clone(), &field_list,
116117
name.clone(),
117-
// new_lifetime_count
118118
);
119119
});
120120
}
@@ -136,10 +136,9 @@ pub(crate) fn extract_struct_from_function_signature(
136136
references,
137137
name.clone()
138138
);
139-
processed.into_iter().for_each(|(path, node, import)| {
140-
apply_references(ctx.config.insert_use, path, node, import, edition, used_params_range.clone(), &field_list,
139+
processed.into_iter().for_each(|(path, import)| {
140+
apply_references(ctx.config.insert_use, path, import, edition, used_params_range.clone(), &field_list,
141141
name.clone(),
142-
// new_lifetime_count
143142
);
144143
});
145144
}
@@ -178,12 +177,13 @@ pub(crate) fn extract_struct_from_function_signature(
178177
let def = create_struct_def(name.clone(), &fn_ast_mut, &used_param_list, &field_list, generics);
179178
tracing::info!("extract_struct_from_function_signature: creating struct");
180179

181-
let indent = fn_ast_mut.indent_level();
180+
// if in impl block then put struct before the impl block
181+
let (indent, syntax) = param_list.self_param().and_then(|_|ctx.find_node_at_range::<ast::Impl>() ).map(|impl_|builder.make_mut(impl_)).map(|impl_|( impl_.indent_level(), impl_.syntax().clone())).unwrap_or((fn_ast.indent_level(), fn_ast_mut.syntax().clone()));
182182
let def = def.indent(indent);
183183

184184

185185
ted::insert_all(
186-
ted::Position::before(fn_ast_mut.syntax()),
186+
ted::Position::before(syntax),
187187
vec![
188188
def.syntax().clone().into(),
189189
make::tokens::whitespace(&format!("\n\n{indent}")).into(),
@@ -477,14 +477,15 @@ fn process_references(
477477
function_module_def: &ModuleDef,
478478
refs: Vec<FileReference>,
479479
name: ast::Name,
480-
) -> Vec<(ast::PathSegment, SyntaxNode, Option<(ImportScope, hir::ModPath)>)> {
480+
) -> Vec<(CallExpr, Option<(ImportScope, hir::ModPath)>)> {
481481
// we have to recollect here eagerly as we are about to edit the tree we need to calculate the changes
482482
// and corresponding nodes up front
483483
let name = make::name_ref(name.text_non_mutable());
484484
refs.into_iter()
485485
.flat_map(|reference| {
486-
let (segment, scope_node, module) = reference_to_node(&ctx.sema, reference)?;
486+
let (call, scope_node, module) = reference_to_node(&ctx.sema, reference)?;
487487
let scope_node = builder.make_syntax_mut(scope_node);
488+
let call = builder.make_mut(call);
488489
if !visited_modules.contains(&module) {
489490
let mod_path = module.find_use_path(
490491
ctx.sema.db,
@@ -497,81 +498,57 @@ fn process_references(
497498
mod_path.push_segment(hir::Name::new_root(name.text_non_mutable()).clone());
498499
let scope = ImportScope::find_insert_use_container(&scope_node, &ctx.sema)?;
499500
visited_modules.insert(module);
500-
return Some((segment, scope_node, Some((scope, mod_path))));
501+
return Some((call, Some((scope, mod_path))));
501502
}
502503
}
503-
Some((segment, scope_node, None))
504+
Some((call, None))
504505
})
505506
.collect()
506507
}
507508
fn reference_to_node(
508509
sema: &hir::Semantics<'_, RootDatabase>,
509510
reference: FileReference,
510-
) -> Option<(ast::PathSegment, SyntaxNode, hir::Module)> {
511-
// filter out the reference in macro (seems to be probalamtic with lifetimes/generics arguments)
512-
let segment =
513-
reference.name.as_name_ref()?.syntax().parent().and_then(ast::PathSegment::cast)?;
511+
) -> Option<(CallExpr, SyntaxNode, hir::Module)> {
512+
// find neareat method call/call to the reference because different amount of parents between
513+
// name and full call depending on if its method call or normal call
514+
let node =
515+
find_node_at_range::<CallExpr>(reference.name.as_name_ref()?.syntax(), reference.range)?;
514516

515517
// let segment_range = segment.syntax().text_range();
516518
// if segment_range != reference.range {
517519
// return None;
518520
// }
519521

520-
let parent = segment.parent_path().syntax().parent()?;
521-
let expr_or_pat = match_ast! {
522-
match parent {
523-
ast::PathExpr(_it) => parent.parent()?,
524-
ast::RecordExpr(_it) => parent,
525-
ast::TupleStructPat(_it) => parent,
526-
ast::RecordPat(_it) => parent,
527-
_ => return None,
528-
}
529-
};
530-
let module = sema.scope(&expr_or_pat)?.module();
522+
let module = sema.scope(&node.syntax())?.module();
531523

532-
Some((segment.clone_for_update(), expr_or_pat, module))
524+
Some((node.clone(), node.syntax().clone(), module))
533525
}
534526

535527
fn apply_references(
536528
insert_use_cfg: InsertUseConfig,
537-
segment: ast::PathSegment,
538-
node: SyntaxNode,
529+
call: CallExpr,
539530
import: Option<(ImportScope, hir::ModPath)>,
540531
edition: Edition,
541532
used_params_range: Range<usize>,
542533
field_list: &ast::RecordFieldList,
543534
name: ast::Name,
544-
// new_lifetime_count: usize,
545535
) -> Option<()> {
546536
if let Some((scope, path)) = import {
547537
insert_use(&scope, mod_path_to_ast(&path, edition), &insert_use_cfg);
548538
}
549-
// TODO: figure out lifetimes in referecnecs
550-
// becauuse we have to convert from segment being non turbofish, also only need
551-
// generics/lifetimes that are used in struct possibly not all the no ones for the original call
552-
// if no specified lifetimes/generics we just give empty one
553-
// if new_lifetime_count > 0 {
554-
// (0..new_lifetime_count).for_each(|_| {
555-
// segment
556-
// .get_or_create_generic_arg_list()
557-
// .add_generic_arg(make::lifetime_arg(make::lifetime("'_")).clone_for_update().into())
558-
// });
559-
// }
560539

561540
// current idea: the lifetimes can be inferred from the call
562-
if let Some(generics) = segment.generic_arg_list() {
563-
ted::remove(generics.syntax());
564-
}
565-
ted::replace(segment.name_ref()?.syntax(), name.clone_for_update().syntax());
566-
// deep clone to prevent cycle
567-
let path = make::path_from_segments(std::iter::once(segment.clone_subtree()), false);
568-
// TODO: do I need to to method call to
569-
let call = CallExpr::cast(node)?;
541+
let path = make::path_from_text(name.text_non_mutable());
570542
let fields = make::record_expr_field_list(
571543
call.arg_list()?
572544
.args()
573-
.skip(used_params_range.start - 1)
574-
.take(used_params_range.end - used_params_range.start)
545+
.skip(match call {
546+
// for some reason the indices for parameters of method go in increments of 3s (but
547+
// start at 4 to accommodate the self parameter)
548+
CallExpr::Method(_) => used_params_range.start / 3 - 1,
549+
CallExpr::Normal(_) => used_params_range.start - 1,
550+
})
551+
// the zip implicitly makes that it will only take the amount of parameters required
575552
.zip(field_list.fields())
576553
.map(|e| {
577554
e.1.name().map(|name| -> RecordExprField {
@@ -582,11 +559,57 @@ fn apply_references(
582559
);
583560
let record_expr = make::record_expr(path, fields).clone_for_update();
584561

585-
call.arg_list()?
586-
.syntax()
587-
.splice_children(used_params_range, vec![record_expr.syntax().syntax_element()]);
562+
// range for method definition used parames seems to be off
563+
call.arg_list()?.syntax().splice_children(
564+
match call {
565+
// but at call sites methods don't include the self argument as part of the "arg list" so
566+
// we have to decduct one parameters (for some reason length 3) from range
567+
CallExpr::Method(_) => (used_params_range.start - 3)..(used_params_range.end - 3),
568+
CallExpr::Normal(_) => used_params_range,
569+
},
570+
vec![record_expr.syntax().syntax_element()],
571+
);
588572
Some(())
589573
}
574+
575+
#[derive(Debug, Clone)]
576+
enum CallExpr {
577+
Normal(ast::CallExpr),
578+
Method(ast::MethodCallExpr),
579+
}
580+
impl AstNode for CallExpr {
581+
fn can_cast(kind: SyntaxKind) -> bool
582+
where
583+
Self: Sized,
584+
{
585+
kind == ast::MethodCallExpr::kind() && kind == ast::CallExpr::kind()
586+
}
587+
588+
fn cast(syntax: SyntaxNode) -> Option<Self>
589+
where
590+
Self: Sized,
591+
{
592+
ast::CallExpr::cast(syntax.clone())
593+
.map(CallExpr::Normal)
594+
.or(MethodCallExpr::cast(syntax).map(CallExpr::Method))
595+
}
596+
597+
fn syntax(&self) -> &SyntaxNode {
598+
match self {
599+
CallExpr::Normal(call_expr) => call_expr.syntax(),
600+
CallExpr::Method(method_call_expr) => method_call_expr.syntax(),
601+
}
602+
}
603+
}
604+
impl HasArgList for CallExpr {
605+
fn arg_list(&self) -> Option<ast::ArgList> {
606+
match self {
607+
CallExpr::Normal(call_expr) => call_expr.arg_list(),
608+
CallExpr::Method(method_call_expr) => method_call_expr.arg_list(),
609+
}
610+
}
611+
}
612+
590613
#[cfg(test)]
591614
mod tests {
592615
use super::*;
@@ -802,4 +825,32 @@ fn foo<'a>(FooStruct { bar, .. }: FooStruct<'a, '_>, baz: i32) {
802825
r"fn foo($0i: impl ToString) { }",
803826
);
804827
}
828+
#[test]
829+
fn test_extract_function_signature_in_method() {
830+
check_assist(
831+
extract_struct_from_function_signature,
832+
r#"
833+
struct Foo
834+
impl Foo {
835+
fn foo(&self, $0j: i32, i: i32$0, z:i32) { }
836+
}
837+
838+
fn bar() {
839+
Foo.foo(1, 2, 3)
840+
}
841+
"#,
842+
r#"
843+
struct Foo
844+
struct FooStruct{ j: i32, i: i32 }
845+
846+
impl Foo {
847+
fn foo(&self, FooStruct { j, i, .. }: FooStruct, z:i32) { }
848+
}
849+
850+
fn bar() {
851+
Foo.foo(FooStruct { j: 1, i: 2 }, 3)
852+
}
853+
"#,
854+
);
855+
}
805856
}

0 commit comments

Comments
 (0)