Commit 44f7114

Merge branch 'main' of github.com:snowflakedb/snowpark-python into bkogan-default-artifact
2 parents eea6081 + fac826e commit 44f7114

10 files changed, 86 additions and 22 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@
 
 #### Improvements
 
+- Reduced the size of queries generated by certain `DataFrame.join` operations.
+- Removed redundant aliases in generated queries (for example, `SELECT "A" AS "A"` is now always simplified to `SELECT "A"`).
+
 ### Snowpark pandas API Updates
 
 #### New Features
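
The second improvement above can be observed directly in the SQL that Snowpark generates. Below is a minimal sketch (not part of this commit) that prints the query text for an identity alias; the connection parameters are placeholders, and the exact SQL may differ slightly between versions.

# Sketch: observe the redundant-alias removal on a toy DataFrame.
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col

# Placeholder credentials; replace with a real configuration.
session = Session.builder.configs(
    {"account": "<account>", "user": "<user>", "password": "<password>"}
).create()

df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
# col("b").as_("b") is an identity alias; the generated query now projects "B"
# directly instead of "B" AS "B".
print(df.select("a", col("b").as_("b")).queries["queries"][-1])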

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 11 additions & 3 deletions
@@ -763,10 +763,18 @@ def unary_expression_extractor(
                 for k, v in df_alias_dict.items():
                     if v == expr.child.name:
                         df_alias_dict[k] = updated_due_to_inheritance  # type: ignore
+            origin = self.analyze(
+                expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
+            )
+            if (
+                isinstance(expr.child, (Attribute, UnresolvedAttribute))
+                and origin == quoted_name
+            ):
+                # If the column name matches the target of the alias (`quoted_name`),
+                # we can directly emit the column name without an AS clause.
+                return origin
             return alias_expression(
-                self.analyze(
-                    expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
-                ),
+                origin,
                 quoted_name,
             )
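
The same check appears in both the live analyzer (above) and the local-testing analyzer (later in this commit). As a rough standalone sketch, assuming a rendered child expression and an already-quoted alias target, the decision reduces to the following; the helper name and signature here are illustrative, not part of the codebase.

# Sketch: emit a plain column reference when the alias is an identity.
def render_alias(rendered_child: str, quoted_name: str, child_is_column: bool) -> str:
    # Only skip the AS clause for plain column references (Attribute /
    # UnresolvedAttribute) whose rendered form already equals the alias target.
    if child_is_column and rendered_child == quoted_name:
        return rendered_child
    return f"{rendered_child} AS {quoted_name}"

assert render_alias('"B"', '"B"', True) == '"B"'
assert render_alias('"A" + 1', '"B"', False) == '"A" + 1 AS "B"'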

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 10 additions & 3 deletions
@@ -2244,11 +2244,18 @@ def derive_column_states_from_subquery(
             else Attribute(quoted_c_name, DataType())
         )
         from_c_state = from_.column_states.get(quoted_c_name)
+        result_name = analyzer.analyze(
+            c, from_.df_aliased_col_name_to_real_col_name, parse_local_name=True
+        ).strip(" ")
         if from_c_state and from_c_state.change_state != ColumnChangeState.DROPPED:
             # review later. should use parse_column_name
-            if c_name != analyzer.analyze(
-                c, from_.df_aliased_col_name_to_real_col_name, parse_local_name=True
-            ).strip(" "):
+            # SNOW-2895675: Always treat Aliases as "changed", even if the alias is an identity.
+            # The fact that this check is needed may point to a bug in column state analysis; revisit later.
+            # The following tests fail without this check:
+            # - tests/integ/test_cte.py::test_sql_simplifier
+            # - tests/integ/scala/test_dataframe_suite.py::test_rename_join_dataframe
+            # - tests/integ/test_dataframe.py::test_dataframe_alias
+            if c_name != result_name or isinstance(c, Alias):
                 column_states[quoted_c_name] = ColumnState(
                     quoted_c_name,
                     ColumnChangeState.CHANGED_EXP,

src/snowflake/snowpark/dataframe.py

Lines changed: 28 additions & 3 deletions
@@ -293,9 +293,24 @@ def _alias_if_needed(
         )
         return col.alias(f'"{prefix}{unquoted_col_name}"')
     else:
+        # Removal of redundant aliases (like `"A" AS "A"`) is handled at the analyzer level.
         return col.alias(f'"{unquoted_col_name}"')
 
 
+def _populate_expr_to_alias(df: "DataFrame") -> None:
+    """
+    Populate the expr_to_alias mapping for a DataFrame's output columns.
+    This is needed for column lineage tracking when we skip the select() wrapping
+    optimization in _disambiguate.
+    """
+    for attr in df._output:
+        # Map each attribute's expr_id to its quoted column name.
+        # This allows later lookups like df["column_name"] to resolve correctly.
+        # Use quote_name() for consistency with the Alias handling in analyzer.py (lines 743, 756).
+        if attr.expr_id not in df._plan.expr_to_alias:
+            df._plan.expr_to_alias[attr.expr_id] = quote_name(attr.name)
+
+
 def _disambiguate(
     lhs: "DataFrame",
     rhs: "DataFrame",
@@ -322,11 +337,21 @@
         for n in lhs_names
         if n in set(rhs_names) and n not in normalized_using_columns
     ]
+
+    if not common_col_names:
+        # Optimization: there are no column name conflicts, so we can skip the aliasing and the select() wrapping.
+        # We still need to populate expr_to_alias for column lineage tracking,
+        # so that df["column_name"] can resolve correctly after the join.
+        # This case is covered by
+        # tests/integ/scala/test_dataframe_join_suite.py::test_name_alias_on_multiple_join.
+        _populate_expr_to_alias(lhs)
+        _populate_expr_to_alias(rhs)
+        return lhs, rhs
+
     all_names = [unquote_if_quoted(n) for n in lhs_names + rhs_names]
 
-    if common_col_names:
-        # We use the session of the LHS DataFrame to report this telemetry
-        lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()
+    # We use the session of the LHS DataFrame to report this telemetry.
+    lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()
 
     lsuffix = lsuffix or lhs._alias
     rsuffix = rsuffix or rhs._alias
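
For context, here is a hedged usage sketch of the situation this fast path targets (not from the commit): two DataFrames with disjoint column names are joined, and columns are then referenced through the original DataFrame objects, which is exactly the lookup that _populate_expr_to_alias keeps working. `session` is assumed to already exist.

# Sketch: join without overlapping names, then reference parent columns.
df_left = session.create_dataframe([[1, 2]], schema=["a", "b"])
df_right = session.create_dataframe([[3, 4]], schema=["c", "d"])

joined = df_left.join(df_right)  # no common names, so no l_/r_ prefixing or select() wrapping
result = joined.select(df_left["a"], df_right["d"])  # lineage lookup resolves via expr_to_alias
result.show()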

src/snowflake/snowpark/mock/_analyzer.py

Lines changed: 11 additions & 5 deletions
@@ -654,12 +654,18 @@ def unary_expression_extractor(
                 if v == expr.child.name:
                     df_alias_dict[k] = quoted_name
 
-        alias_exp = alias_expression(
-            self.analyze(
-                expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
-            ),
-            quoted_name,
+        origin = self.analyze(
+            expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
         )
+        if (
+            isinstance(expr.child, (Attribute, UnresolvedAttribute))
+            and origin == quoted_name
+        ):
+            # If the column name matches the target of the alias (`quoted_name`),
+            # we can directly emit the column name without an AS clause.
+            return origin
+
+        alias_exp = alias_expression(origin, quoted_name)
 
         expr_str = alias_exp if keep_alias else expr.name or keep_alias
         expr_str = expr_str.upper() if parse_local_name else expr_str

src/snowflake/snowpark/mock/_plan.py

Lines changed: 2 additions & 2 deletions
@@ -1495,8 +1495,8 @@ def aggregate_by_groups(cur_group: TableEmulator):
     if isinstance(source_plan, Project):
         return TableEmulator(ColumnEmulator(col) for col in source_plan.project_list)
     if isinstance(source_plan, Join):
-        L_expr_to_alias = {}
-        R_expr_to_alias = {}
+        L_expr_to_alias = dict(getattr(source_plan.left, "expr_to_alias", None) or {})
+        R_expr_to_alias = dict(getattr(source_plan.right, "expr_to_alias", None) or {})
         left = execute_mock_plan(source_plan.left, L_expr_to_alias).reset_index(
             drop=True
         )
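
The expression `dict(getattr(plan, "expr_to_alias", None) or {})` is a defensive copy: it tolerates child plans that have no expr_to_alias attribute (or a None value) and keeps the join from mutating the child's own mapping. A tiny self-contained illustration of the idiom, using a hypothetical stand-in class:

# Sketch: the fallback-and-copy idiom used for the join's alias maps.
class _FakePlan:
    pass

plan = _FakePlan()
aliases = dict(getattr(plan, "expr_to_alias", None) or {})
assert aliases == {}  # missing attribute falls back to an empty dict

plan.expr_to_alias = {1: '"A"'}
aliases = dict(getattr(plan, "expr_to_alias", None) or {})
aliases[2] = '"B"'
assert plan.expr_to_alias == {1: '"A"'}  # the original mapping is not mutated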

tests/integ/compiler/test_query_generator.py

Lines changed: 9 additions & 0 deletions
@@ -551,6 +551,15 @@ def test_select_alias(session):
     check_generated_plan_queries(df2._plan)
 
 
+def test_select_alias_identity(session):
+    df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
+    df_res = df.select("a", col("b").as_("b"))
+    # Because "b" was aliased to itself, the emitted SQL should drop the AS clause.
+    assert Utils.normalize_sql(df_res.queries["queries"][-1]) == Utils.normalize_sql(
+        'SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT))'
+    )
+
+
 def test_nullable_is_false_dataframe(session):
     from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD

tests/integ/test_query_line_intervals.py

Lines changed: 2 additions & 2 deletions
@@ -135,9 +135,9 @@ def test_get_plan_from_line_numbers_join_operations(session):
     )
 
     line_to_expected_pattern = {
-        2: r'SELECT \* FROM \(\(SELECT "ID" AS "l_\d+_ID", "NAME" AS "NAME" FROM \(SELECT \$1 AS "ID", \$2 AS "NAME" FROM VALUES \(1 :: INT, \'A\' :: STRING\), \(2 :: INT, \'B\' :: STRING\)\)\) AS SNOWPARK_LEFT INNER JOIN \(SELECT "ID" AS "r_\d+_ID", "VALUE" AS "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)\) AS SNOWPARK_RIGHT ON \("l_\d+_ID" = "r_\d+_ID"\)\)',
+        2: r'SELECT \* FROM \(\(SELECT "ID" AS "l_\d+_ID", "NAME" FROM \(SELECT \$1 AS "ID", \$2 AS "NAME" FROM VALUES \(1 :: INT, \'A\' :: STRING\), \(2 :: INT, \'B\' :: STRING\)\)\) AS SNOWPARK_LEFT INNER JOIN \(SELECT "ID" AS "r_\d+_ID", "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)\) AS SNOWPARK_RIGHT ON \("l_\d+_ID" = "r_\d+_ID"\)\)',
         7: r'SELECT \$1 AS "ID", \$2 AS "NAME" FROM VALUES \(1 :: INT, \'A\' :: STRING\), \(2 :: INT, \'B\' :: STRING\)',
-        14: r'SELECT "ID" AS "r_\d+_ID", "VALUE" AS "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)',
+        14: r'SELECT "ID" AS "r_\d+_ID", "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)',
     }
 
     for line_num, expected_pattern in line_to_expected_pattern.items():

tests/integ/test_query_plan_analysis.py

Lines changed: 3 additions & 1 deletion
@@ -358,7 +358,9 @@ def test_join_statement(session: Session, sample_table: str):
     assert_df_subtree_query_complexity(
         df5,
         sum_node_complexities(
-            get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN: 2}
+            get_cumulative_node_complexity(df1),
+            get_cumulative_node_complexity(df2),
+            {PlanNodeCategory.COLUMN: 2, PlanNodeCategory.JOIN: 1},
         ),
     )

tests/integ/test_simplifier_suite.py

Lines changed: 7 additions & 3 deletions
@@ -5,6 +5,7 @@
 import itertools
 import sys
 import time
+import re
 from typing import Tuple
 
 import pytest
@@ -1097,14 +1098,17 @@ def test_join_dataframes(session, simplifier_table):
 
     df = df_left.join(df_right)
     df1 = df.select("a").select("a").select("a")
-    assert df1.queries["queries"][0].count("SELECT") == 8
+    assert df1.queries["queries"][0].count("SELECT") == 6
+    df1.queries["queries"][0]
+    normalized_sql = re.sub(r"\s+", " ", df1.queries["queries"][0])
+    assert not any(f'"{c}" AS "{c}"' in normalized_sql for c in ["A", "B", "C", "D"])
 
     df2 = (
         df.select((col("a") + 1).as_("a"))
         .select((col("a") + 1).as_("a"))
         .select((col("a") + 1).as_("a"))
     )
-    assert df2.queries["queries"][0].count("SELECT") == 10
+    assert df2.queries["queries"][0].count("SELECT") == 8
 
     df3 = df.with_column("x", df_left.a).with_column("y", df_right.d)
     assert '"A" AS "X", "D" AS "Y"' in Utils.normalize_sql(df3.queries["queries"][0])
@@ -1114,7 +1118,7 @@ def test_join_dataframes(session, simplifier_table):
     df4 = df_right.to_df("e", "f")
     df5 = df_left.join(df4)
     df6 = df5.with_column("x", df_right.c).with_column("y", df4.f)
-    assert df6.queries["queries"][0].count("SELECT") == 10
+    assert df6.queries["queries"][0].count("SELECT") == 8
     Utils.check_answer(df6, [Row(1, 2, 3, 4, 3, 4)])
