Skip to content

Commit 7bf4aff

Browse files
committed
allow setting variables with "set x = y" instead of "set $x = y"
1 parent 9013f44 commit 7bf4aff

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
- **variables** . SQLPage now support setting and reusing variables between statements. This allows you to write more complex SQL queries, and to reuse the result of a query in multiple places.
66
```sql
77
-- Set a variable
8-
SET $person = 'Alice';
8+
SET person = 'Alice';
99
-- Use it in a query
1010
SELECT 'text' AS component, 'Hello ' || $person AS contents;
1111
```

src/webserver/database/sql.rs

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,13 @@ impl AsyncFromStrWithState for ParsedSqlFile {
9797
}
9898
}
9999

100+
#[derive(Debug, PartialEq)]
100101
struct StmtWithParams {
101102
query: String,
102103
params: Vec<StmtParam>,
103104
}
104105

106+
#[derive(Debug)]
105107
enum ParsedStatement {
106108
StmtWithParams(StmtWithParams),
107109
StaticSimpleSelect(serde_json::Map<String, serde_json::Value>),
@@ -268,10 +270,12 @@ fn extract_set_variable(stmt: &mut Statement) -> Option<(StmtParam, String)> {
268270
} = stmt
269271
{
270272
if let ([ident], [value]) = (name.as_mut_slice(), value.as_mut_slice()) {
271-
if let Some(variable) = extract_ident_param(ident) {
272-
let query = format!("SELECT {value}");
273-
return Some((variable, query));
274-
}
273+
let variable = if let Some(variable) = extract_ident_param(ident) {
274+
variable
275+
} else {
276+
StmtParam::GetOrPost(std::mem::take(&mut ident.value))
277+
};
278+
return Some((variable, format!("SELECT {value}")));
275279
}
276280
}
277281
None
@@ -439,8 +443,8 @@ pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String {
439443
DEFAULT_PLACEHOLDER.to_string()
440444
}
441445

442-
fn extract_ident_param(Ident { value, quote_style }: &mut Ident) -> Option<StmtParam> {
443-
if quote_style.is_none() && value.starts_with('$') || value.starts_with(':') {
446+
fn extract_ident_param(Ident { value, .. }: &mut Ident) -> Option<StmtParam> {
447+
if value.starts_with('$') || value.starts_with(':') {
444448
let name = std::mem::take(value);
445449
Some(map_param(name))
446450
} else {
@@ -599,6 +603,30 @@ mod test {
599603
}
600604
}
601605

606+
#[test]
607+
fn test_set_variable() {
608+
let sql = "set x = $y";
609+
for &(dialect, db_kind) in ALL_DIALECTS {
610+
let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
611+
let stmt = parse_single_statement(&mut parser, db_kind);
612+
if let Some(ParsedStatement::SetVariable {
613+
variable,
614+
value: StmtWithParams { query, params },
615+
}) = stmt
616+
{
617+
assert_eq!(
618+
variable,
619+
StmtParam::GetOrPost("x".to_string()),
620+
"{dialect:?}"
621+
);
622+
assert!(query.starts_with("SELECT "));
623+
assert_eq!(params, [StmtParam::GetOrPost("y".to_string())]);
624+
} else {
625+
panic!("Failed for dialect {dialect:?}: {stmt:#?}",);
626+
}
627+
}
628+
}
629+
602630
#[test]
603631
fn is_own_placeholder() {
604632
assert!(ParameterExtractor {

tests/test_set_variable.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
set $what_does_it_do = 'wo' || 'rks';
1+
set what_does_it_do = 'wo' || 'rks';
22
select 'text' as component, 'It ' || $what_does_it_do || ' !' as contents;

0 commit comments

Comments
 (0)