Skip to content

Commit acbf5e7

Browse files
committed
Initial support for dialects
1 parent 9385670 commit acbf5e7

File tree

2 files changed

+88
-37
lines changed

2 files changed

+88
-37
lines changed

src/lib.rs

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ pub fn format(query: &str, params: &QueryParams, options: &FormatOptions) -> Str
2727
formatter::format(&tokens, params, options)
2828
}
2929

30+
/// The SQL dialect to use
31+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32+
pub enum Dialect {
33+
/// Best effort; most dialect-specific constructs are disabled
34+
Generic,
35+
/// Recognizes the `[...]` array subscript notation
36+
PostgreSql,
37+
/// It uses the `[brackets to quote]` notation
38+
SQLServer,
39+
}
40+
3041
/// Options for controlling how the library formats SQL
3142
#[derive(Debug, Clone)]
3243
pub struct FormatOptions<'a> {
@@ -68,6 +79,10 @@ pub struct FormatOptions<'a> {
6879
///
6980
/// Default: false,
7081
pub joins_as_top_level: bool,
82+
/// The SQL dialect to use when tokenizing
83+
///
84+
/// Default: Generic
85+
pub dialect: Dialect,
7186
}
7287

7388
impl<'a> Default for FormatOptions<'a> {
@@ -82,6 +97,7 @@ impl<'a> Default for FormatOptions<'a> {
8297
max_inline_arguments: None,
8398
max_inline_top_level: None,
8499
joins_as_top_level: false,
100+
dialect: Dialect::Generic,
85101
}
86102
}
87103
}
@@ -1328,7 +1344,10 @@ mod tests {
13281344
#[test]
13291345
fn it_recognizes_bracketed_strings() {
13301346
let inputs = ["[foo JOIN bar]", "[foo ]] JOIN bar]"];
1331-
let options = FormatOptions::default();
1347+
let options = FormatOptions {
1348+
dialect: Dialect::SQLServer,
1349+
..Default::default()
1350+
};
13321351
for input in &inputs {
13331352
assert_eq!(&format(input, &QueryParams::None, &options), input);
13341353
}
@@ -1338,7 +1357,10 @@ mod tests {
13381357
fn it_recognizes_at_variables() {
13391358
let input =
13401359
"SELECT @variable, @a1_2.3$, @'var name', @\"var name\", @`var name`, @[var name];";
1341-
let options = FormatOptions::default();
1360+
let options = FormatOptions {
1361+
dialect: Dialect::SQLServer,
1362+
..Default::default()
1363+
};
13421364
let expected = indoc!(
13431365
"
13441366
SELECT
@@ -1363,7 +1385,10 @@ mod tests {
13631385
("var name".to_string(), "'var value'".to_string()),
13641386
("var\\name".to_string(), "'var\\ value'".to_string()),
13651387
];
1366-
let options = FormatOptions::default();
1388+
let options = FormatOptions {
1389+
dialect: Dialect::SQLServer,
1390+
..Default::default()
1391+
};
13671392
let expected = indoc!(
13681393
"
13691394
SELECT
@@ -1386,7 +1411,10 @@ mod tests {
13861411
fn it_recognizes_colon_variables() {
13871412
let input =
13881413
"SELECT :variable, :a1_2.3$, :'var name', :\"var name\", :`var name`, :[var name];";
1389-
let options = FormatOptions::default();
1414+
let options = FormatOptions {
1415+
dialect: Dialect::SQLServer,
1416+
..Default::default()
1417+
};
13901418
let expected = indoc!(
13911419
"
13921420
SELECT
@@ -1419,7 +1447,10 @@ mod tests {
14191447
"'super weird value'".to_string(),
14201448
),
14211449
];
1422-
let options = FormatOptions::default();
1450+
let options = FormatOptions {
1451+
dialect: Dialect::SQLServer,
1452+
..Default::default()
1453+
};
14231454
let expected = indoc!(
14241455
"
14251456
SELECT

src/tokenizer.rs

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use winnow::prelude::*;
88
use winnow::token::{any, one_of, rest, take, take_until, take_while};
99
use winnow::Result;
1010

11-
use crate::FormatOptions;
11+
use crate::{Dialect, FormatOptions};
1212

1313
pub(crate) fn tokenize<'a>(
1414
mut input: &'a str,
@@ -32,6 +32,7 @@ pub(crate) fn tokenize<'a>(
3232
last_reserved_token.clone(),
3333
last_reserved_top_level_token.clone(),
3434
named_placeholders,
35+
options.dialect,
3536
) {
3637
match result.kind {
3738
TokenKind::Reserved => {
@@ -124,13 +125,14 @@ fn get_next_token<'a>(
124125
last_reserved_token: Option<Token<'a>>,
125126
last_reserved_top_level_token: Option<Token<'a>>,
126127
named_placeholders: bool,
128+
dialect: Dialect,
127129
) -> Result<Token<'a>> {
128130
alt((
129131
get_comment_token,
130132
|input: &mut _| get_type_specifier_token(input, previous_token.clone()),
131-
get_string_token,
132-
get_open_paren_token,
133-
get_close_paren_token,
133+
|input: &mut _| get_string_token(input, dialect),
134+
|input: &mut _| get_open_paren_token(input, dialect),
135+
|input: &mut _| get_close_paren_token(input, dialect),
134136
get_number_token,
135137
|input: &mut _| {
136138
get_reserved_word_token(
@@ -141,7 +143,7 @@ fn get_next_token<'a>(
141143
)
142144
},
143145
get_operator_token,
144-
|input: &mut _| get_placeholder_token(input, named_placeholders),
146+
|input: &mut _| get_placeholder_token(input, named_placeholders, dialect),
145147
get_word_token,
146148
get_any_other_char,
147149
))
@@ -237,10 +239,10 @@ pub fn take_till_escaping<'a>(
237239
// 4. single quoted string using '' or \' to escape
238240
// 5. national character quoted string using N'' or N\' to escape
239241
// 6. hex(blob literal) does not need to escape
240-
fn get_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
242+
fn get_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
241243
dispatch! {any;
242244
'`' => (take_till_escaping('`', &['`']), any).void(),
243-
'[' => (take_till_escaping(']', &[']']), any).void(),
245+
'[' if dialect == Dialect::SQLServer => (take_till_escaping(']', &[']']), any).void(),
244246
'"' => (take_till_escaping('"', &['"', '\\']), any).void(),
245247
'\'' => (take_till_escaping('\'', &['\'', '\\']), any).void(),
246248
'N' => ('\'', take_till_escaping('\'', &['\'', '\\']), any).void(),
@@ -260,10 +262,10 @@ fn get_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
260262
}
261263

262264
// Like above but it doesn't replace double quotes
263-
fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
265+
fn get_placeholder_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
264266
dispatch! {any;
265267
'`'=>( take_till_escaping('`', &['`']), any).void(),
266-
'['=>( take_till_escaping(']', &[']']), any).void(),
268+
'[' if dialect == Dialect::SQLServer =>( take_till_escaping(']', &[']']), any).void(),
267269
'"'=>( take_till_escaping('"', &['\\']), any).void(),
268270
'\''=>( take_till_escaping('\'', &['\\']), any).void(),
269271
'N' =>('\'', take_till_escaping('\'', &['\\']), any).void(),
@@ -279,44 +281,57 @@ fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
279281
})
280282
}
281283

282-
fn get_open_paren_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
283-
alt(("(", terminated(Caseless("CASE"), end_of_word)))
284-
.parse_next(input)
285-
.map(|token| Token {
286-
kind: TokenKind::OpenParen,
287-
value: token,
288-
key: None,
289-
alias: token,
290-
})
284+
fn get_open_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
285+
let case = terminated(Caseless("CASE"), end_of_word);
286+
let open_paren = if dialect == Dialect::PostgreSql {
287+
("(", "[", case)
288+
} else {
289+
("(", "(", case)
290+
};
291+
292+
alt(open_paren).parse_next(input).map(|token| Token {
293+
kind: TokenKind::OpenParen,
294+
value: token,
295+
key: None,
296+
alias: token,
297+
})
291298
}
292299

293-
fn get_close_paren_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
294-
alt((")", terminated(Caseless("END"), end_of_word)))
295-
.parse_next(input)
296-
.map(|token| Token {
297-
kind: TokenKind::CloseParen,
298-
value: token,
299-
key: None,
300-
alias: token,
301-
})
300+
fn get_close_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
301+
let end = terminated(Caseless("END"), end_of_word);
302+
let close_paren = if dialect == Dialect::PostgreSql {
303+
(")", "]", end)
304+
} else {
305+
(")", ")", end)
306+
};
307+
alt(close_paren).parse_next(input).map(|token| Token {
308+
kind: TokenKind::CloseParen,
309+
value: token,
310+
key: None,
311+
alias: token,
312+
})
302313
}
303314

304-
fn get_placeholder_token<'i>(input: &mut &'i str, named_placeholders: bool) -> Result<Token<'i>> {
315+
fn get_placeholder_token<'i>(
316+
input: &mut &'i str,
317+
named_placeholders: bool,
318+
dialect: Dialect,
319+
) -> Result<Token<'i>> {
305320
// The precedence changes based on 'named_placeholders' but not the exhaustiveness.
306321
// This is to ensure the formatting is the same even if parameters aren't used.
307322

308323
if named_placeholders {
309324
alt((
310325
get_ident_named_placeholder_token,
311-
get_string_named_placeholder_token,
326+
|input: &mut _| get_string_named_placeholder_token(input, dialect),
312327
get_indexed_placeholder_token,
313328
))
314329
.parse_next(input)
315330
} else {
316331
alt((
317332
get_indexed_placeholder_token,
318333
get_ident_named_placeholder_token,
319-
get_string_named_placeholder_token,
334+
|input: &mut _| get_string_named_placeholder_token(input, dialect),
320335
))
321336
.parse_next(input)
322337
}
@@ -365,8 +380,13 @@ fn get_ident_named_placeholder_token<'i>(input: &mut &'i str) -> Result<Token<'i
365380
})
366381
}
367382

368-
fn get_string_named_placeholder_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
369-
(one_of(('@', ':')), get_placeholder_string_token)
383+
fn get_string_named_placeholder_token<'i>(
384+
input: &mut &'i str,
385+
dialect: Dialect,
386+
) -> Result<Token<'i>> {
387+
(one_of(('@', ':')), |input: &mut _| {
388+
get_placeholder_string_token(input, dialect)
389+
})
370390
.take()
371391
.parse_next(input)
372392
.map(|token| {

0 commit comments

Comments
 (0)