Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

#### Improvements

- Reduced the size of queries generated by certain `DataFrame.join` operations.
- Removed redundant aliases in generated queries (for example, `SELECT "A" AS "A"` is now always simplified to `SELECT "A"`).

### Snowpark pandas API Updates

#### New Features
Expand Down
14 changes: 11 additions & 3 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,10 +763,18 @@ def unary_expression_extractor(
for k, v in df_alias_dict.items():
if v == expr.child.name:
df_alias_dict[k] = updated_due_to_inheritance # type: ignore
origin = self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
)
if (
isinstance(expr.child, (Attribute, UnresolvedAttribute))
and origin == quoted_name
):
# If the column name matches the target of the alias (`quoted_name`),
# we can directly emit the column name without an AS clause.
return origin
return alias_expression(
self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
),
origin,
quoted_name,
)

Expand Down
13 changes: 10 additions & 3 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,11 +2244,18 @@ def derive_column_states_from_subquery(
else Attribute(quoted_c_name, DataType())
)
from_c_state = from_.column_states.get(quoted_c_name)
result_name = analyzer.analyze(
c, from_.df_aliased_col_name_to_real_col_name, parse_local_name=True
).strip(" ")
if from_c_state and from_c_state.change_state != ColumnChangeState.DROPPED:
# review later. should use parse_column_name
if c_name != analyzer.analyze(
c, from_.df_aliased_col_name_to_real_col_name, parse_local_name=True
).strip(" "):
# SNOW-2895675: Always treat an Alias as "changed", even when it is an identity alias.
# The fact this check is needed may be a bug in column state analysis, and we should revisit it later.
# The following tests fail without this check:
# - tests/integ/test_cte.py::test_sql_simplifier
# - tests/integ/scala/test_dataframe_suite.py::test_rename_join_dataframe
# - tests/integ/test_dataframe.py::test_dataframe_alias
if c_name != result_name or isinstance(c, Alias):
column_states[quoted_c_name] = ColumnState(
quoted_c_name,
ColumnChangeState.CHANGED_EXP,
Expand Down
31 changes: 28 additions & 3 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,24 @@ def _alias_if_needed(
)
return col.alias(f'"{prefix}{unquoted_col_name}"')
else:
# Removal of redundant aliases (like `"A" AS "A"`) is handled at the analyzer level.
return col.alias(f'"{unquoted_col_name}"')


def _populate_expr_to_alias(df: "DataFrame") -> None:
    """
    Record an expr_to_alias entry for every output column of *df*.

    When the no-conflict fast path in ``_disambiguate`` skips the usual
    ``select()`` wrapping, these entries are what let later lookups such as
    ``df["column_name"]`` resolve through the join. Names are normalized with
    ``quote_name()`` to stay consistent with how the analyzer records aliases.
    """
    alias_map = df._plan.expr_to_alias
    for attribute in df._output:
        # Existing entries win: only fill in expr_ids that have no alias yet.
        if attribute.expr_id in alias_map:
            continue
        alias_map[attribute.expr_id] = quote_name(attribute.name)


def _disambiguate(
lhs: "DataFrame",
rhs: "DataFrame",
Expand All @@ -322,11 +337,21 @@ def _disambiguate(
for n in lhs_names
if n in set(rhs_names) and n not in normalized_using_columns
]

if not common_col_names:
# Optimization: No column name conflicts, so we can skip aliasing and the select() wrapping.
# But we still need to populate expr_to_alias for column lineage tracking,
# so that df["column_name"] can resolve correctly after the join.
# This is identified by the test case
# tests/integ/scala/test_dataframe_join_suite.py::test_name_alias_on_multiple_join.
_populate_expr_to_alias(lhs)
_populate_expr_to_alias(rhs)
return lhs, rhs

all_names = [unquote_if_quoted(n) for n in lhs_names + rhs_names]

if common_col_names:
# We use the session of the LHS DataFrame to report this telemetry
lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()
# We use the session of the LHS DataFrame to report this telemetry
lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()

lsuffix = lsuffix or lhs._alias
rsuffix = rsuffix or rhs._alias
Expand Down
16 changes: 11 additions & 5 deletions src/snowflake/snowpark/mock/_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,12 +654,18 @@ def unary_expression_extractor(
if v == expr.child.name:
df_alias_dict[k] = quoted_name

alias_exp = alias_expression(
self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
),
quoted_name,
origin = self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
)
if (
isinstance(expr.child, (Attribute, UnresolvedAttribute))
and origin == quoted_name
):
# If the column name matches the target of the alias (`quoted_name`),
# we can directly emit the column name without an AS clause.
return origin

alias_exp = alias_expression(origin, quoted_name)

expr_str = alias_exp if keep_alias else expr.name or keep_alias
expr_str = expr_str.upper() if parse_local_name else expr_str
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,8 +1495,8 @@ def aggregate_by_groups(cur_group: TableEmulator):
if isinstance(source_plan, Project):
return TableEmulator(ColumnEmulator(col) for col in source_plan.project_list)
if isinstance(source_plan, Join):
L_expr_to_alias = {}
R_expr_to_alias = {}
L_expr_to_alias = dict(getattr(source_plan.left, "expr_to_alias", None) or {})
R_expr_to_alias = dict(getattr(source_plan.right, "expr_to_alias", None) or {})
left = execute_mock_plan(source_plan.left, L_expr_to_alias).reset_index(
drop=True
)
Expand Down
9 changes: 9 additions & 0 deletions tests/integ/compiler/test_query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,15 @@ def test_select_alias(session):
check_generated_plan_queries(df2._plan)


def test_select_alias_identity(session):
    """An identity alias (column aliased to its own name) must not emit an AS clause."""
    df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
    # col("b").as_("b") is an identity alias; the analyzer should simplify it away.
    df_res = df.select("a", col("b").as_("b"))
    # Because "b" was aliased to itself, the emitted SQL should drop the AS clause.
    assert Utils.normalize_sql(df_res.queries["queries"][-1]) == Utils.normalize_sql(
        'SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT))'
    )


def test_nullable_is_false_dataframe(session):
from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD

Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_query_line_intervals.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def test_get_plan_from_line_numbers_join_operations(session):
)

line_to_expected_pattern = {
2: r'SELECT \* FROM \(\(SELECT "ID" AS "l_\d+_ID", "NAME" AS "NAME" FROM \(SELECT \$1 AS "ID", \$2 AS "NAME" FROM VALUES \(1 :: INT, \'A\' :: STRING\), \(2 :: INT, \'B\' :: STRING\)\)\) AS SNOWPARK_LEFT INNER JOIN \(SELECT "ID" AS "r_\d+_ID", "VALUE" AS "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)\) AS SNOWPARK_RIGHT ON \("l_\d+_ID" = "r_\d+_ID"\)\)',
2: r'SELECT \* FROM \(\(SELECT "ID" AS "l_\d+_ID", "NAME" FROM \(SELECT \$1 AS "ID", \$2 AS "NAME" FROM VALUES \(1 :: INT, \'A\' :: STRING\), \(2 :: INT, \'B\' :: STRING\)\)\) AS SNOWPARK_LEFT INNER JOIN \(SELECT "ID" AS "r_\d+_ID", "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)\) AS SNOWPARK_RIGHT ON \("l_\d+_ID" = "r_\d+_ID"\)\)',
7: r'SELECT \$1 AS "ID", \$2 AS "NAME" FROM VALUES \(1 :: INT, \'A\' :: STRING\), \(2 :: INT, \'B\' :: STRING\)',
14: r'SELECT "ID" AS "r_\d+_ID", "VALUE" AS "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)',
14: r'SELECT "ID" AS "r_\d+_ID", "VALUE" FROM \(SELECT \$1 AS "ID", \$2 AS "VALUE" FROM VALUES \(1 :: INT, 10 :: INT\), \(2 :: INT, 20 :: INT\)\)',
}

for line_num, expected_pattern in line_to_expected_pattern.items():
Expand Down
4 changes: 3 additions & 1 deletion tests/integ/test_query_plan_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,9 @@ def test_join_statement(session: Session, sample_table: str):
assert_df_subtree_query_complexity(
df5,
sum_node_complexities(
get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN: 2}
get_cumulative_node_complexity(df1),
get_cumulative_node_complexity(df2),
{PlanNodeCategory.COLUMN: 2, PlanNodeCategory.JOIN: 1},
),
)

