Skip to content

Commit 205f97e

Browse files
authored
Merge branch 'main' into aling/v4-diamond-join
2 parents 0867348 + 66e7b65 commit 205f97e

File tree

8 files changed

+111
-40
lines changed

8 files changed

+111
-40
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
- Fixed a bug where creating a Dataframe with large number of values raised `Unsupported feature 'SCOPED_TEMPORARY'.` error if thread-safe session was disabled.
1717
- Fixed a bug where `df.describe` raised internal SQL execution error when the dataframe is created from reading a stage file and CTE optimization is enabled.
18+
- Fixed a bug where `df.order_by(A).select(B).distinct()` would generate invalid SQL when simplified query generation was enabled using `session.conf.set("use_simplified_query_generation", True)`.
19+
- Disabled simplified query generation by default.
1820

1921
#### Improvements
2022

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,9 @@ def distinct(self) -> "SelectStatement":
12781278
# has a limit clause to avoid moving distinct in front of limit.
12791279
and (not self.limit_)
12801280
and (not self.offset)
1281+
# .order_by(col1).select(col2).distinct() cannot be flattened because
1282+
# SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL
1283+
and (not (self.order_by and self.projection))
12811284
and not has_data_generator_exp(self.projection)
12821285
)
12831286
if can_be_flattened:

src/snowflake/snowpark/_internal/telemetry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class TelemetryField(Enum):
112112
"select_expr": 1,
113113
"drop": 1,
114114
"agg": 2,
115+
"distinct": 2,
115116
"with_column": 1,
116117
"with_columns": 1,
117118
"with_column_renamed": 1,

src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from snowflake.snowpark.modin.plugin._internal.utils import (
4141
TempObjectType,
4242
generate_snowflake_quoted_identifiers_helper,
43+
get_default_snowpark_pandas_statement_params,
4344
parse_object_construct_snowflake_quoted_identifier_and_extract_pandas_label,
4445
parse_snowflake_object_construct_identifier_to_map,
4546
)
@@ -303,6 +304,7 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover
303304
# We have to use the current pandas version to ensure the behavior consistency
304305
packages=[native_pd] + packages,
305306
session=session,
307+
statement_params=get_default_snowpark_pandas_statement_params(),
306308
)
307309

308310
return func_udtf
@@ -683,6 +685,7 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def
683685
# behavior is consistent with client-side pandas behavior.
684686
packages=[native_pd] + list(session.get_packages().values()),
685687
session=session,
688+
statement_params=get_default_snowpark_pandas_statement_params(),
686689
)
687690

688691

@@ -947,6 +950,7 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def
947950
# behavior is consistent with client-side pandas behavior.
948951
packages=[native_pd] + list(session.get_packages().values()),
949952
session=session,
953+
statement_params=get_default_snowpark_pandas_statement_params(),
950954
)
951955

952956

@@ -1019,6 +1023,7 @@ def apply_func(x): # type: ignore[no-untyped-def] # pragma: no cover
10191023
strict=bool(na_action == "ignore"),
10201024
session=session,
10211025
packages=packages,
1026+
statement_params=get_default_snowpark_pandas_statement_params(),
10221027
)
10231028
return func_udf
10241029

src/snowflake/snowpark/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def __init__(self, session: "Session", conf: Dict[str, Any]) -> None:
374374
"use_constant_subquery_alias": True,
375375
"flatten_select_after_filter_and_orderby": True,
376376
"collect_stacktrace_in_query_tag": False,
377-
"use_simplified_query_generation": True,
377+
"use_simplified_query_generation": False,
378378
} # For config that's temporary/to be removed soon
379379
self._lock = self._session._lock
380380
for key, val in conf.items():

tests/integ/scala/test_snowflake_plan_suite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def test_plan_height(session, temp_table, sql_simplifier_enabled):
210210

211211
aggregate1 = df3.distinct()
212212
if sql_simplifier_enabled:
213-
assert aggregate1._plan.plan_state[PlanState.PLAN_HEIGHT] == 2
213+
assert aggregate1._plan.plan_state[PlanState.PLAN_HEIGHT] == 4
214214
else:
215215
assert aggregate1._plan.plan_state[PlanState.PLAN_HEIGHT] == 3
216216

tests/integ/test_simplifier_suite.py

