Skip to content

Commit 4f8d21c

Browse files
committed
fix #701
1 parent 1c5fb16 commit 4f8d21c

File tree

8 files changed

+171
-24
lines changed

8 files changed

+171
-24
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
- Fixed layout issues in the card component when embedding content with `embed`: remove double border and padding.
4141
- ![embedded card screenshot](https://github.com/user-attachments/assets/ea85438d-5fcb-4eed-b90b-a4385675355d)
4242
- Added support for `empty_option` in the form component to add an empty option before the options defined in `options`. Useful when generating other options from a database table.
43+
- Allow nested json objects and arrays as sqlpage function parameters (useful in `sqlpage.fetch`).
4344

4445
## 0.30.1 (2024-10-31)
4546
- fix a bug where table sorting would break if table search was not also enabled.

src/webserver/database/execute_queries.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ pub fn stop_at_first_error(
9696
results_stream.take_while(move |item| {
9797
// We stop the stream AFTER the first error, so that the error is still returned to the client, but the rest of the queries are not executed.
9898
let should_continue = !has_error;
99-
has_error |= matches!(item, DbItem::Error(_));
99+
if let DbItem::Error(err) = item {
100+
log::error!("{err:?}");
101+
has_error = true;
102+
}
100103
futures_util::future::ready(should_continue)
101104
})
102105
}
@@ -167,7 +170,6 @@ async fn execute_set_variable_query<'a>(
167170
Ok(None) => None,
168171
Err(e) => {
169172
let err = display_db_error(source_file, &statement.query, e);
170-
log::error!("{err}");
171173
return Err(err);
172174
}
173175
};
@@ -244,7 +246,6 @@ fn parse_single_sql_result(
244246
}
245247
Err(err) => {
246248
let nice_err = display_db_error(source_file, sql, err);
247-
log::error!("{:?}", nice_err);
248249
DbItem::Error(nice_err)
249250
}
250251
}

src/webserver/database/sql.rs

