diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py
index 7878b4c818..c380b79565 100644
--- a/src/snowflake/snowpark/dataframe.py
+++ b/src/snowflake/snowpark/dataframe.py
@@ -293,7 +293,26 @@ def _alias_if_needed(
         )
         return col.alias(f'"{prefix}{unquoted_col_name}"')
     else:
-        return col.alias(f'"{unquoted_col_name}"')
+        # No alias is needed when the source name equals the destination name.
+        # Populate expr_to_alias directly instead of creating a redundant Alias expression.
+        # This avoids generating "A" AS "A" in SQL while still maintaining column tracking.
+        quoted_name = quote_name(unquoted_col_name)
+        df._plan.expr_to_alias[col._expression.expr_id] = quoted_name
+        return col
+
+
+def _populate_expr_to_alias(df: "DataFrame") -> None:
+    """
+    Populate the expr_to_alias mapping for a DataFrame's output columns.
+    This is needed for column lineage tracking when we skip the select() wrapping
+    optimization in _disambiguate.
+    """
+    for attr in df._output:
+        # Map each attribute's expr_id to its quoted column name so that later
+        # lookups like df["column_name"] resolve correctly.
+        # Use quote_name() for consistency with analyzer.py Alias handling (lines 743, 756).
+        if attr.expr_id not in df._plan.expr_to_alias:
+            df._plan.expr_to_alias[attr.expr_id] = quote_name(attr.name)


 def _disambiguate(
@@ -322,11 +341,20 @@ def _disambiguate(
         for n in lhs_names
         if n in set(rhs_names) and n not in normalized_using_columns
     ]
+
+    if not common_col_names:
+        # Optimization: there are no column name conflicts, so we can skip the
+        # aliasing and the select() wrapping. We still need to populate
+        # expr_to_alias for column lineage tracking, so that df["column_name"]
+        # resolves correctly after the join. Covered by test_name_alias_on_multiple_join.
+        _populate_expr_to_alias(lhs)
+        _populate_expr_to_alias(rhs)
+        return lhs, rhs
+
     all_names = [unquote_if_quoted(n) for n in lhs_names + rhs_names]

-    if common_col_names:
-        # We use the session of the LHS DataFrame to report this telemetry
-        lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()
+    # We use the session of the LHS DataFrame to report this telemetry.
+    lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()

     lsuffix = lsuffix or lhs._alias
     rsuffix = rsuffix or rhs._alias
diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py
index fa1b76eba2..a0275600ae 100644
--- a/src/snowflake/snowpark/mock/_plan.py
+++ b/src/snowflake/snowpark/mock/_plan.py
@@ -1495,8 +1495,10 @@ def aggregate_by_groups(cur_group: TableEmulator):
     if isinstance(source_plan, Project):
         return TableEmulator(ColumnEmulator(col) for col in source_plan.project_list)
     if isinstance(source_plan, Join):
-        L_expr_to_alias = {}
-        R_expr_to_alias = {}
+        # Seed each side's alias map with any aliases already recorded on the
+        # child plan, since _disambiguate may no longer wrap it in a select().
+        L_expr_to_alias = dict(getattr(source_plan.left, "expr_to_alias", None) or {})
+        R_expr_to_alias = dict(getattr(source_plan.right, "expr_to_alias", None) or {})
         left = execute_mock_plan(source_plan.left, L_expr_to_alias).reset_index(
             drop=True
         )
diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py
index 3803ff3d5c..84a8646417 100644
--- a/tests/integ/test_query_plan_analysis.py
+++ b/tests/integ/test_query_plan_analysis.py
@@ -358,7 +358,9 @@ def test_join_statement(session: Session, sample_table: str):
     assert_df_subtree_query_complexity(
         df5,
         sum_node_complexities(
-            get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN: 2}
+            get_cumulative_node_complexity(df1),
+            get_cumulative_node_complexity(df2),
+            {PlanNodeCategory.COLUMN: 2, PlanNodeCategory.JOIN: 1},
         ),
     )

diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py
index 467e85d659..68804657aa 100644
--- a/tests/integ/test_simplifier_suite.py
+++ b/tests/integ/test_simplifier_suite.py
@@ -5,6 +5,7 @@
 import itertools
+import re
 import sys
 import time
 from typing import Tuple

 import pytest
@@ -1097,14 +1098,17 @@ def test_join_dataframes(session, simplifier_table):
     df = df_left.join(df_right)

     df1 = df.select("a").select("a").select("a")
-    assert df1.queries["queries"][0].count("SELECT") == 8
+    assert df1.queries["queries"][0].count("SELECT") == 6
+    normalized_sql = re.sub(r"\s+", " ", df1.queries["queries"][0])
+    # Self-aliases such as "A" AS "A" should no longer appear in the SQL.
+    assert not any(f'"{c}" AS "{c}"' in normalized_sql for c in ["A", "B", "C", "D"])

     df2 = (
         df.select((col("a") + 1).as_("a"))
         .select((col("a") + 1).as_("a"))
         .select((col("a") + 1).as_("a"))
     )
-    assert df2.queries["queries"][0].count("SELECT") == 10
+    assert df2.queries["queries"][0].count("SELECT") == 8

     df3 = df.with_column("x", df_left.a).with_column("y", df_right.d)
     assert '"A" AS "X", "D" AS "Y"' in Utils.normalize_sql(df3.queries["queries"][0])
@@ -1114,7 +1118,7 @@
     df4 = df_right.to_df("e", "f")
     df5 = df_left.join(df4)
     df6 = df5.with_column("x", df_right.c).with_column("y", df4.f)
-    assert df6.queries["queries"][0].count("SELECT") == 10
+    assert df6.queries["queries"][0].count("SELECT") == 8

     Utils.check_answer(df6, [Row(1, 2, 3, 4, 3, 4)])
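For context, a minimal sketch of the join shape this patch optimizes, mirroring test_join_dataframes above. It assumes a reachable Snowflake account; `connection_parameters` is a hypothetical dict of valid credentials, and the data and column names are illustrative only, not taken from the patch.

from snowflake.snowpark import Session

# `connection_parameters` is a placeholder for real connection settings.
session = Session.builder.configs(connection_parameters).create()

# Two join inputs with disjoint column names, as in test_join_dataframes.
df_left = session.create_dataframe([[1, 2]], schema=["a", "b"])
df_right = session.create_dataframe([[3, 4]], schema=["c", "d"])

# With no name conflicts, _disambiguate now returns both inputs unchanged:
# the generated SQL drops the per-side SELECT wrapper that re-aliased every
# column to itself (e.g. "A" AS "A"), yet df1["a"] still resolves because
# expr_to_alias is populated directly.
df1 = df_left.join(df_right).select("a")
print(df1.queries["queries"][0])  # fewer nested SELECTs than before the patch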