8787 is_sql_select_statement ,
8888 ExprAliasUpdateDict ,
8989)
90+ import snowflake .snowpark .context as context
9091
9192# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
9293# Python 3.9 can use both
@@ -1377,17 +1378,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
13771378 )
13781379 )
13791380 or (
1380- new_column_states .dropped_columns
1381+ # unflattenable condition: dropped column is used in subquery WHERE clause and dropped column status is NEW or CHANGED in the subquery
1382+ # reason: we should not flatten because the dropped column is not available in the new query, leading to WHERE clause error
1383+ # sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' cannot be flattened to 'select "b" from table where "c" > 1'
1384+ context ._is_snowpark_connect_compatible_mode
1385+ and new_column_states .dropped_columns
13811386 and any (
1382- new_column_states [_col ].change_state == ColumnChangeState .DROPPED
1383- and self .column_states [_col ].change_state
1387+ self .column_states [_col ].change_state
13841388 in (ColumnChangeState .NEW , ColumnChangeState .CHANGED_EXP )
1385- and _col in subquery_dependent_columns
1386- for _col in (new_column_states .dropped_columns )
1389+ for _col in (
1390+ subquery_dependent_columns & new_column_states .dropped_columns
1391+ )
13871392 )
13881393 )
13891394 ):
1390- # or (new_column_states[_col].change_state == ColumnChangeState.DROPPED and self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP))
13911395 can_be_flattened = False
13921396 elif self .order_by and (
13931397 (subquery_dependent_columns := derive_dependent_columns (* self .order_by ))
@@ -1400,13 +1404,17 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14001404 )
14011405 )
14021406 or (
1403- new_column_states .dropped_columns
1407+ # unflattenable condition: dropped column is used in subquery ORDER BY clause and dropped column status is NEW or CHANGED in the subquery
1408+ # reason: we should not flatten because the dropped column is not available in the new query, leading to ORDER BY clause error
1409+ # sample query: 'select "b" from (select "a" as "c", "b" from table order by "c")' cannot be flattened to 'select "b" from table order by "c"'
1410+ context ._is_snowpark_connect_compatible_mode
1411+ and new_column_states .dropped_columns
14041412 and any (
1405- new_column_states [_col ].change_state == ColumnChangeState .DROPPED
1406- and self .column_states [_col ].change_state
1413+ self .column_states [_col ].change_state
14071414 in (ColumnChangeState .NEW , ColumnChangeState .CHANGED_EXP )
1408- and _col in subquery_dependent_columns
1409- for _col in (new_column_states .dropped_columns )
1415+ for _col in (
1416+ subquery_dependent_columns & new_column_states .dropped_columns
1417+ )
14101418 )
14111419 )
14121420 ):
@@ -1478,6 +1486,10 @@ def filter(self, col: Expression) -> "SelectStatement":
14781486 derive_dependent_columns (col ), self .column_states , "filter"
14791487 )
14801488 and not has_data_generator_or_window_function_exp (self .projection )
1489+ and not (
1490+ context ._is_snowpark_connect_compatible_mode
1491+ and has_aggregation_function_exp (self .projection )
1492+ ) # sum(col) as new_col, new_col can not be flattened in where clause
14811493 and not (self .order_by and self .limit_ is not None )
14821494 )
14831495 if can_be_flattened :
@@ -2044,10 +2056,10 @@ def can_clause_dependent_columns_flatten(
20442056 subquery_column_states : ColumnStateDict ,
20452057 clause : Literal ["filter" , "sort" ],
20462058) -> bool :
2047- if clause not in [ "filter" , "sort" ]:
2048- raise ValueError (
2049- f"Invalid clause called in can_clause_dependent_columns_flatten: { clause } "
2050- )
2059+ assert clause in (
2060+ "filter" ,
2061+ "sort" ,
2062+ ), f"Invalid clause called in can_clause_dependent_columns_flatten: { clause } "
20512063 if dependent_columns == COLUMN_DEPENDENCY_DOLLAR :
20522064 return False
20532065 elif (
@@ -2061,11 +2073,19 @@ def can_clause_dependent_columns_flatten(
20612073 for dc in dependent_columns :
20622074 dc_state = subquery_column_states .get (dc )
20632075 if dc_state :
2064- if (
2065- dc_state .change_state == ColumnChangeState .CHANGED_EXP
2066- and clause == "filter"
2067- ):
2068- return False
2076+ if dc_state .change_state == ColumnChangeState .CHANGED_EXP :
2077+ if (
2078+ clause == "filter"
2079+ ): # where can not be flattened because 'where' is evaluated before projection, flattening leads to wrong result
2080+ # df.select((col('a') + 1).alias('a')).filter(col('a') > 5) -- the filter should be applied to the new 'a', but flattening would evaluate it against the old 'a'
2081+ return False
2082+ else : # clause == 'sort'
2083+ # df.select((col('a') + 1).alias('a')).sort(col('a')) -- this is valid to flatten because 'order by' is evaluated after projection
2084+ # however, if the order by is a data generator, it should not be flattened because generator is evaluated dynamically according to the order.
2085+ return context ._is_snowpark_connect_compatible_mode
2086+ elif dc_state .change_state == ColumnChangeState .NEW :
2087+ return context ._is_snowpark_connect_compatible_mode
2088+
20692089 return True
20702090
20712091
@@ -2286,6 +2306,10 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
22862306 if expressions is None :
22872307 return False
22882308 for exp in expressions :
2309+ if not context ._is_snowpark_connect_compatible_mode and isinstance (
2310+ exp , WindowExpression
2311+ ):
2312+ return True
22892313 if isinstance (exp , FunctionExpression ) and (
22902314 exp .is_data_generator
22912315 or exp .name .lower () in SEQUENCE_DEPENDENT_DATA_GENERATION
@@ -2311,4 +2335,19 @@ def has_window_function_exp(expressions: Optional[List["Expression"]]) -> bool:
def has_data_generator_or_window_function_exp(
    expressions: Optional[List["Expression"]],
) -> bool:
    """Report whether ``expressions`` contain a data generator or window function.

    The data-generator check (``has_data_generator_exp``) always runs.  The
    separate window-function check (``has_window_function_exp``) is consulted
    only when ``context._is_snowpark_connect_compatible_mode`` is truthy;
    otherwise the data-generator result is returned as-is.

    Args:
        expressions: Expressions to inspect; passed through unchanged to the
            underlying checks.

    Returns:
        The result of the data-generator check, or, in Snowpark Connect
        compatible mode, that result OR-ed with the window-function check.
    """
    generator_found = has_data_generator_exp(expressions)
    if context._is_snowpark_connect_compatible_mode:
        # Short-circuits: the window check runs only if no generator was found.
        return generator_found or has_window_function_exp(expressions)
    return generator_found
2341+
2342+
def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool:
    """Report whether any expression (or nested child) is an aggregation call.

    An expression counts as an aggregation when it is a ``FunctionExpression``
    whose lower-cased ``name`` is a member of
    ``context._aggregation_function_set``.  Each non-``None`` expression's
    ``children`` are searched recursively with the same rule.

    Args:
        expressions: Expressions to inspect; ``None`` is treated as "nothing
            to inspect".

    Returns:
        ``True`` as soon as one aggregation function is found, else ``False``.
    """
    if expressions is None:
        return False
    return any(
        (
            isinstance(node, FunctionExpression)
            and node.name.lower() in context._aggregation_function_set
        )
        # Not an aggregation itself: descend into its children, if any.
        or (node is not None and has_aggregation_function_exp(node.children))
        for node in expressions
    )
0 commit comments