@@ -4,8 +4,8 @@ use crate::file_cache::AsyncFromStrWithState;
44use crate :: { AppState , Database } ;
55use async_trait:: async_trait;
66use sqlparser:: ast:: {
7- DataType , Expr , Function , FunctionArg , FunctionArgExpr , Ident , ObjectName , Statement , Value ,
8- VisitMut , VisitorMut ,
7+ BinaryOperator , DataType , Expr , Function , FunctionArg , FunctionArgExpr , Ident , ObjectName ,
8+ Statement , Value , VisitMut , VisitorMut ,
99} ;
1010use sqlparser:: dialect:: GenericDialect ;
1111use sqlparser:: parser:: { Parser , ParserError } ;
@@ -182,6 +182,10 @@ struct ParameterExtractor {
182182 parameters : Vec < StmtParam > ,
183183}
184184
185+ const PLACEHOLDER_PREFIXES : [ ( AnyKind , & str ) ; 2 ] =
186+ [ ( AnyKind :: Postgres , "$" ) , ( AnyKind :: Mssql , "@p" ) ] ;
187+ const DEFAULT_PLACEHOLDER : & str = "?" ;
188+
185189impl ParameterExtractor {
186190 fn extract_parameters (
187191 sql_ast : & mut sqlparser:: ast:: Statement ,
@@ -200,7 +204,8 @@ impl ParameterExtractor {
200204 let data_type = match self . db_kind {
201205 // MySQL requires CAST(? AS CHAR) and does not understand CAST(? AS TEXT)
202206 AnyKind :: MySql => DataType :: Char ( None ) ,
203- _ => DataType :: Text ,
207+ AnyKind :: Postgres => DataType :: Text ,
208+ _ => DataType :: Varchar ( None ) ,
204209 } ;
205210 let value = Expr :: Value ( Value :: Placeholder ( name) ) ;
206211 Expr :: Cast {
@@ -220,6 +225,21 @@ impl ParameterExtractor {
220225 self . parameters . push ( param) ;
221226 placeholder
222227 }
228+
229+ fn is_own_placeholder ( & self , param : & str ) -> bool {
230+ if let Some ( ( _, prefix) ) = PLACEHOLDER_PREFIXES
231+ . iter ( )
232+ . find ( |( kind, _prefix) | * kind == self . db_kind )
233+ {
234+ if let Some ( param) = param. strip_prefix ( prefix) {
235+ if let Ok ( index) = param. parse :: < usize > ( ) {
236+ return index <= self . parameters . len ( ) + 1 ;
237+ }
238+ }
239+ return false ;
240+ }
241+ param == DEFAULT_PLACEHOLDER
242+ }
223243}
224244
225245/** This is a helper struct to format a list of arguments for an error message. */
@@ -299,21 +319,20 @@ fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> {
299319
300320#[ inline]
301321pub fn make_placeholder ( db_kind : AnyKind , arg_number : usize ) -> String {
302- match db_kind {
303- // Postgres only supports numbered parameters with $1, $2, etc.
304- AnyKind :: Postgres => format ! ( "${arg_number}" ) ,
305- // MSSQL only supports named parameters with @p1, @p2, etc.
306- AnyKind :: Mssql => format ! ( "@p{arg_number}" ) ,
307- _ => '?' . to_string ( ) ,
322+ if let Some ( ( _, prefix) ) = PLACEHOLDER_PREFIXES
323+ . iter ( )
324+ . find ( |( kind, _) | * kind == db_kind)
325+ {
326+ return format ! ( "{prefix}{arg_number}" ) ;
308327 }
328+ DEFAULT_PLACEHOLDER . to_string ( )
309329}
310330
311331impl VisitorMut for ParameterExtractor {
312332 type Break = ( ) ;
313333 fn pre_visit_expr ( & mut self , value : & mut Expr ) -> ControlFlow < Self :: Break > {
314334 match value {
315- Expr :: Value ( Value :: Placeholder ( param) )
316- if param. chars ( ) . nth ( 1 ) . is_some_and ( char:: is_alphabetic) =>
335+ Expr :: Value ( Value :: Placeholder ( param) ) if !self . is_own_placeholder ( param) =>
317336 // this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves
318337 {
319338 let new_expr = self . make_placeholder ( ) ;
@@ -334,6 +353,26 @@ impl VisitorMut for ParameterExtractor {
334353 let arguments = std:: mem:: take ( args) ;
335354 * value = self . handle_builtin_function ( func_name, arguments) ;
336355 }
356+ // Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL
357+ Expr :: BinaryOp {
358+ left,
359+ op : BinaryOperator :: StringConcat ,
360+ right,
361+ } if self . db_kind == AnyKind :: Mssql => {
362+ let left = std:: mem:: replace ( left. as_mut ( ) , Expr :: Value ( Value :: Null ) ) ;
363+ let right = std:: mem:: replace ( right. as_mut ( ) , Expr :: Value ( Value :: Null ) ) ;
364+ * value = Expr :: Function ( Function {
365+ name : ObjectName ( vec ! [ Ident :: new( "CONCAT" ) ] ) ,
366+ args : vec ! [
367+ FunctionArg :: Unnamed ( FunctionArgExpr :: Expr ( left) ) ,
368+ FunctionArg :: Unnamed ( FunctionArgExpr :: Expr ( right) ) ,
369+ ] ,
370+ over : None ,
371+ distinct : false ,
372+ special : false ,
373+ order_by : vec ! [ ] ,
374+ } ) ;
375+ }
337376 _ => ( ) ,
338377 }
339378 ControlFlow :: < ( ) > :: Continue ( ( ) )
@@ -390,6 +429,17 @@ mod test {
390429 ) ;
391430 }
392431
432+ #[ test]
433+ fn test_mssql_statement_rewrite ( ) {
434+ let mut ast = parse_stmt ( "select '' || $1 from t" ) ;
435+ let parameters = ParameterExtractor :: extract_parameters ( & mut ast, AnyKind :: Mssql ) ;
436+ assert_eq ! (
437+ ast. to_string( ) ,
438+ "SELECT CONCAT('', CAST(@p1 AS VARCHAR)) FROM t"
439+ ) ;
440+ assert_eq ! ( parameters, [ StmtParam :: GetOrPost ( "1" . to_string( ) ) , ] ) ;
441+ }
442+
393443 #[ test]
394444 fn test_static_extract ( ) {
395445 assert_eq ! (
0 commit comments