Skip to content

Commit fc75f21

Browse files
committed
parameter protection and agg function check for fitler
1 parent 1d3ad20 commit fc75f21

File tree

3 files changed

+74
-20
lines changed

3 files changed

+74
-20
lines changed

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
is_sql_select_statement,
8888
ExprAliasUpdateDict,
8989
)
90+
import snowflake.snowpark.context as context
9091

9192
# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
9293
# Python 3.9 can use both
@@ -1377,17 +1378,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
13771378
)
13781379
)
13791380
or (
1380-
new_column_states.dropped_columns
1381+
# unflattenable condition: dropped column is used in subquery WHERE clause and dropped column status is NEW or CHANGED in the subquery
1382+
# reason: we should not flatten because the dropped column is not available in the new query, leading to WHERE clause error
1383+
# sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' can not be flatten to 'select "b" from table where "c" > 1'
1384+
context._is_snowpark_connect_compatible_mode
1385+
and new_column_states.dropped_columns
13811386
and any(
1382-
new_column_states[_col].change_state == ColumnChangeState.DROPPED
1383-
and self.column_states[_col].change_state
1387+
self.column_states[_col].change_state
13841388
in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
1385-
and _col in subquery_dependent_columns
1386-
for _col in (new_column_states.dropped_columns)
1389+
for _col in (
1390+
subquery_dependent_columns & new_column_states.dropped_columns
1391+
)
13871392
)
13881393
)
13891394
):
1390-
# or (new_column_states[_col].change_state == ColumnChangeState.DROPPED and self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP))
13911395
can_be_flattened = False
13921396
elif self.order_by and (
13931397
(subquery_dependent_columns := derive_dependent_columns(*self.order_by))
@@ -1400,13 +1404,17 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14001404
)
14011405
)
14021406
or (
1403-
new_column_states.dropped_columns
1407+
# unflattenable condition: dropped column is used in subquery ORDER BY clause and dropped column status is NEW or CHANGED in the subquery
1408+
# reason: we should not flatten because the dropped column is not available in the new query, leading to ORDER BY clause error
1409+
# sample query: 'select "b" from (select "a" as "c", "b" order by "c")' can not be flatten to 'select "b" from table order by "c"'
1410+
context._is_snowpark_connect_compatible_mode
1411+
and new_column_states.dropped_columns
14041412
and any(
1405-
new_column_states[_col].change_state == ColumnChangeState.DROPPED
1406-
and self.column_states[_col].change_state
1413+
self.column_states[_col].change_state
14071414
in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
1408-
and _col in subquery_dependent_columns
1409-
for _col in (new_column_states.dropped_columns)
1415+
for _col in (
1416+
subquery_dependent_columns & new_column_states.dropped_columns
1417+
)
14101418
)
14111419
)
14121420
):
@@ -1478,6 +1486,10 @@ def filter(self, col: Expression) -> "SelectStatement":
14781486
derive_dependent_columns(col), self.column_states, "filter"
14791487
)
14801488
and not has_data_generator_or_window_function_exp(self.projection)
1489+
and not (
1490+
context._is_snowpark_connect_compatible_mode
1491+
and has_aggregation_function_exp(self.projection)
1492+
) # sum(col) as new_col, new_col can not be flattened in where clause
14811493
and not (self.order_by and self.limit_ is not None)
14821494
)
14831495
if can_be_flattened:
@@ -2044,10 +2056,10 @@ def can_clause_dependent_columns_flatten(
20442056
subquery_column_states: ColumnStateDict,
20452057
clause: Literal["filter", "sort"],
20462058
) -> bool:
2047-
if clause not in ["filter", "sort"]:
2048-
raise ValueError(
2049-
f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}"
2050-
)
2059+
assert clause in (
2060+
"filter",
2061+
"sort",
2062+
), f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}"
20512063
if dependent_columns == COLUMN_DEPENDENCY_DOLLAR:
20522064
return False
20532065
elif (
@@ -2061,11 +2073,19 @@ def can_clause_dependent_columns_flatten(
20612073
for dc in dependent_columns:
20622074
dc_state = subquery_column_states.get(dc)
20632075
if dc_state:
2064-
if (
2065-
dc_state.change_state == ColumnChangeState.CHANGED_EXP
2066-
and clause == "filter"
2067-
):
2068-
return False
2076+
if dc_state.change_state == ColumnChangeState.CHANGED_EXP:
2077+
if (
2078+
clause == "filter"
2079+
): # where can not be flattened because 'where' is evaluated before projection, flattening leads to wrong result
2080+
# df.select((col('a') + 1).alias('a')).filter(col('a') > 5) -- this should be applied to the new 'a', flattening will use the old 'a' to evaluated
2081+
return False
2082+
else: # clause == 'sort'
2083+
# df.select((col('a') + 1).alias('a')).sort(col('a')) -- this is valid to flatten because 'order by' is evaluated after projection
2084+
# however, if the order by is a data generator, it should not be flattened because generator is evaluated dynamically according to the order.
2085+
return context._is_snowpark_connect_compatible_mode
2086+
elif dc_state.change_state == ColumnChangeState.NEW:
2087+
return context._is_snowpark_connect_compatible_mode
2088+
20692089
return True
20702090

20712091

@@ -2286,6 +2306,10 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
22862306
if expressions is None:
22872307
return False
22882308
for exp in expressions:
2309+
if not context._is_snowpark_connect_compatible_mode and isinstance(
2310+
exp, WindowExpression
2311+
):
2312+
return True
22892313
if isinstance(exp, FunctionExpression) and (
22902314
exp.is_data_generator
22912315
or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION
@@ -2311,4 +2335,19 @@ def has_window_function_exp(expressions: Optional[List["Expression"]]) -> bool:
23112335
def has_data_generator_or_window_function_exp(
23122336
expressions: Optional[List["Expression"]],
23132337
) -> bool:
2338+
if not context._is_snowpark_connect_compatible_mode:
2339+
return has_data_generator_exp(expressions)
23142340
return has_data_generator_exp(expressions) or has_window_function_exp(expressions)
2341+
2342+
2343+
def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool:
2344+
if expressions is None:
2345+
return False
2346+
for exp in expressions:
2347+
if isinstance(exp, FunctionExpression) and (
2348+
exp.name.lower() in context._aggregation_function_set
2349+
):
2350+
return True
2351+
if exp is not None and has_aggregation_function_exp(exp.children):
2352+
return True
2353+
return False

src/snowflake/snowpark/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
# This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect
3333
_is_snowpark_connect_compatible_mode = False
34+
_aggregation_function_set = set()
3435

3536
# Following are internal-only global flags, used to enable development features.
3637
_enable_dataframe_trace_on_error = False

src/snowflake/snowpark/session.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,20 @@ def create(self) -> "Session":
521521
_add_session(session)
522522
else:
523523
session = self._create_internal(self._options.get("connection"))
524+
if context._is_snowpark_connect_compatible_mode:
525+
for sql in [
526+
"""select function_name from information_schema.functions where is_aggregate = 'YES'""",
527+
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
528+
]:
529+
try:
530+
context._aggregation_function_set.update(
531+
{r[0] for r in session.sql(sql).collect()}
532+
)
533+
except BaseException as e:
534+
_logger.debug(
535+
"Unable to get aggregation functions from the database: %s",
536+
e,
537+
)
524538

525539
if self._app_name:
526540
if self._format_json:

0 commit comments

Comments
 (0)