Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from snowflake.snowpark._internal.compiler.utils import (
TreeNode,
is_active_transaction,
is_with_query_block,
replace_child,
update_resolvable_node,
)
Expand Down Expand Up @@ -271,7 +272,7 @@ def _try_to_breakdown_plan(self, root: TreeNode) -> List[LogicalPlan]:
)
break

partition = self._get_partitioned_plan(root, child)
partition = self._get_partitioned_plan(child)
plans.append(partition)
complexity_score = get_complexity_score(root)

Expand Down Expand Up @@ -356,7 +357,21 @@ def _find_node_to_breakdown(
current_node_validity_statistics,
)

def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePlan:
def _extract_child_from_with_query_block(
self, child: Optional[LogicalPlan]
) -> TreeNode:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this to utils right below is_with_query_block function to group the functionality.

if isinstance(child, WithQueryBlock):
return child.children[0]
if isinstance(child, SnowflakePlan):
return self._extract_child_from_with_query_block(child.source_plan)
if isinstance(child, SelectSnowflakePlan):
return self._extract_child_from_with_query_block(child.snowflake_plan)

raise ValueError(
f"Invalid node type {type(child)} for partitioning."
) # pragma: no cover

def _get_partitioned_plan(self, child: TreeNode) -> SnowflakePlan:
"""This method takes cuts the child out from the root, creates a temp table plan for the
partitioned child and returns the plan. The steps involved are:

Expand All @@ -375,7 +390,9 @@ def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePla
[temp_table_name],
None,
SaveMode.ERROR_IF_EXISTS,
child,
self._extract_child_from_with_query_block(child)
if is_with_query_block(child)
else child,
table_type="temp",
creation_source=TableCreationSource.LARGE_QUERY_BREAKDOWN,
)
Expand Down
21 changes: 11 additions & 10 deletions src/snowflake/snowpark/_internal/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,17 @@ def is_active_transaction(session):
return session._run_query("SELECT CURRENT_TRANSACTION()")[0][0] is not None


def is_with_query_block(node: Optional[LogicalPlan]) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the node needs to be optional? when it will be none?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it will be none if node is a snowflakeplan and node.source_plan is None. But it can be handled within the if-else logic so I'll make it non-optional here

if isinstance(node, WithQueryBlock):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment here that this function is check whether a given logical node is a node or resolved SnowflakePlan or Selectable for WithQueryBlock

return True
if isinstance(node, SnowflakePlan):
return is_with_query_block(node.source_plan)
if isinstance(node, SelectSnowflakePlan):
return is_with_query_block(node.snowflake_plan)

return False


def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
"""A helper function to plot the query plan tree using graphviz useful for debugging.
It plots the plan if the environment variable ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING
Expand Down Expand Up @@ -456,16 +467,6 @@ def get_sql_text(node: LogicalPlan) -> str: # pragma: no cover

return f"{name=}\n{score=}, {ref_ctes=}, {sql_size=}\n{sql_preview=}"

def is_with_query_block(node: Optional[LogicalPlan]) -> bool: # pragma: no cover
if isinstance(node, WithQueryBlock):
return True
if isinstance(node, SnowflakePlan):
return is_with_query_block(node.source_plan)
if isinstance(node, SelectSnowflakePlan):
return is_with_query_block(node.snowflake_plan)

return False

g = graphviz.Graph(format="png")

curr_level = [root]
Expand Down
15 changes: 7 additions & 8 deletions tests/integ/test_large_query_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import logging
import os
import re
import tempfile
from unittest.mock import patch

Expand Down Expand Up @@ -219,8 +218,7 @@ def test_breakdown_at_with_query_node(session):
queries = final_df.queries
assert len(queries["queries"]) == 2
assert queries["queries"][0].startswith("CREATE SCOPED TEMPORARY TABLE")
# SNOW-1734385: Remove it when the issue is fixed
assert "WITH SNOWPARK_TEMP_CTE_" in queries["queries"][0]
assert "WITH SNOWPARK_TEMP_CTE_" not in queries["queries"][0]
assert len(queries["post_actions"]) == 1


Expand Down Expand Up @@ -773,12 +771,13 @@ def test_large_query_breakdown_with_nested_cte(session):
queries = final_df.queries
assert len(queries["queries"]) == 2
assert len(queries["post_actions"]) == 1
match = re.search(r"SNOWPARK_TEMP_CTE_[\w]+", queries["queries"][0])
assert match is not None
cte_name_for_first_partition = match.group()

# assert that the first query contains the base temp table name
assert temp_table in queries["queries"][0]

# assert that query for upper cte node is re-written and does not
# contain the cte name for the first partition
assert cte_name_for_first_partition not in queries["queries"][1]
# contain query for the base temp table
assert temp_table not in queries["queries"][1]

check_result_with_and_without_breakdown(session, final_df)

Expand Down
Loading