Skip to content

Commit cdddcae

Browse files
bors[bot]Veykril
andauthored
Merge #6477
6477: Add infer_function_return_type assist r=matklad a=Veykril This adds an assist to insert a functions return type if it hasn't been specified yet by inferring it from the functions tail expression. This assist only becomes active if the cursor is on the tail expression. See #6303 (comment) Co-authored-by: Lukas Wirth <[email protected]>
2 parents fe13a4a + 186431e commit cdddcae

File tree

3 files changed

+352
-0
lines changed

3 files changed

+352
-0
lines changed
Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
use hir::HirDisplay;
2+
use syntax::{ast, AstNode, TextRange, TextSize};
3+
use test_utils::mark;
4+
5+
use crate::{AssistContext, AssistId, AssistKind, Assists};
6+
7+
// Assist: infer_function_return_type
8+
//
9+
// Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
10+
// type specified. This assists is useable in a functions or closures tail expression or return type position.
11+
//
12+
// ```
13+
// fn foo() { 4<|>2i32 }
14+
// ```
15+
// ->
16+
// ```
17+
// fn foo() -> i32 { 42i32 }
18+
// ```
19+
pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
20+
let (tail_expr, builder_edit_pos, wrap_expr) = extract_tail(ctx)?;
21+
let module = ctx.sema.scope(tail_expr.syntax()).module()?;
22+
let ty = ctx.sema.type_of_expr(&tail_expr)?;
23+
if ty.is_unit() {
24+
return None;
25+
}
26+
let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
27+
28+
acc.add(
29+
AssistId("infer_function_return_type", AssistKind::RefactorRewrite),
30+
"Add this function's return type",
31+
tail_expr.syntax().text_range(),
32+
|builder| {
33+
match builder_edit_pos {
34+
InsertOrReplace::Insert(insert_pos) => {
35+
builder.insert(insert_pos, &format!("-> {} ", ty))
36+
}
37+
InsertOrReplace::Replace(text_range) => {
38+
builder.replace(text_range, &format!("-> {}", ty))
39+
}
40+
}
41+
if wrap_expr {
42+
mark::hit!(wrap_closure_non_block_expr);
43+
// `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
44+
builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr));
45+
}
46+
},
47+
)
48+
}
49+
50+
enum InsertOrReplace {
51+
Insert(TextSize),
52+
Replace(TextRange),
53+
}
54+
55+
/// Check the potentially already specified return type and reject it or turn it into a builder command
56+
/// if allowed.
57+
fn ret_ty_to_action(ret_ty: Option<ast::RetType>, insert_pos: TextSize) -> Option<InsertOrReplace> {
58+
match ret_ty {
59+
Some(ret_ty) => match ret_ty.ty() {
60+
Some(ast::Type::InferType(_)) | None => {
61+
mark::hit!(existing_infer_ret_type);
62+
mark::hit!(existing_infer_ret_type_closure);
63+
Some(InsertOrReplace::Replace(ret_ty.syntax().text_range()))
64+
}
65+
_ => {
66+
mark::hit!(existing_ret_type);
67+
mark::hit!(existing_ret_type_closure);
68+
None
69+
}
70+
},
71+
None => Some(InsertOrReplace::Insert(insert_pos + TextSize::from(1))),
72+
}
73+
}
74+
75+
fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool)> {
76+
let (tail_expr, return_type_range, action, wrap_expr) =
77+
if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
78+
let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end();
79+
let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?;
80+
81+
let body = closure.body()?;
82+
let body_start = body.syntax().first_token()?.text_range().start();
83+
let (tail_expr, wrap_expr) = match body {
84+
ast::Expr::BlockExpr(block) => (block.expr()?, false),
85+
body => (body, true),
86+
};
87+
88+
let ret_range = TextRange::new(rpipe_pos, body_start);
89+
(tail_expr, ret_range, action, wrap_expr)
90+
} else {
91+
let func = ctx.find_node_at_offset::<ast::Fn>()?;
92+
let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end();
93+
let action = ret_ty_to_action(func.ret_type(), rparen_pos)?;
94+
95+
let body = func.body()?;
96+
let tail_expr = body.expr()?;
97+
98+
let ret_range_end = body.l_curly_token()?.text_range().start();
99+
let ret_range = TextRange::new(rparen_pos, ret_range_end);
100+
(tail_expr, ret_range, action, false)
101+
};
102+
let frange = ctx.frange.range;
103+
if return_type_range.contains_range(frange) {
104+
mark::hit!(cursor_in_ret_position);
105+
mark::hit!(cursor_in_ret_position_closure);
106+
} else if tail_expr.syntax().text_range().contains_range(frange) {
107+
mark::hit!(cursor_on_tail);
108+
mark::hit!(cursor_on_tail_closure);
109+
} else {
110+
return None;
111+
}
112+
Some((tail_expr, action, wrap_expr))
113+
}
114+
115+
#[cfg(test)]
116+
mod tests {
117+
use crate::tests::{check_assist, check_assist_not_applicable};
118+
119+
use super::*;
120+
121+
#[test]
122+
fn infer_return_type_specified_inferred() {
123+
mark::check!(existing_infer_ret_type);
124+
check_assist(
125+
infer_function_return_type,
126+
r#"fn foo() -> <|>_ {
127+
45
128+
}"#,
129+
r#"fn foo() -> i32 {
130+
45
131+
}"#,
132+
);
133+
}
134+
135+
#[test]
136+
fn infer_return_type_specified_inferred_closure() {
137+
mark::check!(existing_infer_ret_type_closure);
138+
check_assist(
139+
infer_function_return_type,
140+
r#"fn foo() {
141+
|| -> _ {<|>45};
142+
}"#,
143+
r#"fn foo() {
144+
|| -> i32 {45};
145+
}"#,
146+
);
147+
}
148+
149+
#[test]
150+
fn infer_return_type_cursor_at_return_type_pos() {
151+
mark::check!(cursor_in_ret_position);
152+
check_assist(
153+
infer_function_return_type,
154+
r#"fn foo() <|>{
155+
45
156+
}"#,
157+
r#"fn foo() -> i32 {
158+
45
159+
}"#,
160+
);
161+
}
162+
163+
#[test]
164+
fn infer_return_type_cursor_at_return_type_pos_closure() {
165+
mark::check!(cursor_in_ret_position_closure);
166+
check_assist(
167+
infer_function_return_type,
168+
r#"fn foo() {
169+
|| <|>45
170+
}"#,
171+
r#"fn foo() {
172+
|| -> i32 {45}
173+
}"#,
174+
);
175+
}
176+
177+
#[test]
178+
fn infer_return_type() {
179+
mark::check!(cursor_on_tail);
180+
check_assist(
181+
infer_function_return_type,
182+
r#"fn foo() {
183+
45<|>
184+
}"#,
185+
r#"fn foo() -> i32 {
186+
45
187+
}"#,
188+
);
189+
}
190+
191+
#[test]
192+
fn infer_return_type_nested() {
193+
check_assist(
194+
infer_function_return_type,
195+
r#"fn foo() {
196+
if true {
197+
3<|>
198+
} else {
199+
5
200+
}
201+
}"#,
202+
r#"fn foo() -> i32 {
203+
if true {
204+
3
205+
} else {
206+
5
207+
}
208+
}"#,
209+
);
210+
}
211+
212+
#[test]
213+
fn not_applicable_ret_type_specified() {
214+
mark::check!(existing_ret_type);
215+
check_assist_not_applicable(
216+
infer_function_return_type,
217+
r#"fn foo() -> i32 {
218+
( 45<|> + 32 ) * 123
219+
}"#,
220+
);
221+
}
222+
223+
#[test]
224+
fn not_applicable_non_tail_expr() {
225+
check_assist_not_applicable(
226+
infer_function_return_type,
227+
r#"fn foo() {
228+
let x = <|>3;
229+
( 45 + 32 ) * 123
230+
}"#,
231+
);
232+
}
233+
234+
#[test]
235+
fn not_applicable_unit_return_type() {
236+
check_assist_not_applicable(
237+
infer_function_return_type,
238+
r#"fn foo() {
239+
(<|>)
240+
}"#,
241+
);
242+
}
243+
244+
#[test]
245+
fn infer_return_type_closure_block() {
246+
mark::check!(cursor_on_tail_closure);
247+
check_assist(
248+
infer_function_return_type,
249+
r#"fn foo() {
250+
|x: i32| {
251+
x<|>
252+
};
253+
}"#,
254+
r#"fn foo() {
255+
|x: i32| -> i32 {
256+
x
257+
};
258+
}"#,
259+
);
260+
}
261+
262+
#[test]
263+
fn infer_return_type_closure() {
264+
check_assist(
265+
infer_function_return_type,
266+
r#"fn foo() {
267+
|x: i32| { x<|> };
268+
}"#,
269+
r#"fn foo() {
270+
|x: i32| -> i32 { x };
271+
}"#,
272+
);
273+
}
274+
275+
#[test]
276+
fn infer_return_type_closure_wrap() {
277+
mark::check!(wrap_closure_non_block_expr);
278+
check_assist(
279+
infer_function_return_type,
280+
r#"fn foo() {
281+
|x: i32| x<|>;
282+
}"#,
283+
r#"fn foo() {
284+
|x: i32| -> i32 {x};
285+
}"#,
286+
);
287+
}
288+
289+
#[test]
290+
fn infer_return_type_nested_closure() {
291+
check_assist(
292+
infer_function_return_type,
293+
r#"fn foo() {
294+
|| {
295+
if true {
296+
3<|>
297+
} else {
298+
5
299+
}
300+
}
301+
}"#,
302+
r#"fn foo() {
303+
|| -> i32 {
304+
if true {
305+
3
306+
} else {
307+
5
308+
}
309+
}
310+
}"#,
311+
);
312+
}
313+
314+
#[test]
315+
fn not_applicable_ret_type_specified_closure() {
316+
mark::check!(existing_ret_type_closure);
317+
check_assist_not_applicable(
318+
infer_function_return_type,
319+
r#"fn foo() {
320+
|| -> i32 { 3<|> }
321+
}"#,
322+
);
323+
}
324+
325+
#[test]
326+
fn not_applicable_non_tail_expr_closure() {
327+
check_assist_not_applicable(
328+
infer_function_return_type,
329+
r#"fn foo() {
330+
|| -> i32 {
331+
let x = 3<|>;
332+
6
333+
}
334+
}"#,
335+
);
336+
}
337+
}

crates/assists/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ mod handlers {
143143
mod generate_function;
144144
mod generate_impl;
145145
mod generate_new;
146+
mod infer_function_return_type;
146147
mod inline_local_variable;
147148
mod introduce_named_lifetime;
148149
mod invert_if;
@@ -190,6 +191,7 @@ mod handlers {
190191
generate_function::generate_function,
191192
generate_impl::generate_impl,
192193
generate_new::generate_new,
194+
infer_function_return_type::infer_function_return_type,
193195
inline_local_variable::inline_local_variable,
194196
introduce_named_lifetime::introduce_named_lifetime,
195197
invert_if::invert_if,

crates/assists/src/tests/generated.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,19 @@ impl<T: Clone> Ctx<T> {
505505
)
506506
}
507507

508+
#[test]
509+
fn doctest_infer_function_return_type() {
510+
check_doc_test(
511+
"infer_function_return_type",
512+
r#####"
513+
fn foo() { 4<|>2i32 }
514+
"#####,
515+
r#####"
516+
fn foo() -> i32 { 42i32 }
517+
"#####,
518+
)
519+
}
520+
508521
#[test]
509522
fn doctest_inline_local_variable() {
510523
check_doc_test(

0 commit comments

Comments
 (0)