Merged
13 changes: 9 additions & 4 deletions src/formatter.rs
@@ -76,6 +76,7 @@ pub(crate) fn format(
formatter.format_no_change(token, &mut formatted_query);
continue;
}

match token.kind {
TokenKind::Whitespace => {
// ignore (we do our own whitespace formatting)
@@ -368,7 +369,8 @@ impl<'a> Formatter<'a> {
];

const ADD_WHITESPACE_BETWEEN: &[TokenKind] = &[TokenKind::CloseParen, TokenKind::Reserved];

const BEFORE_ARRAY: &[TokenKind] =
&[TokenKind::CloseParen, TokenKind::Word, TokenKind::Reserved];
let inlined = self.inline_block.begin_if_possible(self.tokens, self.index);
let previous_non_whitespace_token = self.previous_non_whitespace_token(1);
let fold_in_top_level = !inlined
@@ -386,13 +388,16 @@ impl<'a> Formatter<'a> {
// Take out the preceding space unless there was whitespace there in the original query
// or another opening parens or line comment
let previous_token = self.previous_token(1);
if previous_token.is_none()
|| !PRESERVE_WHITESPACE_FOR.contains(&previous_token.unwrap().kind)
if previous_token.is_none_or(|t| !PRESERVE_WHITESPACE_FOR.contains(&t.kind))
|| previous_non_whitespace_token
.is_some_and(|t| token.value == "[" && BEFORE_ARRAY.contains(&t.kind))
{
self.trim_spaces_end(query);
}

if previous_non_whitespace_token.is_some_and(|t| ADD_WHITESPACE_BETWEEN.contains(&t.kind)) {
if previous_non_whitespace_token
.is_some_and(|t| token.value != "[" && ADD_WHITESPACE_BETWEEN.contains(&t.kind))
{
self.trim_spaces_end(query);
query.push(' ');
}
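The net effect of this hunk is that a `[` token neither receives the space normally added after a CloseParen or Reserved token nor keeps a space inherited from the original query, so index expressions hug their base. A minimal sketch of the behaviour, mirroring the new `it_formats_array_index_notation` test further down (the `sqlformat` crate name and root re-exports are assumed here):

```rust
use sqlformat::{format, Dialect, FormatOptions, QueryParams};

fn main() {
    // PostgreSql is the only dialect in which '[' is tokenized as OpenParen,
    // so only there does the formatter's array-whitespace rule kick in.
    let options = FormatOptions {
        dialect: Dialect::PostgreSql,
        ..Default::default()
    };
    assert_eq!(
        format(
            "SELECT a [ 1 ] + b [ 2 ] [ 5+1 ] > c [3] ;",
            &QueryParams::None,
            &options,
        ),
        "SELECT\n  a[1] + b[2][5 + 1] > c[3];"
    );
}
```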
106 changes: 100 additions & 6 deletions src/lib.rs
@@ -27,6 +27,17 @@ pub fn format(query: &str, params: &QueryParams, options: &FormatOptions) -> Str
formatter::format(&tokens, params, options)
}

/// The SQL dialect to use. This affects parsing of special characters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dialect {
/// Generic SQL syntax, most dialect-specific constructs are disabled
Generic,
/// Enables array syntax (`[`, `]`) and operators
PostgreSql,
/// Enables `[bracketed identifiers]` and `@variables`
SQLServer,
}

/// Options for controlling how the library formats SQL
#[derive(Debug, Clone)]
pub struct FormatOptions<'a> {
@@ -68,6 +79,10 @@ pub struct FormatOptions<'a> {
///
/// Default: false,
pub joins_as_top_level: bool,
/// Tell the SQL dialect to use
///
/// Default: Generic
pub dialect: Dialect,
}

impl<'a> Default for FormatOptions<'a> {
@@ -82,6 +97,7 @@ impl<'a> Default for FormatOptions<'a> {
max_inline_arguments: None,
max_inline_top_level: None,
joins_as_top_level: false,
dialect: Dialect::Generic,
}
}
}
@@ -475,7 +491,10 @@ mod tests {
#[test]
fn it_formats_type_specifiers() {
let input = "SELECT id, ARRAY [] :: UUID [] FROM UNNEST($1 :: UUID []) WHERE $1::UUID[] IS NOT NULL;";
let options = FormatOptions::default();
let options = FormatOptions {
dialect: Dialect::PostgreSql,
..Default::default()
};
let expected = indoc!(
"
SELECT
@@ -490,6 +509,66 @@
assert_eq!(format(input, &QueryParams::None, &options), expected);
}

#[test]
fn it_formats_arrays_as_function_arguments() {
let input =
"SELECT array_position(ARRAY['sun','mon','tue', 'wed', 'thu','fri', 'sat'], 'mon');";
let options = FormatOptions {
dialect: Dialect::PostgreSql,
..Default::default()
};
let expected = indoc!(
"
SELECT
array_position(
ARRAY['sun', 'mon', 'tue', 'wed', 'thu', 'fri', 'sat'],
'mon'
);"
);

assert_eq!(format(input, &QueryParams::None, &options), expected);
}

#[test]
fn it_formats_arrays_as_values() {
let input = " INSERT INTO t VALUES('a', ARRAY[0, 1,2,3], ARRAY[['a','b'], ['c' ,'d']]);";
let options = FormatOptions {
dialect: Dialect::PostgreSql,
max_inline_block: 10,
max_inline_top_level: Some(50),
..Default::default()
};
let expected = indoc!(
"
INSERT INTO t
VALUES (
'a',
ARRAY[0, 1, 2, 3],
ARRAY[
['a', 'b'],
['c', 'd']
]
);"
);

assert_eq!(format(input, &QueryParams::None, &options), expected);
}

#[test]
fn it_formats_array_index_notation() {
let input = "SELECT a [ 1 ] + b [ 2 ] [ 5+1 ] > c [3] ;";
let options = FormatOptions {
dialect: Dialect::PostgreSql,
..Default::default()
};
let expected = indoc!(
"
SELECT
a[1] + b[2][5 + 1] > c[3];"
);

assert_eq!(format(input, &QueryParams::None, &options), expected);
}
#[test]
fn it_formats_limit_of_single_value_and_offset() {
let input = "LIMIT 5 OFFSET 8;";
@@ -1328,7 +1407,10 @@ mod tests {
#[test]
fn it_recognizes_bracketed_strings() {
let inputs = ["[foo JOIN bar]", "[foo ]] JOIN bar]"];
let options = FormatOptions::default();
let options = FormatOptions {
dialect: Dialect::SQLServer,
..Default::default()
};
for input in &inputs {
assert_eq!(&format(input, &QueryParams::None, &options), input);
}
@@ -1338,7 +1420,10 @@
fn it_recognizes_at_variables() {
let input =
"SELECT @variable, @a1_2.3$, @'var name', @\"var name\", @`var name`, @[var name];";
let options = FormatOptions::default();
let options = FormatOptions {
dialect: Dialect::SQLServer,
..Default::default()
};
let expected = indoc!(
"
SELECT
@@ -1363,7 +1448,10 @@
("var name".to_string(), "'var value'".to_string()),
("var\\name".to_string(), "'var\\ value'".to_string()),
];
let options = FormatOptions::default();
let options = FormatOptions {
dialect: Dialect::SQLServer,
..Default::default()
};
let expected = indoc!(
"
SELECT
@@ -1386,7 +1474,10 @@
fn it_recognizes_colon_variables() {
let input =
"SELECT :variable, :a1_2.3$, :'var name', :\"var name\", :`var name`, :[var name];";
let options = FormatOptions::default();
let options = FormatOptions {
dialect: Dialect::SQLServer,
..Default::default()
};
let expected = indoc!(
"
SELECT
@@ -1419,7 +1510,10 @@
"'super weird value'".to_string(),
),
];
let options = FormatOptions::default();
let options = FormatOptions {
dialect: Dialect::SQLServer,
..Default::default()
};
let expected = indoc!(
"
SELECT
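Taken together, the lib.rs changes give callers a single knob for dialect-specific parsing while keeping `Generic` as the default, so callers that build options via `Default` keep their current behaviour. A caller-side sketch of the new option (crate name and root re-exports assumed from the `pub` items above):

```rust
use sqlformat::{format, Dialect, FormatOptions, QueryParams};

fn main() {
    // PostgreSql enables `[`/`]` array syntax; SQLServer would instead enable
    // `[bracketed identifiers]` and `@variables`; Generic disables both.
    let options = FormatOptions {
        dialect: Dialect::PostgreSql,
        ..Default::default()
    };

    let formatted = format(
        "SELECT array_position(ARRAY['sun','mon','tue'], 'mon');",
        &QueryParams::None,
        &options,
    );
    println!("{formatted}");
}
```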
84 changes: 52 additions & 32 deletions src/tokenizer.rs
@@ -8,7 +8,7 @@ use winnow::prelude::*;
use winnow::token::{any, one_of, rest, take, take_until, take_while};
use winnow::Result;

use crate::FormatOptions;
use crate::{Dialect, FormatOptions};

pub(crate) fn tokenize<'a>(
mut input: &'a str,
@@ -32,6 +32,7 @@ pub(crate) fn tokenize<'a>(
last_reserved_token.clone(),
last_reserved_top_level_token.clone(),
named_placeholders,
options.dialect,
) {
match result.kind {
TokenKind::Reserved => {
@@ -124,13 +125,14 @@ fn get_next_token<'a>(
last_reserved_token: Option<Token<'a>>,
last_reserved_top_level_token: Option<Token<'a>>,
named_placeholders: bool,
dialect: Dialect,
) -> Result<Token<'a>> {
alt((
get_comment_token,
|input: &mut _| get_type_specifier_token(input, previous_token.clone()),
get_string_token,
get_open_paren_token,
get_close_paren_token,
|input: &mut _| get_string_token(input, dialect),
|input: &mut _| get_open_paren_token(input, dialect),
|input: &mut _| get_close_paren_token(input, dialect),
get_number_token,
|input: &mut _| {
get_reserved_word_token(
@@ -141,7 +143,7 @@
)
},
get_operator_token,
|input: &mut _| get_placeholder_token(input, named_placeholders),
|input: &mut _| get_placeholder_token(input, named_placeholders, dialect),
get_word_token,
get_any_other_char,
))
@@ -237,10 +239,10 @@ pub fn take_till_escaping<'a>(
// 4. single quoted string using '' or \' to escape
// 5. national character quoted string using N'' or N\' to escape
// 6. hex(blob literal) does not need to escape
fn get_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
fn get_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
dispatch! {any;
'`' => (take_till_escaping('`', &['`']), any).void(),
'[' => (take_till_escaping(']', &[']']), any).void(),
'[' if dialect == Dialect::SQLServer => (take_till_escaping(']', &[']']), any).void(),
'"' => (take_till_escaping('"', &['"', '\\']), any).void(),
'\'' => (take_till_escaping('\'', &['\'', '\\']), any).void(),
'N' => ('\'', take_till_escaping('\'', &['\'', '\\']), any).void(),
@@ -260,10 +262,10 @@
}

// Like above but it doesn't replace double quotes
fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
fn get_placeholder_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
dispatch! {any;
'`'=>( take_till_escaping('`', &['`']), any).void(),
'['=>( take_till_escaping(']', &[']']), any).void(),
'[' if dialect == Dialect::SQLServer =>( take_till_escaping(']', &[']']), any).void(),
'"'=>( take_till_escaping('"', &['\\']), any).void(),
'\''=>( take_till_escaping('\'', &['\\']), any).void(),
'N' =>('\'', take_till_escaping('\'', &['\\']), any).void(),
@@ -279,44 +281,57 @@ fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
})
}

fn get_open_paren_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
alt(("(", terminated(Caseless("CASE"), end_of_word)))
.parse_next(input)
.map(|token| Token {
kind: TokenKind::OpenParen,
value: token,
key: None,
alias: token,
})
fn get_open_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
let case = terminated(Caseless("CASE"), end_of_word);
let open_paren = if dialect == Dialect::PostgreSql {
("(", "[", case)
} else {
("(", "(", case)
};

alt(open_paren).parse_next(input).map(|token| Token {
kind: TokenKind::OpenParen,
value: token,
key: None,
alias: token,
})
}

fn get_close_paren_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
alt((")", terminated(Caseless("END"), end_of_word)))
.parse_next(input)
.map(|token| Token {
kind: TokenKind::CloseParen,
value: token,
key: None,
alias: token,
})
fn get_close_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
let end = terminated(Caseless("END"), end_of_word);
let close_paren = if dialect == Dialect::PostgreSql {
(")", "]", end)
} else {
(")", ")", end)
};
alt(close_paren).parse_next(input).map(|token| Token {
kind: TokenKind::CloseParen,
value: token,
key: None,
alias: token,
})
}

fn get_placeholder_token<'i>(input: &mut &'i str, named_placeholders: bool) -> Result<Token<'i>> {
fn get_placeholder_token<'i>(
input: &mut &'i str,
named_placeholders: bool,
dialect: Dialect,
) -> Result<Token<'i>> {
// The precedence changes based on 'named_placeholders' but not the exhaustiveness.
// This is to ensure the formatting is the same even if parameters aren't used.

if named_placeholders {
alt((
get_ident_named_placeholder_token,
get_string_named_placeholder_token,
|input: &mut _| get_string_named_placeholder_token(input, dialect),
get_indexed_placeholder_token,
))
.parse_next(input)
} else {
alt((
get_indexed_placeholder_token,
get_ident_named_placeholder_token,
get_string_named_placeholder_token,
|input: &mut _| get_string_named_placeholder_token(input, dialect),
))
.parse_next(input)
}
Expand Down Expand Up @@ -365,8 +380,13 @@ fn get_ident_named_placeholder_token<'i>(input: &mut &'i str) -> Result<Token<'i
})
}

fn get_string_named_placeholder_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
(one_of(('@', ':')), get_placeholder_string_token)
fn get_string_named_placeholder_token<'i>(
input: &mut &'i str,
dialect: Dialect,
) -> Result<Token<'i>> {
(one_of(('@', ':')), |input: &mut _| {
get_placeholder_string_token(input, dialect)
})
.take()
.parse_next(input)
.map(|token| {
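To summarize the tokenizer changes: the same `[` character now takes one of three paths depending on `options.dialect`. PostgreSql turns it into an OpenParen (arrays and index expressions), SQLServer consumes `[...]` as a quoted-identifier string, and Generic treats it as an ordinary character. A dependency-free sketch of that routing (names here are illustrative, not the crate's API):

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Dialect {
    Generic,
    PostgreSql,
    SQLServer,
}

/// Illustrative token roles; the real tokenizer uses winnow parsers and its
/// own TokenKind. This only shows the dialect-dependent decision.
#[derive(Debug, PartialEq)]
enum BracketRole {
    OpenParen,        // PostgreSql: '[' opens an array literal or index expression
    QuotedIdentifier, // SQLServer: '[...]' is a bracket-quoted identifier
    PlainChar,        // Generic: '[' is just another character
}

fn classify_open_bracket(dialect: Dialect) -> BracketRole {
    match dialect {
        Dialect::PostgreSql => BracketRole::OpenParen,
        Dialect::SQLServer => BracketRole::QuotedIdentifier,
        Dialect::Generic => BracketRole::PlainChar,
    }
}

fn main() {
    assert_eq!(classify_open_bracket(Dialect::PostgreSql), BracketRole::OpenParen);
    assert_eq!(classify_open_bracket(Dialect::SQLServer), BracketRole::QuotedIdentifier);
    assert_eq!(classify_open_bracket(Dialect::Generic), BracketRole::PlainChar);
}
```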