1+ use either:: Either ;
12use ide_db:: {
23 famous_defs:: FamousDefs ,
34 syntax_helpers:: node_ext:: { for_each_tail_expr, walk_expr} ,
45} ;
5- use itertools:: Itertools ;
66use syntax:: {
7- ast:: { self , Expr , HasGenericArgs } ,
8- match_ast, AstNode , NodeOrToken , SyntaxKind , TextRange ,
7+ ast:: { self , syntax_factory :: SyntaxFactory , HasArgList , HasGenericArgs } ,
8+ match_ast, AstNode , NodeOrToken , SyntaxKind ,
99} ;
1010
1111use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -39,11 +39,11 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
3939pub ( crate ) fn unwrap_return_type ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
4040 let ret_type = ctx. find_node_at_offset :: < ast:: RetType > ( ) ?;
4141 let parent = ret_type. syntax ( ) . parent ( ) ?;
42- let body = match_ast ! {
42+ let body_expr = match_ast ! {
4343 match parent {
44- ast:: Fn ( func) => func. body( ) ?,
44+ ast:: Fn ( func) => func. body( ) ?. into ( ) ,
4545 ast:: ClosureExpr ( closure) => match closure. body( ) ? {
46- Expr :: BlockExpr ( block) => block,
46+ ast :: Expr :: BlockExpr ( block) => block. into ( ) ,
4747 // closures require a block when a return type is specified
4848 _ => return None ,
4949 } ,
@@ -65,72 +65,110 @@ pub(crate) fn unwrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) ->
6565 let happy_type = extract_wrapped_type ( type_ref) ?;
6666
6767 acc. add ( kind. assist_id ( ) , kind. label ( ) , type_ref. syntax ( ) . text_range ( ) , |builder| {
68- let body = ast:: Expr :: BlockExpr ( body) ;
68+ let mut editor = builder. make_editor ( & parent) ;
69+ let make = SyntaxFactory :: new ( ) ;
6970
7071 let mut exprs_to_unwrap = Vec :: new ( ) ;
7172 let tail_cb = & mut |e : & _ | tail_cb_impl ( & mut exprs_to_unwrap, e) ;
72- walk_expr ( & body , & mut |expr| {
73- if let Expr :: ReturnExpr ( ret_expr) = expr {
73+ walk_expr ( & body_expr , & mut |expr| {
74+ if let ast :: Expr :: ReturnExpr ( ret_expr) = expr {
7475 if let Some ( ret_expr_arg) = & ret_expr. expr ( ) {
7576 for_each_tail_expr ( ret_expr_arg, tail_cb) ;
7677 }
7778 }
7879 } ) ;
79- for_each_tail_expr ( & body , tail_cb) ;
80+ for_each_tail_expr ( & body_expr , tail_cb) ;
8081
8182 let is_unit_type = is_unit_type ( & happy_type) ;
8283 if is_unit_type {
83- let mut text_range = ret_type. syntax ( ) . text_range ( ) ;
84-
8584 if let Some ( NodeOrToken :: Token ( token) ) = ret_type. syntax ( ) . next_sibling_or_token ( ) {
8685 if token. kind ( ) == SyntaxKind :: WHITESPACE {
87- text_range = TextRange :: new ( text_range . start ( ) , token. text_range ( ) . end ( ) ) ;
86+ editor . delete ( token) ;
8887 }
8988 }
9089
91- builder . delete ( text_range ) ;
90+ editor . delete ( ret_type . syntax ( ) ) ;
9291 } else {
93- builder . replace ( type_ref. syntax ( ) . text_range ( ) , happy_type. syntax ( ) . text ( ) ) ;
92+ editor . replace ( type_ref. syntax ( ) , happy_type. syntax ( ) ) ;
9493 }
9594
96- for ret_expr_arg in exprs_to_unwrap {
97- let ret_expr_str = ret_expr_arg. to_string ( ) ;
98-
99- let needs_replacing = match kind {
100- UnwrapperKind :: Option => ret_expr_str. starts_with ( "Some(" ) ,
101- UnwrapperKind :: Result => {
102- ret_expr_str. starts_with ( "Ok(" ) || ret_expr_str. starts_with ( "Err(" )
103- }
104- } ;
95+ let mut final_placeholder = None ;
96+ for tail_expr in exprs_to_unwrap {
97+ match & tail_expr {
98+ ast:: Expr :: CallExpr ( call_expr) => {
99+ let ast:: Expr :: PathExpr ( path_expr) = call_expr. expr ( ) . unwrap ( ) else {
100+ continue ;
101+ } ;
102+
103+ let path_str = path_expr. path ( ) . unwrap ( ) . to_string ( ) ;
104+ let needs_replacing = match kind {
105+ UnwrapperKind :: Option => path_str == "Some" ,
106+ UnwrapperKind :: Result => path_str == "Ok" || path_str == "Err" ,
107+ } ;
108+
109+ if !needs_replacing {
110+ continue ;
111+ }
105112
106- if needs_replacing {
107- let arg_list = ret_expr_arg. syntax ( ) . children ( ) . find_map ( ast:: ArgList :: cast) ;
108- if let Some ( arg_list) = arg_list {
113+ let arg_list = call_expr. arg_list ( ) . unwrap ( ) ;
109114 if is_unit_type {
110- match ret_expr_arg . syntax ( ) . prev_sibling_or_token ( ) {
111- // Useful to delete the entire line without leaving trailing whitespaces
112- Some ( whitespace ) => {
113- let new_range = TextRange :: new (
114- whitespace . text_range ( ) . start ( ) ,
115- ret_expr_arg . syntax ( ) . text_range ( ) . end ( ) ,
116- ) ;
117- builder . delete ( new_range ) ;
115+ let tail_parent = tail_expr
116+ . syntax ( )
117+ . parent ( )
118+ . and_then ( Either :: < ast :: ReturnExpr , ast :: StmtList > :: cast )
119+ . unwrap ( ) ;
120+ match tail_parent {
121+ Either :: Left ( ret_expr ) => {
122+ editor . replace ( ret_expr . syntax ( ) , make . expr_return ( None ) . syntax ( ) )
118123 }
119- None => {
120- builder. delete ( ret_expr_arg. syntax ( ) . text_range ( ) ) ;
124+ Either :: Right ( stmt_list) => {
125+ let new_block = if stmt_list. statements ( ) . next ( ) . is_none ( ) {
126+ make. expr_empty_block ( )
127+ } else {
128+ make. block_expr ( stmt_list. statements ( ) , None )
129+ } ;
130+ editor. replace (
131+ stmt_list. syntax ( ) ,
132+ new_block. stmt_list ( ) . unwrap ( ) . syntax ( ) ,
133+ ) ;
121134 }
122135 }
123- } else {
124- builder. replace (
125- ret_expr_arg. syntax ( ) . text_range ( ) ,
126- arg_list. args ( ) . join ( ", " ) ,
136+ } else if let Some ( first_arg) = arg_list. args ( ) . next ( ) {
137+ editor. replace ( tail_expr. syntax ( ) , first_arg. syntax ( ) ) ;
138+ }
139+ }
140+ ast:: Expr :: PathExpr ( path_expr) => {
141+ let UnwrapperKind :: Option = kind else {
142+ continue ;
143+ } ;
144+
145+ if path_expr. path ( ) . unwrap ( ) . to_string ( ) != "None" {
146+ continue ;
147+ }
148+
149+ let new_tail_expr = make. expr_unit ( ) ;
150+ editor. replace ( path_expr. syntax ( ) , new_tail_expr. syntax ( ) ) ;
151+ if let Some ( cap) = ctx. config . snippet_cap {
152+ editor. add_annotation (
153+ new_tail_expr. syntax ( ) ,
154+ builder. make_placeholder_snippet ( cap) ,
127155 ) ;
156+
157+ final_placeholder = Some ( new_tail_expr) ;
128158 }
129159 }
130- } else if matches ! ( kind, UnwrapperKind :: Option if ret_expr_str == "None" ) {
131- builder. replace ( ret_expr_arg. syntax ( ) . text_range ( ) , "()" ) ;
160+ _ => ( ) ,
132161 }
133162 }
163+
164+ if let Some ( cap) = ctx. config . snippet_cap {
165+ if let Some ( final_placeholder) = final_placeholder {
166+ editor. add_annotation ( final_placeholder. syntax ( ) , builder. make_tabstop_after ( cap) ) ;
167+ }
168+ }
169+
170+ editor. add_mappings ( make. finish_with_mappings ( ) ) ;
171+ builder. add_file_edits ( ctx. file_id ( ) , editor) ;
134172 } )
135173}
136174
@@ -168,12 +206,12 @@ impl UnwrapperKind {
168206
169207fn tail_cb_impl ( acc : & mut Vec < ast:: Expr > , e : & ast:: Expr ) {
170208 match e {
171- Expr :: BreakExpr ( break_expr) => {
209+ ast :: Expr :: BreakExpr ( break_expr) => {
172210 if let Some ( break_expr_arg) = break_expr. expr ( ) {
173211 for_each_tail_expr ( & break_expr_arg, & mut |e| tail_cb_impl ( acc, e) )
174212 }
175213 }
176- Expr :: ReturnExpr ( _) => {
214+ ast :: Expr :: ReturnExpr ( _) => {
177215 // all return expressions have already been handled by the walk loop
178216 }
179217 e => acc. push ( e. clone ( ) ) ,
@@ -238,8 +276,7 @@ fn foo() -> Option<()$0> {
238276}
239277"# ,
240278 r#"
241- fn foo() {
242- }
279+ fn foo() {}
243280"# ,
244281 "Unwrap Option return type" ,
245282 ) ;
@@ -254,8 +291,7 @@ fn foo() -> Option<()$0>{
254291}
255292"# ,
256293 r#"
257- fn foo() {
258- }
294+ fn foo() {}
259295"# ,
260296 "Unwrap Option return type" ,
261297 ) ;
@@ -280,7 +316,42 @@ fn foo() -> i32 {
280316 if true {
281317 42
282318 } else {
283- ()
319+ ${1:()}$0
320+ }
321+ }
322+ "# ,
323+ "Unwrap Option return type" ,
324+ ) ;
325+ }
326+
327+ #[ test]
328+ fn unwrap_option_return_type_multi_none ( ) {
329+ check_assist_by_label (
330+ unwrap_return_type,
331+ r#"
332+ //- minicore: option
333+ fn foo() -> Option<i3$02> {
334+ if false {
335+ return None;
336+ }
337+
338+ if true {
339+ Some(42)
340+ } else {
341+ None
342+ }
343+ }
344+ "# ,
345+ r#"
346+ fn foo() -> i32 {
347+ if false {
348+ return ${1:()};
349+ }
350+
351+ if true {
352+ 42
353+ } else {
354+ ${2:()}$0
284355 }
285356}
286357"# ,
@@ -1262,8 +1333,7 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
12621333}
12631334"# ,
12641335 r#"
1265- fn foo() {
1266- }
1336+ fn foo() {}
12671337"# ,
12681338 "Unwrap Result return type" ,
12691339 ) ;
@@ -1278,8 +1348,7 @@ fn foo() -> Result<(), Box<dyn Error$0>>{
12781348}
12791349"# ,
12801350 r#"
1281- fn foo() {
1282- }
1351+ fn foo() {}
12831352"# ,
12841353 "Unwrap Result return type" ,
12851354 ) ;
0 commit comments