Lines changed: 102 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,8 @@ fn parse_single_statement(
144144
semicolon = true;
145145
}
146146
let mut params = ParameterExtractor::extract_parameters(&mut stmt, db_kind);
147-
if let Some((variable, query)) = extract_set_variable(&mut stmt) {
148-
return Some(ParsedStatement::SetVariable {
149-
variable,
150-
value: StmtWithParams {
151-
query,
152-
params,
153-
delayed_functions: Vec::new(),
154-
json_columns: Vec::new(),
155-
},
156-
});
147+
if let Some((variable, value)) = extract_set_variable(&mut stmt, &mut params, db_kind) {
148+
return Some(ParsedStatement::SetVariable { variable, value });
157149
}
158150
if let Some(csv_import) = extract_csv_copy_statement(&mut stmt) {
159151
return Some(ParsedStatement::CsvImport(csv_import));
@@ -406,7 +398,11 @@ fn is_simple_select_placeholder(e: &Expr) -> bool {
406398
}
407399
}
408400

409-
fn extract_set_variable(stmt: &mut Statement) -> Option<(StmtParam, String)> {
401+
fn extract_set_variable(
402+
stmt: &mut Statement,
403+
params: &mut Vec<StmtParam>,
404+
db_kind: AnyKind,
405+
) -> Option<(StmtParam, StmtWithParams)> {
410406
if let Statement::SetVariable {
411407
variables: OneOrManyWithParens::One(ObjectName(name)),
412408
value,
@@ -420,7 +416,20 @@ fn extract_set_variable(stmt: &mut Statement) -> Option<(StmtParam, String)> {
420416
} else {
421417
StmtParam::PostOrGet(std::mem::take(&mut ident.value))
422418
};
423-
return Some((variable, format!("SELECT {value}")));
419+
let owned_expr = std::mem::replace(value, Expr::Value(Value::Null));
420+
let mut select_stmt: Statement = expr_to_statement(owned_expr);
421+
let delayed_functions = extract_toplevel_functions(&mut select_stmt);
422+
remove_invalid_function_calls(&mut select_stmt, params);
423+
let json_columns = extract_json_columns(&select_stmt, db_kind);
424+
return Some((
425+
variable,
426+
StmtWithParams {
427+
query: select_stmt.to_string(),
428+
params: std::mem::take(params),
429+
delayed_functions,
430+
json_columns,
431+
},
432+
));
424433
}
425434
}
426435
None
@@ -862,6 +871,47 @@ fn is_json_function(expr: &Expr) -> bool {
862871
}
863872
}
864873

874+
fn expr_to_statement(expr: Expr) -> Statement {
875+
Statement::Query(Box::new(sqlparser::ast::Query {
876+
with: None,
877+
body: Box::new(sqlparser::ast::SetExpr::Select(Box::new(
878+
sqlparser::ast::Select {
879+
distinct: None,
880+
top: None,
881+
projection: vec![SelectItem::ExprWithAlias {
882+
expr,
883+
alias: Ident::new("sqlpage_set_expr"),
884+
}],
885+
into: None,
886+
from: vec![],
887+
lateral_views: vec![],
888+
selection: None,
889+
group_by: sqlparser::ast::GroupByExpr::Expressions(vec![], vec![]),
890+
cluster_by: vec![],
891+
distribute_by: vec![],
892+
sort_by: vec![],
893+
having: None,
894+
named_window: vec![],
895+
qualify: None,
896+
top_before_distinct: false,
897+
prewhere: None,
898+
window_before_qualify: false,
899+
value_table_mode: None,
900+
connect_by: None,
901+
},
902+
))),
903+
order_by: None,
904+
limit: None,
905+
offset: None,
906+
fetch: None,
907+
locks: vec![],
908+
limit_by: vec![],
909+
for_clause: None,
910+
settings: None,
911+
format_clause: None,
912+
}))
913+
}
914+
865915
#[cfg(test)]
866916
mod test {
867917
use super::super::sqlpage_functions::functions::SqlPageFunctionName;
@@ -1169,7 +1219,7 @@ mod test {
11691219
StmtParam::PostOrGet("x".to_string()),
11701220
"{dialect:?}"
11711221
);
1172-
assert_eq!(query, "SELECT 42");
1222+
assert_eq!(query, "SELECT 42 AS sqlpage_set_expr");
11731223
assert!(params.is_empty());
11741224
} else {
11751225
panic!("Failed for dialect {dialect:?}: {stmt:#?}",);
@@ -1261,4 +1311,43 @@ mod test {
12611311
]
12621312
);
12631313
}
1314+
1315+
#[test]
1316+
fn test_set_variable_with_sqlpage_function() {
1317+
let sql = "set x = sqlpage.url_encode(some_db_function())";
1318+
for &(dialect, db_kind) in ALL_DIALECTS {
1319+
let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
1320+
let stmt = parse_single_statement(&mut parser, db_kind, sql);
1321+
let Some(ParsedStatement::SetVariable {
1322+
variable,
1323+
value:
1324+
StmtWithParams {
1325+
query,
1326+
params,
1327+
delayed_functions,
1328+
json_columns,
1329+
..
1330+
},
1331+
}) = stmt
1332+
else {
1333+
panic!("for dialect {dialect:?}: {stmt:#?} instead of SetVariable");
1334+
};
1335+
assert_eq!(
1336+
variable,
1337+
StmtParam::PostOrGet("x".to_string()),
1338+
"{dialect:?}"
1339+
);
1340+
assert_eq!(
1341+
delayed_functions,
1342+
[DelayedFunctionCall {
1343+
function: SqlPageFunctionName::url_encode,
1344+
argument_col_names: vec!["_sqlpage_f0_a0".to_string()],
1345+
target_col_name: "sqlpage_set_expr".to_string()
1346+
}]
1347+
);
1348+
assert_eq!(query, "SELECT some_db_function() AS _sqlpage_f0_a0");
1349+
assert_eq!(params, []);
1350+
assert_eq!(json_columns, Vec::<String>::new());
1351+
}
1352+
}
12641353
}

src/webserver/database/sqlpage_functions/http_fetch_request.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use anyhow::Context;
2+
13
use super::function_traits::BorrowFromStr;
24
use std::borrow::Cow;
35

@@ -55,10 +57,11 @@ impl<'a> BorrowFromStr<'a> for HttpFetchRequest<'a> {
5557
}
5658
} else {
5759
match s {
58-
Cow::Borrowed(s) => serde_json::from_str(s)?,
59-
Cow::Owned(s) => serde_json::from_str::<HttpFetchRequest<'_>>(&s)
60-
.map(HttpFetchRequest::into_owned)?,
60+
Cow::Borrowed(s) => serde_json::from_str(s),
61+
Cow::Owned(ref s) => serde_json::from_str::<HttpFetchRequest<'_>>(s)
62+
.map(HttpFetchRequest::into_owned),
6163
}
64+
.with_context(|| format!("Invalid http fetch request definition: {s}"))?
6265
})
6366
}
6467
}

