From acbf5e7912384ec1d0231ae2e8fb4488bbc7c75c Mon Sep 17 00:00:00 2001 From: Luca Barbato Date: Tue, 30 Sep 2025 22:28:46 +0200 Subject: [PATCH 1/4] Initial support for dialects --- src/lib.rs | 41 ++++++++++++++++++++--- src/tokenizer.rs | 84 ++++++++++++++++++++++++++++++------------------ 2 files changed, 88 insertions(+), 37 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9734e70..126e5f6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,17 @@ pub fn format(query: &str, params: &QueryParams, options: &FormatOptions) -> Str formatter::format(&tokens, params, options) } +/// The SQL dialect to use +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Dialect { + /// Best effort, most dialect-specific constructs are disabled + Generic, + /// It considers array notations + PostgreSql, + /// It uses the `[brakets to quote]` notation + SQLServer, +} + /// Options for controlling how the library formats SQL #[derive(Debug, Clone)] pub struct FormatOptions<'a> { @@ -68,6 +79,10 @@ pub struct FormatOptions<'a> { /// /// Default: false, pub joins_as_top_level: bool, + /// Tell the SQL dialect to use + /// + /// Default: Generic + pub dialect: Dialect, } impl<'a> Default for FormatOptions<'a> { @@ -82,6 +97,7 @@ impl<'a> Default for FormatOptions<'a> { max_inline_arguments: None, max_inline_top_level: None, joins_as_top_level: false, + dialect: Dialect::Generic, } } } @@ -1328,7 +1344,10 @@ mod tests { #[test] fn it_recognizes_bracketed_strings() { let inputs = ["[foo JOIN bar]", "[foo ]] JOIN bar]"]; - let options = FormatOptions::default(); + let options = FormatOptions { + dialect: Dialect::SQLServer, + ..Default::default() + }; for input in &inputs { assert_eq!(&format(input, &QueryParams::None, &options), input); } @@ -1338,7 +1357,10 @@ mod tests { fn it_recognizes_at_variables() { let input = "SELECT @variable, @a1_2.3$, @'var name', @\"var name\", @`var name`, @[var name];"; - let options = FormatOptions::default(); + let options = FormatOptions { + 
dialect: Dialect::SQLServer, + ..Default::default() + }; let expected = indoc!( " SELECT @@ -1363,7 +1385,10 @@ mod tests { ("var name".to_string(), "'var value'".to_string()), ("var\\name".to_string(), "'var\\ value'".to_string()), ]; - let options = FormatOptions::default(); + let options = FormatOptions { + dialect: Dialect::SQLServer, + ..Default::default() + }; let expected = indoc!( " SELECT @@ -1386,7 +1411,10 @@ mod tests { fn it_recognizes_colon_variables() { let input = "SELECT :variable, :a1_2.3$, :'var name', :\"var name\", :`var name`, :[var name];"; - let options = FormatOptions::default(); + let options = FormatOptions { + dialect: Dialect::SQLServer, + ..Default::default() + }; let expected = indoc!( " SELECT @@ -1419,7 +1447,10 @@ mod tests { "'super weird value'".to_string(), ), ]; - let options = FormatOptions::default(); + let options = FormatOptions { + dialect: Dialect::SQLServer, + ..Default::default() + }; let expected = indoc!( " SELECT diff --git a/src/tokenizer.rs b/src/tokenizer.rs index de125dc..cf75545 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -8,7 +8,7 @@ use winnow::prelude::*; use winnow::token::{any, one_of, rest, take, take_until, take_while}; use winnow::Result; -use crate::FormatOptions; +use crate::{Dialect, FormatOptions}; pub(crate) fn tokenize<'a>( mut input: &'a str, @@ -32,6 +32,7 @@ pub(crate) fn tokenize<'a>( last_reserved_token.clone(), last_reserved_top_level_token.clone(), named_placeholders, + options.dialect, ) { match result.kind { TokenKind::Reserved => { @@ -124,13 +125,14 @@ fn get_next_token<'a>( last_reserved_token: Option>, last_reserved_top_level_token: Option>, named_placeholders: bool, + dialect: Dialect, ) -> Result> { alt(( get_comment_token, |input: &mut _| get_type_specifier_token(input, previous_token.clone()), - get_string_token, - get_open_paren_token, - get_close_paren_token, + |input: &mut _| get_string_token(input, dialect), + |input: &mut _| get_open_paren_token(input, dialect), + 
|input: &mut _| get_close_paren_token(input, dialect), get_number_token, |input: &mut _| { get_reserved_word_token( @@ -141,7 +143,7 @@ fn get_next_token<'a>( ) }, get_operator_token, - |input: &mut _| get_placeholder_token(input, named_placeholders), + |input: &mut _| get_placeholder_token(input, named_placeholders, dialect), get_word_token, get_any_other_char, )) @@ -237,10 +239,10 @@ pub fn take_till_escaping<'a>( // 4. single quoted string using '' or \' to escape // 5. national character quoted string using N'' or N\' to escape // 6. hex(blob literal) does not need to escape -fn get_string_token<'i>(input: &mut &'i str) -> Result> { +fn get_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result> { dispatch! {any; '`' => (take_till_escaping('`', &['`']), any).void(), - '[' => (take_till_escaping(']', &[']']), any).void(), + '[' if dialect == Dialect::SQLServer => (take_till_escaping(']', &[']']), any).void(), '"' => (take_till_escaping('"', &['"', '\\']), any).void(), '\'' => (take_till_escaping('\'', &['\'', '\\']), any).void(), 'N' => ('\'', take_till_escaping('\'', &['\'', '\\']), any).void(), @@ -260,10 +262,10 @@ fn get_string_token<'i>(input: &mut &'i str) -> Result> { } // Like above but it doesn't replace double quotes -fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result> { +fn get_placeholder_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result> { dispatch! 
{any; '`'=>( take_till_escaping('`', &['`']), any).void(), - '['=>( take_till_escaping(']', &[']']), any).void(), + '[' if dialect == Dialect::SQLServer =>( take_till_escaping(']', &[']']), any).void(), '"'=>( take_till_escaping('"', &['\\']), any).void(), '\''=>( take_till_escaping('\'', &['\\']), any).void(), 'N' =>('\'', take_till_escaping('\'', &['\\']), any).void(), @@ -279,36 +281,49 @@ fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result> { }) } -fn get_open_paren_token<'i>(input: &mut &'i str) -> Result> { - alt(("(", terminated(Caseless("CASE"), end_of_word))) - .parse_next(input) - .map(|token| Token { - kind: TokenKind::OpenParen, - value: token, - key: None, - alias: token, - }) +fn get_open_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result> { + let case = terminated(Caseless("CASE"), end_of_word); + let open_paren = if dialect == Dialect::PostgreSql { + ("(", "[", case) + } else { + ("(", "(", case) + }; + + alt(open_paren).parse_next(input).map(|token| Token { + kind: TokenKind::OpenParen, + value: token, + key: None, + alias: token, + }) } -fn get_close_paren_token<'i>(input: &mut &'i str) -> Result> { - alt((")", terminated(Caseless("END"), end_of_word))) - .parse_next(input) - .map(|token| Token { - kind: TokenKind::CloseParen, - value: token, - key: None, - alias: token, - }) +fn get_close_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result> { + let end = terminated(Caseless("END"), end_of_word); + let close_paren = if dialect == Dialect::PostgreSql { + (")", "]", end) + } else { + (")", ")", end) + }; + alt(close_paren).parse_next(input).map(|token| Token { + kind: TokenKind::CloseParen, + value: token, + key: None, + alias: token, + }) } -fn get_placeholder_token<'i>(input: &mut &'i str, named_placeholders: bool) -> Result> { +fn get_placeholder_token<'i>( + input: &mut &'i str, + named_placeholders: bool, + dialect: Dialect, +) -> Result> { // The precedence changes based on 'named_placeholders' but not 
the exhaustiveness. // This is to ensure the formatting is the same even if parameters aren't used. if named_placeholders { alt(( get_ident_named_placeholder_token, - get_string_named_placeholder_token, + |input: &mut _| get_string_named_placeholder_token(input, dialect), get_indexed_placeholder_token, )) .parse_next(input) @@ -316,7 +331,7 @@ fn get_placeholder_token<'i>(input: &mut &'i str, named_placeholders: bool) -> R alt(( get_indexed_placeholder_token, get_ident_named_placeholder_token, - get_string_named_placeholder_token, + |input: &mut _| get_string_named_placeholder_token(input, dialect), )) .parse_next(input) } @@ -365,8 +380,13 @@ fn get_ident_named_placeholder_token<'i>(input: &mut &'i str) -> Result(input: &mut &'i str) -> Result> { - (one_of(('@', ':')), get_placeholder_string_token) +fn get_string_named_placeholder_token<'i>( + input: &mut &'i str, + dialect: Dialect, +) -> Result> { + (one_of(('@', ':')), |input: &mut _| { + get_placeholder_string_token(input, dialect) + }) .take() .parse_next(input) .map(|token| { From 5c0de0097c8051145c6567f1e54dee063be09db0 Mon Sep 17 00:00:00 2001 From: Luca Barbato Date: Tue, 30 Sep 2025 22:47:38 +0200 Subject: [PATCH 2/4] Add support for PostgreSQL arrays --- src/formatter.rs | 1 + src/lib.rs | 50 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/src/formatter.rs b/src/formatter.rs index 12f00a0..b7cadec 100644 --- a/src/formatter.rs +++ b/src/formatter.rs @@ -76,6 +76,7 @@ pub(crate) fn format( formatter.format_no_change(token, &mut formatted_query); continue; } + match token.kind { TokenKind::Whitespace => { // ignore (we do our own whitespace formatting) diff --git a/src/lib.rs b/src/lib.rs index 126e5f6..80fc12b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -491,7 +491,10 @@ mod tests { #[test] fn it_formats_type_specifiers() { let input = "SELECT id, ARRAY [] :: UUID [] FROM UNNEST($1 :: UUID []) WHERE $1::UUID[] IS NOT NULL;"; - let options =
FormatOptions::default(); + let options = FormatOptions { + dialect: Dialect::PostgreSql, + ..Default::default() + }; let expected = indoc!( " SELECT @@ -506,6 +509,51 @@ mod tests { assert_eq!(format(input, &QueryParams::None, &options), expected); } + #[test] + fn it_formats_arrays_as_function_arguments() { + let input = + "SELECT array_position(ARRAY['sun','mon','tue', 'wed', 'thu','fri', 'sat'], 'mon');"; + let options = FormatOptions { + dialect: Dialect::PostgreSql, + ..Default::default() + }; + let expected = indoc!( + " + SELECT + array_position( + ARRAY['sun', 'mon', 'tue', 'wed', 'thu', 'fri', 'sat'], + 'mon' + );" + ); + + assert_eq!(format(input, &QueryParams::None, &options), expected); + } + + #[test] + fn it_formats_arrays_as_values() { + let input = " INSERT INTO t VALUES('a', ARRAY[0, 1,2,3], ARRAY[['a','b'], ['c' ,'d']]);"; + let options = FormatOptions { + dialect: Dialect::PostgreSql, + max_inline_block: 10, + max_inline_top_level: Some(50), + ..Default::default() + }; + let expected = indoc!( + " + INSERT INTO t + VALUES ( + 'a', + ARRAY[0, 1, 2, 3], + ARRAY[ + ['a', 'b'], + ['c', 'd'] + ] + );" + ); + + assert_eq!(format(input, &QueryParams::None, &options), expected); + } + #[test] fn it_formats_limit_of_single_value_and_offset() { let input = "LIMIT 5 OFFSET 8;"; From 554f761aa2f847895f8bb09a2b5758cb802fec5a Mon Sep 17 00:00:00 2001 From: Luca Barbato Date: Wed, 8 Oct 2025 05:29:01 +0200 Subject: [PATCH 3/4] Correctly format array index notation --- src/formatter.rs | 12 ++++++++---- src/lib.rs | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/formatter.rs b/src/formatter.rs index b7cadec..ed876a3 100644 --- a/src/formatter.rs +++ b/src/formatter.rs @@ -369,7 +369,8 @@ impl<'a> Formatter<'a> { ]; const ADD_WHITESPACE_BETWEEN: &[TokenKind] = &[TokenKind::CloseParen, TokenKind::Reserved]; - + const BEFORE_ARRAY: &[TokenKind] = + &[TokenKind::CloseParen, TokenKind::Word, TokenKind::Reserved]; let inlined = 
self.inline_block.begin_if_possible(self.tokens, self.index); let previous_non_whitespace_token = self.previous_non_whitespace_token(1); let fold_in_top_level = !inlined @@ -387,13 +388,16 @@ impl<'a> Formatter<'a> { // Take out the preceding space unless there was whitespace there in the original query // or another opening parens or line comment let previous_token = self.previous_token(1); - if previous_token.is_none() - || !PRESERVE_WHITESPACE_FOR.contains(&previous_token.unwrap().kind) + if previous_token.is_none_or(|t| !PRESERVE_WHITESPACE_FOR.contains(&t.kind)) + || previous_non_whitespace_token + .is_some_and(|t| token.value == "[" && BEFORE_ARRAY.contains(&t.kind)) { self.trim_spaces_end(query); } - if previous_non_whitespace_token.is_some_and(|t| ADD_WHITESPACE_BETWEEN.contains(&t.kind)) { + if previous_non_whitespace_token + .is_some_and(|t| token.value != "[" && ADD_WHITESPACE_BETWEEN.contains(&t.kind)) + { self.trim_spaces_end(query); query.push(' '); } diff --git a/src/lib.rs b/src/lib.rs index 80fc12b..ccceb9e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -554,6 +554,21 @@ mod tests { assert_eq!(format(input, &QueryParams::None, &options), expected); } + #[test] + fn it_formats_array_index_notation() { + let input = "SELECT a [ 1 ] + b [ 2 ] [ 5+1 ] > c [3] ;"; + let options = FormatOptions { + dialect: Dialect::PostgreSql, + ..Default::default() + }; + let expected = indoc!( + " + SELECT + a[1] + b[2][5 + 1] > c[3];" + ); + + assert_eq!(format(input, &QueryParams::None, &options), expected); + } #[test] fn it_formats_limit_of_single_value_and_offset() { let input = "LIMIT 5 OFFSET 8;"; From 260f9558d2d386b891fdd86815651cd94c8143df Mon Sep 17 00:00:00 2001 From: Luca Barbato Date: Wed, 8 Oct 2025 14:46:25 +0200 Subject: [PATCH 4/4] Update src/lib.rs Co-authored-by: Josh Holmer --- src/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ccceb9e..c0cc968 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ 
-27,14 +27,14 @@ pub fn format(query: &str, params: &QueryParams, options: &FormatOptions) -> Str formatter::format(&tokens, params, options) } -/// The SQL dialect to use +/// The SQL dialect to use. This affects parsing of special characters. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Dialect { - /// Best effort, most dialect-specific constructs are disabled + /// Generic SQL syntax, most dialect-specific constructs are disabled Generic, - /// It considers array notations + /// Enables array syntax (`[`, `]`) and operators PostgreSql, - /// It uses the `[brakets to quote]` notation + /// Enables `[bracketed identifiers]` and `@variables` SQLServer, }