Skip to content

Commit 82d6b79

Browse files
authored
linter: fix comparison edge cases with identifiers (#601)
Follow up to #600
1 parent f9dc2b0 commit 82d6b79

15 files changed

+138
-52
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/// Postgres identifiers are case insensitive unless they're quoted.
///
/// This type handles the casing rules for us to make comparisons easier:
/// unquoted identifiers are folded to lowercase, while quoted identifiers
/// keep their exact spelling with the surrounding double quotes removed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct Identifier(String);

impl Identifier {
    // TODO: we need to handle more advanced identifiers like:
    // U&"d!0061t!+000061" UESCAPE '!'

    /// Builds a normalized identifier from its source text.
    ///
    /// * `"Foo"` (wrapped in double quotes) is stored verbatim as `Foo`.
    /// * `Foo` (unquoted) is stored lowercased as `foo`.
    ///
    /// Text that merely starts and ends with the *same* single `"` character
    /// (i.e. the one-character string `"`) is not a quoted identifier; it is
    /// treated as unquoted text instead of slicing out of bounds.
    pub fn new(s: &str) -> Self {
        // `strip_prefix` followed by `strip_suffix` only succeeds when there
        // are two distinct quote characters, so a lone `"` can never produce
        // an inverted slice range and panic.
        let quoted_body = s
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'));

        let normalized = match quoted_body {
            Some(inner) => inner.to_string(),
            None => s.to_lowercase(),
        };
        Identifier(normalized)
    }
}
19+
20+
#[cfg(test)]
mod test {
    use crate::identifier::Identifier;

    /// Verifies the identifier folding rules described in the Postgres docs:
    /// unquoted names compare case-insensitively, quoted names compare exactly.
    #[test]
    fn case_folds_correctly() {
        // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
        // For example, the identifiers FOO, foo, and "foo" are considered the
        // same by PostgreSQL, but "Foo" and "FOO" are different from these
        // three and each other.
        assert_eq!(Identifier::new("FOO"), Identifier::new("foo"));
        assert_eq!(Identifier::new(r#""foo""#), Identifier::new("foo"));
        assert_eq!(Identifier::new(r#""foo""#), Identifier::new("FOO"));

        // The second half of the rule: quoted identifiers keep their case, so
        // "Foo" and "FOO" match neither the folded form nor each other.
        assert_ne!(Identifier::new(r#""Foo""#), Identifier::new("foo"));
        assert_ne!(Identifier::new(r#""FOO""#), Identifier::new("foo"));
        assert_ne!(Identifier::new(r#""Foo""#), Identifier::new(r#""FOO""#));
    }
}

crates/squawk_linter/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ mod ignore_index;
1919
mod version;
2020
mod visitors;
2121

22+
mod identifier;
2223
mod rules;
23-
mod text;
2424
use rules::adding_field_with_default;
2525
use rules::adding_foreign_key_constraint;
2626
use rules::adding_not_null_field;

crates/squawk_linter/src/rules/adding_field_with_default.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use squawk_syntax::ast;
55
use squawk_syntax::ast::AstNode;
66
use squawk_syntax::{Parse, SourceFile};
77

8+
use crate::identifier::Identifier;
89
use crate::{Linter, Rule, Violation};
910

1011
fn is_const_expr(expr: &ast::Expr) -> bool {
@@ -16,11 +17,12 @@ fn is_const_expr(expr: &ast::Expr) -> bool {
1617
}
1718

1819
lazy_static! {
19-
static ref NON_VOLATILE_FUNCS: HashSet<String> = {
20+
static ref NON_VOLATILE_FUNCS: HashSet<Identifier> = {
2021
NON_VOLATILE_BUILT_IN_FUNCTIONS
2122
.split('\n')
22-
.map(|x| x.trim().to_lowercase())
23+
.map(|x| x.trim())
2324
.filter(|x| !x.is_empty())
25+
.map(|x| Identifier::new(x))
2426
.collect()
2527
};
2628
}
@@ -36,7 +38,8 @@ fn is_non_volatile(expr: &ast::Expr) -> bool {
3638
return false;
3739
};
3840

39-
let non_volatile_name = NON_VOLATILE_FUNCS.contains(name_ref.text().as_str());
41+
let non_volatile_name =
42+
NON_VOLATILE_FUNCS.contains(&Identifier::new(name_ref.text().as_str()));
4043

4144
no_args && non_volatile_name
4245
} else {

crates/squawk_linter/src/rules/constraint_missing_not_valid.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@ use squawk_syntax::{
55
ast::{self, AstNode},
66
};
77

8-
use crate::{Linter, Rule, Violation, text::trim_quotes};
8+
use crate::{
9+
Linter, Rule, Violation,
10+
identifier::Identifier,
11+
};
912

1013
pub fn tables_created_in_transaction(
1114
assume_in_transaction: bool,
1215
file: &ast::SourceFile,
13-
) -> HashSet<String> {
16+
) -> HashSet<Identifier> {
1417
let mut created_table_names = HashSet::new();
1518
let mut inside_transaction = assume_in_transaction;
1619
for stmt in file.stmts() {
@@ -29,7 +32,7 @@ pub fn tables_created_in_transaction(
2932
else {
3033
continue;
3134
};
32-
created_table_names.insert(trim_quotes(&table_name.text()).to_string());
35+
created_table_names.insert(Identifier::new(&table_name.text()));
3336
}
3437
_ => (),
3538
}
@@ -43,7 +46,7 @@ fn not_valid_validate_in_transaction(
4346
file: &ast::SourceFile,
4447
) {
4548
let mut inside_transaction = assume_in_transaction;
46-
let mut not_valid_names: HashSet<String> = HashSet::new();
49+
let mut not_valid_names: HashSet<Identifier> = HashSet::new();
4750
for stmt in file.stmts() {
4851
match stmt {
4952
ast::Stmt::AlterTable(alter_table) => {
@@ -54,7 +57,7 @@ fn not_valid_validate_in_transaction(
5457
validate_constraint.name_ref().map(|x| x.text().to_string())
5558
{
5659
if inside_transaction
57-
&& not_valid_names.contains(trim_quotes(&constraint_name))
60+
&& not_valid_names.contains(&Identifier::new(&constraint_name))
5861
{
5962
ctx.report(
6063
Violation::new(
@@ -70,9 +73,7 @@ fn not_valid_validate_in_transaction(
7073
if add_constraint.not_valid().is_some() {
7174
if let Some(constraint) = add_constraint.constraint() {
7275
if let Some(constraint_name) = constraint.name() {
73-
not_valid_names.insert(
74-
trim_quotes(&constraint_name.text()).to_string(),
75-
);
76+
not_valid_names.insert(Identifier::new(&constraint_name.text()));
7677
}
7778
}
7879
}
@@ -117,7 +118,7 @@ pub(crate) fn constraint_missing_not_valid(ctx: &mut Linter, parse: &Parse<Sourc
117118
};
118119
for action in alter_table.actions() {
119120
if let ast::AlterTableAction::AddConstraint(add_constraint) = action {
120-
if !tables_created.contains(trim_quotes(&table_name))
121+
if !tables_created.contains(&Identifier::new(&table_name))
121122
&& add_constraint.not_valid().is_none()
122123
{
123124
if let Some(ast::Constraint::UniqueConstraint(uc)) =

crates/squawk_linter/src/rules/disallow_unique_constraint.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use squawk_syntax::{
33
ast::{self, AstNode},
44
};
55

6-
use crate::{Linter, Rule, Violation, text::trim_quotes};
6+
use crate::{Linter, Rule, Violation, identifier::Identifier};
77

88
use super::constraint_missing_not_valid::tables_created_in_transaction;
99

@@ -30,7 +30,7 @@ pub(crate) fn disallow_unique_constraint(ctx: &mut Linter, parse: &Parse<SourceF
3030
add_constraint.constraint()
3131
{
3232
if unique_constraint.using_index().is_none()
33-
&& !tables_created.contains(trim_quotes(&table_name))
33+
&& !tables_created.contains(&Identifier::new(&table_name))
3434
{
3535
ctx.report(Violation::new(
3636
Rule::DisallowedUniqueConstraint,

crates/squawk_linter/src/rules/prefer_bigint_over_int.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::collections::HashSet;
33
use squawk_syntax::ast::AstNode;
44
use squawk_syntax::{Parse, SourceFile, ast};
55

6+
use crate::identifier::Identifier;
67
use crate::{Linter, Rule, Violation};
78

89
use crate::visitors::check_not_allowed_types;
@@ -11,8 +12,12 @@ use crate::visitors::is_not_valid_int_type;
1112
use lazy_static::lazy_static;
1213

1314
lazy_static! {
14-
static ref INT_TYPES: HashSet<&'static str> =
15-
HashSet::from(["integer", "int4", "serial", "serial4",]);
15+
static ref INT_TYPES: HashSet<Identifier> = HashSet::from([
16+
Identifier::new("integer"),
17+
Identifier::new("int4"),
18+
Identifier::new("serial"),
19+
Identifier::new("serial4"),
20+
]);
1621
}
1722

1823
fn check_ty_for_big_int(ctx: &mut Linter, ty: Option<ast::Type>) {

crates/squawk_linter/src/rules/prefer_bigint_over_smallint.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::collections::HashSet;
33
use squawk_syntax::ast::AstNode;
44
use squawk_syntax::{Parse, SourceFile, ast};
55

6+
use crate::identifier::Identifier;
67
use crate::{Linter, Rule, Violation};
78

89
use crate::visitors::check_not_allowed_types;
@@ -11,8 +12,12 @@ use crate::visitors::is_not_valid_int_type;
1112
use lazy_static::lazy_static;
1213

1314
lazy_static! {
14-
static ref SMALL_INT_TYPES: HashSet<&'static str> =
15-
HashSet::from(["smallint", "int2", "smallserial", "serial2",]);
15+
static ref SMALL_INT_TYPES: HashSet<Identifier> = HashSet::from([
16+
Identifier::new("smallint"),
17+
Identifier::new("int2"),
18+
Identifier::new("smallserial"),
19+
Identifier::new("serial2"),
20+
]);
1621
}
1722

1823
fn check_ty_for_small_int(ctx: &mut Linter, ty: Option<ast::Type>) {

crates/squawk_linter/src/rules/prefer_identity.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,20 @@ use squawk_syntax::{
55
ast::{self, AstNode},
66
};
77

8-
use crate::{Linter, Rule, Violation};
8+
use crate::{Linter, Rule, Violation, identifier::Identifier};
99

1010
use lazy_static::lazy_static;
1111

1212
use crate::visitors::{check_not_allowed_types, is_not_valid_int_type};
1313

1414
lazy_static! {
15-
static ref SERIAL_TYPES: HashSet<&'static str> = HashSet::from([
16-
"serial",
17-
"serial2",
18-
"serial4",
19-
"serial8",
20-
"smallserial",
21-
"bigserial",
15+
static ref SERIAL_TYPES: HashSet<Identifier> = HashSet::from([
16+
Identifier::new("serial"),
17+
Identifier::new("serial2"),
18+
Identifier::new("serial4"),
19+
Identifier::new("serial8"),
20+
Identifier::new("smallserial"),
21+
Identifier::new("bigserial"),
2222
]);
2323
}
2424

@@ -86,6 +86,23 @@ create table users (
8686
assert_debug_snapshot!(errors);
8787
}
8888

89+
#[test]
90+
fn ok_when_quoted() {
91+
let sql = r#"
92+
create table users (
93+
id "serial"
94+
);
95+
create table users (
96+
id "bigserial"
97+
);
98+
"#;
99+
let file = squawk_syntax::SourceFile::parse(sql);
100+
let mut linter = Linter::from([Rule::PreferIdentity]);
101+
let errors = linter.lint(file, sql);
102+
assert_eq!(errors.len(), 2);
103+
assert_debug_snapshot!(errors);
104+
}
105+
89106
#[test]
90107
fn ok() {
91108
let sql = r#"

crates/squawk_linter/src/rules/prefer_robust_stmts.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use squawk_syntax::{
55
ast::{self, AstNode},
66
};
77

8-
use crate::{Linter, Rule, Violation, text::trim_quotes};
8+
use crate::{
9+
Linter, Rule, Violation,
10+
identifier::Identifier,
11+
};
912

1013
#[derive(PartialEq)]
1114
enum Constraint {
@@ -16,7 +19,7 @@ enum Constraint {
1619
pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse<SourceFile>) {
1720
let file = parse.tree();
1821
let mut inside_transaction = ctx.settings.assume_in_transaction;
19-
let mut constraint_names: HashMap<String, Constraint> = HashMap::new();
22+
let mut constraint_names: HashMap<Identifier, Constraint> = HashMap::new();
2023

2124
let mut total_stmts = 0;
2225
for _ in file.stmts() {
@@ -50,7 +53,7 @@ pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse<SourceFile>) {
5053
ast::AlterTableAction::DropConstraint(drop_constraint) => {
5154
if let Some(constraint_name) = drop_constraint.name_ref() {
5255
constraint_names.insert(
53-
trim_quotes(constraint_name.text().as_str()).to_string(),
56+
Identifier::new(constraint_name.text().as_str()),
5457
Constraint::Dropped,
5558
);
5659
}
@@ -68,7 +71,7 @@ pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse<SourceFile>) {
6871
ast::AlterTableAction::ValidateConstraint(validate_constraint) => {
6972
if let Some(constraint_name) = validate_constraint.name_ref() {
7073
if constraint_names
71-
.contains_key(trim_quotes(constraint_name.text().as_str()))
74+
.contains_key(&Identifier::new(constraint_name.text().as_str()))
7275
{
7376
continue;
7477
}
@@ -79,8 +82,8 @@ pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse<SourceFile>) {
7982
let constraint = add_constraint.constraint();
8083
if let Some(constraint_name) = constraint.and_then(|x| x.name()) {
8184
let name_text = constraint_name.text();
82-
let name = trim_quotes(name_text.as_str());
83-
if let Some(constraint) = constraint_names.get_mut(name) {
85+
let name = Identifier::new(name_text.as_str());
86+
if let Some(constraint) = constraint_names.get_mut(&name) {
8487
if *constraint == Constraint::Dropped {
8588
*constraint = Constraint::Added;
8689
continue;

crates/squawk_linter/src/rules/prefer_text_field.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use squawk_syntax::{
55
ast::{self, AstNode},
66
};
77

8-
use crate::{Linter, Rule, Violation, text::trim_quotes};
8+
use crate::{Linter, Rule, Violation, identifier::Identifier};
99

1010
use crate::visitors::check_not_allowed_types;
1111

@@ -35,10 +35,10 @@ fn is_not_allowed_varchar(ty: &ast::Type) -> bool {
3535
return false;
3636
};
3737
// if we don't have any args, then it's the same as `text`
38-
trim_quotes(ty_name.as_str()) == "varchar" && path_type.arg_list().is_some()
38+
Identifier::new(ty_name.as_str()) == Identifier::new("varchar") && path_type.arg_list().is_some()
3939
}
4040
ast::Type::CharType(char_type) => {
41-
trim_quotes(&char_type.text()) == "varchar" && char_type.arg_list().is_some()
41+
Identifier::new(&char_type.text()) == Identifier::new("varchar") && char_type.arg_list().is_some()
4242
}
4343
ast::Type::BitType(_) => false,
4444
ast::Type::DoubleType(_) => false,

0 commit comments

Comments
 (0)