@@ -2,7 +2,7 @@ use std::iter;
22
33use syntax:: {
44 ast:: { self , make, BlockExpr , Expr , LoopBodyOwner } ,
5- AstNode , SyntaxNode ,
5+ match_ast , AstNode , SyntaxNode ,
66} ;
77use test_utils:: mark;
88
@@ -21,8 +21,18 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
2121// ```
2222pub ( crate ) fn change_return_type_to_result ( acc : & mut Assists , ctx : & AssistContext ) -> Option < ( ) > {
2323 let ret_type = ctx. find_node_at_offset :: < ast:: RetType > ( ) ?;
24- // FIXME: extend to lambdas as well
25- let fn_def = ret_type. syntax ( ) . parent ( ) . and_then ( ast:: Fn :: cast) ?;
24+ let parent = ret_type. syntax ( ) . parent ( ) ?;
25+ let block_expr = match_ast ! {
26+ match parent {
27+ ast:: Fn ( func) => func. body( ) ?,
28+ ast:: ClosureExpr ( closure) => match closure. body( ) ? {
29+ Expr :: BlockExpr ( block) => block,
30+ // closures require a block when a return type is specified
31+ _ => return None ,
32+ } ,
33+ _ => return None ,
34+ }
35+ } ;
2636
2737 let type_ref = & ret_type. ty ( ) ?;
2838 let ret_type_str = type_ref. syntax ( ) . text ( ) . to_string ( ) ;
@@ -34,16 +44,14 @@ pub(crate) fn change_return_type_to_result(acc: &mut Assists, ctx: &AssistContex
3444 }
3545 }
3646
37- let block_expr = & fn_def. body ( ) ?;
38-
3947 acc. add (
4048 AssistId ( "change_return_type_to_result" , AssistKind :: RefactorRewrite ) ,
4149 "Wrap return type in Result" ,
4250 type_ref. syntax ( ) . text_range ( ) ,
4351 |builder| {
4452 let mut tail_return_expr_collector = TailReturnCollector :: new ( ) ;
45- tail_return_expr_collector. collect_jump_exprs ( block_expr, false ) ;
46- tail_return_expr_collector. collect_tail_exprs ( block_expr) ;
53+ tail_return_expr_collector. collect_jump_exprs ( & block_expr, false ) ;
54+ tail_return_expr_collector. collect_tail_exprs ( & block_expr) ;
4755
4856 for ret_expr_arg in tail_return_expr_collector. exprs_to_wrap {
4957 let ok_wrapped = make:: expr_call (
@@ -285,16 +293,20 @@ mod tests {
285293 }
286294
287295 #[ test]
288- fn change_return_type_to_result_simple_return_type ( ) {
296+ fn change_return_type_to_result_simple_closure ( ) {
289297 check_assist (
290298 change_return_type_to_result,
291- r#"fn foo() -> i32<|> {
292- let test = "test";
293- return 42i32;
299+ r#"fn foo() {
300+ || -> i32<|> {
301+ let test = "test";
302+ return 42i32;
303+ };
294304 }"# ,
295- r#"fn foo() -> Result<i32, ${0:_}> {
296- let test = "test";
297- return Ok(42i32);
305+ r#"fn foo() {
306+ || -> Result<i32, ${0:_}> {
307+ let test = "test";
308+ return Ok(42i32);
309+ };
298310 }"# ,
299311 ) ;
300312 }
@@ -310,6 +322,29 @@ mod tests {
310322 ) ;
311323 }
312324
325+ #[ test]
326+ fn change_return_type_to_result_simple_return_type_bad_cursor_closure ( ) {
327+ check_assist_not_applicable (
328+ change_return_type_to_result,
329+ r#"fn foo() {
330+ || -> i32 {
331+ let test = "test";<|>
332+ return 42i32;
333+ };
334+ }"# ,
335+ ) ;
336+ }
337+
338+ #[ test]
339+ fn change_return_type_to_result_closure_non_block ( ) {
340+ check_assist_not_applicable (
341+ change_return_type_to_result,
342+ r#"fn foo() {
343+ || -> i<|>32 3;
344+ }"# ,
345+ ) ;
346+ }
347+
313348 #[ test]
314349 fn change_return_type_to_result_simple_return_type_already_result_std ( ) {
315350 check_assist_not_applicable (
@@ -333,6 +368,19 @@ mod tests {
333368 ) ;
334369 }
335370
371+ #[ test]
372+ fn change_return_type_to_result_simple_return_type_already_result_closure ( ) {
373+ check_assist_not_applicable (
374+ change_return_type_to_result,
375+ r#"fn foo() {
376+ || -> Result<i32<|>, String> {
377+ let test = "test";
378+ return 42i32;
379+ };
380+ }"# ,
381+ ) ;
382+ }
383+
336384 #[ test]
337385 fn change_return_type_to_result_simple_with_cursor ( ) {
338386 check_assist (
@@ -363,6 +411,25 @@ mod tests {
363411 ) ;
364412 }
365413
414+ #[ test]
415+ fn change_return_type_to_result_simple_with_tail_closure ( ) {
416+ check_assist (
417+ change_return_type_to_result,
418+ r#"fn foo() {
419+ || -><|> i32 {
420+ let test = "test";
421+ 42i32
422+ };
423+ }"# ,
424+ r#"fn foo() {
425+ || -> Result<i32, ${0:_}> {
426+ let test = "test";
427+ Ok(42i32)
428+ };
429+ }"# ,
430+ ) ;
431+ }
432+
366433 #[ test]
367434 fn change_return_type_to_result_simple_with_tail_only ( ) {
368435 check_assist (
@@ -375,6 +442,7 @@ mod tests {
375442 }"# ,
376443 ) ;
377444 }
445+
378446 #[ test]
379447 fn change_return_type_to_result_simple_with_tail_block_like ( ) {
380448 check_assist (
@@ -396,6 +464,31 @@ mod tests {
396464 ) ;
397465 }
398466
467+ #[ test]
468+ fn change_return_type_to_result_simple_without_block_closure ( ) {
469+ check_assist (
470+ change_return_type_to_result,
471+ r#"fn foo() {
472+ || -> i32<|> {
473+ if true {
474+ 42i32
475+ } else {
476+ 24i32
477+ }
478+ };
479+ }"# ,
480+ r#"fn foo() {
481+ || -> Result<i32, ${0:_}> {
482+ if true {
483+ Ok(42i32)
484+ } else {
485+ Ok(24i32)
486+ }
487+ };
488+ }"# ,
489+ ) ;
490+ }
491+
399492 #[ test]
400493 fn change_return_type_to_result_simple_with_nested_if ( ) {
401494 check_assist (
0 commit comments