Lines changed: 82 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -140,44 +140,49 @@ def test_set_same_operator(session, set_operator):
140140
],
141141
)
142142
def test_distinct_set_operator(session, distinct_table, action, operator):
143-
df1 = session.table(distinct_table)
144-
df2 = session.table(distinct_table)
143+
try:
144+
original = session.conf.get("use_simplified_query_generation")
145+
session.conf.set("use_simplified_query_generation", True)
146+
df1 = session.table(distinct_table)
147+
df2 = session.table(distinct_table)
145148

146-
df = action(df1, df2.distinct())
147-
assert (
148-
df.queries["queries"][0]
149-
== f"""( SELECT * FROM {distinct_table}){operator}( SELECT DISTINCT * FROM {distinct_table})"""
150-
)
149+
df = action(df1, df2.distinct())
150+
assert (
151+
df.queries["queries"][0]
152+
== f"""( SELECT * FROM {distinct_table}){operator}( SELECT DISTINCT * FROM {distinct_table})"""
153+
)
151154

152-
df = action(df1.distinct(), df2)
153-
assert (
154-
df.queries["queries"][0]
155-
== f"""( SELECT DISTINCT * FROM {distinct_table}){operator}( SELECT * FROM {distinct_table})"""
156-
)
155+
df = action(df1.distinct(), df2)
156+
assert (
157+
df.queries["queries"][0]
158+
== f"""( SELECT DISTINCT * FROM {distinct_table}){operator}( SELECT * FROM {distinct_table})"""
159+
)
157160

158-
df = action(df1, df2).distinct()
159-
assert (
160-
df.queries["queries"][0]
161-
== f"""SELECT DISTINCT * FROM (( SELECT * FROM {distinct_table}){operator}( SELECT * FROM {distinct_table}))"""
162-
)
161+
df = action(df1, df2).distinct()
162+
assert (
163+
df.queries["queries"][0]
164+
== f"""SELECT DISTINCT * FROM (( SELECT * FROM {distinct_table}){operator}( SELECT * FROM {distinct_table}))"""
165+
)
163166

164-
df = action(df1, df2.distinct()).distinct()
165-
assert (
166-
df.queries["queries"][0]
167-
== f"""SELECT DISTINCT * FROM (( SELECT * FROM {distinct_table}){operator}( SELECT DISTINCT * FROM {distinct_table}))"""
168-
)
167+
df = action(df1, df2.distinct()).distinct()
168+
assert (
169+
df.queries["queries"][0]
170+
== f"""SELECT DISTINCT * FROM (( SELECT * FROM {distinct_table}){operator}( SELECT DISTINCT * FROM {distinct_table}))"""
171+
)
169172

170-
df = action(df1.distinct(), df2).distinct()
171-
assert (
172-
df.queries["queries"][0]
173-
== f"""SELECT DISTINCT * FROM (( SELECT DISTINCT * FROM {distinct_table}){operator}( SELECT * FROM {distinct_table}))"""
174-
)
173+
df = action(df1.distinct(), df2).distinct()
174+
assert (
175+
df.queries["queries"][0]
176+
== f"""SELECT DISTINCT * FROM (( SELECT DISTINCT * FROM {distinct_table}){operator}( SELECT * FROM {distinct_table}))"""
177+
)
175178

176-
df = action(df1.distinct(), df2.distinct()).distinct()
177-
assert (
178-
df.queries["queries"][0]
179-
== f"""SELECT DISTINCT * FROM (( SELECT DISTINCT * FROM {distinct_table}){operator}( SELECT DISTINCT * FROM {distinct_table}))"""
180-
)
179+
df = action(df1.distinct(), df2.distinct()).distinct()
180+
assert (
181+
df.queries["queries"][0]
182+
== f"""SELECT DISTINCT * FROM (( SELECT DISTINCT * FROM {distinct_table}){operator}( SELECT DISTINCT * FROM {distinct_table}))"""
183+
)
184+
finally:
185+
session.conf.set("use_simplified_query_generation", original)
181186

182187

