Skip to content

Commit c6ddb90

Browse files
committed
Add references to fn args during completion
1 parent ac4b134 commit c6ddb90

File tree

5 files changed

+151
-10
lines changed

5 files changed

+151
-10
lines changed

crates/hir/src/code_model.rs

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -708,12 +708,24 @@ impl Function {
708708
Some(SelfParam { func: self.id })
709709
}
710710

711-
pub fn params(self, db: &dyn HirDatabase) -> Vec<Param> {
711+
pub fn params(self, db: &dyn HirDatabase) -> Vec<Type> {
712+
let resolver = self.id.resolver(db.upcast());
713+
let ctx = hir_ty::TyLoweringContext::new(db, &resolver);
714+
let environment = TraitEnvironment::lower(db, &resolver);
712715
db.function_data(self.id)
713716
.params
714717
.iter()
715718
.skip(if self.self_param(db).is_some() { 1 } else { 0 })
716-
.map(|_| Param { _ty: () })
719+
.map(|type_ref| {
720+
let ty = Type {
721+
krate: self.id.lookup(db.upcast()).container.module(db.upcast()).krate,
722+
ty: InEnvironment {
723+
value: Ty::from_hir_ext(&ctx, type_ref).0,
724+
environment: environment.clone(),
725+
},
726+
};
727+
ty
728+
})
717729
.collect()
718730
}
719731

@@ -747,10 +759,6 @@ pub struct SelfParam {
747759
func: FunctionId,
748760
}
749761

750-
pub struct Param {
751-
_ty: (),
752-
}
753-
754762
impl SelfParam {
755763
pub fn access(self, db: &dyn HirDatabase) -> Access {
756764
let func_data = db.function_data(self.func);
@@ -1100,6 +1108,12 @@ impl Local {
11001108
ast.map_left(|it| it.cast().unwrap().to_node(&root)).map_right(|it| it.to_node(&root))
11011109
})
11021110
}
1111+
1112+
pub fn can_unify(self, other: Type, db: &dyn HirDatabase) -> bool {
1113+
let def = DefWithBodyId::from(self.parent);
1114+
let infer = db.infer(def);
1115+
db.can_unify(def, infer[self.pat_id].clone(), other.ty.value)
1116+
}
11031117
}
11041118

11051119
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -1276,6 +1290,14 @@ impl Type {
12761290
)
12771291
}
12781292

1293+
pub fn remove_ref(&self) -> Option<Type> {
1294+
if let Ty::Apply(ApplicationTy { ctor: TypeCtor::Ref(_), .. }) = self.ty.value {
1295+
self.ty.value.substs().map(|substs| self.derived(substs[0].clone()))
1296+
} else {
1297+
None
1298+
}
1299+
}
1300+
12791301
pub fn is_unknown(&self) -> bool {
12801302
matches!(self.ty.value, Ty::Unknown)
12811303
}

crates/hir_ty/src/db.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
2626
#[salsa::invoke(crate::infer::infer_query)]
2727
fn infer_query(&self, def: DefWithBodyId) -> Arc<InferenceResult>;
2828

29+
#[salsa::invoke(crate::infer::can_unify)]
30+
fn can_unify(&self, def: DefWithBodyId, ty1: Ty, ty2: Ty) -> bool;
31+
2932
#[salsa::invoke(crate::lower::ty_query)]
3033
#[salsa::cycle(crate::lower::ty_recover)]
3134
fn ty(&self, def: TyDefId) -> Binders<Ty>;

crates/hir_ty/src/infer.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ macro_rules! ty_app {
5555
};
5656
}
5757

58-
mod unify;
58+
pub mod unify;
5959
mod path;
6060
mod expr;
6161
mod pat;
@@ -78,6 +78,19 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<Infer
7878
Arc::new(ctx.resolve_all())
7979
}
8080

