@@ -144,16 +144,8 @@ fn parse_single_statement(
144144 semicolon = true ;
145145 }
146146 let mut params = ParameterExtractor :: extract_parameters ( & mut stmt, db_kind) ;
147- if let Some ( ( variable, query) ) = extract_set_variable ( & mut stmt) {
148- return Some ( ParsedStatement :: SetVariable {
149- variable,
150- value : StmtWithParams {
151- query,
152- params,
153- delayed_functions : Vec :: new ( ) ,
154- json_columns : Vec :: new ( ) ,
155- } ,
156- } ) ;
147+ if let Some ( ( variable, value) ) = extract_set_variable ( & mut stmt, & mut params, db_kind) {
148+ return Some ( ParsedStatement :: SetVariable { variable, value } ) ;
157149 }
158150 if let Some ( csv_import) = extract_csv_copy_statement ( & mut stmt) {
159151 return Some ( ParsedStatement :: CsvImport ( csv_import) ) ;
@@ -406,7 +398,11 @@ fn is_simple_select_placeholder(e: &Expr) -> bool {
406398 }
407399}
408400
409- fn extract_set_variable ( stmt : & mut Statement ) -> Option < ( StmtParam , String ) > {
401+ fn extract_set_variable (
402+ stmt : & mut Statement ,
403+ params : & mut Vec < StmtParam > ,
404+ db_kind : AnyKind ,
405+ ) -> Option < ( StmtParam , StmtWithParams ) > {
410406 if let Statement :: SetVariable {
411407 variables : OneOrManyWithParens :: One ( ObjectName ( name) ) ,
412408 value,
@@ -420,7 +416,20 @@ fn extract_set_variable(stmt: &mut Statement) -> Option<(StmtParam, String)> {
420416 } else {
421417 StmtParam :: PostOrGet ( std:: mem:: take ( & mut ident. value ) )
422418 } ;
423- return Some ( ( variable, format ! ( "SELECT {value}" ) ) ) ;
419+ let owned_expr = std:: mem:: replace ( value, Expr :: Value ( Value :: Null ) ) ;
420+ let mut select_stmt: Statement = expr_to_statement ( owned_expr) ;
421+ let delayed_functions = extract_toplevel_functions ( & mut select_stmt) ;
422+ remove_invalid_function_calls ( & mut select_stmt, params) ;
423+ let json_columns = extract_json_columns ( & select_stmt, db_kind) ;
424+ return Some ( (
425+ variable,
426+ StmtWithParams {
427+ query : select_stmt. to_string ( ) ,
428+ params : std:: mem:: take ( params) ,
429+ delayed_functions,
430+ json_columns,
431+ } ,
432+ ) ) ;
424433 }
425434 }
426435 None
@@ -862,6 +871,47 @@ fn is_json_function(expr: &Expr) -> bool {
862871 }
863872}
864873
874+ fn expr_to_statement ( expr : Expr ) -> Statement {
875+ Statement :: Query ( Box :: new ( sqlparser:: ast:: Query {
876+ with : None ,
877+ body : Box :: new ( sqlparser:: ast:: SetExpr :: Select ( Box :: new (
878+ sqlparser:: ast:: Select {
879+ distinct : None ,
880+ top : None ,
881+ projection : vec ! [ SelectItem :: ExprWithAlias {
882+ expr,
883+ alias: Ident :: new( "sqlpage_set_expr" ) ,
884+ } ] ,
885+ into : None ,
886+ from : vec ! [ ] ,
887+ lateral_views : vec ! [ ] ,
888+ selection : None ,
889+ group_by : sqlparser:: ast:: GroupByExpr :: Expressions ( vec ! [ ] , vec ! [ ] ) ,
890+ cluster_by : vec ! [ ] ,
891+ distribute_by : vec ! [ ] ,
892+ sort_by : vec ! [ ] ,
893+ having : None ,
894+ named_window : vec ! [ ] ,
895+ qualify : None ,
896+ top_before_distinct : false ,
897+ prewhere : None ,
898+ window_before_qualify : false ,
899+ value_table_mode : None ,
900+ connect_by : None ,
901+ } ,
902+ ) ) ) ,
903+ order_by : None ,
904+ limit : None ,
905+ offset : None ,
906+ fetch : None ,
907+ locks : vec ! [ ] ,
908+ limit_by : vec ! [ ] ,
909+ for_clause : None ,
910+ settings : None ,
911+ format_clause : None ,
912+ } ) )
913+ }
914+
865915#[ cfg( test) ]
866916mod test {
867917 use super :: super :: sqlpage_functions:: functions:: SqlPageFunctionName ;
@@ -1169,7 +1219,7 @@ mod test {
11691219 StmtParam :: PostOrGet ( "x" . to_string( ) ) ,
11701220 "{dialect:?}"
11711221 ) ;
1172- assert_eq ! ( query, "SELECT 42" ) ;
1222+ assert_eq ! ( query, "SELECT 42 AS sqlpage_set_expr " ) ;
11731223 assert ! ( params. is_empty( ) ) ;
11741224 } else {
11751225 panic ! ( "Failed for dialect {dialect:?}: {stmt:#?}" , ) ;
@@ -1261,4 +1311,43 @@ mod test {
12611311 ]
12621312 ) ;
12631313 }
1314+
1315+ #[ test]
1316+ fn test_set_variable_with_sqlpage_function ( ) {
1317+ let sql = "set x = sqlpage.url_encode(some_db_function())" ;
1318+ for & ( dialect, db_kind) in ALL_DIALECTS {
1319+ let mut parser = Parser :: new ( dialect) . try_with_sql ( sql) . unwrap ( ) ;
1320+ let stmt = parse_single_statement ( & mut parser, db_kind, sql) ;
1321+ let Some ( ParsedStatement :: SetVariable {
1322+ variable,
1323+ value :
1324+ StmtWithParams {
1325+ query,
1326+ params,
1327+ delayed_functions,
1328+ json_columns,
1329+ ..
1330+ } ,
1331+ } ) = stmt
1332+ else {
1333+ panic ! ( "for dialect {dialect:?}: {stmt:#?} instead of SetVariable" ) ;
1334+ } ;
1335+ assert_eq ! (
1336+ variable,
1337+ StmtParam :: PostOrGet ( "x" . to_string( ) ) ,
1338+ "{dialect:?}"
1339+ ) ;
1340+ assert_eq ! (
1341+ delayed_functions,
1342+ [ DelayedFunctionCall {
1343+ function: SqlPageFunctionName :: url_encode,
1344+ argument_col_names: vec![ "_sqlpage_f0_a0" . to_string( ) ] ,
1345+ target_col_name: "sqlpage_set_expr" . to_string( )
1346+ } ]
1347+ ) ;
1348+ assert_eq ! ( query, "SELECT some_db_function() AS _sqlpage_f0_a0" ) ;
1349+ assert_eq ! ( params, [ ] ) ;
1350+ assert_eq ! ( json_columns, Vec :: <String >:: new( ) ) ;
1351+ }
1352+ }
12641353}
0 commit comments