Expand Down
10 changes: 7 additions & 3 deletions tests/integ/test_simplifier_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import itertools
import sys
import time
import re
from typing import Tuple

import pytest
Expand Down Expand Up @@ -1097,14 +1098,17 @@ def test_join_dataframes(session, simplifier_table):

df = df_left.join(df_right)
df1 = df.select("a").select("a").select("a")
assert df1.queries["queries"][0].count("SELECT") == 8
assert df1.queries["queries"][0].count("SELECT") == 6
df1.queries["queries"][0]
normalized_sql = re.sub(r"\s+", " ", df1.queries["queries"][0])
assert not any(f'"{c}" AS "{c}"' in normalized_sql for c in ["A", "B", "C", "D"])

df2 = (
df.select((col("a") + 1).as_("a"))
.select((col("a") + 1).as_("a"))
.select((col("a") + 1).as_("a"))
)
assert df2.queries["queries"][0].count("SELECT") == 10
assert df2.queries["queries"][0].count("SELECT") == 8

df3 = df.with_column("x", df_left.a).with_column("y", df_right.d)
assert '"A" AS "X", "D" AS "Y"' in Utils.normalize_sql(df3.queries["queries"][0])
Expand All @@ -1114,7 +1118,7 @@ def test_join_dataframes(session, simplifier_table):
df4 = df_right.to_df("e", "f")
df5 = df_left.join(df4)
df6 = df5.with_column("x", df_right.c).with_column("y", df4.f)
assert df6.queries["queries"][0].count("SELECT") == 10
assert df6.queries["queries"][0].count("SELECT") == 8
Utils.check_answer(df6, [Row(1, 2, 3, 4, 3, 4)])


Expand Down