81+
pub(crate) fn can_unify(db: &dyn HirDatabase, def: DefWithBodyId, ty1: Ty, ty2: Ty) -> bool {
82+
let resolver = def.resolver(db.upcast());
83+
let mut ctx = InferenceContext::new(db, def, resolver);
84+
85+
let ty1 = ctx.canonicalizer().canonicalize_ty(ty1).value;
86+
let ty2 = ctx.canonicalizer().canonicalize_ty(ty2).value;
87+
let mut kinds = Vec::from(ty1.kinds.to_vec());
88+
kinds.extend_from_slice(ty2.kinds.as_ref());
89+
let tys = crate::Canonical::new((ty1.value, ty2.value), kinds);
90+
91+
unify(&tys).is_some()
92+
}
93+
8194
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
8295
enum ExprOrPatId {
8396
ExprId(ExprId),

crates/hir_ty/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use crate::{
4343
};
4444

4545
pub use autoderef::autoderef;
46-
pub use infer::{InferTy, InferenceResult};
46+
pub use infer::{unify, InferTy, InferenceResult};
4747
pub use lower::CallableDefId;
4848
pub use lower::{
4949
associated_type_shorthand_candidates, callable_item_sig, ImplTraitLoweringMode, TyDefId,

crates/ide/src/completion/presentation.rs

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,22 @@ impl Completions {
191191
func: hir::Function,
192192
local_name: Option<String>,
193193
) {
194+
fn add_arg(arg: &str, ty: &Type, ctx: &CompletionContext) -> String {
195+
let mut prefix = "";
196+
if let Some(derefed_ty) = ty.remove_ref() {
197+
ctx.scope.process_all_names(&mut |name, scope| {
198+
if prefix != "" {
199+
return;
200+
}
201+
if let ScopeDef::Local(local) = scope {
202+
if name.to_string() == arg && local.can_unify(derefed_ty.clone(), ctx.db) {
203+
prefix = if ty.is_mutable_reference() { "&mut " } else { "&" };
204+
}
205+
}
206+
});
207+
}
208+
prefix.to_string() + arg
209+
};
194210
let name = local_name.unwrap_or_else(|| func.name(ctx.db).to_string());
195211
let ast_node = func.source(ctx.db).value;
196212

@@ -205,12 +221,20 @@ impl Completions {
205221
.set_deprecated(is_deprecated(func, ctx.db))
206222
.detail(function_declaration(&ast_node));
207223

224+
let params_ty = func.params(ctx.db);
208225
let params = ast_node
209226
.param_list()
210227
.into_iter()
211228
.flat_map(|it| it.params())
212-
.flat_map(|it| it.pat())
213-
.map(|pat| pat.to_string().trim_start_matches('_').into())
229+
.zip(params_ty)
230+
.flat_map(|(it, param_ty)| {
231+
if let Some(pat) = it.pat() {
232+
let name = pat.to_string();
233+
let arg = name.trim_start_matches('_');
234+
return Some(add_arg(arg, &param_ty, ctx));
235+
}
236+
None
237+
})
214238
.collect();
215239

216240
builder = builder.add_call_parens(ctx, name, Params::Named(params));
@@ -863,6 +887,85 @@ fn main() { foo(${1:foo}, ${2:bar}, ${3:ho_ge_})$0 }
863887
);
864888
}
865889

890+
#[test]
891+
fn insert_ref_when_matching_local_in_scope() {
892+
check_edit(
893+
"ref_arg",
894+
r#"
895+
struct Foo {}
896+
fn ref_arg(x: &Foo) {}
897+
fn main() {
898+
let x = Foo {};
899+
ref_ar<|>
900+
}
901+
"#,
902+
r#"
903+
struct Foo {}
904+
fn ref_arg(x: &Foo) {}
905+
fn main() {
906+
let x = Foo {};
907+
ref_arg(${1:&x})$0
908+
}
909+
"#,
910+
);
911+
}
912+
913+
#[test]
914+
fn insert_mut_ref_when_matching_local_in_scope() {
915+
check_edit(
916+
"ref_arg",
917+
r#"
918+
struct Foo {}
919+
fn ref_arg(x: &mut Foo) {}
920+
fn main() {
921+
let x = Foo {};
922+
ref_ar<|>
923+
}
924+
"#,
925+
r#"
926+
struct Foo {}
927+
fn ref_arg(x: &mut Foo) {}
928+
fn main() {
929+
let x = Foo {};
930+
ref_arg(${1:&mut x})$0
931+
}
932+
"#,
933+
);
934+
}
935+
936+
#[test]
937+
fn insert_ref_when_matching_local_in_scope_for_method() {
938+
check_edit(
939+
"apply_foo",
940+
r#"
941+
struct Foo {}
942+
struct Bar {}
943+
impl Bar {
944+
fn apply_foo(&self, x: &Foo) {}
945+
}
946+
947+
fn main() {
948+
let x = Foo {};
949+
let y = Bar {};
950+
y.<|>
951+
}
952+
"#,
953+
r#"
954+
struct Foo {}
955+
struct Bar {}
956+
impl Bar {
957+
fn apply_foo(&self, x: &Foo) {}
958+
}
959+
960+
fn main() {
961+
let x = Foo {};
962+
let y = Bar {};
963+
y.apply_foo(${1:&x})$0
964+
}
965+
"#,
966+
);
967+
}
968+
866969
#[test]
867970
fn inserts_parens_for_tuple_enums() {
868971
mark::check!(inserts_parens_for_tuple_enums);

0 commit comments

Comments
 (0)