@@ -9,9 +9,9 @@ use async_trait::async_trait;
99use sqlparser:: ast:: helpers:: attached_token:: AttachedToken ;
1010use sqlparser:: ast:: {
1111 BinaryOperator , CastKind , CharacterLength , DataType , Expr , Function , FunctionArg ,
12- FunctionArgExpr , FunctionArgumentList , FunctionArguments , Ident , ObjectName ,
13- OneOrManyWithParens , SelectItem , SetExpr , Spanned , Statement , Value , Visit , VisitMut , Visitor ,
14- VisitorMut ,
12+ FunctionArgExpr , FunctionArgumentList , FunctionArguments , Ident , ObjectName , ObjectNamePart ,
13+ OneOrManyWithParens , SelectFlavor , SelectItem , SetExpr , Spanned , Statement , Value ,
14+ ValueWithSpan , Visit , VisitMut , Visitor , VisitorMut ,
1515} ;
1616use sqlparser:: dialect:: { Dialect , MsSqlDialect , MySqlDialect , PostgreSqlDialect , SQLiteDialect } ;
1717use sqlparser:: parser:: { Parser , ParserError } ;
@@ -329,7 +329,7 @@ fn extract_toplevel_functions(stmt: &mut Statement) -> Vec<DelayedFunctionCall>
329329 let argument_col_name = format ! ( "_sqlpage_f{func_idx}_a{arg_idx}" ) ;
330330 argument_col_names. push ( argument_col_name. clone ( ) ) ;
331331 let expr_to_insert = SelectItem :: ExprWithAlias {
332- expr : std:: mem:: replace ( expr, Expr :: Value ( Value :: Null ) ) ,
332+ expr : std:: mem:: replace ( expr, Expr :: value ( Value :: Null ) ) ,
333333 alias : Ident :: new ( argument_col_name) ,
334334 } ;
335335 select_items_to_add. push ( SelectItemToAdd {
@@ -417,10 +417,21 @@ fn extract_static_simple_select(
417417 return None ;
418418 } ;
419419 let value = match expr {
420- Expr :: Value ( Value :: Boolean ( b) ) => Static ( Bool ( * b) ) ,
421- Expr :: Value ( Value :: Number ( n, _) ) => Static ( Number ( n. parse ( ) . ok ( ) ?) ) ,
422- Expr :: Value ( Value :: SingleQuotedString ( s) ) => Static ( String ( s. clone ( ) ) ) ,
423- Expr :: Value ( Value :: Null ) => Static ( Null ) ,
420+ Expr :: Value ( ValueWithSpan {
421+ value : Value :: Boolean ( b) ,
422+ ..
423+ } ) => Static ( Bool ( * b) ) ,
424+ Expr :: Value ( ValueWithSpan {
425+ value : Value :: Number ( n, _) ,
426+ ..
427+ } ) => Static ( Number ( n. parse ( ) . ok ( ) ?) ) ,
428+ Expr :: Value ( ValueWithSpan {
429+ value : Value :: SingleQuotedString ( s) ,
430+ ..
431+ } ) => Static ( String ( s. clone ( ) ) ) ,
432+ Expr :: Value ( ValueWithSpan {
433+ value : Value :: Null , ..
434+ } ) => Static ( Null ) ,
424435 e if is_simple_select_placeholder ( e) => {
425436 if let Some ( p) = params_iter. next ( ) {
426437 Dynamic ( p)
@@ -446,7 +457,10 @@ fn extract_static_simple_select(
446457
447458fn is_simple_select_placeholder ( e : & Expr ) -> bool {
448459 match e {
449- Expr :: Value ( Value :: Placeholder ( _) ) => true ,
460+ Expr :: Value ( ValueWithSpan {
461+ value : Value :: Placeholder ( _) ,
462+ ..
463+ } ) => true ,
450464 Expr :: Cast {
451465 expr,
452466 data_type : DataType :: Text | DataType :: Varchar ( _) | DataType :: Char ( _) ,
@@ -469,13 +483,15 @@ fn extract_set_variable(
469483 hivevar : false ,
470484 } = stmt
471485 {
472- if let ( [ ident] , [ value] ) = ( name. as_mut_slice ( ) , value. as_mut_slice ( ) ) {
486+ if let ( [ ObjectNamePart :: Identifier ( ident) ] , [ value] ) =
487+ ( name. as_mut_slice ( ) , value. as_mut_slice ( ) )
488+ {
473489 let variable = if let Some ( variable) = extract_ident_param ( ident) {
474490 variable
475491 } else {
476492 StmtParam :: PostOrGet ( std:: mem:: take ( & mut ident. value ) )
477493 } ;
478- let owned_expr = std:: mem:: replace ( value, Expr :: Value ( Value :: Null ) ) ;
494+ let owned_expr = std:: mem:: replace ( value, Expr :: value ( Value :: Null ) ) ;
479495 let mut select_stmt: Statement = expr_to_statement ( owned_expr) ;
480496 let delayed_functions = extract_toplevel_functions ( & mut select_stmt) ;
481497 if let Err ( err) = validate_function_calls ( & select_stmt) {
@@ -576,7 +592,7 @@ impl ParameterExtractor {
576592 AnyKind :: Mssql => DataType :: Varchar ( Some ( CharacterLength :: Max ) ) ,
577593 _ => DataType :: Text ,
578594 } ;
579- let value = Expr :: Value ( Value :: Placeholder ( name) ) ;
595+ let value = Expr :: value ( Value :: Placeholder ( name) ) ;
580596 Expr :: Cast {
581597 expr : Box :: new ( value) ,
582598 data_type,
@@ -693,9 +709,10 @@ pub(super) fn function_args_to_stmt_params(
693709
694710fn expr_to_stmt_param ( arg : & mut Expr ) -> Option < StmtParam > {
695711 match arg {
696- Expr :: Value ( Value :: Placeholder ( placeholder) ) => {
697- Some ( map_param ( std:: mem:: take ( placeholder) ) )
698- }
712+ Expr :: Value ( ValueWithSpan {
713+ value : Value :: Placeholder ( placeholder) ,
714+ ..
715+ } ) => Some ( map_param ( std:: mem:: take ( placeholder) ) ) ,
699716 Expr :: Identifier ( ident) => extract_ident_param ( ident) ,
700717 Expr :: Function ( Function {
701718 name : ObjectName ( func_name_parts) ,
@@ -710,13 +727,17 @@ fn expr_to_stmt_param(arg: &mut Expr) -> Option<StmtParam> {
710727 sqlpage_func_name ( func_name_parts) ,
711728 args. as_mut_slice ( ) ,
712729 ) ) ,
713- Expr :: Value ( Value :: SingleQuotedString ( param_value) ) => {
714- Some ( StmtParam :: Literal ( std:: mem:: take ( param_value) ) )
715- }
716- Expr :: Value ( Value :: Number ( param_value, _is_long) ) => {
717- Some ( StmtParam :: Literal ( param_value. clone ( ) ) )
718- }
719- Expr :: Value ( Value :: Null ) => Some ( StmtParam :: Null ) ,
730+ Expr :: Value ( ValueWithSpan {
731+ value : Value :: SingleQuotedString ( param_value) ,
732+ ..
733+ } ) => Some ( StmtParam :: Literal ( std:: mem:: take ( param_value) ) ) ,
734+ Expr :: Value ( ValueWithSpan {
735+ value : Value :: Number ( param_value, _is_long) ,
736+ ..
737+ } ) => Some ( StmtParam :: Literal ( param_value. clone ( ) ) ) ,
738+ Expr :: Value ( ValueWithSpan {
739+ value : Value :: Null , ..
740+ } ) => Some ( StmtParam :: Null ) ,
720741 Expr :: BinaryOp {
721742 // 'str1' || 'str2'
722743 left,
@@ -741,7 +762,10 @@ fn expr_to_stmt_param(arg: &mut Expr) -> Option<StmtParam> {
741762 } ) ,
742763 ..
743764 } ) if func_name_parts. len ( ) == 1 => {
744- let func_name = func_name_parts[ 0 ] . value . as_str ( ) ;
765+ let func_name = func_name_parts[ 0 ]
766+ . as_ident ( )
767+ . map ( |ident| ident. value . as_str ( ) )
768+ . unwrap_or_default ( ) ;
745769 if func_name. eq_ignore_ascii_case ( "concat" ) {
746770 let mut concat_args = Vec :: with_capacity ( args. len ( ) ) ;
747771 for arg in args {
@@ -829,7 +853,10 @@ impl VisitorMut for ParameterExtractor {
829853 self . replace_with_placeholder ( value, param) ;
830854 }
831855 }
832- Expr :: Value ( Value :: Placeholder ( param) ) if !self . is_own_placeholder ( param) =>
856+ Expr :: Value ( ValueWithSpan {
857+ value : Value :: Placeholder ( param) ,
858+ ..
859+ } ) if !self . is_own_placeholder ( param) =>
833860 // this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves
834861 {
835862 let name = std:: mem:: take ( param) ;
@@ -860,10 +887,10 @@ impl VisitorMut for ParameterExtractor {
860887 op : BinaryOperator :: StringConcat ,
861888 right,
862889 } if self . db_kind == AnyKind :: Mssql => {
863- let left = std:: mem:: replace ( left. as_mut ( ) , Expr :: Value ( Value :: Null ) ) ;
864- let right = std:: mem:: replace ( right. as_mut ( ) , Expr :: Value ( Value :: Null ) ) ;
890+ let left = std:: mem:: replace ( left. as_mut ( ) , Expr :: value ( Value :: Null ) ) ;
891+ let right = std:: mem:: replace ( right. as_mut ( ) , Expr :: value ( Value :: Null ) ) ;
865892 * value = Expr :: Function ( Function {
866- name : ObjectName ( vec ! [ Ident :: new( "CONCAT" ) ] ) ,
893+ name : ObjectName ( vec ! [ ObjectNamePart :: Identifier ( Ident :: new( "CONCAT" ) ) ] ) ,
867894 args : FunctionArguments :: List ( FunctionArgumentList {
868895 args : vec ! [
869896 FunctionArg :: Unnamed ( FunctionArgExpr :: Expr ( left) ) ,
@@ -896,18 +923,22 @@ impl VisitorMut for ParameterExtractor {
896923
897924const SQLPAGE_FUNCTION_NAMESPACE : & str = "sqlpage" ;
898925
899- fn is_sqlpage_func ( func_name_parts : & [ Ident ] ) -> bool {
900- if let [ Ident { value, .. } , Ident { .. } ] = func_name_parts {
926+ fn is_sqlpage_func ( func_name_parts : & [ ObjectNamePart ] ) -> bool {
927+ if let [ ObjectNamePart :: Identifier ( Ident { value, .. } ) , ObjectNamePart :: Identifier ( Ident { .. } ) ] =
928+ func_name_parts
929+ {
901930 value == SQLPAGE_FUNCTION_NAMESPACE
902931 } else {
903932 false
904933 }
905934}
906935
907- fn extract_sqlpage_function_name ( func_name_parts : & [ Ident ] ) -> Option < SqlPageFunctionName > {
908- if let [ Ident {
936+ fn extract_sqlpage_function_name (
937+ func_name_parts : & [ ObjectNamePart ] ,
938+ ) -> Option < SqlPageFunctionName > {
939+ if let [ ObjectNamePart :: Identifier ( Ident {
909940 value : namespace, ..
910- } , Ident { value, .. } ] = func_name_parts
941+ } ) , ObjectNamePart :: Identifier ( Ident { value, .. } ) ] = func_name_parts
911942 {
912943 if namespace == SQLPAGE_FUNCTION_NAMESPACE {
913944 return SqlPageFunctionName :: from_str ( value) . ok ( ) ;
@@ -916,8 +947,10 @@ fn extract_sqlpage_function_name(func_name_parts: &[Ident]) -> Option<SqlPageFun
916947 None
917948}
918949
919- fn sqlpage_func_name ( func_name_parts : & [ Ident ] ) -> & str {
920- if let [ Ident { .. } , Ident { value, .. } ] = func_name_parts {
950+ fn sqlpage_func_name ( func_name_parts : & [ ObjectNamePart ] ) -> & str {
951+ if let [ ObjectNamePart :: Identifier ( Ident { .. } ) , ObjectNamePart :: Identifier ( Ident { value, .. } ) ] =
952+ func_name_parts
953+ {
921954 value
922955 } else {
923956 debug_assert ! (
@@ -955,7 +988,7 @@ fn extract_json_columns(stmt: &Statement, db_kind: AnyKind) -> Vec<String> {
955988fn is_json_function ( expr : & Expr ) -> bool {
956989 match expr {
957990 Expr :: Function ( function) => {
958- if let [ Ident { value, .. } ] = function. name . 0 . as_slice ( ) {
991+ if let [ ObjectNamePart :: Identifier ( Ident { value, .. } ) ] = function. name . 0 . as_slice ( ) {
959992 [
960993 "json_object" ,
961994 "json_array" ,
@@ -979,10 +1012,17 @@ fn is_json_function(expr: &Expr) -> bool {
9791012 }
9801013 }
9811014 Expr :: Cast { data_type, .. } => {
982- matches ! ( data_type, DataType :: JSON | DataType :: JSONB )
983- || ( matches ! ( data_type, DataType :: Custom ( ObjectName ( parts) , _) if
984- ( parts. len( ) == 1 )
985- && ( parts[ 0 ] . value. eq_ignore_ascii_case( "json" ) ) ) )
1015+ if matches ! ( data_type, DataType :: JSON | DataType :: JSONB ) {
1016+ true
1017+ } else if let DataType :: Custom ( ObjectName ( parts) , _) = data_type {
1018+ if let [ ObjectNamePart :: Identifier ( ident) ] = parts. as_slice ( ) {
1019+ ident. value . eq_ignore_ascii_case ( "json" )
1020+ } else {
1021+ false
1022+ }
1023+ } else {
1024+ false
1025+ }
9861026 }
9871027 _ => false ,
9881028 }
@@ -1019,6 +1059,7 @@ fn expr_to_statement(expr: Expr) -> Statement {
10191059 window_before_qualify : false ,
10201060 value_table_mode : None ,
10211061 connect_by : None ,
1062+ flavor : SelectFlavor :: Standard ,
10221063 } ,
10231064 ) ) ) ,
10241065 order_by : None ,
0 commit comments