Skip to content

Commit 533e2af

Browse files
authored
Improve array support (#106)
1 parent cdc0732 commit 533e2af

File tree

3 files changed

+161
-42
lines changed

3 files changed

+161
-42
lines changed

src/formatter.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ pub(crate) fn format(
7676
formatter.format_no_change(token, &mut formatted_query);
7777
continue;
7878
}
79+
7980
match token.kind {
8081
TokenKind::Whitespace => {
8182
// ignore (we do our own whitespace formatting)
@@ -386,7 +387,8 @@ impl<'a> Formatter<'a> {
386387
];
387388

388389
const ADD_WHITESPACE_BETWEEN: &[TokenKind] = &[TokenKind::CloseParen, TokenKind::Reserved];
389-
390+
const BEFORE_ARRAY: &[TokenKind] =
391+
&[TokenKind::CloseParen, TokenKind::Word, TokenKind::Reserved];
390392
let inlined = self.inline_block.begin_if_possible(self.tokens, self.index);
391393
let previous_non_whitespace_token = self.previous_non_whitespace_token(1);
392394
let fold_in_top_level = !inlined
@@ -405,13 +407,16 @@ impl<'a> Formatter<'a> {
405407
// Take out the preceding space unless there was whitespace there in the original query
406408
// or another opening parens or line comment
407409
let previous_token = self.previous_token(1);
408-
if previous_token.is_none()
409-
|| !PRESERVE_WHITESPACE_FOR.contains(&previous_token.unwrap().kind)
410+
if previous_token.is_none_or(|t| !PRESERVE_WHITESPACE_FOR.contains(&t.kind))
411+
|| previous_non_whitespace_token
412+
.is_some_and(|t| token.value == "[" && BEFORE_ARRAY.contains(&t.kind))
410413
{
411414
self.trim_spaces_end(query);
412415
}
413416

414-
if previous_non_whitespace_token.is_some_and(|t| ADD_WHITESPACE_BETWEEN.contains(&t.kind)) {
417+
if previous_non_whitespace_token
418+
.is_some_and(|t| token.value != "[" && ADD_WHITESPACE_BETWEEN.contains(&t.kind))
419+
{
415420
self.trim_spaces_end(query);
416421
query.push(' ');
417422
}

src/lib.rs

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ pub fn format(query: &str, params: &QueryParams, options: &FormatOptions) -> Str
2727
formatter::format(&tokens, params, options)
2828
}
2929

30+
/// The SQL dialect to use. This affects parsing of special characters.
31+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32+
pub enum Dialect {
33+
/// Generic SQL syntax, most dialect-specific constructs are disabled
34+
Generic,
35+
/// Enables array syntax (`[`, `]`) and operators
36+
PostgreSql,
37+
/// Enables `[bracketed identifiers]` and `@variables`
38+
SQLServer,
39+
}
40+
3041
/// Options for controlling how the library formats SQL
3142
#[derive(Debug, Clone)]
3243
pub struct FormatOptions<'a> {
@@ -68,6 +79,10 @@ pub struct FormatOptions<'a> {
6879
///
6980
/// Default: false
7081
pub joins_as_top_level: bool,
82+
/// The SQL dialect to use
83+
///
84+
/// Default: Generic
85+
pub dialect: Dialect,
7186
}
7287

7388
impl<'a> Default for FormatOptions<'a> {
@@ -82,6 +97,7 @@ impl<'a> Default for FormatOptions<'a> {
8297
max_inline_arguments: None,
8398
max_inline_top_level: None,
8499
joins_as_top_level: false,
100+
dialect: Dialect::Generic,
85101
}
86102
}
87103
}
@@ -496,7 +512,10 @@ mod tests {
496512
#[test]
497513
fn it_formats_type_specifiers() {
498514
let input = "SELECT id, ARRAY [] :: UUID [] FROM UNNEST($1 :: UUID []) WHERE $1::UUID[] IS NOT NULL;";
499-
let options = FormatOptions::default();
515+
let options = FormatOptions {
516+
dialect: Dialect::PostgreSql,
517+
..Default::default()
518+
};
500519
let expected = indoc!(
501520
"
502521
SELECT
@@ -511,6 +530,66 @@ mod tests {
511530
assert_eq!(format(input, &QueryParams::None, &options), expected);
512531
}
513532

533+
#[test]
534+
fn it_formats_arrays_as_function_arguments() {
535+
let input =
536+
"SELECT array_position(ARRAY['sun','mon','tue', 'wed', 'thu','fri', 'sat'], 'mon');";
537+
let options = FormatOptions {
538+
dialect: Dialect::PostgreSql,
539+
..Default::default()
540+
};
541+
let expected = indoc!(
542+
"
543+
SELECT
544+
array_position(
545+
ARRAY['sun', 'mon', 'tue', 'wed', 'thu', 'fri', 'sat'],
546+
'mon'
547+
);"
548+
);
549+
550+
assert_eq!(format(input, &QueryParams::None, &options), expected);
551+
}
552+
553+
#[test]
554+
fn it_formats_arrays_as_values() {
555+
let input = " INSERT INTO t VALUES('a', ARRAY[0, 1,2,3], ARRAY[['a','b'], ['c' ,'d']]);";
556+
let options = FormatOptions {
557+
dialect: Dialect::PostgreSql,
558+
max_inline_block: 10,
559+
max_inline_top_level: Some(50),
560+
..Default::default()
561+
};
562+
let expected = indoc!(
563+
"
564+
INSERT INTO t
565+
VALUES (
566+
'a',
567+
ARRAY[0, 1, 2, 3],
568+
ARRAY[
569+
['a', 'b'],
570+
['c', 'd']
571+
]
572+
);"
573+
);
574+
575+
assert_eq!(format(input, &QueryParams::None, &options), expected);
576+
}
577+
578+
#[test]
579+
fn it_formats_array_index_notation() {
580+
let input = "SELECT a [ 1 ] + b [ 2 ] [ 5+1 ] > c [3] ;";
581+
let options = FormatOptions {
582+
dialect: Dialect::PostgreSql,
583+
..Default::default()
584+
};
585+
let expected = indoc!(
586+
"
587+
SELECT
588+
a[1] + b[2][5 + 1] > c[3];"
589+
);
590+
591+
assert_eq!(format(input, &QueryParams::None, &options), expected);
592+
}
514593
#[test]
515594
fn it_formats_limit_of_single_value_and_offset() {
516595
let input = "LIMIT 5 OFFSET 8;";
@@ -1349,7 +1428,10 @@ mod tests {
13491428
#[test]
13501429
fn it_recognizes_bracketed_strings() {
13511430
let inputs = ["[foo JOIN bar]", "[foo ]] JOIN bar]"];
1352-
let options = FormatOptions::default();
1431+
let options = FormatOptions {
1432+
dialect: Dialect::SQLServer,
1433+
..Default::default()
1434+
};
13531435
for input in &inputs {
13541436
assert_eq!(&format(input, &QueryParams::None, &options), input);
13551437
}
@@ -1359,7 +1441,10 @@ mod tests {
13591441
fn it_recognizes_at_variables() {
13601442
let input =
13611443
"SELECT @variable, @a1_2.3$, @'var name', @\"var name\", @`var name`, @[var name];";
1362-
let options = FormatOptions::default();
1444+
let options = FormatOptions {
1445+
dialect: Dialect::SQLServer,
1446+
..Default::default()
1447+
};
13631448
let expected = indoc!(
13641449
"
13651450
SELECT
@@ -1384,7 +1469,10 @@ mod tests {
13841469
("var name".to_string(), "'var value'".to_string()),
13851470
("var\\name".to_string(), "'var\\ value'".to_string()),
13861471
];
1387-
let options = FormatOptions::default();
1472+
let options = FormatOptions {
1473+
dialect: Dialect::SQLServer,
1474+
..Default::default()
1475+
};
13881476
let expected = indoc!(
13891477
"
13901478
SELECT
@@ -1407,7 +1495,10 @@ mod tests {
14071495
fn it_recognizes_colon_variables() {
14081496
let input =
14091497
"SELECT :variable, :a1_2.3$, :'var name', :\"var name\", :`var name`, :[var name];";
1410-
let options = FormatOptions::default();
1498+
let options = FormatOptions {
1499+
dialect: Dialect::SQLServer,
1500+
..Default::default()
1501+
};
14111502
let expected = indoc!(
14121503
"
14131504
SELECT
@@ -1440,7 +1531,10 @@ mod tests {
14401531
"'super weird value'".to_string(),
14411532
),
14421533
];
1443-
let options = FormatOptions::default();
1534+
let options = FormatOptions {
1535+
dialect: Dialect::SQLServer,
1536+
..Default::default()
1537+
};
14441538
let expected = indoc!(
14451539
"
14461540
SELECT

src/tokenizer.rs

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use winnow::prelude::*;
88
use winnow::token::{any, one_of, rest, take, take_until, take_while};
99
use winnow::Result;
1010

11-
use crate::FormatOptions;
11+
use crate::{Dialect, FormatOptions};
1212

1313
pub(crate) fn tokenize<'a>(
1414
mut input: &'a str,
@@ -32,6 +32,7 @@ pub(crate) fn tokenize<'a>(
3232
last_reserved_token.clone(),
3333
last_reserved_top_level_token.clone(),
3434
named_placeholders,
35+
options.dialect,
3536
) {
3637
match result.kind {
3738
TokenKind::Reserved => {
@@ -124,13 +125,14 @@ fn get_next_token<'a>(
124125
last_reserved_token: Option<Token<'a>>,
125126
last_reserved_top_level_token: Option<Token<'a>>,
126127
named_placeholders: bool,
128+
dialect: Dialect,
127129
) -> Result<Token<'a>> {
128130
alt((
129131
get_comment_token,
130132
|input: &mut _| get_type_specifier_token(input, previous_token.clone()),
131-
get_string_token,
132-
get_open_paren_token,
133-
get_close_paren_token,
133+
|input: &mut _| get_string_token(input, dialect),
134+
|input: &mut _| get_open_paren_token(input, dialect),
135+
|input: &mut _| get_close_paren_token(input, dialect),
134136
get_number_token,
135137
|input: &mut _| {
136138
get_reserved_word_token(
@@ -141,7 +143,7 @@ fn get_next_token<'a>(
141143
)
142144
},
143145
get_operator_token,
144-
|input: &mut _| get_placeholder_token(input, named_placeholders),
146+
|input: &mut _| get_placeholder_token(input, named_placeholders, dialect),
145147
get_word_token,
146148
get_any_other_char,
147149
))
@@ -238,10 +240,10 @@ pub fn take_till_escaping<'a>(
238240
// 4. single quoted string using '' or \' to escape
239241
// 5. national character quoted string using N'' or N\' to escape
240242
// 6. hex(blob literal) does not need to escape
241-
fn get_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
243+
fn get_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
242244
dispatch! {any;
243245
'`' => (take_till_escaping('`', &['`']), any).void(),
244-
'[' => (take_till_escaping(']', &[']']), any).void(),
246+
'[' if dialect == Dialect::SQLServer => (take_till_escaping(']', &[']']), any).void(),
245247
'"' => (take_till_escaping('"', &['"', '\\']), any).void(),
246248
'\'' => (take_till_escaping('\'', &['\'', '\\']), any).void(),
247249
'N' => ('\'', take_till_escaping('\'', &['\'', '\\']), any).void(),
@@ -261,10 +263,10 @@ fn get_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
261263
}
262264

263265
// Like above but it doesn't replace double quotes
264-
fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
266+
fn get_placeholder_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
265267
dispatch! {any;
266268
'`'=>( take_till_escaping('`', &['`']), any).void(),
267-
'['=>( take_till_escaping(']', &[']']), any).void(),
269+
'[' if dialect == Dialect::SQLServer =>( take_till_escaping(']', &[']']), any).void(),
268270
'"'=>( take_till_escaping('"', &['\\']), any).void(),
269271
'\''=>( take_till_escaping('\'', &['\\']), any).void(),
270272
'N' =>('\'', take_till_escaping('\'', &['\\']), any).void(),
@@ -280,44 +282,57 @@ fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
280282
})
281283
}
282284

283-
fn get_open_paren_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
284-
alt(("(", terminated(Caseless("CASE"), end_of_word)))
285-
.parse_next(input)
286-
.map(|token| Token {
287-
kind: TokenKind::OpenParen,
288-
value: token,
289-
key: None,
290-
alias: token,
291-
})
285+
fn get_open_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
286+
let case = terminated(Caseless("CASE"), end_of_word);
287+
let open_paren = if dialect == Dialect::PostgreSql {
288+
("(", "[", case)
289+
} else {
290+
("(", "(", case)
291+
};
292+
293+
alt(open_paren).parse_next(input).map(|token| Token {
294+
kind: TokenKind::OpenParen,
295+
value: token,
296+
key: None,
297+
alias: token,
298+
})
292299
}
293300

294-
fn get_close_paren_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
295-
alt((")", terminated(Caseless("END"), end_of_word)))
296-
.parse_next(input)
297-
.map(|token| Token {
298-
kind: TokenKind::CloseParen,
299-
value: token,
300-
key: None,
301-
alias: token,
302-
})
301+
fn get_close_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
302+
let end = terminated(Caseless("END"), end_of_word);
303+
let close_paren = if dialect == Dialect::PostgreSql {
304+
(")", "]", end)
305+
} else {
306+
(")", ")", end)
307+
};
308+
alt(close_paren).parse_next(input).map(|token| Token {
309+
kind: TokenKind::CloseParen,
310+
value: token,
311+
key: None,
312+
alias: token,
313+
})
303314
}
304315

305-
fn get_placeholder_token<'i>(input: &mut &'i str, named_placeholders: bool) -> Result<Token<'i>> {
316+
fn get_placeholder_token<'i>(
317+
input: &mut &'i str,
318+
named_placeholders: bool,
319+
dialect: Dialect,
320+
) -> Result<Token<'i>> {
306321
// The precedence changes based on 'named_placeholders' but not the exhaustiveness.
307322
// This is to ensure the formatting is the same even if parameters aren't used.
308323

309324
if named_placeholders {
310325
alt((
311326
get_ident_named_placeholder_token,
312-
get_string_named_placeholder_token,
327+
|input: &mut _| get_string_named_placeholder_token(input, dialect),
313328
get_indexed_placeholder_token,
314329
))
315330
.parse_next(input)
316331
} else {
317332
alt((
318333
get_indexed_placeholder_token,
319334
get_ident_named_placeholder_token,
320-
get_string_named_placeholder_token,
335+
|input: &mut _| get_string_named_placeholder_token(input, dialect),
321336
))
322337
.parse_next(input)
323338
}
@@ -366,8 +381,13 @@ fn get_ident_named_placeholder_token<'i>(input: &mut &'i str) -> Result<Token<'i
366381
})
367382
}
368383

369-
fn get_string_named_placeholder_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
370-
(one_of(('@', ':')), get_placeholder_string_token)
384+
fn get_string_named_placeholder_token<'i>(
385+
input: &mut &'i str,
386+
dialect: Dialect,
387+
) -> Result<Token<'i>> {
388+
(one_of(('@', ':')), |input: &mut _| {
389+
get_placeholder_string_token(input, dialect)
390+
})
371391
.take()
372392
.parse_next(input)
373393
.map(|token| {

0 commit comments

Comments
 (0)