@@ -5,12 +5,14 @@ use ra_db::{SourceDatabase, SourceDatabaseExt};
55use ra_ide_db:: symbol_index:: SymbolsDatabase ;
66use ra_ide_db:: RootDatabase ;
77use ra_syntax:: ast:: make:: try_expr_from_text;
8- use ra_syntax:: ast:: { AstToken , Comment , RecordField , RecordLit } ;
9- use ra_syntax:: { AstNode , SyntaxElement , SyntaxNode } ;
8+ use ra_syntax:: ast:: {
9+ ArgList , AstToken , CallExpr , Comment , Expr , MethodCallExpr , RecordField , RecordLit ,
10+ } ;
11+ use ra_syntax:: { AstNode , SyntaxElement , SyntaxKind , SyntaxNode } ;
1012use ra_text_edit:: { TextEdit , TextEditBuilder } ;
1113use rustc_hash:: FxHashMap ;
1214use std:: collections:: HashMap ;
13- use std:: str:: FromStr ;
15+ use std:: { iter :: once , str:: FromStr } ;
1416
1517#[ derive( Debug , PartialEq ) ]
1618pub struct SsrError ( String ) ;
@@ -219,6 +221,50 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
219221 )
220222 }
221223
224+ fn check_call_and_method_call (
225+ pattern : CallExpr ,
226+ code : MethodCallExpr ,
227+ placeholders : & [ Var ] ,
228+ match_ : Match ,
229+ ) -> Option < Match > {
230+ let ( pattern_name, pattern_type_args) = if let Some ( Expr :: PathExpr ( path_exr) ) =
231+ pattern. expr ( )
232+ {
233+ let segment = path_exr. path ( ) . and_then ( |p| p. segment ( ) ) ;
234+ ( segment. as_ref ( ) . and_then ( |s| s. name_ref ( ) ) , segment. and_then ( |s| s. type_arg_list ( ) ) )
235+ } else {
236+ ( None , None )
237+ } ;
238+ let match_ = check_opt_nodes ( pattern_name, code. name_ref ( ) , placeholders, match_) ?;
239+ let match_ =
240+ check_opt_nodes ( pattern_type_args, code. type_arg_list ( ) , placeholders, match_) ?;
241+ let pattern_args = pattern. syntax ( ) . children ( ) . find_map ( ArgList :: cast) ?. args ( ) ;
242+ let code_args = code. syntax ( ) . children ( ) . find_map ( ArgList :: cast) ?. args ( ) ;
243+ let code_args = once ( code. expr ( ) ?) . chain ( code_args) ;
244+ check_iter ( pattern_args, code_args, placeholders, match_)
245+ }
246+
247+ fn check_method_call_and_call (
248+ pattern : MethodCallExpr ,
249+ code : CallExpr ,
250+ placeholders : & [ Var ] ,
251+ match_ : Match ,
252+ ) -> Option < Match > {
253+ let ( code_name, code_type_args) = if let Some ( Expr :: PathExpr ( path_exr) ) = code. expr ( ) {
254+ let segment = path_exr. path ( ) . and_then ( |p| p. segment ( ) ) ;
255+ ( segment. as_ref ( ) . and_then ( |s| s. name_ref ( ) ) , segment. and_then ( |s| s. type_arg_list ( ) ) )
256+ } else {
257+ ( None , None )
258+ } ;
259+ let match_ = check_opt_nodes ( pattern. name_ref ( ) , code_name, placeholders, match_) ?;
260+ let match_ =
261+ check_opt_nodes ( pattern. type_arg_list ( ) , code_type_args, placeholders, match_) ?;
262+ let code_args = code. syntax ( ) . children ( ) . find_map ( ArgList :: cast) ?. args ( ) ;
263+ let pattern_args = pattern. syntax ( ) . children ( ) . find_map ( ArgList :: cast) ?. args ( ) ;
264+ let pattern_args = once ( pattern. expr ( ) ?) . chain ( pattern_args) ;
265+ check_iter ( pattern_args, code_args, placeholders, match_)
266+ }
267+
222268 fn check_opt_nodes (
223269 pattern : Option < impl AstNode > ,
224270 code : Option < impl AstNode > ,
@@ -227,8 +273,8 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
227273 ) -> Option < Match > {
228274 match ( pattern, code) {
229275 ( Some ( pattern) , Some ( code) ) => check (
230- & SyntaxElement :: from ( pattern. syntax ( ) . clone ( ) ) ,
231- & SyntaxElement :: from ( code. syntax ( ) . clone ( ) ) ,
276+ & pattern. syntax ( ) . clone ( ) . into ( ) ,
277+ & code. syntax ( ) . clone ( ) . into ( ) ,
232278 placeholders,
233279 match_,
234280 ) ,
@@ -237,6 +283,33 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
237283 }
238284 }
239285
286+ fn check_iter < T , I1 , I2 > (
287+ mut pattern : I1 ,
288+ mut code : I2 ,
289+ placeholders : & [ Var ] ,
290+ match_ : Match ,
291+ ) -> Option < Match >
292+ where
293+ T : AstNode ,
294+ I1 : Iterator < Item = T > ,
295+ I2 : Iterator < Item = T > ,
296+ {
297+ pattern
298+ . by_ref ( )
299+ . zip ( code. by_ref ( ) )
300+ . fold ( Some ( match_) , |accum, ( a, b) | {
301+ accum. and_then ( |match_| {
302+ check (
303+ & a. syntax ( ) . clone ( ) . into ( ) ,
304+ & b. syntax ( ) . clone ( ) . into ( ) ,
305+ placeholders,
306+ match_,
307+ )
308+ } )
309+ } )
310+ . filter ( |_| pattern. next ( ) . is_none ( ) && code. next ( ) . is_none ( ) )
311+ }
312+
240313 fn check (
241314 pattern : & SyntaxElement ,
242315 code : & SyntaxElement ,
@@ -260,6 +333,14 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
260333 ( RecordLit :: cast ( pattern. clone ( ) ) , RecordLit :: cast ( code. clone ( ) ) )
261334 {
262335 check_record_lit ( pattern, code, placeholders, match_)
336+ } else if let ( Some ( pattern) , Some ( code) ) =
337+ ( CallExpr :: cast ( pattern. clone ( ) ) , MethodCallExpr :: cast ( code. clone ( ) ) )
338+ {
339+ check_call_and_method_call ( pattern, code, placeholders, match_)
340+ } else if let ( Some ( pattern) , Some ( code) ) =
341+ ( MethodCallExpr :: cast ( pattern. clone ( ) ) , CallExpr :: cast ( code. clone ( ) ) )
342+ {
343+ check_method_call_and_call ( pattern, code, placeholders, match_)
263344 } else {
264345 let mut pattern_children = pattern
265346 . children_with_tokens ( )
@@ -290,16 +371,15 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
290371 let kind = pattern. pattern . kind ( ) ;
291372 let matches = code
292373 . descendants ( )
293- . filter ( |n| n. kind ( ) == kind)
374+ . filter ( |n| {
375+ n. kind ( ) == kind
376+ || ( kind == SyntaxKind :: CALL_EXPR && n. kind ( ) == SyntaxKind :: METHOD_CALL_EXPR )
377+ || ( kind == SyntaxKind :: METHOD_CALL_EXPR && n. kind ( ) == SyntaxKind :: CALL_EXPR )
378+ } )
294379 . filter_map ( |code| {
295380 let match_ =
296381 Match { place : code. clone ( ) , binding : HashMap :: new ( ) , ignored_comments : vec ! [ ] } ;
297- check (
298- & SyntaxElement :: from ( pattern. pattern . clone ( ) ) ,
299- & SyntaxElement :: from ( code) ,
300- & pattern. vars ,
301- match_,
302- )
382+ check ( & pattern. pattern . clone ( ) . into ( ) , & code. into ( ) , & pattern. vars , match_)
303383 } )
304384 . collect ( ) ;
305385 SsrMatches { matches }
@@ -498,4 +578,22 @@ mod tests {
498578 "fn main() { foo::new(1, 2) }" ,
499579 )
500580 }
581+
582+ #[ test]
583+ fn ssr_call_and_method_call ( ) {
584+ assert_ssr_transform (
585+ "foo::<'a>($a:expr, $b:expr)) ==>> foo2($a, $b)" ,
586+ "fn main() { get().bar.foo::<'a>(1); }" ,
587+ "fn main() { foo2(get().bar, 1); }" ,
588+ )
589+ }
590+
591+ #[ test]
592+ fn ssr_method_call_and_call ( ) {
593+ assert_ssr_transform (
594+ "$o:expr.foo::<i32>($a:expr)) ==>> $o.foo2($a)" ,
595+ "fn main() { X::foo::<i32>(x, 1); }" ,
596+ "fn main() { x.foo2(1); }" ,
597+ )
598+ }
501599}
0 commit comments