Skip to content

Commit 0328665

Browse files
committed
$parameter handling in sqlite
1 parent 9827add commit 0328665

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

src/webserver/database/sql.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,14 @@ impl VisitorMut for ParameterExtractor {
378378
type Break = ();
379379
fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow<Self::Break> {
380380
match value {
381+
Expr::Identifier(Ident {
382+
value: var_name,
383+
quote_style: None,
384+
}) if var_name.starts_with('$') || var_name.starts_with(':') => {
385+
let name = std::mem::take(var_name);
386+
*value = self.make_placeholder();
387+
self.parameters.push(map_param(name));
388+
}
381389
Expr::Value(Value::Placeholder(param)) if !self.is_own_placeholder(param) =>
382390
// this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves
383391
{
@@ -480,6 +488,56 @@ mod test {
480488
);
481489
}
482490

491+
#[test]
492+
fn test_statement_rewrite_sqlite() {
493+
let mut ast = parse_stmt("select $x, :y from t", SQLiteDialect {});
494+
let parameters = ParameterExtractor::extract_parameters(&mut ast, AnyKind::Sqlite);
495+
assert_eq!(
496+
ast.to_string(),
497+
"SELECT CAST(? AS VARCHAR), CAST(? AS VARCHAR) FROM t"
498+
);
499+
assert_eq!(
500+
parameters,
501+
[
502+
StmtParam::GetOrPost("x".to_string()),
503+
StmtParam::Post("y".to_string()),
504+
]
505+
);
506+
}
507+
508+
#[test]
509+
fn is_own_placeholder() {
510+
assert!(ParameterExtractor {
511+
db_kind: AnyKind::Postgres,
512+
parameters: vec![]
513+
}
514+
.is_own_placeholder("$1"));
515+
516+
assert!(ParameterExtractor {
517+
db_kind: AnyKind::Postgres,
518+
parameters: vec![StmtParam::Get('x'.to_string())]
519+
}
520+
.is_own_placeholder("$2"));
521+
522+
assert!(!ParameterExtractor {
523+
db_kind: AnyKind::Postgres,
524+
parameters: vec![]
525+
}
526+
.is_own_placeholder("$2"));
527+
528+
assert!(ParameterExtractor {
529+
db_kind: AnyKind::Sqlite,
530+
parameters: vec![]
531+
}
532+
.is_own_placeholder("?"));
533+
534+
assert!(!ParameterExtractor {
535+
db_kind: AnyKind::Sqlite,
536+
parameters: vec![]
537+
}
538+
.is_own_placeholder("$1"));
539+
}
540+
483541
#[test]
484542
fn test_mssql_statement_rewrite() {
485543
let mut ast = parse_stmt("select '' || $1 from [a schema].[a table]", MsSqlDialect {});

0 commit comments

Comments
 (0)