Skip to content

Commit 3db0617

Browse files
authored
SNOW-2261400: Fix incorrect join condition in repeated subquery elimination (#3808)
1 parent 02e594e commit 3db0617

File tree

3 files changed

+128
-7
lines changed

3 files changed

+128
-7
lines changed

src/snowflake/snowpark/_internal/compiler/cte_utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ def get_duplicated_node_complexity_distribution(
224224

225225
def encode_query_id(node: "TreeNode") -> Optional[str]:
226226
"""
227-
Encode the query and its query parameter into an id using sha256.
228-
227+
Encode the query, its query parameter, expr_to_alias and df_aliased_col_name_to_real_col_name
228+
into an id using sha256.
229229
230230
Returns:
231231
If encode succeed, return the first 10 encoded value.
@@ -252,7 +252,25 @@ def encode_query_id(node: "TreeNode") -> Optional[str]:
252252
# to avoid being detected as a common subquery.
253253
return None
254254

255-
string = f"{query}#{query_params}" if query_params else query
255+
def stringify(d):
256+
if isinstance(d, dict):
257+
key_value_pairs = list(d.items())
258+
key_value_pairs.sort(key=lambda x: x[0])
259+
return str(key_value_pairs)
260+
else:
261+
return str(d)
262+
263+
string = query
264+
if query_params:
265+
string = f"{string}#{query_params}"
266+
if hasattr(node, "expr_to_alias") and node.expr_to_alias:
267+
string = f"{string}#{stringify(node.expr_to_alias)}"
268+
if (
269+
hasattr(node, "df_aliased_col_name_to_real_col_name")
270+
and node.df_aliased_col_name_to_real_col_name
271+
):
272+
string = f"{string}#{stringify(node.df_aliased_col_name_to_real_col_name)}"
273+
256274
try:
257275
return hashlib.sha256(string.encode()).hexdigest()[:10]
258276
except Exception as ex:

tests/integ/test_cte.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,82 @@ def test_same_duplicate_subtree(session):
566566
assert count_number_of_ctes(df_result2.queries["queries"][-1]) == 3
567567

568568

569+
@pytest.mark.parametrize("use_different_df", [True, False])
570+
def test_cte_preserves_join_suffix_aliases(session, use_different_df):
571+
df_ad_group = session.create_dataframe(
572+
[["1048771", "group_1", "campaign_1"]],
573+
schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"],
574+
)
575+
576+
df_ad_group_excv = session.create_dataframe(
577+
[["1048771", "group_1", "device", "8308"]],
578+
schema=["ACCOUNT_ID", "AD_GROUP_ID", "DEVICE", "EXTERNAL_CONVERSION_ID"],
579+
)
580+
581+
df_ad_group_excv = df_ad_group_excv.join(
582+
df_ad_group,
583+
df_ad_group.col("AD_GROUP_ID") == df_ad_group_excv.col("AD_GROUP_ID"),
584+
rsuffix="_WITH_AD_GROUP",
585+
).select(
586+
col("ACCOUNT_ID"),
587+
col("CAMPAIGN_ID"),
588+
col("AD_GROUP_ID"),
589+
lit(None).as_("AD_ID"),
590+
)
591+
592+
if use_different_df:
593+
df_ad_group = session.create_dataframe(
594+
[["1048771", "group_1", "campaign_1"]],
595+
schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"],
596+
)
597+
598+
df_ad_group_ad = session.create_dataframe(
599+
[["1048771", "ad_1", "group_1"]],
600+
schema=["ACCOUNT_ID", "AD_ID", "AD_GROUP_ID"],
601+
)
602+
603+
df_ad_excv = session.create_dataframe(
604+
[["1048771", "group_1", "ad_1", "device", "8308"]],
605+
schema=[
606+
"ACCOUNT_ID",
607+
"AD_GROUP_ID",
608+
"AD_ID",
609+
"DEVICE",
610+
"EXTERNAL_CONVERSION_ID",
611+
],
612+
)
613+
614+
df_ad_excv = (
615+
df_ad_excv.join(
616+
df_ad_group_ad,
617+
df_ad_group_ad.col("AD_ID") == df_ad_excv.col("AD_ID"),
618+
rsuffix="_WITH_AD_GROUP_AD",
619+
)
620+
.join(
621+
df_ad_group,
622+
df_ad_group.col("AD_GROUP_ID") == df_ad_group_ad.col("AD_GROUP_ID"),
623+
rsuffix="_WITH_AD_GROUP",
624+
)
625+
.select(
626+
col("ACCOUNT_ID"),
627+
col("CAMPAIGN_ID"),
628+
col("AD_GROUP_ID"),
629+
col("AD_ID"),
630+
)
631+
)
632+
633+
df_union = df_ad_group_excv.union_all(df_ad_excv)
634+
union_sql = df_union.queries["queries"][-1]
635+
636+
# the second one is incorrect join condition as we have rsuffix for join alias
637+
assert 'ON ("AD_GROUP_ID_WITH_AD_GROUP" = "AD_GROUP_ID")' in union_sql
638+
assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql
639+
# when using different df_ad_group with disambiguation, because rsuffix in join,
640+
# they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE
641+
# However there is still a CTE for create_dataframe call
642+
assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1
643+
644+
569645
@pytest.mark.parametrize(
570646
"mode", ["append", "truncate", "overwrite", "errorifexists", "ignore"]
571647
)
@@ -736,12 +812,12 @@ def test_sql_simplifier(session):
736812
describe_count_for_optimized=1 if session._join_alias_fix else None,
737813
)
738814
with SqlCounter(query_count=0, describe_count=0):
739-
# When adding a lsuffix, the columns of right dataframe don't need to be renamed,
740-
# so we will get a common CTE with filter
815+
# When adding a lsuffix, expr alias map will be updated, so df2 and df3 are considered
816+
# different and have different ids. So only df1 and df will be converted to a CTE
741817
assert (
742-
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 2
818+
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 1
743819
)
744-
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 2
820+
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 3
745821

746822
df7 = df1.with_column("c", lit(1))
747823
df8 = df1.with_column("c", lit(1)).with_column("d", lit(1))

tests/unit/test_cte.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
33
#
44

5+
import hashlib
6+
from types import SimpleNamespace
57
from unittest import mock
68

79
import pytest
@@ -103,3 +105,28 @@ def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer):
103105
encode_node_id_with_query(select_statement_node)
104106
== f"{expected_hash}_SelectStatement"
105107
)
108+
109+
110+
def test_encode_node_id_with_query_includes_aliases():
111+
node = SimpleNamespace(
112+
sql_query="select col1 from t",
113+
query_params=(("p1", 1), ("p2", "x")),
114+
expr_to_alias={"uuid1": "ALIAS1"},
115+
df_aliased_col_name_to_real_col_name={"ALIAS1": "col1"},
116+
)
117+
118+
def stringify_dict(d: dict) -> str:
119+
key_value_pairs = list(d.items())
120+
key_value_pairs.sort(key=lambda x: x[0])
121+
return str(key_value_pairs)
122+
123+
expected_string = node.sql_query
124+
if node.query_params:
125+
expected_string = f"{expected_string}#{node.query_params}"
126+
if node.expr_to_alias:
127+
expected_string = f"{expected_string}#{stringify_dict(node.expr_to_alias)}"
128+
if node.df_aliased_col_name_to_real_col_name:
129+
expected_string = f"{expected_string}#{stringify_dict(node.df_aliased_col_name_to_real_col_name)}"
130+
131+
expected_hash = hashlib.sha256(expected_string.encode()).hexdigest()[:10]
132+
assert encode_node_id_with_query(node) == f"{expected_hash}_SimpleNamespace"

0 commit comments

Comments
 (0)