diff --git a/src/formatter.rs b/src/formatter.rs index 95d1570..cdbbffd 100644 --- a/src/formatter.rs +++ b/src/formatter.rs @@ -200,8 +200,19 @@ impl<'a> Formatter<'a> { } fn format_type_specifier(&self, token: &Token<'_>, query: &mut String) { + const WHITESPACE_BEFORE: &[TokenKind] = &[ + TokenKind::Reserved, + TokenKind::ReservedNewline, + TokenKind::ReservedNewlineAfter, + ]; self.trim_all_spaces_end(query); query.push_str(token.value); + if self + .next_non_whitespace_token(1) + .is_some_and(|t| WHITESPACE_BEFORE.contains(&t.kind)) + { + query.push(' ') + } } fn format_block_comment(&mut self, token: &Token<'_>, query: &mut String) { self.add_new_line(query); @@ -565,6 +576,17 @@ impl<'a> Formatter<'a> { } } + fn next_non_whitespace_token(&self, idx: usize) -> Option<&Token<'_>> { + let index = self.index.checked_add(idx); + if let Some(index) = index { + self.tokens[index..] + .iter() + .find(|t| t.kind != TokenKind::Whitespace) + } else { + None + } + } + fn next_token(&self, idx: usize) -> Option<&Token<'_>> { let index = self.index.checked_add(idx); if let Some(index) = index { diff --git a/src/lib.rs b/src/lib.rs index 0b05b33..316748a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -471,7 +471,7 @@ mod tests { #[test] fn it_formats_type_specifiers() { - let input = "SELECT id, ARRAY [] :: UUID [] FROM UNNEST($1 :: UUID []);"; + let input = "SELECT id, ARRAY [] :: UUID [] FROM UNNEST($1 :: UUID []) WHERE $1::UUID[] IS NOT NULL;"; let options = FormatOptions::default(); let expected = indoc!( " @@ -479,7 +479,9 @@ mod tests { id, ARRAY[]::UUID[] FROM - UNNEST($1::UUID[]);" + UNNEST($1::UUID[]) + WHERE + $1::UUID[] IS NOT NULL;" ); assert_eq!(format(input, &QueryParams::None, &options), expected);