Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import hashlib
import logging
from collections import defaultdict
from collections import Counter, defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
Expand All @@ -16,6 +16,7 @@
WithQueryBlock,
)
from snowflake.snowpark._internal.utils import is_sql_select_statement
import snowflake.snowpark.context as context

if TYPE_CHECKING:
from snowflake.snowpark._internal.compiler.utils import TreeNode # pragma: no cover
Expand Down Expand Up @@ -57,6 +58,14 @@ def find_duplicate_subtrees(
# during this process
invalid_ids_for_deduplication = set()

# When _is_snowpark_connect_compatible_mode is enabled, we track unique
# object identities per encoded_node_id to avoid merging nodes from
# different DataFrame construction calls that happen to produce
# identical SQL. Only the same Python object appearing multiple times
# (e.g. df.union_all(df)) should be treated as a duplicate.
use_object_identity = context._is_snowpark_connect_compatible_mode
object_ids_per_node_id: Dict[str, Set[int]] = defaultdict(set)

from snowflake.snowpark._internal.analyzer.select_statement import (
Selectable,
SelectStatement,
Expand Down Expand Up @@ -115,15 +124,17 @@ def traverse(root: "TreeNode") -> None:
while len(current_level) > 0:
next_level = []
for node in current_level:
id_node_map[node.encoded_node_id_with_query].append(node)
encoded_id = node.encoded_node_id_with_query
id_node_map[encoded_id].append(node)

if use_object_identity:
object_ids_per_node_id[encoded_id].add(id(node))

if is_select_from_file_node(node):
invalid_ids_for_deduplication.add(node.encoded_node_id_with_query)
invalid_ids_for_deduplication.add(encoded_id)

for child in node.children_plan_nodes:
id_parents_map[child.encoded_node_id_with_query].add(
node.encoded_node_id_with_query
)
id_parents_map[child.encoded_node_id_with_query].add(encoded_id)
next_level.append(child)
current_level = next_level

Expand All @@ -138,6 +149,30 @@ def traverse(root: "TreeNode") -> None:
next_level.append(parent_id)
current_level = next_level

def _node_occurrence_count(encoded_node_id_with_query: str) -> int:
    """Return how many times this node occurs in the traversed tree.

    In connect-compatible mode, distinct Python objects that happen to
    share the same encoded id are each counted once — a true duplicate
    requires the *same* object reachable from multiple parents.

    When several distinct objects share one encoded id, the result is
    the highest occurrence count of any single object. This covers
    shapes like union(union(df1, df1), union(df2, df2)): df1 and df2
    emit identical SQL, yet df1 itself occurs twice and should still be
    deduplicated into a CTE.
    """
    nodes = id_node_map[encoded_node_id_with_query]
    if not use_object_identity:
        # Plain mode: every occurrence counts, regardless of identity.
        return len(nodes)
    distinct_objects = object_ids_per_node_id[encoded_node_id_with_query]
    if len(distinct_objects) <= 1:
        # A single object (or none): total occurrences is the answer.
        return len(nodes)
    # Multiple objects share this encoded id — count per object and
    # report the most-repeated one.
    occurrences_per_object = Counter(id(node) for node in nodes)
    return max(occurrences_per_object.values())

def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
# when a sql query is a select statement, its encoded_node_id_with_query
# contains _, which is used to separate the query id and node type name.
Expand All @@ -154,10 +189,10 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
if encoded_node_id_with_query in invalid_ids_for_deduplication:
return False

is_duplicate_node = len(id_node_map[encoded_node_id_with_query]) > 1
is_duplicate_node = _node_occurrence_count(encoded_node_id_with_query) > 1
if is_duplicate_node:
is_any_parent_unique_node = any(
len(id_node_map[id]) == 1
_node_occurrence_count(id) == 1
for id in id_parents_map[encoded_node_id_with_query]
)
if is_any_parent_unique_node:
Expand Down
104 changes: 80 additions & 24 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import tracemalloc
from unittest import mock
import uuid

import pytest

Expand Down Expand Up @@ -60,8 +61,17 @@
WITH = "WITH"


@pytest.fixture(params=[False, True], ids=["connect_mode_off", "connect_mode_on"])
def is_connect_mode(request):
    """Parametrize every test over _is_snowpark_connect_compatible_mode."""
    # Equivalent to the with-statement form: patch the flag for the
    # duration of the test, restoring it afterwards even on failure.
    patcher = mock.patch.object(
        context, "_is_snowpark_connect_compatible_mode", request.param
    )
    patcher.start()
    try:
        yield request.param
    finally:
        patcher.stop()


@pytest.fixture(autouse=True)
def setup(request, session):
def setup(request, session, is_connect_mode):
is_cte_optimization_enabled = session._cte_optimization_enabled
is_query_compilation_enabled = session._query_compilation_stage_enabled
session._query_compilation_stage_enabled = True
Expand Down Expand Up @@ -251,7 +261,9 @@ def test_binary(session, type, action):
assert len(plan_queries["post_actions"]) == 1


def test_join_with_alias_dataframe(session):
def test_join_with_alias_dataframe(session, is_connect_mode):
c1 = f"col1_{uuid.uuid4().hex[:8]}"
c2 = f"col2_{uuid.uuid4().hex[:8]}"
expected_describe_count = (
3
if (session.reduce_describe_query_enabled and session.sql_simplifier_enabled)
Expand All @@ -260,11 +272,11 @@ def test_join_with_alias_dataframe(session):
with SqlCounter(
query_count=2, describe_count=expected_describe_count, join_count=2
):
df1 = session.create_dataframe([[1, 6]], schema=["col1", "col2"])
df1 = session.create_dataframe([[1, 6]], schema=[c1, c2])
df_res = (
df1.alias("L")
.join(df1.alias("R"), col("L", "col1") == col("R", "col1"))
.select(col("L", "col1"), col("R", "col2"))
.join(df1.alias("R"), col("L", c1) == col("R", c1))
.select(col("L", c1), col("R", c2))
)

session._cte_optimization_enabled = False
Expand Down Expand Up @@ -355,7 +367,7 @@ def test_join_with_set_operation(session):


@pytest.mark.parametrize("type, action", binary_operations)
def test_variable_binding_binary(session, type, action):
def test_variable_binding_binary(session, type, action, is_connect_mode):
df1 = session.sql(
"select $1 as a, $2 as b from values (?, ?), (?, ?)", params=[1, "a", 2, "b"]
)
Expand All @@ -372,10 +384,12 @@ def test_variable_binding_binary(session, type, action):
join_count = 1
if type == "union":
union_count = 1
# df1 and df3 are different Python objects with the same SQL.
# In connect mode they should NOT be deduplicated.
check_result(
session,
action(df1, df3),
expect_cte_optimized=True,
expect_cte_optimized=not is_connect_mode,
query_count=1,
describe_count=0,
union_count=union_count,
Expand Down Expand Up @@ -551,21 +565,24 @@ def test_number_of_ctes(session, type, action):
)


def test_different_df_same_query(session):
def test_different_df_same_query(session, is_connect_mode):
df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a")
df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a")
df = df2.union_all(df1)
# df1 and df2 are different Python objects with the same SQL.
# In connect mode they should NOT be deduplicated.
check_result(
session,
df,
expect_cte_optimized=True,
expect_cte_optimized=not is_connect_mode,
query_count=1,
describe_count=0,
union_count=1,
join_count=0,
)
with SqlCounter(query_count=0, describe_count=0):
assert count_number_of_ctes(df.queries["queries"][-1]) == 1
expected_cte_count = 0 if is_connect_mode else 1
assert count_number_of_ctes(df.queries["queries"][-1]) == expected_cte_count


def test_same_duplicate_subtree(session):
Expand Down Expand Up @@ -624,7 +641,7 @@ def test_same_duplicate_subtree(session):


@pytest.mark.parametrize("use_different_df", [True, False])
def test_cte_preserves_join_suffix_aliases(session, use_different_df):
def test_cte_preserves_join_suffix_aliases(session, use_different_df, is_connect_mode):
df_ad_group = session.create_dataframe(
[["1048771", "group_1", "campaign_1"]],
schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"],
Expand Down Expand Up @@ -695,8 +712,14 @@ def test_cte_preserves_join_suffix_aliases(session, use_different_df):
assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql
# when using different df_ad_group with disambiguation, because rsuffix in join,
# they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE
# However there is still a CTE for create_dataframe call
assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1
# However there is still a CTE for create_dataframe call.
# In connect mode with use_different_df, all create_dataframe calls are
# distinct objects so no CTEs are produced.
if is_connect_mode and use_different_df:
expected_cte_count = 0
else:
expected_cte_count = 1
assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == expected_cte_count


@pytest.mark.parametrize(
Expand Down Expand Up @@ -807,7 +830,7 @@ def test_explain(session):
assert "WITH SNOWPARK_TEMP_CTE" in explain_string


def test_sql_simplifier(session):
def test_sql_simplifier(session, is_connect_mode):
if not session._sql_simplifier_enabled:
pytest.skip("SQL simplifier is not enabled")

Expand All @@ -822,6 +845,9 @@ def test_sql_simplifier(session):
df2 = df1.select("a", "b")
df3 = df1.select("a", "b").select("a", "b")
df4 = df1.union_by_name(df2).union_by_name(df3)
# df1, df2, df3 are different Python objects that simplify to the same SQL.
# In connect mode they are not deduplicated, but df (create_dataframe) is
# still the same object appearing across all branches → still CTE'd.
check_result(
session,
df4,
Expand All @@ -832,11 +858,35 @@ def test_sql_simplifier(session):
join_count=0,
)
with SqlCounter(query_count=0, describe_count=0):
# after applying sql simplifier, there is only one CTE (df1, df2, df3 have the same query)
assert (
count_number_of_ctes(Utils.normalize_sql(df4.queries["queries"][-1])) == 1
)
assert Utils.normalize_sql(df4.queries["queries"][-1]).count(filter_clause) == 1
if is_connect_mode:
# df1, df2, df3 are different objects → not merged.
# Only df (create_dataframe) is the same object across all branches → 1 CTE.
# Generated SQL:
# WITH CTE AS (SELECT $1 AS "A", $2 AS "B" FROM VALUES ...)
# (SELECT "A","B" FROM (CTE) WHERE ("A"=1))
# UNION (SELECT "A","B" FROM (CTE) WHERE ("A"=1))
# UNION (SELECT "A","B" FROM (CTE) WHERE ("A"=1))
assert (
count_number_of_ctes(Utils.normalize_sql(df4.queries["queries"][-1]))
== 1
)
assert (
Utils.normalize_sql(df4.queries["queries"][-1]).count(filter_clause)
== 3
)
else:
# df1, df2, df3 all simplify to the same SQL and are merged into 1 CTE.
# Generated SQL:
# WITH CTE AS (SELECT "A","B" FROM (VALUES ...) WHERE ("A"=1))
# (CTE) UNION (CTE) UNION (CTE)
assert (
count_number_of_ctes(Utils.normalize_sql(df4.queries["queries"][-1]))
== 1
)
assert (
Utils.normalize_sql(df4.queries["queries"][-1]).count(filter_clause)
== 1
)

df5 = df1.join(df2).join(df3)
check_result(
Expand Down Expand Up @@ -988,18 +1038,20 @@ def test_sql_non_select(session):
)


def test_sql_with(session):
def test_sql_with(session, is_connect_mode):
df1 = session.sql("with t as (select 1 as A) select * from t")
df2 = session.sql("with t as (select 1 as A) select * from t")

df_result = df1.union(df2).select("A").filter(lit(True))

# df1 and df2 are different Python objects with the same SQL.
# In connect mode they should NOT be deduplicated.
check_result(
session,
df_result,
# with ... select is also treated as a select query
# see is_sql_select_statement() function
expect_cte_optimized=True,
expect_cte_optimized=not is_connect_mode,
query_count=1,
describe_count=0,
union_count=1,
Expand Down Expand Up @@ -1325,7 +1377,7 @@ def test_table_select_cte(session):
],
)
def test_dataframe_queries_with_cte_reuses_schema_cache(
session, reduce_describe_enabled, expected_describe_counts
session, reduce_describe_enabled, expected_describe_counts, is_connect_mode
):
"""Test that calling dataframe.queries (not same dataframe but same operation) multiple times with CTE optimization
does not issue extra DESCRIBE queries when reduce_describe_query_enabled is True.
Expand All @@ -1335,9 +1387,13 @@ def test_dataframe_queries_with_cte_reuses_schema_cache(
identical SQL (with same CTE names), allowing the schema cache to hit.
"""

# randomize column names to avoid schema cache hits from prior test runs in the same session.
col_a = f"col_{uuid.uuid4().hex[:8]}"
col_b = f"col_{uuid.uuid4().hex[:8]}"

def create_cte_dataframe():
"""Create a DataFrame that triggers CTE optimization (same df used twice)."""
df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
df = session.create_dataframe([[1, 2], [3, 4]], schema=[col_a, col_b])
return df.union_all(df)

def access_queries_and_schema(df):
Expand All @@ -1347,7 +1403,7 @@ def access_queries_and_schema(df):

with mock.patch.object(
session, "_reduce_describe_query_enabled", reduce_describe_enabled
), mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True):
):
for expected_describe_count in expected_describe_counts:
df_union = create_cte_dataframe()
with SqlCounter(query_count=0, describe_count=expected_describe_count):
Expand Down
Loading
Loading