Skip to content

Commit 830c8b6

Browse files
authored
Fix limit() after sort() in aggregation query (#3596)
1 parent 041e624 commit 830c8b6

File tree

5 files changed

+141
-18
lines changed

5 files changed

+141
-18
lines changed

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,9 +1166,10 @@ def do_resolve_with_resolved_children(
11661166
)
11671167

11681168
if isinstance(logical_plan, Limit):
1169-
on_top_of_order_by = isinstance(
1170-
logical_plan.child, SnowflakePlan
1171-
) and isinstance(logical_plan.child.source_plan, Sort)
1169+
on_top_of_order_by = logical_plan.is_limit_append or (
1170+
isinstance(logical_plan.child, SnowflakePlan)
1171+
and isinstance(logical_plan.child.source_plan, Sort)
1172+
)
11721173
return self.plan_builder.limit(
11731174
self.to_sql_try_avoid_cast(
11741175
logical_plan.limit_expr, df_aliased_col_name_to_real_col_name

src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,18 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
288288

289289
class Limit(LogicalPlan):
    """Logical plan node applying LIMIT/OFFSET to *child*.

    Args:
        limit_expr: Expression producing the row-count bound.
        offset_expr: Expression producing the row offset.
        child: The logical plan the limit is applied on top of.
        is_limit_append: When True, the analyzer treats this limit the same
            way as a limit sitting directly on top of an ORDER BY (it is
            OR-ed into ``on_top_of_order_by`` in the analyzer), appending
            the LIMIT clause to the child's query level instead of nesting.
            Defaults to False for backward compatibility.
    """

    def __init__(
        self,
        limit_expr: Expression,
        offset_expr: Expression,
        child: LogicalPlan,
        is_limit_append: bool = False,
    ) -> None:
        super().__init__()
        self.limit_expr = limit_expr
        self.offset_expr = offset_expr
        self.child = child
        # Register the child so generic plan traversal/complexity walks see it.
        self.children.append(child)
        self.is_limit_append = is_limit_append
298303

299304
@property
300305
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:

src/snowflake/snowpark/dataframe.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def __init__(
617617

618618
self._statement_params = None
619619
self.is_cached: bool = is_cached #: Whether the dataframe is cached.
620-
self._is_grouped_by_and_aggregated = False
620+
self._ops_after_agg = None
621621

622622
# Whether all columns are VARIANT data type,
623623
# which support querying nested fields via dot notations
@@ -1970,7 +1970,8 @@ def filter(
19701970
# the filtering for dataframe after aggregation without nesting using HAVING
19711971
if (
19721972
context._is_snowpark_connect_compatible_mode
1973-
and self._is_grouped_by_and_aggregated
1973+
and self._ops_after_agg is not None
1974+
and "filter" not in self._ops_after_agg
19741975
):
19751976
having_plan = Filter(filter_col_expr, self._plan, is_having=True)
19761977
if self._select_statement:
@@ -1985,7 +1986,8 @@ def filter(
19851986
)
19861987
else:
19871988
df = self._with_plan(having_plan, _ast_stmt=stmt)
1988-
df._is_grouped_by_and_aggregated = True
1989+
df._ops_after_agg = self._ops_after_agg.copy()
1990+
df._ops_after_agg.add("filter")
19891991
return df
19901992
else:
19911993
if self._select_statement:
@@ -2134,7 +2136,8 @@ def sort(
21342136
# the sorting for dataframe after aggregation without nesting
21352137
if (
21362138
context._is_snowpark_connect_compatible_mode
2137-
and self._is_grouped_by_and_aggregated
2139+
and self._ops_after_agg is not None
2140+
and "sort" not in self._ops_after_agg
21382141
):
21392142
sort_plan = Sort(sort_exprs, self._plan, is_order_by_append=True)
21402143
if self._select_statement:
@@ -2149,7 +2152,8 @@ def sort(
21492152
)
21502153
else:
21512154
df = self._with_plan(sort_plan, _ast_stmt=stmt)
2152-
df._is_grouped_by_and_aggregated = True
2155+
df._ops_after_agg = self._ops_after_agg.copy()
2156+
df._ops_after_agg.add("sort")
21532157
return df
21542158
else:
21552159
df = (
@@ -2855,13 +2859,39 @@ def limit(
28552859
else:
28562860
stmt = None
28572861

2858-
if self._select_statement:
2862+
# In snowpark_connect_compatible mode, we need to handle
2863+
# the limit for dataframe after aggregation without nesting
2864+
if (
2865+
context._is_snowpark_connect_compatible_mode
2866+
and self._ops_after_agg is not None
2867+
and "limit" not in self._ops_after_agg
2868+
):
2869+
limit_plan = Limit(
2870+
Literal(n), Literal(offset), self._plan, is_limit_append=True
2871+
)
2872+
if self._select_statement:
2873+
df = self._with_plan(
2874+
self._session._analyzer.create_select_statement(
2875+
from_=self._session._analyzer.create_select_snowflake_plan(
2876+
limit_plan, analyzer=self._session._analyzer
2877+
),
2878+
analyzer=self._session._analyzer,
2879+
),
2880+
_ast_stmt=stmt,
2881+
)
2882+
else:
2883+
df = self._with_plan(limit_plan, _ast_stmt=stmt)
2884+
df._ops_after_agg = self._ops_after_agg.copy()
2885+
df._ops_after_agg.add("limit")
2886+
return df
2887+
else:
2888+
if self._select_statement:
2889+
return self._with_plan(
2890+
self._select_statement.limit(n, offset=offset), _ast_stmt=stmt
2891+
)
28592892
return self._with_plan(
2860-
self._select_statement.limit(n, offset=offset), _ast_stmt=stmt
2893+
Limit(Literal(n), Literal(offset), self._plan), _ast_stmt=stmt
28612894
)
2862-
return self._with_plan(
2863-
Limit(Literal(n), Literal(offset), self._plan), _ast_stmt=stmt
2864-
)
28652895

28662896
@df_api_usage
28672897
@publicapi

src/snowflake/snowpark/relational_grouped_dataframe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def agg(
324324
agg_exprs.append(_str_to_expr(e[1], _emit_ast)(col_expr))
325325

326326
df = self._to_df(agg_exprs, _emit_ast=False)
327-
df._is_grouped_by_and_aggregated = True
327+
df._ops_after_agg = set()
328328

329329
if _emit_ast:
330330
df._ast_id = stmt.uid
@@ -515,7 +515,7 @@ def end_partition(
515515
),
516516
_emit_ast=False,
517517
)
518-
df._is_grouped_by_and_aggregated = True
518+
df._ops_after_agg = set()
519519

520520
if _emit_ast:
521521
stmt = working_dataframe._session._ast_batch.bind()
@@ -694,7 +694,7 @@ def count(self, _emit_ast: bool = True) -> DataFrame:
694694
],
695695
_emit_ast=False,
696696
)
697-
df._is_grouped_by_and_aggregated = True
697+
df._ops_after_agg = set()
698698

699699
# TODO: count seems similar to mean, min, .... Can we unify implementation here?
700700
if _emit_ast:
@@ -730,7 +730,7 @@ def _function(
730730
)._expression
731731
agg_exprs.append(expr)
732732
df = self._to_df(agg_exprs)
733-
df._is_grouped_by_and_aggregated = True
733+
df._ops_after_agg = set()
734734

735735
if _emit_ast:
736736
stmt = self._dataframe._session._ast_batch.bind()

tests/integ/test_df_aggregate.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,3 +841,90 @@ def test_agg_filter_and_sort_with_grouping_snowpark_connect_compatible(session):
841841
assert results6[0][2] == 1 # gc=1 for NULL course
842842
finally:
843843
context._is_snowpark_connect_compatible_mode = original_value
844+
845+
846+
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="HAVING, ORDER BY append, and limit append are not supported in local testing mode",
)
def test_filter_sort_limit_snowpark_connect_compatible(session):
    """Verify filter/sort/limit chaining after aggregation in snowpark-connect
    compatible mode: the first filter/sort/limit after an aggregation stay in
    the same query level (HAVING / ORDER BY / LIMIT appended to one SELECT),
    while repeating an operation forces a new nested query level.
    """
    original_value = context._is_snowpark_connect_compatible_mode

    try:
        # Force compatible mode for the duration of the test; restored in finally.
        context._is_snowpark_connect_compatible_mode = True
        df = session.create_dataframe(
            [(1, 2, 3), (3, 2, 1), (3, 2, 1)], ["a", "b", "c"]
        )

        # Basic aggregation with filter, sort, limit - should be in same level
        agg_df = df.group_by("a").agg(
            sum_("b").alias("sum_b"), count("c").alias("count_c")
        )
        result_df1 = agg_df.filter(col("sum_b") > 1).sort("a").limit(10)

        # Check the result
        Utils.check_answer(result_df1, [Row(1, 2, 1), Row(3, 4, 2)])

        # Check that filter, sort, and limit are in the same query level (single SELECT)
        query1 = result_df1.queries["queries"][-1]
        # Count SELECT statements - should be 3 for operations in same level
        assert query1.upper().count("SELECT") == 3
        assert "ORDER BY" in query1.upper()
        assert "LIMIT" in query1.upper()
        assert "HAVING" in query1.upper()

        # Duplicate sort operations - second sort should be in next level
        result_df2 = agg_df.sort("a").sort("sum_b")

        # Check the result
        Utils.check_answer(result_df2, [Row(1, 2, 1), Row(3, 4, 2)])

        # Check that the second sort creates a new query level
        query2 = result_df2.queries["queries"][-1]
        # Should have 4 SELECT statements for nested query
        assert query2.upper().count("SELECT") == 4

        # filter.sort().limit().sort() - last sort should be in next level
        result_df3 = (
            agg_df.filter(col("count_c") >= 1)
            .sort("a")
            .limit(10)
            .sort("sum_b", ascending=False)
        )

        # Check the result
        Utils.check_answer(result_df3, [Row(3, 4, 2), Row(1, 2, 1)])

        # Check query structure - should have nested SELECT due to sort after limit
        query3 = result_df3.queries["queries"][-1]
        assert query3.upper().count("SELECT") == 4

        # limit().limit() - second limit should create new level
        result_df5 = agg_df.limit(10).limit(1)

        # Check the result (should return only first row)
        assert result_df5.count() == 1

        # Check query structure - nested due to second limit
        query5 = result_df5.queries["queries"][-1]
        assert query5.upper().count("SELECT") == 4

        # Complex chain - filter().sort().limit().filter().sort()
        result_df6 = (
            agg_df.filter(col("sum_b") >= 2)
            .sort("a")
            .limit(10)
            .filter(col("count_c") > 1)
            .sort("sum_b", ascending=False)
        )

        # Check the result
        Utils.check_answer(result_df6, [Row(3, 4, 2)])

        # Check query structure - should have multiple levels due to operations after limit
        query6 = result_df6.queries["queries"][-1]
        # Should have 4 SELECT statements
        assert query6.upper().count("SELECT") == 4

    finally:
        context._is_snowpark_connect_compatible_mode = original_value

0 commit comments

Comments
 (0)