src/webserver/database/syntax_tree.rs

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,23 @@ async fn json_object_params<'a, 'b>(
203203
let val = it
204204
.next()
205205
.ok_or_else(|| anyhow::anyhow!("Odd number of arguments in JSON_OBJECT"))?;
206-
let val = Box::pin(extract_req_param(val, request, db_connection)).await?;
207-
map_ser.serialize_value(&val)?;
206+
207+
match val {
208+
StmtParam::JsonObject(args) => {
209+
let raw_json = Box::pin(json_object_params(args, request, db_connection)).await?;
210+
let obj = cow_to_raw_json(&raw_json);
211+
map_ser.serialize_value(&obj)?;
212+
}
213+
StmtParam::JsonArray(args) => {
214+
let raw_json = Box::pin(json_array_params(args, request, db_connection)).await?;
215+
let obj = cow_to_raw_json(&raw_json);
216+
map_ser.serialize_value(&obj)?;
217+
}
218+
val => {
219+
let evaluated = Box::pin(extract_req_param(val, request, db_connection)).await?;
220+
map_ser.serialize_value(&evaluated)?;
221+
}
222+
};
208223
}
209224
map_ser.end()?;
210225
Ok(Some(Cow::Owned(String::from_utf8(result)?)))
@@ -220,9 +235,33 @@ async fn json_array_params<'a, 'b>(
220235
let mut ser = serde_json::Serializer::new(&mut result);
221236
let mut seq_ser = ser.serialize_seq(Some(args.len()))?;
222237
for element in args {
223-
let element = Box::pin(extract_req_param(element, request, db_connection)).await?;
224-
seq_ser.serialize_element(&element)?;
238+
match element {
239+
StmtParam::JsonObject(args) => {
240+
let raw_json = json_object_params(args, request, db_connection).await?;
241+
let obj = cow_to_raw_json(&raw_json);
242+
seq_ser.serialize_element(&obj)?;
243+
}
244+
StmtParam::JsonArray(args) => {
245+
let raw_json = Box::pin(json_array_params(args, request, db_connection)).await?;
246+
let obj = cow_to_raw_json(&raw_json);
247+
seq_ser.serialize_element(&obj)?;
248+
}
249+
element => {
250+
let evaluated =
251+
Box::pin(extract_req_param(element, request, db_connection)).await?;
252+
seq_ser.serialize_element(&evaluated)?;
253+
}
254+
};
225255
}
226256
seq_ser.end()?;
227257
Ok(Some(Cow::Owned(String::from_utf8(result)?)))
228258
}
259+
260+
fn cow_to_raw_json<'a>(
261+
raw_json: &'a Option<Cow<'a, str>>,
262+
) -> Option<&'a serde_json::value::RawValue> {
263+
raw_json
264+
.as_deref()
265+
.map(serde_json::from_str::<&'a serde_json::value::RawValue>)
266+
.map(Result::unwrap)
267+
}

tests/sql_test_files/it_works_native_json_array_impl.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ set res = sqlpage.fetch(json_object(
44
'headers', json_object('x-custom', '1'),
55
'body', json_array('hello', 'world')
66
));
7-
set expected = 'POST /post|accept-encoding: br, gzip, deflate, zstd|content-length: 18|content-type: application/json|host: localhost:62802|user-agent: sqlpage|x-custom: 1|["hello", "world"]';
7+
set expected = 'POST /post|accept-encoding: br, gzip, deflate, zstd|content-length: 17|content-type: application/json|host: localhost:62802|user-agent: sqlpage|x-custom: 1|["hello","world"]';
88
select 'text' as component,
99
case $res
1010
when $expected then 'It works !'
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
set my_var = sqlpage.url_encode(' ');
2+
select 'text' as component,
3+
CASE $my_var
4+
WHEN '%20' THEN 'It works !'
5+
ELSE 'It failed !'
6+
END
7+
AS contents;
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
set my_var = sqlpage.url_encode(UPPER('a'));
2+
select 'text' as component,
3+
CASE $my_var
4+
WHEN 'A' THEN 'It works !'
5+
ELSE 'It failed !'
6+
END
7+
AS contents;

0 commit comments

Comments
 (0)