1+ use either:: Either ;
12use ide_db:: {
23 famous_defs:: FamousDefs ,
34 syntax_helpers:: node_ext:: { for_each_tail_expr, walk_expr} ,
45} ;
5- use itertools:: Itertools ;
66use syntax:: {
7- ast:: { self , Expr , HasGenericArgs } ,
8- match_ast, AstNode , NodeOrToken , SyntaxKind , TextRange ,
7+ ast:: { self , syntax_factory :: SyntaxFactory , HasArgList , HasGenericArgs } ,
8+ match_ast, AstNode , NodeOrToken , SyntaxKind ,
99} ;
1010
1111use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -39,11 +39,11 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
3939pub ( crate ) fn unwrap_return_type ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
4040 let ret_type = ctx. find_node_at_offset :: < ast:: RetType > ( ) ?;
4141 let parent = ret_type. syntax ( ) . parent ( ) ?;
42- let body = match_ast ! {
42+ let body_expr = match_ast ! {
4343 match parent {
44- ast:: Fn ( func) => func. body( ) ?,
44+ ast:: Fn ( func) => func. body( ) ?. into ( ) ,
4545 ast:: ClosureExpr ( closure) => match closure. body( ) ? {
46- Expr :: BlockExpr ( block) => block,
46+ ast :: Expr :: BlockExpr ( block) => block. into ( ) ,
4747 // closures require a block when a return type is specified
4848 _ => return None ,
4949 } ,
@@ -65,72 +65,94 @@ pub(crate) fn unwrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) ->
6565 let happy_type = extract_wrapped_type ( type_ref) ?;
6666
6767 acc. add ( kind. assist_id ( ) , kind. label ( ) , type_ref. syntax ( ) . text_range ( ) , |builder| {
68- let body = ast:: Expr :: BlockExpr ( body) ;
68+ let mut editor = builder. make_editor ( & parent) ;
69+ let make = SyntaxFactory :: new ( ) ;
6970
7071 let mut exprs_to_unwrap = Vec :: new ( ) ;
7172 let tail_cb = & mut |e : & _ | tail_cb_impl ( & mut exprs_to_unwrap, e) ;
72- walk_expr ( & body , & mut |expr| {
73- if let Expr :: ReturnExpr ( ret_expr) = expr {
73+ walk_expr ( & body_expr , & mut |expr| {
74+ if let ast :: Expr :: ReturnExpr ( ret_expr) = expr {
7475 if let Some ( ret_expr_arg) = & ret_expr. expr ( ) {
7576 for_each_tail_expr ( ret_expr_arg, tail_cb) ;
7677 }
7778 }
7879 } ) ;
79- for_each_tail_expr ( & body , tail_cb) ;
80+ for_each_tail_expr ( & body_expr , tail_cb) ;
8081
8182 let is_unit_type = is_unit_type ( & happy_type) ;
8283 if is_unit_type {
83- let mut text_range = ret_type. syntax ( ) . text_range ( ) ;
84-
8584 if let Some ( NodeOrToken :: Token ( token) ) = ret_type. syntax ( ) . next_sibling_or_token ( ) {
8685 if token. kind ( ) == SyntaxKind :: WHITESPACE {
87- text_range = TextRange :: new ( text_range . start ( ) , token. text_range ( ) . end ( ) ) ;
86+ editor . delete ( token) ;
8887 }
8988 }
9089
91- builder . delete ( text_range ) ;
90+ editor . delete ( ret_type . syntax ( ) ) ;
9291 } else {
93- builder . replace ( type_ref. syntax ( ) . text_range ( ) , happy_type. syntax ( ) . text ( ) ) ;
92+ editor . replace ( type_ref. syntax ( ) , happy_type. syntax ( ) ) ;
9493 }
9594
96- for ret_expr_arg in exprs_to_unwrap {
97- let ret_expr_str = ret_expr_arg. to_string ( ) ;
98-
99- let needs_replacing = match kind {
100- UnwrapperKind :: Option => ret_expr_str. starts_with ( "Some(" ) ,
101- UnwrapperKind :: Result => {
102- ret_expr_str. starts_with ( "Ok(" ) || ret_expr_str. starts_with ( "Err(" )
103- }
104- } ;
95+ for tail_expr in exprs_to_unwrap {
96+ match & tail_expr {
97+ ast:: Expr :: CallExpr ( call_expr) => {
98+ let ast:: Expr :: PathExpr ( path_expr) = call_expr. expr ( ) . unwrap ( ) else {
99+ continue ;
100+ } ;
101+
102+ let path_str = path_expr. path ( ) . unwrap ( ) . to_string ( ) ;
103+ let needs_replacing = match kind {
104+ UnwrapperKind :: Option => path_str == "Some" ,
105+ UnwrapperKind :: Result => path_str == "Ok" || path_str == "Err" ,
106+ } ;
107+
108+ if !needs_replacing {
109+ continue ;
110+ }
105111
106- if needs_replacing {
107- let arg_list = ret_expr_arg. syntax ( ) . children ( ) . find_map ( ast:: ArgList :: cast) ;
108- if let Some ( arg_list) = arg_list {
112+ let arg_list = call_expr. arg_list ( ) . unwrap ( ) ;
109113 if is_unit_type {
110- match ret_expr_arg . syntax ( ) . prev_sibling_or_token ( ) {
111- // Useful to delete the entire line without leaving trailing whitespaces
112- Some ( whitespace ) => {
113- let new_range = TextRange :: new (
114- whitespace . text_range ( ) . start ( ) ,
115- ret_expr_arg . syntax ( ) . text_range ( ) . end ( ) ,
116- ) ;
117- builder . delete ( new_range ) ;
114+ let tail_parent = tail_expr
115+ . syntax ( )
116+ . parent ( )
117+ . and_then ( Either :: < ast :: ReturnExpr , ast :: StmtList > :: cast )
118+ . unwrap ( ) ;
119+ match tail_parent {
120+ Either :: Left ( ret_expr ) => {
121+ editor . replace ( ret_expr . syntax ( ) , make . expr_return ( None ) . syntax ( ) )
118122 }
119- None => {
120- builder. delete ( ret_expr_arg. syntax ( ) . text_range ( ) ) ;
123+ Either :: Right ( stmt_list) => {
124+ let new_block = if stmt_list. statements ( ) . next ( ) . is_none ( ) {
125+ make. expr_empty_block ( )
126+ } else {
127+ make. block_expr ( stmt_list. statements ( ) , None )
128+ } ;
129+ editor. replace (
130+ stmt_list. syntax ( ) ,
131+ new_block. stmt_list ( ) . unwrap ( ) . syntax ( ) ,
132+ ) ;
121133 }
122134 }
123- } else {
124- builder. replace (
125- ret_expr_arg. syntax ( ) . text_range ( ) ,
126- arg_list. args ( ) . join ( ", " ) ,
127- ) ;
135+ } else if let Some ( first_arg) = arg_list. args ( ) . next ( ) {
136+ editor. replace ( tail_expr. syntax ( ) , first_arg. syntax ( ) ) ;
128137 }
129138 }
130- } else if matches ! ( kind, UnwrapperKind :: Option if ret_expr_str == "None" ) {
131- builder. replace ( ret_expr_arg. syntax ( ) . text_range ( ) , "()" ) ;
139+ ast:: Expr :: PathExpr ( path_expr) => {
140+ let UnwrapperKind :: Option = kind else {
141+ continue ;
142+ } ;
143+
144+ if path_expr. path ( ) . unwrap ( ) . to_string ( ) != "None" {
145+ continue ;
146+ }
147+
148+ editor. replace ( path_expr. syntax ( ) , make. expr_unit ( ) . syntax ( ) ) ;
149+ }
150+ _ => ( ) ,
132151 }
133152 }
153+
154+ editor. add_mappings ( make. finish_with_mappings ( ) ) ;
155+ builder. add_file_edits ( ctx. file_id ( ) , editor) ;
134156 } )
135157}
136158
@@ -168,12 +190,12 @@ impl UnwrapperKind {
168190
169191fn tail_cb_impl ( acc : & mut Vec < ast:: Expr > , e : & ast:: Expr ) {
170192 match e {
171- Expr :: BreakExpr ( break_expr) => {
193+ ast :: Expr :: BreakExpr ( break_expr) => {
172194 if let Some ( break_expr_arg) = break_expr. expr ( ) {
173195 for_each_tail_expr ( & break_expr_arg, & mut |e| tail_cb_impl ( acc, e) )
174196 }
175197 }
176- Expr :: ReturnExpr ( _) => {
198+ ast :: Expr :: ReturnExpr ( _) => {
177199 // all return expressions have already been handled by the walk loop
178200 }
179201 e => acc. push ( e. clone ( ) ) ,
@@ -238,8 +260,7 @@ fn foo() -> Option<()$0> {
238260}
239261"# ,
240262 r#"
241- fn foo() {
242- }
263+ fn foo() {}
243264"# ,
244265 "Unwrap Option return type" ,
245266 ) ;
@@ -254,8 +275,7 @@ fn foo() -> Option<()$0>{
254275}
255276"# ,
256277 r#"
257- fn foo() {
258- }
278+ fn foo() {}
259279"# ,
260280 "Unwrap Option return type" ,
261281 ) ;
@@ -1262,8 +1282,7 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
12621282}
12631283"# ,
12641284 r#"
1265- fn foo() {
1266- }
1285+ fn foo() {}
12671286"# ,
12681287 "Unwrap Result return type" ,
12691288 ) ;
@@ -1278,8 +1297,7 @@ fn foo() -> Result<(), Box<dyn Error$0>>{
12781297}
12791298"# ,
12801299 r#"
1281- fn foo() {
1282- }
1300+ fn foo() {}
12831301"# ,
12841302 "Unwrap Result return type" ,
12851303 ) ;
0 commit comments