Skip to content

Commit ac3811c

Browse files
committed
fix
1 parent 33a176d commit ac3811c

File tree

4 files changed

+74
-35
lines changed

4 files changed

+74
-35
lines changed

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
44
#
55
import uuid
6-
from collections import Counter, defaultdict
6+
from collections import defaultdict
77
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union
88

99
import snowflake.snowpark
@@ -151,7 +151,10 @@
151151
)
152152
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
153153
from snowflake.snowpark._internal.telemetry import TelemetryField
154-
from snowflake.snowpark._internal.utils import quote_name
154+
from snowflake.snowpark._internal.utils import (
155+
quote_name,
156+
merge_multiple_snowflake_plan_expr_to_alias,
157+
)
155158
from snowflake.snowpark.types import BooleanType, _NumericType
156159

157160
ARRAY_BIND_THRESHOLD = 512
@@ -804,22 +807,9 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
804807
# Selectable doesn't have children. It already has the expr_to_alias dict.
805808
self.alias_maps_to_use = logical_plan.expr_to_alias.copy()
806809
else:
807-
use_maps = {}
808-
# get counts of expr_to_alias keys
809-
counts = Counter()
810-
for v in resolved_children.values():
811-
if v.expr_to_alias:
812-
counts.update(list(v.expr_to_alias.keys()))
813-
814-
# Keep only non-shared expr_to_alias keys
815-
# let (df1.join(df2)).join(df2.join(df3)).select(df2) report error
816-
for v in resolved_children.values():
817-
if v.expr_to_alias:
818-
use_maps.update(
819-
{p: q for p, q in v.expr_to_alias.items() if counts[p] < 2}
820-
)
821-
822-
self.alias_maps_to_use = use_maps
810+
self.alias_maps_to_use = merge_multiple_snowflake_plan_expr_to_alias(
811+
list(resolved_children.values())
812+
)
823813

