Skip to content

Commit d92ea24

Browse files
committed
Loosen flattening rules for sort and filter
1 parent d5b232d commit d92ea24

File tree

2 files changed

+80
-33
lines changed

2 files changed

+80
-33
lines changed

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Sequence,
2121
Set,
2222
Union,
23+
Literal,
2324
)
2425

2526
import snowflake.snowpark._internal.utils
@@ -1362,7 +1363,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
13621363
):
13631364
# TODO: Clean up, this entire if case is parameter protection
13641365
can_be_flattened = False
1365-
elif (self.where or self.order_by or self.limit_) and has_data_generator_exp(
1366+
elif (self.where or self.order_by or self.limit_) and has_data_generator_or_window_function_exp(
13661367
cols
13671368
):
13681369
can_be_flattened = False
@@ -1453,9 +1454,9 @@ def filter(self, col: Expression) -> "SelectStatement":
14531454
can_be_flattened = (
14541455
(not self.flatten_disabled)
14551456
and can_clause_dependent_columns_flatten(
1456-
derive_dependent_columns(col), self.column_states
1457+
derive_dependent_columns(col), self.column_states, "filter"
14571458
)
1458-
and not has_data_generator_exp(self.projection)
1459+
and not has_data_generator_or_window_function_exp(self.projection)
14591460
and not (self.order_by and self.limit_ is not None)
14601461
)
14611462
if can_be_flattened:
@@ -1490,7 +1491,7 @@ def sort(self, cols: List[Expression]) -> "SelectStatement":
14901491
and (not self.limit_)
14911492
and (not self.offset)
14921493
and can_clause_dependent_columns_flatten(
1493-
derive_dependent_columns(*cols), self.column_states
1494+
derive_dependent_columns(*cols), self.column_states, "sort"
14941495
)
14951496
and not has_data_generator_exp(self.projection)
14961497
)
@@ -1529,7 +1530,7 @@ def distinct(self) -> "SelectStatement":
15291530
# .order_by(col1).select(col2).distinct() cannot be flattened because
15301531
# SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL
15311532
and (not (self.order_by and self.has_projection))
1532-
and not has_data_generator_exp(self.projection)
1533+
and not has_data_generator_or_window_function_exp(self.projection)
15331534
)
15341535
if can_be_flattened:
15351536
new = copy(self)
@@ -2020,7 +2021,10 @@ def can_projection_dependent_columns_be_flattened(
20202021
def can_clause_dependent_columns_flatten(
20212022
dependent_columns: Optional[AbstractSet[str]],
20222023
subquery_column_states: ColumnStateDict,
2024+
clause: Literal["filter", "sort"],
20232025
) -> bool:
2026+
if clause not in ["filter", "sort"]:
2027+
raise ValueError(f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}")
20242028
if dependent_columns == COLUMN_DEPENDENCY_DOLLAR:
20252029
return False
20262030
elif (
@@ -2034,15 +2038,7 @@ def can_clause_dependent_columns_flatten(
20342038
for dc in dependent_columns:
20352039
dc_state = subquery_column_states.get(dc)
20362040
if dc_state:
2037-
if dc_state.change_state == ColumnChangeState.CHANGED_EXP:
2038-
return False
2039-
elif dc_state.change_state == ColumnChangeState.NEW:
2040-
# Most of the time this can be flattened. But if a new column uses window function and this column
2041-
# is used in a clause, the sql doesn't work in Snowflake.
2042-
# For instance `select a, rank() over(order by b) as d from test_table where d = 1` doesn't work.
2043-
# But `select a, b as d from test_table where d = 1` works
2044-
# We can inspect whether the referenced new column uses window function. Here we are being
2045-
# conservative for now to not flatten the SQL.
2041+
if dc_state.change_state == ColumnChangeState.CHANGED_EXP and clause == "filter":
20462042
return False
20472043
return True
20482044

@@ -2264,8 +2260,6 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
22642260
if expressions is None:
22652261
return False
22662262
for exp in expressions:
2267-
if isinstance(exp, WindowExpression):
2268-
return True
22692263
if isinstance(exp, FunctionExpression) and (
22702264
exp.is_data_generator
22712265
or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION
@@ -2275,3 +2269,18 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
22752269
if exp is not None and has_data_generator_exp(exp.children):
22762270
return True
22772271
return False
2272+
2273+
2274+
def has_window_function_exp(expressions: Optional[List["Expression"]]) -> bool:
2275+
if expressions is None:
2276+
return False
2277+
for exp in expressions:
2278+
if isinstance(exp, WindowExpression):
2279+
return True
2280+
if exp is not None and has_window_function_exp(exp.children):
2281+
return True
2282+
return False
2283+
2284+
2285+
def has_data_generator_or_window_function_exp(expressions: Optional[List["Expression"]]) -> bool:
2286+
return has_data_generator_exp(expressions) or has_window_function_exp(expressions)

tests/integ/test_simplifier_suite.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import pytest
1111

12-
from snowflake.snowpark import Row
12+
from snowflake.snowpark import Row, Window
1313
from snowflake.snowpark._internal.analyzer.select_statement import (
1414
SET_EXCEPT,
1515
SET_INTERSECT,
@@ -30,6 +30,7 @@
3030
sum as sum_,
3131
table_function,
3232
udtf,
33+
rank,
3334
)
3435
from tests.utils import TestData, Utils
3536

@@ -754,21 +755,34 @@ def test_order_by(setup_reduce_cast, session, simplifier_table):
754755
f'SELECT "A", "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
755756
)
756757

757-
# no flatten because c is a new column
758+
# flatten if a new column is used in the order by clause
758759
df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).sort("a", "b", "c")
759760
assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql(
760-
f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST'
761+
f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST'
761762
)
762763

763-
# no flatten because a and be are changed
764+
# still flatten even if a is changed because it's used in the order by clause
764765
df4 = df.select((col("a") + 1).as_("a"), ((col("b") + 1).as_("b"))).sort("a", "b")
765766
assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql(
766-
f'SELECT * FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
767+
f'SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
767768
)
768769

769-
# subquery has sql text so unable to figure out same-level dependency, so assuming d depends on c. No flatten.
770-
df5 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).sort("a", "b")
770+
# still flatten if a window function is used in the projection
771+
df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).sort("a", "b")
771772
assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql(
773+
f'SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
774+
)
775+
776+
777+
# No flatten if a data generator is used in the projection
778+
df6 = df.select("a", "b", seq1().alias("c")).sort("a", "b")
779+
assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql(
780+
f'SELECT * FROM ( SELECT "A", "B", seq1(0) AS "C" FROM {simplifier_table}) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
781+
)
782+
783+
# subquery has sql text so unable to figure out if a data generator is used in the projection. No flatten.
784+
df7 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).sort("a", "b")
785+
assert Utils.normalize_sql(df7.queries["queries"][-1]) == Utils.normalize_sql(
772786
f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
773787
)
774788

