Skip to content

Commit 6ffbddc

Browse files
authored
SNOW-1734385: Do not create CTE when partitioning WithQueryBlocks (#2948)
1 parent a652886 commit 6ffbddc

File tree

3 files changed

+42
-21
lines changed

3 files changed

+42
-21
lines changed

src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@
5151
)
5252
from snowflake.snowpark._internal.compiler.utils import (
5353
TreeNode,
54+
extract_child_from_with_query_block,
5455
is_active_transaction,
56+
is_with_query_block,
5557
replace_child,
5658
update_resolvable_node,
5759
)
@@ -271,7 +273,7 @@ def _try_to_breakdown_plan(self, root: TreeNode) -> List[LogicalPlan]:
271273
)
272274
break
273275

274-
partition = self._get_partitioned_plan(root, child)
276+
partition = self._get_partitioned_plan(child)
275277
plans.append(partition)
276278
complexity_score = get_complexity_score(root)
277279

@@ -356,7 +358,7 @@ def _find_node_to_breakdown(
356358
current_node_validity_statistics,
357359
)
358360

359-
def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePlan:
361+
def _get_partitioned_plan(self, child: TreeNode) -> SnowflakePlan:
360362
"""This method takes cuts the child out from the root, creates a temp table plan for the
361363
partitioned child and returns the plan. The steps involved are:
362364
@@ -375,7 +377,9 @@ def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePla
375377
[temp_table_name],
376378
None,
377379
SaveMode.ERROR_IF_EXISTS,
378-
child,
380+
extract_child_from_with_query_block(child)
381+
if is_with_query_block(child)
382+
else child,
379383
table_type="temp",
380384
creation_source=TableCreationSource.LARGE_QUERY_BREAKDOWN,
381385
)

src/snowflake/snowpark/_internal/compiler/utils.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,34 @@ def is_active_transaction(session):
356356
return session._run_query("SELECT CURRENT_TRANSACTION()")[0][0] is not None
357357

358358

359+
def extract_child_from_with_query_block(child: LogicalPlan) -> TreeNode:
360+
"""Given a WithQueryBlock node, or a node that contains a WithQueryBlock node, this method
361+
extracts the child node from the WithQueryBlock node and returns it."""
362+
if isinstance(child, WithQueryBlock):
363+
return child.children[0]
364+
if isinstance(child, SnowflakePlan) and child.source_plan is not None:
365+
return extract_child_from_with_query_block(child.source_plan)
366+
if isinstance(child, SelectSnowflakePlan):
367+
return extract_child_from_with_query_block(child.snowflake_plan)
368+
369+
raise ValueError(
370+
f"Invalid node type {type(child)} for partitioning."
371+
) # pragma: no cover
372+
373+
374+
def is_with_query_block(node: LogicalPlan) -> bool:
375+
"""Given a node, this method checks if the node is a WithQueryBlock node or contains a
376+
WithQueryBlock node."""
377+
if isinstance(node, WithQueryBlock):
378+
return True
379+
if isinstance(node, SnowflakePlan) and node.source_plan is not None:
380+
return is_with_query_block(node.source_plan)
381+
if isinstance(node, SelectSnowflakePlan):
382+
return is_with_query_block(node.snowflake_plan)
383+
384+
return False
385+
386+
359387
def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
360388
"""A helper function to plot the query plan tree using graphviz useful for debugging.
361389
It plots the plan if the environment variable ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING
@@ -456,16 +484,6 @@ def get_sql_text(node: LogicalPlan) -> str: # pragma: no cover
456484

457485
return f"{name=}\n{score=}, {ref_ctes=}, {sql_size=}\n{sql_preview=}"
458486

459-
def is_with_query_block(node: Optional[LogicalPlan]) -> bool: # pragma: no cover
460-
if isinstance(node, WithQueryBlock):
461-
return True
462-
if isinstance(node, SnowflakePlan):
463-
return is_with_query_block(node.source_plan)
464-
if isinstance(node, SelectSnowflakePlan):
465-
return is_with_query_block(node.snowflake_plan)
466-
467-
return False
468-
469487
g = graphviz.Graph(format="png")
470488

471489
curr_level = [root]

tests/integ/test_large_query_breakdown.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import logging
77
import os
8-
import re
98
import tempfile
109
from unittest.mock import patch
1110

@@ -219,8 +218,7 @@ def test_breakdown_at_with_query_node(session):
219218
queries = final_df.queries
220219
assert len(queries["queries"]) == 2
221220
assert queries["queries"][0].startswith("CREATE SCOPED TEMPORARY TABLE")
222-
# SNOW-1734385: Remove it when the issue is fixed
223-
assert "WITH SNOWPARK_TEMP_CTE_" in queries["queries"][0]
221+
assert "WITH SNOWPARK_TEMP_CTE_" not in queries["queries"][0]
224222
assert len(queries["post_actions"]) == 1
225223

226224

@@ -773,12 +771,13 @@ def test_large_query_breakdown_with_nested_cte(session):
773771
queries = final_df.queries
774772
assert len(queries["queries"]) == 2
775773
assert len(queries["post_actions"]) == 1
776-
match = re.search(r"SNOWPARK_TEMP_CTE_[\w]+", queries["queries"][0])
777-
assert match is not None
778-
cte_name_for_first_partition = match.group()
774+
775+
# assert that the first query contains the base temp table name
776+
assert temp_table in queries["queries"][0]
777+
779778
# assert that query for upper cte node is re-written and does not
780-
# contain the cte name for the first partition
781-
assert cte_name_for_first_partition not in queries["queries"][1]
779+
# contain query for the base temp table
780+
assert temp_table not in queries["queries"][1]
782781

783782
check_result_with_and_without_breakdown(session, final_df)
784783

0 commit comments

Comments
 (0)