824814
res = self.do_resolve_with_resolved_children(
825815
logical_plan, resolved_children, df_aliased_col_name_to_real_col_name

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
generate_random_alphanumeric,
109109
get_copy_into_table_options,
110110
is_sql_select_statement,
111+
merge_multiple_snowflake_plan_expr_to_alias,
111112
)
112113
from snowflake.snowpark.row import Row
113114
from snowflake.snowpark.types import StructType
@@ -580,17 +581,10 @@ def build_binary(
580581
right_schema_query = schema_value_statement(select_right.attributes)
581582
schema_query = sql_generator(left_schema_query, right_schema_query)
582583

583-
common_columns = set(select_left.expr_to_alias.keys()).intersection(
584-
select_right.expr_to_alias.keys()
584+
new_expr_to_alias = merge_multiple_snowflake_plan_expr_to_alias(
585+
[select_left, select_right]
585586
)
586-
new_expr_to_alias = {
587-
k: v
588-
for k, v in {
589-
**select_left.expr_to_alias,
590-
**select_right.expr_to_alias,
591-
}.items()
592-
if k not in common_columns
593-
}
587+
594588
api_calls = [*select_left.api_calls, *select_right.api_calls]
595589

596590
# Need to do a deduplication to avoid repeated query.

src/snowflake/snowpark/_internal/utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import sys
2222
import threading
2323
import traceback
24+
import uuid
2425
import zipfile
2526
from enum import Enum
2627
from functools import lru_cache
@@ -50,12 +51,15 @@
5051
from snowflake.connector.description import OPERATING_SYSTEM, PLATFORM
5152
from snowflake.connector.options import MissingOptionalDependency, ModuleLikeObject
5253
from snowflake.connector.version import VERSION as connector_version
54+
5355
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
5456
from snowflake.snowpark.context import _should_use_structured_type_semantics
5557
from snowflake.snowpark.row import Row
5658
from snowflake.snowpark.version import VERSION as snowpark_version
5759

5860
if TYPE_CHECKING:
61+
from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
62+
5963
try:
6064
from snowflake.connector.cursor import ResultMetadataV2
6165
except ImportError:
@@ -1411,3 +1415,59 @@ def next(self) -> int:
14111415

14121416

14131417
global_counter: GlobalCounter = GlobalCounter()
1418+
1419+
1420+
def merge_multiple_snowflake_plan_expr_to_alias(
    snowflake_plans: List["SnowflakePlan"],
) -> Dict[uuid.UUID, str]:
    """
    Merge the ``expr_to_alias`` mappings of several plans into a single mapping.

    For every expression id found in any plan:

    * if all plans that know the id agree on one alias, that alias is kept;
    * if the plans disagree, the alias is kept only when exactly one of the
      distinct conflicting aliases is the name of an output column of one of
      the plans;
    * otherwise the id is left out of the result and a debug message is logged.

    Args:
        snowflake_plans (List[SnowflakePlan]): Plans whose ``expr_to_alias``
            mappings are merged. Each plan must expose ``expr_to_alias`` (a
            mapping, or a falsy value which is treated as empty) and ``output``
            (an iterable of objects with a ``name`` attribute).

    Returns:
        Dict[uuid.UUID, str]: Merged expression-to-alias mapping. Empty when
        ``snowflake_plans`` is empty.
    """
    # Group the *distinct* aliases recorded for each expression id in one pass.
    # The inner dict is used as an insertion-ordered set so that the same alias
    # reported by several plans is counted once; counting it twice would make
    # a perfectly resolvable conflict look ambiguous.
    aliases_by_expr: Dict[uuid.UUID, Dict[str, None]] = {}
    for plan in snowflake_plans:
        for expr_id, alias in (plan.expr_to_alias or {}).items():
            aliases_by_expr.setdefault(expr_id, {})[alias] = None

    merged_dict: Dict[uuid.UUID, str] = {}

    # Output column names are only needed to arbitrate conflicts, so they are
    # collected on the first conflict instead of unconditionally.
    all_output_columns = None

    for expr_id, distinct_aliases in aliases_by_expr.items():
        aliases = list(distinct_aliases)

        # Every plan agrees on the alias: nothing to resolve.
        if len(aliases) == 1:
            merged_dict[expr_id] = aliases[0]
            continue

        if all_output_columns is None:
            all_output_columns = {
                attr.name for plan in snowflake_plans for attr in plan.output
            }

        # An alias is a valid candidate only if it is still an output column.
        candidates = [alias for alias in aliases if alias in all_output_columns]

        if len(candidates) == 1:
            merged_dict[expr_id] = candidates[0]
        else:
            # Zero candidates (no alias survives) or several (truly ambiguous):
            # leave the expression unmapped rather than guess.
            _logger.debug(
                "Expression '%s' is associated with multiple aliases across different plans. "
                "Unable to determine which alias to use. Conflicting values: %s",
                expr_id,
                aliases,
            )

    return merged_dict

tests/integ/scala/test_dataframe_join_suite.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,22 +1462,17 @@ def test_select_columns_on_join_result_with_conflict_name(
14621462
"config.getoption('local_testing_mode', default=False)",
14631463
reason="SNOW-1235716: match error behavior",
14641464
)
1465-
def test_nested_join_diamond_shape_error(
1465+
def test_nested_join_diamond_shape(
14661466
session,
14671467
): # TODO: local testing match error behavior
1468-
"""This is supposed to work but currently we don't handle it correctly. We should fix this with a good design."""
14691468
df1 = session.create_dataframe([[1]], schema=["a"])
14701469
df2 = session.create_dataframe([[1]], schema=["a"])
14711470
df3 = df1.join(df2, df1["a"] == df2["a"])
14721471
df4 = df3.select(df1["a"].as_("a"))
14731472
# df1["a"] and df4["a"] have the same expr_id in map expr_to_alias. When they join, only one will be in df5's alias
14741473
# map. It leaves the other one resolved to "a" instead of the alias.
14751474
df5 = df1.join(df4, df1["a"] == df4["a"]) # (df1) JOIN ((df1 JOIN df2)->df4)
1476-
with pytest.raises(
1477-
SnowparkSQLAmbiguousJoinException,
1478-
match="The reference to the column 'A' is ambiguous.",
1479-
):
1480-
df5.collect()
1475+
Utils.check_answer(df5, [Row(1, 1)])
14811476

14821477

14831478
def test_nested_join_diamond_shape_workaround(session):

0 commit comments

Comments
 (0)