Commit a79fe9f

SNOW-1865904 fix query gen when nested cte node is partitioned (#2816)
1 parent 10c612e commit a79fe9f

4 files changed: +47 -11 lines changed

src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py

Lines changed: 0 additions & 5 deletions

@@ -565,19 +565,14 @@ def _replace_child_and_update_ancestors(
         temp_table_selectable.post_actions = [drop_table_query]
 
         parents = self._parent_map[child]
-        updated_nodes = set()
         for parent in parents:
             replace_child(parent, child, temp_table_selectable, self._query_generator)
 
         nodes_to_reset = list(parents)
         while nodes_to_reset:
             node = nodes_to_reset.pop()
-            if node in updated_nodes:
-                # Skip if the node is already updated.
-                continue
 
             update_resolvable_node(node, self._query_generator)
-            updated_nodes.add(node)
 
             parents = self._parent_map[node]
             nodes_to_reset.extend(parents)
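
With the updated_nodes set removed, _replace_child_and_update_ancestors no longer skips an ancestor it has already visited: every node reachable upward from the replaced child is re-resolved, so an ancestor sitting behind a nested CTE regenerates its query against the new temp-table child instead of keeping a stale one. Below is a minimal sketch of that traversal; propagate_replacement, parent_map, replace_child, and update_node are hypothetical stand-ins for the compiler's internal helpers, not the real signatures.

def propagate_replacement(parent_map, child, replacement, replace_child, update_node):
    # Point each direct parent at the materialized temp-table node.
    parents = parent_map[child]
    for parent in parents:
        replace_child(parent, child, replacement)

    # Re-resolve every ancestor unconditionally. Skipping already-visited
    # nodes could leave an ancestor reached through a nested CTE path with a
    # query that still references the old, now-partitioned child.
    nodes_to_reset = list(parents)
    while nodes_to_reset:
        node = nodes_to_reset.pop()
        update_node(node)
        nodes_to_reset.extend(parent_map[node])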

src/snowflake/snowpark/_internal/compiler/plan_compiler.py

Lines changed: 0 additions & 1 deletion

@@ -185,7 +185,6 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
                 error_type=type(e).__name__,
                 error_message=str(e),
             )
-            pass
 
         return self.replace_temp_obj_placeholders(queries)

src/snowflake/snowpark/_internal/compiler/query_generator.py

Lines changed: 5 additions & 5 deletions

@@ -227,11 +227,11 @@ def do_resolve_with_resolved_children(
 
         elif isinstance(logical_plan, WithQueryBlock):
            resolved_child = resolved_children[logical_plan.children[0]]
-            # record the CTE definition of the current block
-            if logical_plan.name not in self.resolved_with_query_block:
-                self.resolved_with_query_block[
-                    logical_plan.name
-                ] = resolved_child.queries[-1]
+            # record the CTE definition of the current block or update the query when
+            # the child is re-resolved during optimization stage.
+            self.resolved_with_query_block[logical_plan.name] = resolved_child.queries[
+                -1
+            ]
 
             resolved_plan = self.plan_builder.with_query_block(
                 logical_plan,
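
Dropping the "if logical_plan.name not in self.resolved_with_query_block" guard makes the mapping a last-write-wins record: when the child of a WithQueryBlock is re-resolved during the optimization stage (for example after large query breakdown materializes part of it into a temp table), the stored CTE definition is refreshed instead of staying stale. A toy illustration of the two behaviors, assuming a plain dict stands in for the query generator's state (the names below are hypothetical):

resolved_with_query_block = {}

def record_cte_definition(name, definition, keep_first=False):
    if keep_first:
        # Old behavior: the first recorded query wins, even if the child is
        # later re-resolved with a different (partitioned) plan.
        resolved_with_query_block.setdefault(name, definition)
    else:
        # New behavior: always overwrite, so a re-resolved child replaces the
        # stale CTE definition.
        resolved_with_query_block[name] = definition

record_cte_definition("SNOWPARK_TEMP_CTE_X", "SELECT ... FROM original_subtree")
record_cte_definition("SNOWPARK_TEMP_CTE_X", "SELECT ... FROM materialized_temp_table")
print(resolved_with_query_block["SNOWPARK_TEMP_CTE_X"])  # the refreshed query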

tests/integ/test_large_query_breakdown.py

Lines changed: 42 additions & 0 deletions

@@ -5,6 +5,7 @@
 
 import logging
 import os
+import re
 import tempfile
 from unittest.mock import patch
 
@@ -734,6 +735,47 @@ def test_optimization_skipped_with_exceptions(
     assert kwargs["error_type"] == error_type.__name__
 
 
+def test_large_query_breakdown_with_nested_cte(session):
+    session.cte_optimization_enabled = True
+    set_bounds(session, 15, 20)
+
+    temp_table = Utils.random_table_name()
+    session.create_dataframe([(1, 2), (3, 4)], ["A", "B"]).write.save_as_table(
+        temp_table, table_type="temp"
+    )
+    base_select = session.table(temp_table)
+    for i in range(2):
+        base_select = base_select.with_column("A", col("A") + lit(i))
+
+    base_df = base_select.union_all(base_select)
+
+    df1 = base_df.with_column("A", col("A") + 1)
+    df2 = base_df.with_column("B", col("B") + 1)
+    for i in range(2):
+        df1 = df1.with_column("A", col("A") + i)
+
+    df1 = df1.group_by("A").agg(sum_distinct(col("B")).alias("B"))
+    df2 = df2.group_by("B").agg(sum_distinct(col("A")).alias("A"))
+    mid_final_df = df1.union_all(df2)
+
+    mid1 = mid_final_df.filter(col("A") > 10)
+    mid2 = mid_final_df.filter(col("B") > 3)
+    final_df = mid1.union_all(mid2)
+
+    with SqlCounter(query_count=1, describe_count=0):
+        queries = final_df.queries
+    assert len(queries["queries"]) == 2
+    assert len(queries["post_actions"]) == 1
+    match = re.search(r"SNOWPARK_TEMP_CTE_[\w]+", queries["queries"][0])
+    assert match is not None
+    cte_name_for_first_partition = match.group()
+    # assert that query for upper cte node is re-written and does not
+    # contain the cte name for the first partition
+    assert cte_name_for_first_partition not in queries["queries"][1]
+
+    check_result_with_and_without_breakdown(session, final_df)
+
+
 def test_complexity_bounds_affect_num_partitions(session, large_query_df):
     """Test complexity bounds affect number of partitions.
     Also test that when partitions are added, drop table queries are added.
