Skip to content

Commit f1d02bf

Browse files
authored
SNOW-3176017: Fix accidental removal of aliases in certain JOIN statements (#4096)
1 parent 1d0b1fb commit f1d02bf

File tree

10 files changed

+254
-64
lines changed

10 files changed

+254
-64
lines changed

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
# Release History
22

3+
## 1.49.0 (TBD)
4+
5+
### Snowpark Python API Updates
6+
7+
#### New Features
8+
9+
#### Bug Fixes
10+
11+
#### Improvements
12+
13+
- Restored the following query improvements that were reverted in 1.47.0 due to bugs:
14+
- Reduced the size of queries generated by certain `DataFrame.join` operations.
15+
- Removed redundant aliases in generated queries (for example, `SELECT "A" AS "A"` is now always simplified to `SELECT "A"`).
16+
317
## 1.48.0 (TBD)
418

519
### Snowpark Python API Updates

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -763,10 +763,18 @@ def unary_expression_extractor(
763763
for k, v in df_alias_dict.items():
764764
if v == expr.child.name:
765765
df_alias_dict[k] = updated_due_to_inheritance # type: ignore
766+
origin = self.analyze(
767+
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
768+
)
769+
if (
770+
isinstance(expr.child, (Attribute, UnresolvedAttribute))
771+
and origin == quoted_name
772+
):
773+
# If the column name matches the target of the alias (`quoted_name`),
774+
# we can directly emit the column name without an AS clause.
775+
return origin
766776
return alias_expression(
767-
self.analyze(
768-
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
769-
),
777+
origin,
770778
quoted_name,
771779
)
772780

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,11 +2244,18 @@ def derive_column_states_from_subquery(
22442244
else Attribute(quoted_c_name, DataType())
22452245
)
22462246
from_c_state = from_.column_states.get(quoted_c_name)
2247+
result_name = analyzer.analyze(
2248+
c, from_.df_aliased_col_name_to_real_col_name, parse_local_name=True
2249+
).strip(" ")
22472250
if from_c_state and from_c_state.change_state != ColumnChangeState.DROPPED:
22482251
# review later. should use parse_column_name
2249-
if c_name != analyzer.analyze(
2250-
c, from_.df_aliased_col_name_to_real_col_name, parse_local_name=True
2251-
).strip(" "):
2252+
# SNOW-2895675: Always treat Aliases as "changed", even if it is an identity.
2253+
# The fact this check is needed may be a bug in column state analysis, and we should revisit it later.
2254+
# The following tests fail without this check:
2255+
# - tests/integ/test_cte.py::test_sql_simplifier
2256+
# - tests/integ/scala/test_dataframe_suite.py::test_rename_join_dataframe
2257+
# - tests/integ/test_dataframe.py::test_dataframe_alias
2258+
if c_name != result_name or isinstance(c, Alias):
22522259
column_states[quoted_c_name] = ColumnState(
22532260
quoted_c_name,
22542261
ColumnChangeState.CHANGED_EXP,

src/snowflake/snowpark/dataframe.py

Lines changed: 94 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -277,29 +277,69 @@ def _get_unaliased(col_name: str) -> List[str]:
277277
return unaliased
278278

279279

280+
def _get_aliased_column_names(
281+
df: "DataFrame",
282+
cs: List[str],
283+
prefix: Optional[str],
284+
suffix: Optional[str],
285+
common_col_names: List[str],
286+
) -> List[str]:
287+
aliases = []
288+
for c in cs:
289+
unquoted_col_name = c.strip('"')
290+
if c in common_col_names:
291+
if suffix:
292+
column_case_insensitive = is_snowflake_quoted_id_case_insensitive(c)
293+
suffix_unqouted_case_insensitive = (
294+
is_snowflake_unquoted_suffix_case_insensitive(suffix)
295+
)
296+
aliases.append(
297+
f'"{unquoted_col_name}{suffix.upper()}"'
298+
if column_case_insensitive and suffix_unqouted_case_insensitive
299+
else f'''"{unquoted_col_name}{escape_quotes(suffix.strip('"'))}"'''
300+
)
301+
else:
302+
aliases.append(f'"{prefix}{unquoted_col_name}"')
303+
else:
304+
# Removal of redundant aliases (like `"A" AS "A"`) is handled at the analyzer level.
305+
aliases.append(f'"{unquoted_col_name}"')
306+
return aliases
307+
308+
309+
def _apply_aliases(
310+
df: "DataFrame",
311+
cs: List[str],
312+
c_aliases: List[str],
313+
) -> List[Column]:
314+
return [
315+
df.col(c, _emit_ast=False).alias(c_alias) for c, c_alias in zip(cs, c_aliases)
316+
]
317+
318+
280319
def _alias_if_needed(
281320
df: "DataFrame",
282-
c: str,
321+
cs: List[str],
283322
prefix: Optional[str],
284323
suffix: Optional[str],
285324
common_col_names: List[str],
286-
):
287-
col = df.col(c, _emit_ast=False)
288-
unquoted_col_name = c.strip('"')
289-
if c in common_col_names:
290-
if suffix:
291-
column_case_insensitive = is_snowflake_quoted_id_case_insensitive(c)
292-
suffix_unqouted_case_insensitive = (
293-
is_snowflake_unquoted_suffix_case_insensitive(suffix)
294-
)
295-
return col.alias(
296-
f'"{unquoted_col_name}{suffix.upper()}"'
297-
if column_case_insensitive and suffix_unqouted_case_insensitive
298-
else f'''"{unquoted_col_name}{escape_quotes(suffix.strip('"'))}"'''
299-
)
300-
return col.alias(f'"{prefix}{unquoted_col_name}"')
301-
else:
302-
return col.alias(f'"{unquoted_col_name}"')
325+
) -> List[Column]:
326+
return _apply_aliases(
327+
df, cs, _get_aliased_column_names(df, cs, prefix, suffix, common_col_names)
328+
)
329+
330+
331+
def _populate_expr_to_alias(df: "DataFrame") -> None:
332+
"""
333+
Populate expr_to_alias mapping for a DataFrame's output columns.
334+
This is needed for column lineage tracking when we skip the select() wrapping
335+
optimization in _disambiguate.
336+
"""
337+
for attr in df._output:
338+
# Map each attribute's expr_id to its quoted column name
339+
# This allows later lookups like df["column_name"] to resolve correctly
340+
# Use quote_name() for consistency with analyzer.py Alias handling (line 743, 756)
341+
if attr.expr_id not in df._plan.expr_to_alias:
342+
df._plan.expr_to_alias[attr.expr_id] = quote_name(attr.name)
303343

304344

305345
def _disambiguate(
@@ -328,11 +368,11 @@ def _disambiguate(
328368
for n in lhs_names
329369
if n in set(rhs_names) and n not in normalized_using_columns
330370
]
371+
331372
all_names = [unquote_if_quoted(n) for n in lhs_names + rhs_names]
332373

333-
if common_col_names:
334-
# We use the session of the LHS DataFrame to report this telemetry
335-
lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()
374+
# We use the session of the LHS DataFrame to report this telemetry
375+
lhs._session._conn._telemetry_client.send_alias_in_join_telemetry()
336376

337377
lsuffix = lsuffix or lhs._alias
338378
rsuffix = rsuffix or rhs._alias
@@ -344,25 +384,37 @@ def _disambiguate(
344384
_generate_deterministic_prefix("r", all_names) if not suffix_provided else ""
345385
)
346386

387+
lhs_aliases = _get_aliased_column_names(
388+
lhs,
389+
lhs_names,
390+
lhs_prefix,
391+
lsuffix,
392+
[] if isinstance(join_type, (LeftSemi, LeftAnti)) else common_col_names,
393+
)
394+
rhs_aliases = _get_aliased_column_names(
395+
rhs, rhs_names, rhs_prefix, rsuffix, common_col_names
396+
)
397+
if all(
398+
l_name == l_aliased for l_name, l_aliased in zip(lhs_names, lhs_aliases)
399+
) and all(r_name == r_aliased for r_name, r_aliased in zip(rhs_names, rhs_aliases)):
400+
# Optimization: No column name conflicts, so we can skip aliasing and the select() wrapping.
401+
# But we still need to populate expr_to_alias for column lineage tracking,
402+
# so that df["column_name"] can resolve correctly after the join.
403+
# This is identified by the test case
404+
# tests/integ/scala/test_dataframe_join_suite.py::test_name_alias_on_multiple_join.
405+
# Note that we must also ensure none of the column names have changed due to internal quote stripping:
406+
# see tests/integ/compiler/test_query_generator.py::test_disambiguate_skips_quoted_alias for details.
407+
_populate_expr_to_alias(lhs)
408+
_populate_expr_to_alias(rhs)
409+
return lhs, rhs
410+
347411
lhs_remapped = lhs.select(
348-
[
349-
_alias_if_needed(
350-
lhs,
351-
name,
352-
lhs_prefix,
353-
lsuffix,
354-
[] if isinstance(join_type, (LeftSemi, LeftAnti)) else common_col_names,
355-
)
356-
for name in lhs_names
357-
],
412+
_apply_aliases(lhs, lhs_names, lhs_aliases),
358413
_emit_ast=False,
359414
)
360415

361416
rhs_remapped = rhs.select(
362-
[
363-
_alias_if_needed(rhs, name, rhs_prefix, rsuffix, common_col_names)
364-
for name in rhs_names
365-
],
417+
_apply_aliases(rhs, rhs_names, rhs_aliases),
366418
_emit_ast=False,
367419
)
368420
return lhs_remapped, rhs_remapped
@@ -5113,16 +5165,13 @@ def _lateral(
51135165
)
51145166
prefix = _generate_prefix("a")
51155167
child = self.select(
5116-
[
5117-
_alias_if_needed(
5118-
self,
5119-
attr.name,
5120-
prefix,
5121-
suffix=None,
5122-
common_col_names=common_col_names,
5123-
)
5124-
for attr in self._output
5125-
],
5168+
_alias_if_needed(
5169+
self,
5170+
[attr.name for attr in self._output],
5171+
prefix,
5172+
suffix=None,
5173+
common_col_names=common_col_names,
5174+
),
51265175
_emit_ast=False,
51275176
)
51285177
return DataFrame(

src/snowflake/snowpark/mock/_analyzer.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -654,12 +654,18 @@ def unary_expression_extractor(
654654
if v == expr.child.name:
655655
df_alias_dict[k] = quoted_name
656656

657-
alias_exp = alias_expression(
658-
self.analyze(
659-
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
660-
),
661-
quoted_name,
657+
origin = self.analyze(
658+
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
662659
)
660+
if (
661+
isinstance(expr.child, (Attribute, UnresolvedAttribute))
662+
and origin == quoted_name
663+
):
664+
# If the column name matches the target of the alias (`quoted_name`),
665+
# we can directly emit the column name without an AS clause.
666+
return origin
667+
668+
alias_exp = alias_expression(origin, quoted_name)
663669

664670
expr_str = alias_exp if keep_alias else expr.name or keep_alias
665671
expr_str = expr_str.upper() if parse_local_name else expr_str

src/snowflake/snowpark/mock/_plan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,8 +1495,8 @@ def aggregate_by_groups(cur_group: TableEmulator):
14951495
if isinstance(source_plan, Project):
14961496
return TableEmulator(ColumnEmulator(col) for col in source_plan.project_list)
14971497
if isinstance(source_plan, Join):
1498-
L_expr_to_alias = {}
1499-
R_expr_to_alias = {}
1498+
L_expr_to_alias = dict(getattr(source_plan.left, "expr_to_alias", None) or {})
1499+
R_expr_to_alias = dict(getattr(source_plan.right, "expr_to_alias", None) or {})
15001500
left = execute_mock_plan(source_plan.left, L_expr_to_alias).reset_index(
15011501
drop=True
15021502
)

tests/integ/compiler/test_query_generator.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import copy
66
from typing import List
77
from unittest.mock import patch
8+
import tempfile
9+
import os
10+
import re
811

12+
import pandas
913
import pytest
1014

1115
import snowflake.snowpark._internal.analyzer.snowflake_plan as snowflake_plan
@@ -551,6 +555,102 @@ def test_select_alias(session):
551555
check_generated_plan_queries(df2._plan)
552556

553557

558+
def test_select_alias_identity(session):
559+
df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
560+
df_res = df.select("a", col("b").as_("b"))
561+
if session.sql_simplifier_enabled:
562+
# Because "b" was aliased to itself, the emitted SQL should drop the AS clause.
563+
ref_query = 'SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT))'
564+
else:
565+
ref_query = 'SELECT "A", "B" FROM (SELECT "A", "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT)))'
566+
assert Utils.normalize_sql(df_res.queries["queries"][-1]) == Utils.normalize_sql(
567+
ref_query
568+
)
569+
570+
571+
def test_disambiguate_skips_quoted_alias(session):
572+
# SNOW-3176017: This tests a previous regression in a SnowML pipeline where alias optimization
573+
# incorrectly removed an alias from """col_0""" (triple-quoted in SQL) to "col_0" (single-quoted).
574+
# Due to differences in code paths for generating select statements, this bug is only apparent with
575+
# triple-quoted identifiers created from a file read operation, and not from a direct `.project` call.
576+
session_stage = session.get_session_stage()
577+
data = [[0, 1, 2], [3, 4, 5]]
578+
pandas_df = pandas.DataFrame(data, columns=["ID", '"COL_0"', '"COL_1"'])
579+
stage_filename = f"{session_stage}/disambiguate_test.parquet"
580+
with tempfile.TemporaryDirectory() as temp_dir:
581+
local_path = os.path.join(temp_dir, "disambiguate_test.parquet")
582+
pandas_df.to_parquet(local_path)
583+
Utils.upload_to_stage(session, stage_filename, local_path, compress=False)
584+
df1 = session.read.parquet(stage_filename)
585+
df2 = session.create_dataframe(data, schema=["ID", "A", "B"])
586+
df_res = df1.join(df2, on=["ID"])[['"COL_0"', '"COL_1"']]
587+
# TODO run with sql simplifier disabled
588+
actual_query = re.sub(
589+
r'@"[\d\w\_]+"\."[\d\w\_]+"\.',
590+
'@"DB_SCHEMA_NAME".',
591+
re.sub(
592+
r"SNOWPARK_TEMP_(STAGE|FILE_FORMAT)_[\d\w]+",
593+
"SNOWPARK_TEMP_NAME",
594+
df_res.queries["queries"][-1],
595+
),
596+
)
597+
if session.sql_simplifier_enabled:
598+
rhs_creation_sql = """
599+
SELECT
600+
"ID",
601+
"A",
602+
"B"
603+
FROM (
604+
SELECT $1 AS "ID", $2 AS "A", $3 AS "B" FROM VALUES (0 :: INT, 1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT, 5 :: INT)
605+
)
606+
"""
607+
else:
608+
rhs_creation_sql = """
609+
SELECT
610+
"ID",
611+
"A",
612+
"B"
613+
FROM (
614+
SELECT
615+
"ID",
616+
"A",
617+
"B"
618+
FROM (
619+
SELECT $1 AS "ID", $2 AS "A", $3 AS "B" FROM VALUES (0 :: INT, 1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT, 5 :: INT)
620+
)
621+
)
622+
"""
623+
624+
ref_query = f'''
625+
SELECT
626+
"COL_0",
627+
"COL_1"
628+
FROM (
629+
SELECT *
630+
FROM (
631+
(
632+
SELECT
633+
"ID",
634+
"""COL_0""" AS "COL_0",
635+
"""COL_1""" AS "COL_1"
636+
FROM (
637+
SELECT $1:"ID"::NUMBER(38, 0) AS "ID", $1:"""COL_0"""::NUMBER(38, 0) AS """COL_0""", $1:"""COL_1"""::NUMBER(38, 0) AS """COL_1""" FROM @"DB_SCHEMA_NAME".SNOWPARK_TEMP_NAME/disambiguate_test.parquet( FILE_FORMAT => 'SNOWPARK_TEMP_NAME')
638+
)
639+
) AS SNOWPARK_LEFT
640+
INNER JOIN
641+
(
642+
{rhs_creation_sql}
643+
) AS SNOWPARK_RIGHT
644+
USING (ID)
645+
)
646+
)
647+
'''
648+
assert Utils.normalize_sql(actual_query) == Utils.normalize_sql(ref_query)
649+
# Ensure the DF can be materialized without error
650+
materialized = df_res.to_pandas()
651+
assert list(materialized.columns) == ["COL_0", "COL_1"]
652+
653+
554654
def test_nullable_is_false_dataframe(session):
555655
from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD
556656

0 commit comments

Comments (0)