Skip to content

Commit c93d0f2

Browse files
committed
Initial support for dialects
1 parent f910e6f commit c93d0f2

File tree

2 files changed

+88
-37
lines changed

2 files changed

+88
-37
lines changed

src/lib.rs

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ pub fn format(query: &str, params: &QueryParams, options: &FormatOptions) -> Str
2424
formatter::format(&tokens, params, options)
2525
}
2626

27+
/// The SQL dialect to use
28+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
29+
pub enum Dialect {
30+
/// Best effort, most dialect-specific constructs are disabled
31+
Generic,
32+
/// Recognizes PostgreSQL array notation (`[` and `]` as brackets)
33+
PostgreSql,
34+
/// Uses the `[brackets to quote]` notation
35+
SQLServer,
36+
}
37+
2738
/// Options for controlling how the library formats SQL
2839
#[derive(Debug, Clone)]
2940
pub struct FormatOptions<'a> {
@@ -65,6 +76,10 @@ pub struct FormatOptions<'a> {
6576
///
6677
/// Default: false,
6778
pub joins_as_top_level: bool,
79+
/// Tell the SQL dialect to use
80+
///
81+
/// Default: Generic
82+
pub dialect: Dialect,
6883
}
6984

7085
impl<'a> Default for FormatOptions<'a> {
@@ -79,6 +94,7 @@ impl<'a> Default for FormatOptions<'a> {
7994
max_inline_arguments: None,
8095
max_inline_top_level: None,
8196
joins_as_top_level: false,
97+
dialect: Dialect::Generic,
8298
}
8399
}
84100
}
@@ -1325,7 +1341,10 @@ mod tests {
13251341
#[test]
13261342
fn it_recognizes_bracketed_strings() {
13271343
let inputs = ["[foo JOIN bar]", "[foo ]] JOIN bar]"];
1328-
let options = FormatOptions::default();
1344+
let options = FormatOptions {
1345+
dialect: Dialect::SQLServer,
1346+
..Default::default()
1347+
};
13291348
for input in &inputs {
13301349
assert_eq!(&format(input, &QueryParams::None, &options), input);
13311350
}
@@ -1335,7 +1354,10 @@ mod tests {
13351354
fn it_recognizes_at_variables() {
13361355
let input =
13371356
"SELECT @variable, @a1_2.3$, @'var name', @\"var name\", @`var name`, @[var name];";
1338-
let options = FormatOptions::default();
1357+
let options = FormatOptions {
1358+
dialect: Dialect::SQLServer,
1359+
..Default::default()
1360+
};
13391361
let expected = indoc!(
13401362
"
13411363
SELECT
@@ -1360,7 +1382,10 @@ mod tests {
13601382
("var name".to_string(), "'var value'".to_string()),
13611383
("var\\name".to_string(), "'var\\ value'".to_string()),
13621384
];
1363-
let options = FormatOptions::default();
1385+
let options = FormatOptions {
1386+
dialect: Dialect::SQLServer,
1387+
..Default::default()
1388+
};
13641389
let expected = indoc!(
13651390
"
13661391
SELECT
@@ -1383,7 +1408,10 @@ mod tests {
13831408
fn it_recognizes_colon_variables() {
13841409
let input =
13851410
"SELECT :variable, :a1_2.3$, :'var name', :\"var name\", :`var name`, :[var name];";
1386-
let options = FormatOptions::default();
1411+
let options = FormatOptions {
1412+
dialect: Dialect::SQLServer,
1413+
..Default::default()
1414+
};
13871415
let expected = indoc!(
13881416
"
13891417
SELECT
@@ -1416,7 +1444,10 @@ mod tests {
14161444
"'super weird value'".to_string(),
14171445
),
14181446
];
1419-
let options = FormatOptions::default();
1447+
let options = FormatOptions {
1448+
dialect: Dialect::SQLServer,
1449+
..Default::default()
1450+
};
14201451
let expected = indoc!(
14211452
"
14221453
SELECT

src/tokenizer.rs

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use winnow::prelude::*;
88
use winnow::token::{any, one_of, rest, take, take_until, take_while};
99
use winnow::Result;
1010

11-
use crate::FormatOptions;
11+
use crate::{Dialect, FormatOptions};
1212

1313
pub(crate) fn tokenize<'a>(
1414
mut input: &'a str,
@@ -32,6 +32,7 @@ pub(crate) fn tokenize<'a>(
3232
last_reserved_token.clone(),
3333
last_reserved_top_level_token.clone(),
3434
named_placeholders,
35+
options.dialect,
3536
) {
3637
match result.kind {
3738
TokenKind::Reserved => {
@@ -124,13 +125,14 @@ fn get_next_token<'a>(
124125
last_reserved_token: Option<Token<'a>>,
125126
last_reserved_top_level_token: Option<Token<'a>>,
126127
named_placeholders: bool,
128+
dialect: Dialect,
127129
) -> Result<Token<'a>> {
128130
alt((
129131
get_comment_token,
130132
|input: &mut _| get_type_specifier_token(input, previous_token.clone()),
131-
get_string_token,
132-
get_open_paren_token,
133-
get_close_paren_token,
133+
|input: &mut _| get_string_token(input, dialect),
134+
|input: &mut _| get_open_paren_token(input, dialect),
135+
|input: &mut _| get_close_paren_token(input, dialect),
134136
get_number_token,
135137
|input: &mut _| {
136138
get_reserved_word_token(
@@ -141,7 +143,7 @@ fn get_next_token<'a>(
141143
)
142144
},
143145
get_operator_token,
144-
|input: &mut _| get_placeholder_token(input, named_placeholders),
146+
|input: &mut _| get_placeholder_token(input, named_placeholders, dialect),
145147
get_word_token,
146148
get_any_other_char,
147149
))
@@ -237,10 +239,10 @@ pub fn take_till_escaping<'a>(
237239
// 4. single quoted string using '' or \' to escape
238240
// 5. national character quoted string using N'' or N\' to escape
239241
// 6. hex(blob literal) does not need to escape
240-
fn get_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
242+
fn get_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
241243
dispatch! {any;
242244
'`' => (take_till_escaping('`', &['`']), any).void(),
243-
'[' => (take_till_escaping(']', &[']']), any).void(),
245+
'[' if dialect == Dialect::SQLServer => (take_till_escaping(']', &[']']), any).void(),
244246
'"' => (take_till_escaping('"', &['"', '\\']), any).void(),
245247
'\'' => (take_till_escaping('\'', &['\'', '\\']), any).void(),
246248
'N' => ('\'', take_till_escaping('\'', &['\'', '\\']), any).void(),
@@ -260,10 +262,10 @@ fn get_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
260262
}
261263

262264
// Like above but it doesn't replace double quotes
263-
fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
265+
fn get_placeholder_string_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
264266
dispatch! {any;
265267
'`'=>( take_till_escaping('`', &['`']), any).void(),
266-
'['=>( take_till_escaping(']', &[']']), any).void(),
268+
'[' if dialect == Dialect::SQLServer =>( take_till_escaping(']', &[']']), any).void(),
267269
'"'=>( take_till_escaping('"', &['\\']), any).void(),
268270
'\''=>( take_till_escaping('\'', &['\\']), any).void(),
269271
'N' =>('\'', take_till_escaping('\'', &['\\']), any).void(),
@@ -279,44 +281,57 @@ fn get_placeholder_string_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
279281
})
280282
}
281283

282-
fn get_open_paren_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
283-
alt(("(", terminated(Caseless("CASE"), end_of_word)))
284-
.parse_next(input)
285-
.map(|token| Token {
286-
kind: TokenKind::OpenParen,
287-
value: token,
288-
key: None,
289-
alias: token,
290-
})
284+
fn get_open_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
285+
let case = terminated(Caseless("CASE"), end_of_word);
286+
let open_paren = if dialect == Dialect::PostgreSql {
287+
("(", "[", case)
288+
} else {
289+
("(", "(", case)
290+
};
291+
292+
alt(open_paren).parse_next(input).map(|token| Token {
293+
kind: TokenKind::OpenParen,
294+
value: token,
295+
key: None,
296+
alias: token,
297+
})
291298
}
292299

293-
fn get_close_paren_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
294-
alt((")", terminated(Caseless("END"), end_of_word)))
295-
.parse_next(input)
296-
.map(|token| Token {
297-
kind: TokenKind::CloseParen,
298-
value: token,
299-
key: None,
300-
alias: token,
301-
})
300+
fn get_close_paren_token<'i>(input: &mut &'i str, dialect: Dialect) -> Result<Token<'i>> {
301+
let end = terminated(Caseless("END"), end_of_word);
302+
let close_paren = if dialect == Dialect::PostgreSql {
303+
(")", "]", end)
304+
} else {
305+
(")", ")", end)
306+
};
307+
alt(close_paren).parse_next(input).map(|token| Token {
308+
kind: TokenKind::CloseParen,
309+
value: token,
310+
key: None,
311+
alias: token,
312+
})
302313
}
303314

304-
fn get_placeholder_token<'i>(input: &mut &'i str, named_placeholders: bool) -> Result<Token<'i>> {
315+
fn get_placeholder_token<'i>(
316+
input: &mut &'i str,
317+
named_placeholders: bool,
318+
dialect: Dialect,
319+
) -> Result<Token<'i>> {
305320
// The precedence changes based on 'named_placeholders' but not the exhaustiveness.
306321
// This is to ensure the formatting is the same even if parameters aren't used.
307322

308323
if named_placeholders {
309324
alt((
310325
get_ident_named_placeholder_token,
311-
get_string_named_placeholder_token,
326+
|input: &mut _| get_string_named_placeholder_token(input, dialect),
312327
get_indexed_placeholder_token,
313328
))
314329
.parse_next(input)
315330
} else {
316331
alt((
317332
get_indexed_placeholder_token,
318333
get_ident_named_placeholder_token,
319-
get_string_named_placeholder_token,
334+
|input: &mut _| get_string_named_placeholder_token(input, dialect),
320335
))
321336
.parse_next(input)
322337
}
@@ -365,8 +380,13 @@ fn get_ident_named_placeholder_token<'i>(input: &mut &'i str) -> Result<Token<'i
365380
})
366381
}
367382

368-
fn get_string_named_placeholder_token<'i>(input: &mut &'i str) -> Result<Token<'i>> {
369-
(one_of(('@', ':')), get_placeholder_string_token)
383+
fn get_string_named_placeholder_token<'i>(
384+
input: &mut &'i str,
385+
dialect: Dialect,
386+
) -> Result<Token<'i>> {
387+
(one_of(('@', ':')), |input: &mut _| {
388+
get_placeholder_string_token(input, dialect)
389+
})
370390
.take()
371391
.parse_next(input)
372392
.map(|token| {

0 commit comments

Comments
 (0)