183188
@pytest.mark.parametrize("set_operator", [SET_UNION_ALL, SET_EXCEPT, SET_INTERSECT])
@@ -1486,19 +1491,58 @@ def test_select_limit_orderby(session):
14861491
[Row(1, "c"), Row(3, "b"), Row(3, "c"), Row(5, "a")],
14871492
False,
14881493
),
1494+
(
1495+
lambda df: df.sort(col("a"), col("b")).distinct(),
1496+
lambda table: f"""SELECT DISTINCT * FROM {table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST""",
1497+
[Row(1, "c"), Row(3, "b"), Row(3, "c"), Row(5, "a")],
1498+
True,
1499+
),
14891500
(
14901501
lambda df: df.select("a", "b").sort(col("a"), col("b")).distinct(),
1491-
lambda table: f"""SELECT DISTINCT "A", "B" FROM {table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST""",
1502+
lambda table: f"""SELECT DISTINCT * FROM ( SELECT "A", "B" FROM {table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST)""",
14921503
[Row(1, "c"), Row(3, "b"), Row(3, "c"), Row(5, "a")],
14931504
True,
14941505
),
1506+
# df.sort(A).select(B).distinct()
1507+
(
1508+
lambda df: df.sort(col("a")).select("b").distinct(),
1509+
lambda table: f"""SELECT DISTINCT * FROM ( SELECT "B" FROM {table} ORDER BY "A" ASC NULLS FIRST)""",
1510+
[Row("a"), Row("b"), Row("c")],
1511+
True,
1512+
),
1513+
# df.sort(A).distinct().select(B)
1514+
(
1515+
lambda df: df.sort(col("a")).distinct().select("b"),
1516+
lambda table: f"""SELECT "B" FROM ( SELECT DISTINCT * FROM {table} ORDER BY "A" ASC NULLS FIRST)""",
1517+
[Row("a"), Row("b"), Row("c"), Row("c")],
1518+
True,
1519+
),
1520+
# df.filter(A).select(B).distinct()
1521+
(
1522+
lambda df: df.filter(col("a") > 1).select("b").distinct(),
1523+
lambda table: f"""SELECT DISTINCT "B" FROM {table} WHERE ("A" > 1)""",
1524+
[Row("a"), Row("b"), Row("c")],
1525+
True,
1526+
),
1527+
# df.filter(A).distinct().select(B)
1528+
(
1529+
lambda df: df.filter(col("a") > 1).distinct().select("b"),
1530+
lambda table: f"""SELECT "B" FROM ( SELECT DISTINCT * FROM {table} WHERE ("A" > 1))""",
1531+
[Row("a"), Row("b"), Row("c")],
1532+
True,
1533+
),
14951534
],
14961535
)
14971536
def test_select_distinct(
14981537
session, distinct_table, operation, expected_query, expected_result, sort_results
14991538
):
1500-
df = session.table(distinct_table)
1501-
df1 = operation(df)
1502-
if expected_result is not None:
1503-
Utils.check_answer(df1, expected_result, sort=sort_results)
1504-
assert df1.queries["queries"][0] == expected_query(distinct_table)
1539+
try:
1540+
original = session.conf.get("use_simplified_query_generation")
1541+
session.conf.set("use_simplified_query_generation", True)
1542+
df = session.table(distinct_table)
1543+
df1 = operation(df)
1544+
if expected_result is not None:
1545+
Utils.check_answer(df1, expected_result, sort=sort_results)
1546+
assert df1.queries["queries"][0] == expected_query(distinct_table)
1547+
finally:
1548+
session.conf.set("use_simplified_query_generation", original)

tests/integ/test_telemetry.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ def test_drop_duplicates_api_calls(session):
297297
"subcalls": [
298298
{
299299
"name": "DataFrame.distinct",
300+
"subcalls": [
301+
{"name": "DataFrame.group_by"},
302+
{"name": "RelationalGroupedDataFrame.agg"},
303+
],
300304
}
301305
],
302306
},
@@ -444,6 +448,10 @@ def test_distinct_api_calls(session):
444448
{"name": "DataFrame.to_df", "subcalls": [{"name": "DataFrame.select"}]},
445449
{
446450
"name": "DataFrame.distinct",
451+
"subcalls": [
452+
{"name": "DataFrame.group_by"},
453+
{"name": "RelationalGroupedDataFrame.agg"},
454+
],
447455
},
448456
]
449457
# check to make sure that the original DF is unchanged
@@ -460,6 +468,10 @@ def test_distinct_api_calls(session):
460468
{"name": "DataFrame.select"},
461469
{
462470
"name": "DataFrame.distinct",
471+
"subcalls": [
472+
{"name": "DataFrame.group_by"},
473+
{"name": "RelationalGroupedDataFrame.agg"},
474+
],
463475
},
464476
{"name": "DataFrame.sort"},
465477
]
@@ -470,6 +482,10 @@ def test_distinct_api_calls(session):
470482
{"name": "DataFrame.select"},
471483
{
472484
"name": "DataFrame.distinct",
485+
"subcalls": [
486+
{"name": "DataFrame.group_by"},
487+
{"name": "RelationalGroupedDataFrame.agg"},
488+
],
473489
},
474490
]
475491

0 commit comments

Comments
 (0)