@@ -2,19 +2,11 @@ use lazy_static::lazy_static;
22use std:: collections:: HashSet ;
33
44use squawk_syntax:: ast:: AstNode ;
5- use squawk_syntax:: { Parse , SourceFile } ;
5+ use squawk_syntax:: { Parse , SourceFile , SyntaxKind } ;
66use squawk_syntax:: { ast, identifier:: Identifier } ;
77
88use crate :: { Linter , Rule , Version , Violation } ;
99
10- fn is_const_expr ( expr : & ast:: Expr ) -> bool {
11- match expr {
12- ast:: Expr :: Literal ( _) => true ,
13- ast:: Expr :: CastExpr ( cast) => matches ! ( cast. expr( ) , Some ( ast:: Expr :: Literal ( _) ) ) ,
14- _ => false ,
15- }
16- }
17-
1810lazy_static ! {
1911 static ref NON_VOLATILE_FUNCS : HashSet <Identifier > = {
2012 NON_VOLATILE_BUILT_IN_FUNCTIONS
@@ -26,8 +18,18 @@ lazy_static! {
2618 } ;
2719}
2820
29- fn is_non_volatile ( expr : & ast:: Expr ) -> bool {
21+ fn is_non_volatile_or_const ( expr : & ast:: Expr ) -> bool {
3022 match expr {
23+ ast:: Expr :: Literal ( _) => true ,
24+ ast:: Expr :: ArrayExpr ( _) => true ,
25+ ast:: Expr :: BinExpr ( bin_expr) => {
26+ if let Some ( lhs) = bin_expr. lhs ( ) {
27+ if let Some ( rhs) = bin_expr. rhs ( ) {
28+ return is_non_volatile_or_const ( & lhs) && is_non_volatile_or_const ( & rhs) ;
29+ }
30+ }
31+ false
32+ }
3133 ast:: Expr :: CallExpr ( call_expr) => {
3234 if let Some ( arglist) = call_expr. arg_list ( ) {
3335 let no_args = arglist. args ( ) . count ( ) == 0 ;
@@ -45,6 +47,24 @@ fn is_non_volatile(expr: &ast::Expr) -> bool {
4547 false
4648 }
4749 }
50+ // array[]::t[] is non-volatile. We don't check for a plain array expr
51+ // since postgres will reject it as a default unless it's cast to a type.
52+ ast:: Expr :: CastExpr ( cast_expr) => {
53+ if let Some ( inner_expr) = cast_expr. expr ( ) {
54+ is_non_volatile_or_const ( & inner_expr)
55+ } else {
56+ false
57+ }
58+ }
59+ // current_timestamp is the same as calling now()
60+ ast:: Expr :: NameRef ( name_ref) => {
61+ if let Some ( child) = name_ref. syntax ( ) . first_child_or_token ( ) {
62+ if child. kind ( ) == SyntaxKind :: CURRENT_TIMESTAMP_KW {
63+ return true ;
64+ }
65+ }
66+ false
67+ }
4868 _ => false ,
4969 }
5070}
@@ -69,7 +89,7 @@ pub(crate) fn adding_field_with_default(ctx: &mut Linter, parse: &Parse<SourceFi
6989 continue ;
7090 } ;
7191 if ctx. settings . pg_version > Version :: new ( 11 , None , None )
72- && ( is_const_expr ( & expr) || is_non_volatile ( & expr ) )
92+ && is_non_volatile_or_const ( & expr)
7393 {
7494 continue ;
7595 }
@@ -181,6 +201,33 @@ ALTER TABLE "core_recipe" ADD COLUMN "foo" boolean DEFAULT true;
181201 assert_debug_snapshot ! ( errors) ;
182202 }
183203
204+ #[ test]
205+ fn default_empty_array_ok ( ) {
206+ let sql = r#"
207+ alter table t add column a double precision[] default array[]::double precision[];
208+
209+ alter table t add column b bigint[] default cast(array[] as bigint[]);
210+
211+ alter table t add column c text[] default array['foo', 'bar']::text[];
212+ "# ;
213+
214+ let errors = lint ( sql, Rule :: AddingFieldWithDefault ) ;
215+ assert ! ( errors. is_empty( ) ) ;
216+ assert_debug_snapshot ! ( errors) ;
217+ }
218+
219+ #[ test]
220+ fn default_with_const_bin_expr ( ) {
221+ let sql = r#"
222+ ALTER TABLE assessments
223+ ADD COLUMN statistics_last_updated_at timestamptz NOT NULL DEFAULT now() - interval '100 years';
224+ "# ;
225+
226+ let errors = lint ( sql, Rule :: AddingFieldWithDefault ) ;
227+ assert ! ( errors. is_empty( ) ) ;
228+ assert_debug_snapshot ! ( errors) ;
229+ }
230+
184231 #[ test]
185232 fn default_str_ok ( ) {
186233 let sql = r#"
@@ -240,6 +287,7 @@ ALTER TABLE "core_recipe" ADD COLUMN "foo" timestamptz DEFAULT now(123);
240287 assert ! ( !errors. is_empty( ) ) ;
241288 assert_debug_snapshot ! ( errors) ;
242289 }
290+
243291 #[ test]
244292 fn default_func_now_ok ( ) {
245293 let sql = r#"
@@ -252,14 +300,25 @@ ALTER TABLE "core_recipe" ADD COLUMN "foo" timestamptz DEFAULT now();
252300 assert_debug_snapshot ! ( errors) ;
253301 }
254302
303+ #[ test]
304+ fn default_func_current_timestamp_ok ( ) {
305+ let sql = r#"
306+ alter table t add column c timestamptz default current_timestamp;
307+ "# ;
308+
309+ let errors = lint ( sql, Rule :: AddingFieldWithDefault ) ;
310+ assert ! ( errors. is_empty( ) ) ;
311+ assert_debug_snapshot ! ( errors) ;
312+ }
313+
255314 #[ test]
256315 fn add_numbers_ok ( ) {
257- // This should be okay, but we don't handle expressions like this at the moment.
258316 let sql = r#"
259317alter table account_metadata add column blah integer default 2 + 2;
260318 "# ;
261319
262320 let errors = lint ( sql, Rule :: AddingFieldWithDefault ) ;
321+ assert ! ( errors. is_empty( ) ) ;
263322 assert_debug_snapshot ! ( errors) ;
264323 }
265324
0 commit comments