@@ -378,6 +378,14 @@ impl VisitorMut for ParameterExtractor {
378378 type Break = ( ) ;
379379 fn pre_visit_expr ( & mut self , value : & mut Expr ) -> ControlFlow < Self :: Break > {
380380 match value {
381+ Expr :: Identifier ( Ident {
382+ value : var_name,
383+ quote_style : None ,
384+ } ) if var_name. starts_with ( '$' ) || var_name. starts_with ( ':' ) => {
385+ let name = std:: mem:: take ( var_name) ;
386+ * value = self . make_placeholder ( ) ;
387+ self . parameters . push ( map_param ( name) ) ;
388+ }
381389 Expr :: Value ( Value :: Placeholder ( param) ) if !self . is_own_placeholder ( param) =>
382390 // this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves
383391 {
@@ -480,6 +488,56 @@ mod test {
480488 ) ;
481489 }
482490
491+ #[ test]
492+ fn test_statement_rewrite_sqlite ( ) {
493+ let mut ast = parse_stmt ( "select $x, :y from t" , SQLiteDialect { } ) ;
494+ let parameters = ParameterExtractor :: extract_parameters ( & mut ast, AnyKind :: Sqlite ) ;
495+ assert_eq ! (
496+ ast. to_string( ) ,
497+ "SELECT CAST(? AS VARCHAR), CAST(? AS VARCHAR) FROM t"
498+ ) ;
499+ assert_eq ! (
500+ parameters,
501+ [
502+ StmtParam :: GetOrPost ( "x" . to_string( ) ) ,
503+ StmtParam :: Post ( "y" . to_string( ) ) ,
504+ ]
505+ ) ;
506+ }
507+
508+ #[ test]
509+ fn is_own_placeholder ( ) {
510+ assert ! ( ParameterExtractor {
511+ db_kind: AnyKind :: Postgres ,
512+ parameters: vec![ ]
513+ }
514+ . is_own_placeholder( "$1" ) ) ;
515+
516+ assert ! ( ParameterExtractor {
517+ db_kind: AnyKind :: Postgres ,
518+ parameters: vec![ StmtParam :: Get ( 'x' . to_string( ) ) ]
519+ }
520+ . is_own_placeholder( "$2" ) ) ;
521+
522+ assert ! ( !ParameterExtractor {
523+ db_kind: AnyKind :: Postgres ,
524+ parameters: vec![ ]
525+ }
526+ . is_own_placeholder( "$2" ) ) ;
527+
528+ assert ! ( ParameterExtractor {
529+ db_kind: AnyKind :: Sqlite ,
530+ parameters: vec![ ]
531+ }
532+ . is_own_placeholder( "?" ) ) ;
533+
534+ assert ! ( !ParameterExtractor {
535+ db_kind: AnyKind :: Sqlite ,
536+ parameters: vec![ ]
537+ }
538+ . is_own_placeholder( "$1" ) ) ;
539+ }
540+
483541 #[ test]
484542 fn test_mssql_statement_rewrite ( ) {
485543 let mut ast = parse_stmt ( "select '' || $1 from [a schema].[a table]" , MsSqlDialect { } ) ;
0 commit comments