Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,26 @@ def _alias_if_needed(
)
return col.alias(f'"{prefix}{unquoted_col_name}"')
else:
return col.alias(f'"{unquoted_col_name}"')
# No alias needed when source name equals destination name.
# Populate expr_to_alias directly instead of creating redundant Alias expression.
# This avoids generating "A" AS "A" in SQL while still maintaining column tracking.
quoted_name = quote_name(unquoted_col_name)
df._plan.expr_to_alias[col._expression.expr_id] = quoted_name
return col


def _populate_expr_to_alias(df: "DataFrame") -> None:
    """
    Record an ``expr_to_alias`` entry for every output column of *df*.

    This keeps column lineage tracking intact when the ``select()`` wrapping
    optimization in ``_disambiguate`` is skipped: later lookups such as
    ``df["column_name"]`` still resolve through ``df._plan.expr_to_alias``.

    Names are passed through ``quote_name()`` so they match the quoting used
    by the analyzer's Alias handling. Existing entries are never overwritten.
    """
    alias_map = df._plan.expr_to_alias
    for attribute in df._output:
        # Skip attributes that already have an alias recorded; only fill gaps.
        if attribute.expr_id in alias_map:
            continue
        alias_map[attribute.expr_id] = quote_name(attribute.name)


def _disambiguate(
Expand Down Expand Up @@ -322,11 +341,20 @@ def _disambiguate(
for n in lhs_names
if n in set(rhs_names) and n not in normalized_using_columns
]

if not common_col_names:
# Optimization: No column name conflicts, so we can skip aliasing and the select() wrapping.
# But we still need to populate expr_to_alias for column lineage tracking,
# so that df["column_name"] can resolve correctly after the join.
# This is identified by the test case test_name_alias_on_multiple_join.
_populate_expr_to_alias(lhs)
_populate_expr_to_alias(rhs)
return lhs, rhs

all_names = [unquote_if_quoted(n) for n in lhs_names + rhs_names]

if common_col_names:
# We use the session of the LHS DataFrame to report this telemetry
lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()
# We use the session of the LHS DataFrame to report this telemetry
lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()

lsuffix = lsuffix or lhs._alias
rsuffix = rsuffix or rhs._alias
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,8 +1495,8 @@ def aggregate_by_groups(cur_group: TableEmulator):
if isinstance(source_plan, Project):
return TableEmulator(ColumnEmulator(col) for col in source_plan.project_list)
if isinstance(source_plan, Join):
L_expr_to_alias = {}
R_expr_to_alias = {}
L_expr_to_alias = dict(getattr(source_plan.left, "expr_to_alias", None) or {})
R_expr_to_alias = dict(getattr(source_plan.right, "expr_to_alias", None) or {})
left = execute_mock_plan(source_plan.left, L_expr_to_alias).reset_index(
drop=True
)
Expand Down
4 changes: 3 additions & 1 deletion tests/integ/test_query_plan_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,9 @@ def test_join_statement(session: Session, sample_table: str):
assert_df_subtree_query_complexity(
df5,
sum_node_complexities(
get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN: 2}
get_cumulative_node_complexity(df1),
get_cumulative_node_complexity(df2),
{PlanNodeCategory.COLUMN: 2, PlanNodeCategory.JOIN: 1},
),
)

Expand Down
10 changes: 7 additions & 3 deletions tests/integ/test_simplifier_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import itertools
import sys
import time
import re
from typing import Tuple

import pytest
Expand Down Expand Up @@ -1097,14 +1098,17 @@ def test_join_dataframes(session, simplifier_table):

df = df_left.join(df_right)
df1 = df.select("a").select("a").select("a")
assert df1.queries["queries"][0].count("SELECT") == 8
assert df1.queries["queries"][0].count("SELECT") == 6
df1.queries["queries"][0]
normalized_sql = re.sub(r"\s+", " ", df1.queries["queries"][0])
assert not any(f'"{c}" AS "{c}"' in normalized_sql for c in ["A", "B", "C", "D"])

df2 = (
df.select((col("a") + 1).as_("a"))
.select((col("a") + 1).as_("a"))
.select((col("a") + 1).as_("a"))
)
assert df2.queries["queries"][0].count("SELECT") == 10
assert df2.queries["queries"][0].count("SELECT") == 8

df3 = df.with_column("x", df_left.a).with_column("y", df_right.d)
assert '"A" AS "X", "D" AS "Y"' in Utils.normalize_sql(df3.queries["queries"][0])
Expand All @@ -1114,7 +1118,7 @@ def test_join_dataframes(session, simplifier_table):
df4 = df_right.to_df("e", "f")
df5 = df_left.join(df4)
df6 = df5.with_column("x", df_right.c).with_column("y", df4.f)
assert df6.queries["queries"][0].count("SELECT") == 10
assert df6.queries["queries"][0].count("SELECT") == 8
Utils.check_answer(df6, [Row(1, 2, 3, 4, 3, 4)])


Expand Down
Loading