@@ -790,33 +804,57 @@ def test_filter(setup_reduce_cast, session, simplifier_table):
790804
assert Utils.normalize_sql(df2.queries["queries"][-1]) == Utils.normalize_sql(
791805
f'SELECT "A", "B" FROM {simplifier_table} WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))'
792806
)
793-
794-
# no flatten because c is a new column
807+
808+
# flatten if a regular new column is in the projection
795809
df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter(
796-
(col("a") > 1) & (col("b") > 2) & (col("c") < 1)
810+
(col("a") > 1) & (col("b") > 2)
797811
)
798812
assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql(
799-
f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))'
813+
f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))'
814+
)
815+
816+
# flatten if a regular new column is used in the filter clause
817+
df4 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter(
818+
(col("a") > 1) & (col("b") > 2) & (col("c") < 1)
819+
)
820+
assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql(
821+
f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))'
822+
)
823+
824+
# no flatten if a window function is used in the projection
825+
df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).filter(
826+
(col("a") > 1) & (col("b") > 2) & (col("c") < 1)
827+
)
828+
assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql(
829+
f'SELECT * FROM ( SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))'
830+
)
831+
832+
# no flatten if a data generator is used in the projection
833+
df6 = df.select("a", "b", seq1().alias("c")).filter(
834+
(col("a") > 1) & (col("b") > 2) & (col("c") < 1)
835+
)
836+
assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql(
837+
f'SELECT * FROM ( SELECT "A", "B", seq1(0) AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))'
800838
)
801839

802840
# no flatten because a and be are changed
803-
df4 = df.select((col("a") + 1).as_("a"), (col("b") + 1).as_("b")).filter(
841+
df7 = df.select((col("a") + 1).as_("a"), (col("b") + 1).as_("b")).filter(
804842
(col("a") > 1) & (col("b") > 2)
805843
)
806-
assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql(
844+
assert Utils.normalize_sql(df7.queries["queries"][-1]) == Utils.normalize_sql(
807845
f'SELECT * FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))'
808846
)
809847

810-
df5 = df4.select("a")
811-
assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql(
848+
df8 = df7.select("a")
849+
assert Utils.normalize_sql(df8.queries["queries"][-1]) == Utils.normalize_sql(
812850
f'SELECT "A" FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))'
813851
)
814852

815853
# subquery has sql text so unable to figure out same-level dependency, so assuming d depends on c. No flatten.
816-
df6 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).filter(
854+
df9 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).filter(
817855
col("a") > 1
818856
)
819-
assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql(
857+
assert Utils.normalize_sql(df9.queries["queries"][-1]) == Utils.normalize_sql(
820858
f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) WHERE ("A" > 1{integer_literal_postfix})'
821859
)
822860

0 commit comments

Comments
 (0)