Skip to content

Commit c1cd1da

Browse files
committed
Add let-chain support for convert_to_guarded_return
- And add early expression `None` in function `Option` return Example --- ```rust fn main() { if$0 let Ok(x) = Err(92) && x < 30 && let Some(y) = Some(8) { foo(x, y); } } ``` -> ```rust fn main() { let Ok(x) = Err(92) else { return }; if x >= 30 { return; } let Some(y) = Some(8) else { return }; foo(x, y); } ```
1 parent da9831c commit c1cd1da

File tree

1 file changed

+179
-38
lines changed

1 file changed

+179
-38
lines changed

src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs

Lines changed: 179 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
use std::iter::once;
22

3-
use ide_db::{
4-
syntax_helpers::node_ext::{is_pattern_cond, single_let},
5-
ty_filter::TryEnum,
6-
};
3+
use hir::Semantics;
4+
use ide_db::{RootDatabase, ty_filter::TryEnum};
75
use syntax::{
86
AstNode,
97
SyntaxKind::{FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
10-
T,
8+
SyntaxNode, T,
119
ast::{
1210
self,
1311
edit::{AstNodeEdit, IndentLevel},
@@ -73,13 +71,7 @@ fn if_expr_to_guarded_return(
7371
return None;
7472
}
7573

76-
// Check if there is an IfLet that we can handle.
77-
let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
78-
let let_ = single_let(cond)?;
79-
(Some(let_.pat()?), let_.expr()?)
80-
} else {
81-
(None, cond)
82-
};
74+
let let_chains = flat_let_chain(cond);
8375

8476
let then_block = if_expr.then_branch()?;
8577
let then_block = then_block.stmt_list()?;
@@ -106,11 +98,7 @@ fn if_expr_to_guarded_return(
10698

10799
let parent_container = parent_block.syntax().parent()?;
108100

109-
let early_expression: ast::Expr = match parent_container.kind() {
110-
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
111-
FN => make::expr_return(None),
112-
_ => return None,
113-
};
101+
let early_expression: ast::Expr = early_expression(parent_container, &ctx.sema)?;
114102

115103
then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{'])?;
116104

@@ -132,32 +120,42 @@ fn if_expr_to_guarded_return(
132120
target,
133121
|edit| {
134122
let if_indent_level = IndentLevel::from_node(if_expr.syntax());
135-
let replacement = match if_let_pat {
136-
None => {
137-
// If.
138-
let new_expr = {
139-
let then_branch =
140-
make::block_expr(once(make::expr_stmt(early_expression).into()), None);
141-
let cond = invert_boolean_expression_legacy(cond_expr);
142-
make::expr_if(cond, then_branch, None).indent(if_indent_level)
143-
};
144-
new_expr.syntax().clone()
145-
}
146-
Some(pat) => {
123+
let replacement = let_chains.into_iter().map(|expr| {
124+
if let ast::Expr::LetExpr(let_expr) = &expr
125+
&& let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr())
126+
{
147127
// If-let.
148128
let let_else_stmt = make::let_else_stmt(
149129
pat,
150130
None,
151-
cond_expr,
152-
ast::make::tail_only_block_expr(early_expression),
131+
expr,
132+
ast::make::tail_only_block_expr(early_expression.clone()),
153133
);
154134
let let_else_stmt = let_else_stmt.indent(if_indent_level);
155135
let_else_stmt.syntax().clone()
136+
} else {
137+
// If.
138+
let new_expr = {
139+
let then_branch = make::block_expr(
140+
once(make::expr_stmt(early_expression.clone()).into()),
141+
None,
142+
);
143+
let cond = invert_boolean_expression_legacy(expr);
144+
make::expr_if(cond, then_branch, None).indent(if_indent_level)
145+
};
146+
new_expr.syntax().clone()
156147
}
157-
};
148+
});
158149

150+
let newline = &format!("\n{if_indent_level}");
159151
let then_statements = replacement
160-
.children_with_tokens()
152+
.enumerate()
153+
.flat_map(|(i, node)| {
154+
(i != 0)
155+
.then(|| make::tokens::whitespace(newline).into())
156+
.into_iter()
157+
.chain(node.children_with_tokens())
158+
})
161159
.chain(
162160
then_block_items
163161
.syntax()
@@ -201,11 +199,7 @@ fn let_stmt_to_guarded_return(
201199
let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
202200
let parent_container = parent_block.syntax().parent()?;
203201

204-
match parent_container.kind() {
205-
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
206-
FN => make::expr_return(None),
207-
_ => return None,
208-
}
202+
early_expression(parent_container, &ctx.sema)?
209203
};
210204

211205
acc.add(
@@ -232,6 +226,44 @@ fn let_stmt_to_guarded_return(
232226
)
233227
}
234228

229+
fn early_expression(
230+
parent_container: SyntaxNode,
231+
sema: &Semantics<'_, RootDatabase>,
232+
) -> Option<ast::Expr> {
233+
if let Some(fn_) = ast::Fn::cast(parent_container.clone())
234+
&& let Some(fn_def) = sema.to_def(&fn_)
235+
&& let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db))
236+
{
237+
let none_expr = make::expr_path(make::ext::ident_path("None"));
238+
return Some(make::expr_return(Some(none_expr)));
239+
}
240+
Some(match parent_container.kind() {
241+
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
242+
FN => make::expr_return(None),
243+
_ => return None,
244+
})
245+
}
246+
247+
fn flat_let_chain(mut expr: ast::Expr) -> Vec<ast::Expr> {
248+
let mut chains = vec![];
249+
250+
while let ast::Expr::BinExpr(bin_expr) = &expr
251+
&& bin_expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And))
252+
&& let (Some(lhs), Some(rhs)) = (bin_expr.lhs(), bin_expr.rhs())
253+
{
254+
if let Some(last) = chains.pop_if(|last| !matches!(last, ast::Expr::LetExpr(_))) {
255+
chains.push(make::expr_bin_op(rhs, ast::BinaryOp::LogicOp(ast::LogicOp::And), last));
256+
} else {
257+
chains.push(rhs);
258+
}
259+
expr = lhs;
260+
}
261+
262+
chains.push(expr);
263+
chains.reverse();
264+
chains
265+
}
266+
235267
#[cfg(test)]
236268
mod tests {
237269
use crate::tests::{check_assist, check_assist_not_applicable};
@@ -268,6 +300,37 @@ fn main() {
268300
);
269301
}
270302

303+
#[test]
304+
fn convert_inside_fn_return_option() {
305+
check_assist(
306+
convert_to_guarded_return,
307+
r#"
308+
//- minicore: option
309+
fn ret_option() -> Option<()> {
310+
bar();
311+
if$0 true {
312+
foo();
313+
314+
// comment
315+
bar();
316+
}
317+
}
318+
"#,
319+
r#"
320+
fn ret_option() -> Option<()> {
321+
bar();
322+
if false {
323+
return None;
324+
}
325+
foo();
326+
327+
// comment
328+
bar();
329+
}
330+
"#,
331+
);
332+
}
333+
271334
#[test]
272335
fn convert_let_inside_fn() {
273336
check_assist(
@@ -316,6 +379,58 @@ fn main() {
316379
);
317380
}
318381

382+
#[test]
383+
fn convert_if_let_chain_result() {
384+
check_assist(
385+
convert_to_guarded_return,
386+
r#"
387+
fn main() {
388+
if$0 let Ok(x) = Err(92)
389+
&& x < 30
390+
&& let Some(y) = Some(8)
391+
{
392+
foo(x, y);
393+
}
394+
}
395+
"#,
396+
r#"
397+
fn main() {
398+
let Ok(x) = Err(92) else { return };
399+
if x >= 30 {
400+
return;
401+
}
402+
let Some(y) = Some(8) else { return };
403+
foo(x, y);
404+
}
405+
"#,
406+
);
407+
408+
check_assist(
409+
convert_to_guarded_return,
410+
r#"
411+
fn main() {
412+
if$0 let Ok(x) = Err(92)
413+
&& x < 30
414+
&& y < 20
415+
&& let Some(y) = Some(8)
416+
{
417+
foo(x, y);
418+
}
419+
}
420+
"#,
421+
r#"
422+
fn main() {
423+
let Ok(x) = Err(92) else { return };
424+
if !(x < 30 && y < 20) {
425+
return;
426+
}
427+
let Some(y) = Some(8) else { return };
428+
foo(x, y);
429+
}
430+
"#,
431+
);
432+
}
433+
319434
#[test]
320435
fn convert_let_ok_inside_fn() {
321436
check_assist(
@@ -560,6 +675,32 @@ fn main() {
560675
);
561676
}
562677

678+
#[test]
679+
fn convert_let_stmt_inside_fn_return_option() {
680+
check_assist(
681+
convert_to_guarded_return,
682+
r#"
683+
//- minicore: option
684+
fn foo() -> Option<i32> {
685+
None
686+
}
687+
688+
fn ret_option() -> Option<i32> {
689+
let x$0 = foo();
690+
}
691+
"#,
692+
r#"
693+
fn foo() -> Option<i32> {
694+
None
695+
}
696+
697+
fn ret_option() -> Option<i32> {
698+
let Some(x) = foo() else { return None };
699+
}
700+
"#,
701+
);
702+
}
703+
563704
#[test]
564705
fn convert_let_stmt_inside_loop() {
565706
check_assist(

0 commit comments

Comments
 (0)