Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@
)
from snowflake.snowpark._internal.compiler.utils import (
TreeNode,
extract_child_from_with_query_block,
is_active_transaction,
is_with_query_block,
replace_child,
update_resolvable_node,
)
Expand Down Expand Up @@ -271,7 +273,7 @@ def _try_to_breakdown_plan(self, root: TreeNode) -> List[LogicalPlan]:
)
break

partition = self._get_partitioned_plan(root, child)
partition = self._get_partitioned_plan(child)
plans.append(partition)
complexity_score = get_complexity_score(root)

Expand Down Expand Up @@ -356,7 +358,7 @@ def _find_node_to_breakdown(
current_node_validity_statistics,
)

def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePlan:
def _get_partitioned_plan(self, child: TreeNode) -> SnowflakePlan:
"""This method takes cuts the child out from the root, creates a temp table plan for the
partitioned child and returns the plan. The steps involved are:

Expand All @@ -375,7 +377,9 @@ def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePla
[temp_table_name],
None,
SaveMode.ERROR_IF_EXISTS,
child,
extract_child_from_with_query_block(child)
if is_with_query_block(child)
else child,
table_type="temp",
creation_source=TableCreationSource.LARGE_QUERY_BREAKDOWN,
)
Expand Down
38 changes: 28 additions & 10 deletions src/snowflake/snowpark/_internal/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,34 @@ def is_active_transaction(session):
return session._run_query("SELECT CURRENT_TRANSACTION()")[0][0] is not None


def extract_child_from_with_query_block(child: LogicalPlan) -> TreeNode:
"""Given a WithQueryBlock node, or a node that contains a WithQueryBlock node, this method
extracts the child node from the WithQueryBlock node and returns it."""
if isinstance(child, WithQueryBlock):
return child.children[0]
if isinstance(child, SnowflakePlan) and child.source_plan is not None:
return extract_child_from_with_query_block(child.source_plan)
if isinstance(child, SelectSnowflakePlan):
return extract_child_from_with_query_block(child.snowflake_plan)

raise ValueError(
f"Invalid node type {type(child)} for partitioning."
) # pragma: no cover


def is_with_query_block(node: LogicalPlan) -> bool:
"""Given a node, this method checks if the node is a WithQueryBlock node or contains a
WithQueryBlock node."""
if isinstance(node, WithQueryBlock):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment here that this function is check whether a given logical node is a node or resolved SnowflakePlan or Selectable for WithQueryBlock

return True
if isinstance(node, SnowflakePlan) and node.source_plan is not None:
return is_with_query_block(node.source_plan)
if isinstance(node, SelectSnowflakePlan):
return is_with_query_block(node.snowflake_plan)

return False


def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
"""A helper function to plot the query plan tree using graphviz useful for debugging.
It plots the plan if the environment variable ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING
Expand Down Expand Up @@ -456,16 +484,6 @@ def get_sql_text(node: LogicalPlan) -> str: # pragma: no cover

return f"{name=}\n{score=}, {ref_ctes=}, {sql_size=}\n{sql_preview=}"

def is_with_query_block(node: Optional[LogicalPlan]) -> bool: # pragma: no cover
if isinstance(node, WithQueryBlock):
return True
if isinstance(node, SnowflakePlan):
return is_with_query_block(node.source_plan)
if isinstance(node, SelectSnowflakePlan):
return is_with_query_block(node.snowflake_plan)

return False

g = graphviz.Graph(format="png")

curr_level = [root]
Expand Down
15 changes: 7 additions & 8 deletions tests/integ/test_large_query_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import logging
import os
import re
import tempfile
from unittest.mock import patch

Expand Down Expand Up @@ -219,8 +218,7 @@ def test_breakdown_at_with_query_node(session):
queries = final_df.queries
assert len(queries["queries"]) == 2
assert queries["queries"][0].startswith("CREATE SCOPED TEMPORARY TABLE")
# SNOW-1734385: Remove it when the issue is fixed
assert "WITH SNOWPARK_TEMP_CTE_" in queries["queries"][0]
assert "WITH SNOWPARK_TEMP_CTE_" not in queries["queries"][0]
assert len(queries["post_actions"]) == 1


Expand Down Expand Up @@ -773,12 +771,13 @@ def test_large_query_breakdown_with_nested_cte(session):
queries = final_df.queries
assert len(queries["queries"]) == 2
assert len(queries["post_actions"]) == 1
match = re.search(r"SNOWPARK_TEMP_CTE_[\w]+", queries["queries"][0])
assert match is not None
cte_name_for_first_partition = match.group()

# assert that the first query contains the base temp table name
assert temp_table in queries["queries"][0]

# assert that query for upper cte node is re-written and does not
# contain the cte name for the first partition
assert cte_name_for_first_partition not in queries["queries"][1]
# contain query for the base temp table
assert temp_table not in queries["queries"][1]

check_result_with_and_without_breakdown(session, final_df)

Expand Down
Loading