From 0d608752dd9f375c3a5beed5353ae98101be222d Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Mon, 23 Mar 2026 12:47:34 -0700 Subject: [PATCH 1/3] defer having/sort/limit construction --- src/snowflake/snowpark/dataframe.py | 153 ++++++++++++------ .../snowpark/relational_grouped_dataframe.py | 10 ++ tests/integ/test_df_aggregate.py | 146 +++++++++++++++++ 3 files changed, 257 insertions(+), 52 deletions(-) diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 3dfe045c84..eff5b33a4a 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -689,7 +689,15 @@ def __init__( self._statement_params = None self.is_cached: bool = is_cached #: Whether the dataframe is cached. + # Internal state variables used to construct flattened GROUP BY clauses in the correct order + # in SCOS compatibility mode. + # See comments on `_build_post_agg_df` for details. self._ops_after_agg = None + self._agg_base_plan = None + self._agg_base_select_statement = None + self._pending_having = None + self._pending_order_by = None + self._pending_limit = None # Whether all columns are VARIANT data type, # which support querying nested fields via dot notations @@ -2114,28 +2122,24 @@ def filter( stmt = _ast_stmt # In snowpark_connect_compatible mode, we need to handle - # the filtering for dataframe after aggregation without nesting using HAVING + # the filtering for dataframe after aggregation without nesting using HAVING. + # We defer the HAVING expression and rebuild the plan from the + # aggregate base so that SQL clauses are emitted in the correct order + # (HAVING -> ORDER BY -> LIMIT) regardless of the user's call order. 
if ( context._is_snowpark_connect_compatible_mode and self._ops_after_agg is not None and "filter" not in self._ops_after_agg ): - having_plan = Filter(filter_col_expr, self._plan, is_having=True) - if self._select_statement: - df = self._with_plan( - self._session._analyzer.create_select_statement( - from_=self._session._analyzer.create_select_snowflake_plan( - having_plan, analyzer=self._session._analyzer - ), - analyzer=self._session._analyzer, - ), - _ast_stmt=stmt, - ) - else: - df = self._with_plan(having_plan, _ast_stmt=stmt) - df._ops_after_agg = self._ops_after_agg.copy() - df._ops_after_agg.add("filter") - return df + new_ops = self._ops_after_agg.copy() + new_ops.add("filter") + return self._build_post_agg_df( + ops_after_agg=new_ops, + pending_having=filter_col_expr, + pending_order_by=self._pending_order_by, + pending_limit=self._pending_limit, + _ast_stmt=stmt, + ) else: if self._select_statement: return self._with_plan( @@ -2330,28 +2334,23 @@ def sort( ) # In snowpark_connect_compatible mode, we need to handle - # the sorting for dataframe after aggregation without nesting + # the sorting for dataframe after aggregation without nesting. + # We defer the ORDER BY expressions and rebuild the plan from + # the aggregate base in correct SQL clause order. 
if ( context._is_snowpark_connect_compatible_mode and self._ops_after_agg is not None and "sort" not in self._ops_after_agg ): - sort_plan = Sort(sort_exprs, self._plan, is_order_by_append=True) - if self._select_statement: - df = self._with_plan( - self._session._analyzer.create_select_statement( - from_=self._session._analyzer.create_select_snowflake_plan( - sort_plan, analyzer=self._session._analyzer - ), - analyzer=self._session._analyzer, - ), - _ast_stmt=stmt, - ) - else: - df = self._with_plan(sort_plan, _ast_stmt=stmt) - df._ops_after_agg = self._ops_after_agg.copy() - df._ops_after_agg.add("sort") - return df + new_ops = self._ops_after_agg.copy() + new_ops.add("sort") + return self._build_post_agg_df( + ops_after_agg=new_ops, + pending_having=self._pending_having, + pending_order_by=sort_exprs, + pending_limit=self._pending_limit, + _ast_stmt=stmt, + ) else: df = ( self._with_plan(self._select_statement.sort(sort_exprs)) @@ -3057,30 +3056,23 @@ def limit( stmt = None # In snowpark_connect_compatible mode, we need to handle - # the limit for dataframe after aggregation without nesting + # the limit for dataframe after aggregation without nesting. + # We defer the LIMIT values and rebuild the plan from + # the aggregate base in correct SQL clause order. 
if ( context._is_snowpark_connect_compatible_mode and self._ops_after_agg is not None and "limit" not in self._ops_after_agg ): - limit_plan = Limit( - Literal(n), Literal(offset), self._plan, is_limit_append=True + new_ops = self._ops_after_agg.copy() + new_ops.add("limit") + return self._build_post_agg_df( + ops_after_agg=new_ops, + pending_having=self._pending_having, + pending_order_by=self._pending_order_by, + pending_limit=(n, offset), + _ast_stmt=stmt, ) - if self._select_statement: - df = self._with_plan( - self._session._analyzer.create_select_statement( - from_=self._session._analyzer.create_select_snowflake_plan( - limit_plan, analyzer=self._session._analyzer - ), - analyzer=self._session._analyzer, - ), - _ast_stmt=stmt, - ) - else: - df = self._with_plan(limit_plan, _ast_stmt=stmt) - df._ops_after_agg = self._ops_after_agg.copy() - df._ops_after_agg.add("limit") - return df else: if self._select_statement: return self._with_plan( @@ -6835,6 +6827,63 @@ def dtypes(self) -> List[Tuple[str, str]]: ] return dtypes + def _build_post_agg_df( + self, + ops_after_agg, + pending_having, + pending_order_by, + pending_limit, + _ast_stmt=None, + ) -> "DataFrame": + """ + When constructing group by aggregation queries in SCOS compatibility mode, we must ensure that + filter (HAVING), sorting (ORDER BY), and LIMIT clauses are emitted in the correct order, regardless of + the order in which the user specified those operations. For example: + + df.groupBy("dept").agg( + count("*").alias("headcount"), + avg("salary").alias("avg_salary"), + ) + .orderBy(col("avg_salary").desc()) + .limit(2) + .filter(col("headcount") > 1) + + Even though `orderBy` is the first operation, we must re-order the `filter` to be first because + SQL syntax requires HAVING, ORDER BY, and LIMIT clauses to appear in that specific order. + + We use `_agg_base_plan` and `_agg_base_select_statement` to re-construct SQL with this constraint. 
+ + This method should only be called in SCOS compatibility mode (context._is_snowpark_connect_compatible_mode). + """ + current = self._agg_base_plan + + if pending_having is not None: + current = Filter(pending_having, current, is_having=True) + if pending_order_by is not None: + current = Sort(pending_order_by, current, is_order_by_append=True) + if pending_limit is not None: + n, offset = pending_limit + current = Limit(Literal(n), Literal(offset), current, is_limit_append=True) + + if self._agg_base_select_statement is not None: + new_plan = self._session._analyzer.create_select_statement( + from_=self._session._analyzer.create_select_snowflake_plan( + current, analyzer=self._session._analyzer + ), + analyzer=self._session._analyzer, + ) + else: + new_plan = current + + df = self._with_plan(new_plan, _ast_stmt=_ast_stmt) + df._ops_after_agg = ops_after_agg + df._agg_base_plan = self._agg_base_plan + df._agg_base_select_statement = self._agg_base_select_statement + df._pending_having = pending_having + df._pending_order_by = pending_order_by + df._pending_limit = pending_limit + return df + def _with_plan(self, plan, _ast_stmt=None) -> "DataFrame": """ :param proto.Bind ast_stmt: The AST statement protobuf corresponding to this value. 
diff --git a/src/snowflake/snowpark/relational_grouped_dataframe.py b/src/snowflake/snowpark/relational_grouped_dataframe.py index 2f29726e4e..03a6201d3e 100644 --- a/src/snowflake/snowpark/relational_grouped_dataframe.py +++ b/src/snowflake/snowpark/relational_grouped_dataframe.py @@ -338,6 +338,8 @@ def agg( # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement if _emit_ast: df._ast_id = stmt.uid @@ -531,6 +533,8 @@ def end_partition( # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement if _emit_ast: stmt = working_dataframe._session._ast_batch.bind() @@ -766,6 +770,8 @@ def count(self, _emit_ast: bool = True, **kwargs) -> DataFrame: # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement # TODO: count seems similar to mean, min, .... Can we unify implementation here? 
if _emit_ast: @@ -815,6 +821,8 @@ def _function( # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement if _emit_ast: stmt = self._dataframe._session._ast_batch.bind() @@ -909,6 +917,8 @@ def ai_agg( # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement if _emit_ast: stmt = self._dataframe._session._ast_batch.bind() diff --git a/tests/integ/test_df_aggregate.py b/tests/integ/test_df_aggregate.py index 9a8795fb49..8df22275d8 100644 --- a/tests/integ/test_df_aggregate.py +++ b/tests/integ/test_df_aggregate.py @@ -928,6 +928,152 @@ def test_filter_sort_limit_snowpark_connect_compatible(session, sql_simplifier_e assert query6.upper().count("SELECT") == 4 if sql_simplifier_enabled else 5 +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="exclude_grouping_columns is not supported", +) +def test_group_by_agg_sort_filter(session): + """ + Tests that post-aggregation clauses (HAVING, ORDER BY, LIMIT) are emitted in valid SQL order + regardless of the DataFrame call order (see SNOW-3266495). + + After a GROUP BY, HAVING must appear before ORDER BY, which in turn must appear before LIMIT. 
+ """ + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + df = session.createDataFrame( + [ + (1, "engineering", 80000), + (2, "engineering", 90000), + (3, "sales", 50000), + (4, "sales", 60000), + (5, "hr", 45000), + (6, "hr", 55000), + (7, "engineering", 85000), + ], + ["id", "dept", "salary"], + ) + agg_df = df.groupBy("dept").agg( + count("*").alias("headcount"), + avg("salary").alias("avg_salary"), + ) + + expected_all = [ + Row("engineering", 3, 85000.0), + Row("sales", 2, 55000.0), + Row("hr", 2, 50000.0), + ] + + # sort -> filter: ORDER BY before HAVING in user code, but + # SQL must be HAVING before ORDER BY. + result1 = agg_df.orderBy(col("avg_salary").desc()).filter(col("headcount") > 1) + Utils.check_answer(result1, expected_all) + + # filter -> sort: already in correct SQL clause order. + result2 = agg_df.filter(col("headcount") > 1).orderBy(col("avg_salary").desc()) + Utils.check_answer(result2, expected_all) + + # sort -> filter -> limit (must swap filter with sort) + result3 = ( + agg_df.orderBy(col("avg_salary").desc()) + .filter(col("headcount") > 1) + .limit(2) + ) + Utils.check_answer( + result3, + [ + Row("engineering", 3, 85000.0), + Row("sales", 2, 55000.0), + ], + ) + + # sort -> limit -> filter: all three clauses in wrong order + # (must move filter to first operation) + result4 = ( + agg_df.orderBy(col("avg_salary").desc()) + .limit(2) + .filter(col("headcount") > 1) + ) + Utils.check_answer( + result4, + [ + Row("engineering", 3, 85000.0), + Row("sales", 2, 55000.0), + ], + ) + + # limit -> sort: LIMIT before ORDER BY in user code. + result5 = agg_df.limit(2).orderBy(col("avg_salary").desc()) + assert result5.count() == 2 + + # limit -> filter: LIMIT before HAVING in user code. 
+ result6 = agg_df.limit(2).filter(col("headcount") > 1) + assert result6.count() <= 2 + + # filter -> limit -> sort (must swap sort and limit) + result7 = ( + agg_df.filter(col("headcount") > 1) + .limit(2) + .orderBy(col("avg_salary").desc()) + ) + Utils.check_answer( + result7, + [ + Row("engineering", 3, 85000.0), + Row("sales", 2, 55000.0), + ], + ) + + # select/distinct between sort and filter should break the + # _ops_after_agg chain, so the subsequent filter uses a regular + # WHERE via subquery rather than a flattened HAVING. + result8 = ( + agg_df.orderBy(col("avg_salary").desc()) + .select("dept", "headcount", "avg_salary") + .filter(col("headcount") > 1) + ) + Utils.check_answer(result8, expected_all) + + result9 = ( + agg_df.orderBy(col("avg_salary").desc()) + .distinct() + .filter(col("headcount") > 1) + ) + Utils.check_answer(result9, expected_all) + + # Repeated sort: first sort uses the _ops_after_agg append path, + # second sort falls back to the regular (subquery) path. + result10 = ( + agg_df.filter(col("headcount") > 1) + .orderBy(col("avg_salary").desc()) + .orderBy(col("headcount").asc()) + ) + Utils.check_answer( + result10, + [ + Row("sales", 2, 55000.0), + Row("hr", 2, 50000.0), + Row("engineering", 3, 85000.0), + ], + ) + + # Repeated filter: first filter uses the HAVING append path, + # second filter falls back to the regular (subquery WHERE) path. 
+ result11 = ( + agg_df.orderBy(col("avg_salary").desc()) + .filter(col("headcount") > 1) + .filter(col("avg_salary") > 50000) + ) + Utils.check_answer( + result11, + [ + Row("engineering", 3, 85000.0), + Row("sales", 2, 55000.0), + ], + ) + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="exclude_grouping_columns is not supported", From 8cbcd4e760c2b851d11061d4d21bfd9f74c94a81 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Fri, 27 Mar 2026 13:37:56 -0700 Subject: [PATCH 2/3] update nesting/limit breaking behavior --- src/snowflake/snowpark/dataframe.py | 67 ++++--- tests/integ/test_df_aggregate.py | 269 +++++++++++++++++++++------- 2 files changed, 245 insertions(+), 91 deletions(-) diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index eff5b33a4a..741cfe916d 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -11,7 +11,7 @@ import sys from collections import Counter from decimal import Decimal -from functools import cached_property +from functools import cached_property, reduce from logging import getLogger from types import ModuleType from typing import ( @@ -54,6 +54,9 @@ create_join_type, ) from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark._internal.analyzer.binary_expression import ( + And, +) from snowflake.snowpark._internal.analyzer.expression import ( Attribute, Expression, @@ -695,9 +698,8 @@ def __init__( self._ops_after_agg = None self._agg_base_plan = None self._agg_base_select_statement = None - self._pending_having = None - self._pending_order_by = None - self._pending_limit = None + self._pending_havings = [] + self._pending_order_bys = [] # Whether all columns are VARIANT data type, # which support querying nested fields via dot notations @@ -2126,18 +2128,19 @@ def filter( # We defer the HAVING expression and rebuild the plan from the # aggregate base so that SQL 
clauses are emitted in the correct order # (HAVING -> ORDER BY -> LIMIT) regardless of the user's call order. + # If there is a LIMIT earlier in the expression tree, then we must produce a new + # sub-query from this filter to ensure correctness. if ( context._is_snowpark_connect_compatible_mode and self._ops_after_agg is not None - and "filter" not in self._ops_after_agg + and "limit" not in self._ops_after_agg ): new_ops = self._ops_after_agg.copy() new_ops.add("filter") return self._build_post_agg_df( ops_after_agg=new_ops, - pending_having=filter_col_expr, - pending_order_by=self._pending_order_by, - pending_limit=self._pending_limit, + pending_havings=self._pending_havings + [filter_col_expr], + pending_order_bys=self._pending_order_bys, _ast_stmt=stmt, ) else: @@ -2337,18 +2340,20 @@ def sort( # the sorting for dataframe after aggregation without nesting. # We defer the ORDER BY expressions and rebuild the plan from # the aggregate base in correct SQL clause order. + # If there is a LIMIT earlier in the expression tree, then we must produce a new + # sub-query from this sort to ensure correctness. 
if ( context._is_snowpark_connect_compatible_mode and self._ops_after_agg is not None - and "sort" not in self._ops_after_agg + and "limit" not in self._ops_after_agg ): new_ops = self._ops_after_agg.copy() new_ops.add("sort") return self._build_post_agg_df( ops_after_agg=new_ops, - pending_having=self._pending_having, - pending_order_by=sort_exprs, - pending_limit=self._pending_limit, + pending_havings=self._pending_havings, + # New ordering clauses must be placed before previously-declared ones + pending_order_bys=sort_exprs + self._pending_order_bys, _ast_stmt=stmt, ) else: @@ -3068,9 +3073,9 @@ def limit( new_ops.add("limit") return self._build_post_agg_df( ops_after_agg=new_ops, - pending_having=self._pending_having, - pending_order_by=self._pending_order_by, - pending_limit=(n, offset), + pending_havings=self._pending_havings, + pending_order_bys=self._pending_order_bys, + limit_parameters=(n, offset), _ast_stmt=stmt, ) else: @@ -6829,10 +6834,10 @@ def dtypes(self) -> List[Tuple[str, str]]: def _build_post_agg_df( self, - ops_after_agg, - pending_having, - pending_order_by, - pending_limit, + ops_after_agg: set[str], + pending_havings: list[Expression], + pending_order_bys: list[Expression], + limit_parameters: Optional[tuple[int, int]] = None, _ast_stmt=None, ) -> "DataFrame": """ @@ -6857,12 +6862,19 @@ def _build_post_agg_df( """ current = self._agg_base_plan - if pending_having is not None: - current = Filter(pending_having, current, is_having=True) - if pending_order_by is not None: - current = Sort(pending_order_by, current, is_order_by_append=True) - if pending_limit is not None: - n, offset = pending_limit + if len(pending_havings) > 0: + current = Filter( + reduce( + lambda acc, expr: And(acc, expr), + pending_havings, + ), + current, + is_having=True, + ) + if len(pending_order_bys) > 0: + current = Sort(pending_order_bys, current, is_order_by_append=True) + if limit_parameters is not None: + n, offset = limit_parameters current = Limit(Literal(n), 
Literal(offset), current, is_limit_append=True) if self._agg_base_select_statement is not None: @@ -6879,9 +6891,8 @@ def _build_post_agg_df( df._ops_after_agg = ops_after_agg df._agg_base_plan = self._agg_base_plan df._agg_base_select_statement = self._agg_base_select_statement - df._pending_having = pending_having - df._pending_order_by = pending_order_by - df._pending_limit = pending_limit + df._pending_havings = pending_havings + df._pending_order_bys = pending_order_bys return df def _with_plan(self, plan, _ast_stmt=None) -> "DataFrame": diff --git a/tests/integ/test_df_aggregate.py b/tests/integ/test_df_aggregate.py index 8df22275d8..6bd2e8ee8d 100644 --- a/tests/integ/test_df_aggregate.py +++ b/tests/integ/test_df_aggregate.py @@ -930,9 +930,9 @@ def test_filter_sort_limit_snowpark_connect_compatible(session, sql_simplifier_e @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", - reason="exclude_grouping_columns is not supported", + reason="HAVING, ORDER BY append, and limit append are not supported in local testing mode", ) -def test_group_by_agg_sort_filter(session): +def test_group_by_agg_sort_filter_sanity(session): """ Tests that post-aggregation clauses (HAVING, ORDER BY, LIMIT) are emitted in valid SQL order regardless of the DataFrame call order (see SNOW-3266495). @@ -958,21 +958,59 @@ def test_group_by_agg_sort_filter(session): count("*").alias("headcount"), avg("salary").alias("avg_salary"), ) + # Checking against exact query text structure is less than ideal, but these tests need to + # verify the level of nesting of certain sub-queries, making it a necessary evil. 
+ agg_query_base = """ + SELECT "DEPT", count(1) AS "HEADCOUNT", avg("SALARY") AS "AVG_SALARY" + FROM ( + SELECT "ID", "DEPT", "SALARY" FROM ( + SELECT $1 AS "ID", $2 AS "DEPT", $3 AS "SALARY" FROM VALUES + (1 :: INT, 'engineering' :: STRING, 80000 :: INT), + (2 :: INT, 'engineering' :: STRING, 90000 :: INT), + (3 :: INT, 'sales' :: STRING, 50000 :: INT), + (4 :: INT, 'sales' :: STRING, 60000 :: INT), + (5 :: INT, 'hr' :: STRING, 45000 :: INT), + (6 :: INT, 'hr' :: STRING, 55000 :: INT), + (7 :: INT, 'engineering' :: STRING, 85000 :: INT) + ) + ) + GROUP BY "DEPT" + """ + + def check_agg_sql(df, expected_sql): + assert Utils.normalize_sql(df.queries["queries"][0]) == Utils.normalize_sql( + expected_sql + ) - expected_all = [ + base_expected_result = [ Row("engineering", 3, 85000.0), Row("sales", 2, 55000.0), Row("hr", 2, 50000.0), ] - # sort -> filter: ORDER BY before HAVING in user code, but - # SQL must be HAVING before ORDER BY. + # sort -> filter: ORDER BY before HAVING in user code, but SQL must be HAVING before ORDER BY. result1 = agg_df.orderBy(col("avg_salary").desc()).filter(col("headcount") > 1) - Utils.check_answer(result1, expected_all) + Utils.check_answer(result1, base_expected_result) + check_agg_sql( + result1, + f""" + {agg_query_base} + HAVING ("HEADCOUNT" > 1) + ORDER BY "AVG_SALARY" DESC NULLS LAST + """, + ) # filter -> sort: already in correct SQL clause order. 
result2 = agg_df.filter(col("headcount") > 1).orderBy(col("avg_salary").desc()) - Utils.check_answer(result2, expected_all) + Utils.check_answer(result2, base_expected_result) + check_agg_sql( + result2, + f""" + {agg_query_base} + HAVING ("HEADCOUNT" > 1) + ORDER BY "AVG_SALARY" DESC NULLS LAST + """, + ) # sort -> filter -> limit (must swap filter with sort) result3 = ( @@ -987,91 +1025,196 @@ def test_group_by_agg_sort_filter(session): Row("sales", 2, 55000.0), ], ) - - # sort -> limit -> filter: all three clauses in wrong order - # (must move filter to first operation) - result4 = ( - agg_df.orderBy(col("avg_salary").desc()) - .limit(2) - .filter(col("headcount") > 1) - ) - Utils.check_answer( - result4, - [ - Row("engineering", 3, 85000.0), - Row("sales", 2, 55000.0), - ], - ) - - # limit -> sort: LIMIT before ORDER BY in user code. - result5 = agg_df.limit(2).orderBy(col("avg_salary").desc()) - assert result5.count() == 2 - - # limit -> filter: LIMIT before HAVING in user code. - result6 = agg_df.limit(2).filter(col("headcount") > 1) - assert result6.count() <= 2 - - # filter -> limit -> sort (must swap sort and limit) - result7 = ( - agg_df.filter(col("headcount") > 1) - .limit(2) - .orderBy(col("avg_salary").desc()) - ) - Utils.check_answer( - result7, - [ - Row("engineering", 3, 85000.0), - Row("sales", 2, 55000.0), - ], + check_agg_sql( + result3, + f""" + {agg_query_base} + HAVING ("HEADCOUNT" > 1) + ORDER BY "AVG_SALARY" DESC NULLS LAST + LIMIT 2 OFFSET 0 + """, ) - # select/distinct between sort and filter should break the + # A new select between sort and filter should break the # _ops_after_agg chain, so the subsequent filter uses a regular # WHERE via subquery rather than a flattened HAVING. 
- result8 = ( + result4 = ( agg_df.orderBy(col("avg_salary").desc()) .select("dept", "headcount", "avg_salary") .filter(col("headcount") > 1) ) - Utils.check_answer(result8, expected_all) - - result9 = ( - agg_df.orderBy(col("avg_salary").desc()) - .distinct() - .filter(col("headcount") > 1) + Utils.check_answer(result4, base_expected_result) + check_agg_sql( + result4, + ( + f""" + SELECT "DEPT", "HEADCOUNT", "AVG_SALARY" + FROM ( + {agg_query_base} + ORDER BY "AVG_SALARY" DESC NULLS LAST + ) + WHERE ("HEADCOUNT" > 1) + """ + if session.sql_simplifier_enabled + else f""" + SELECT * FROM ( + SELECT "DEPT", "HEADCOUNT", "AVG_SALARY" + FROM ( + {agg_query_base} + ORDER BY "AVG_SALARY" DESC NULLS LAST + ) + ) + WHERE ("HEADCOUNT" > 1) + """ + ), ) - Utils.check_answer(result9, expected_all) - # Repeated sort: first sort uses the _ops_after_agg append path, - # second sort falls back to the regular (subquery) path. - result10 = ( - agg_df.filter(col("headcount") > 1) + # Repeated sort: all clauses are placed in a single ORDER BY clause, in the reverse + # order of their declaration. Note that referencing the same column multiple times is valid. + result5 = ( + agg_df.orderBy(col("avg_salary").asc()) + .filter(col("headcount") > 1) .orderBy(col("avg_salary").desc()) .orderBy(col("headcount").asc()) ) Utils.check_answer( - result10, + result5, [ Row("sales", 2, 55000.0), Row("hr", 2, 50000.0), Row("engineering", 3, 85000.0), ], ) + check_agg_sql( + result5, + f""" + {agg_query_base} + HAVING ("HEADCOUNT" > 1) + ORDER BY "HEADCOUNT" ASC NULLS FIRST, + "AVG_SALARY" DESC NULLS LAST, + "AVG_SALARY" ASC NULLS FIRST + """, + ) - # Repeated filter: first filter uses the HAVING append path, - # second filter falls back to the regular (subquery WHERE) path. - result11 = ( - agg_df.orderBy(col("avg_salary").desc()) - .filter(col("headcount") > 1) + # Repeated filter: each clause is ANDed together. 
+ result6 = ( + agg_df.filter(col("headcount") > 1) + .orderBy(col("avg_salary").desc()) + .filter(col("avg_salary") > 50000) ) Utils.check_answer( - result11, + result6, [ Row("engineering", 3, 85000.0), Row("sales", 2, 55000.0), ], ) + check_agg_sql( + result6, + f""" + {agg_query_base} + HAVING (("HEADCOUNT" > 1) AND ("AVG_SALARY" > 50000)) + ORDER BY "AVG_SALARY" DESC NULLS LAST + """, + ) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="HAVING, ORDER BY append, and limit append are not supported in local testing mode", +) +def test_group_by_agg_sort_filter_limit_ordering(session): + """ + Tests that aggregations involving HAVING and ORDER BY/LIMIT apply operations in the correct order. + + ORDER BY -> LIMIT -> FILTER and FILTER -> ORDER BY -> LIMIT do not generally commute. + """ + df = session.createDataFrame( + [ + (1, "engineering", 80000), + (2, "engineering", 90000), + (3, "sales", 50000), + (4, "sales", 60000), + (5, "hr", 45000), + (6, "hr", 55000), + (7, "engineering", 85000), + (8, "research", 90000), + (9, "research", 100000), + (10, "research", 140000), + (11, "AAA", 130000), + ], + ["id", "dept", "salary"], + ) + agg_df = df.groupBy("dept").agg( + count("*").alias("headcount"), + avg("salary").alias("avg_salary"), + ) + + # 1. ORDER BY -> LIMIT -> FILTER + # The ORDER BY + LIMIT drops the RESEARCH and AAA groups before the filter occurs. + result1 = ( + agg_df.orderBy(col("avg_salary").asc()) + .limit(3) + .filter(col("avg_salary") > 54000) + ) + Utils.check_answer( + result1, + [ + Row("sales", 2, 55000.0), + Row("engineering", 3, 85000.0), + ], + ) + + # 2. FILTER -> ORDER BY -> LIMIT + # This is different from case (1), as filtering occurs before ordering/limiting. + result2 = ( + agg_df.filter(col("avg_salary") > 54000) + .orderBy(col("avg_salary").asc()) + .limit(3) + ) + Utils.check_answer( + result2, + [ + Row("sales", 2, 55000.0), + Row("engineering", 3, 85000.0), + Row("research", 3, 110000.0), + ], + ) + + # 3. 
FILTER -> ORDER BY -> LIMIT -> FILTER + # The ordering drops the RESEARCH group, so even though it should survive the final filter, it + # gets dropped from the final result. + result3 = ( + agg_df.filter(col("headcount") > 1) + .orderBy(col("avg_salary").asc()) + .limit(3) + .filter(col("avg_salary") > 54000) + ) + Utils.check_answer( + result3, + [ + Row("sales", 2, 55000.0), + Row("engineering", 3, 85000.0), + ], + ) + + # 4. FILTER -> ORDER BY -> LIMIT -> LIMIT + # The final limit is not necessarily deterministic, but does not commute with the prior LIMIT. + result4 = result2.limit(1) + Utils.check_answer(result4, [Row("sales", 2, 55000.0)]) + + # 5. ORDER BY -> LIMIT -> ORDER BY -> LIMIT + # The RESEARCH group is expected to be dropped by the first ORDER BY + LIMIT (NOTE(review): ENGINEERING and RESEARCH tie on headcount, so which one is dropped is not guaranteed; confirm or add a tie-breaker). + # The SALES and HR groups are dropped by the second ORDER BY + LIMIT. + result5 = ( + agg_df.orderBy(col("headcount").asc()) + .limit(4) + .orderBy(col("avg_salary").desc()) + .limit(2) + ) + Utils.check_answer( + result5, [Row("AAA", 1, 130000.0), Row("engineering", 3, 85000.0)] + ) @pytest.mark.skipif( From 42588729bab0c7e74d18e852e8d54c3390d3c897 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Fri, 27 Mar 2026 13:53:51 -0700 Subject: [PATCH 3/3] update comments --- src/snowflake/snowpark/dataframe.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 741cfe916d..17eca6e5bf 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -3062,8 +3062,6 @@ def limit( # In snowpark_connect_compatible mode, we need to handle # the limit for dataframe after aggregation without nesting. - # We defer the LIMIT values and rebuild the plan from - # the aggregate base in correct SQL clause order. 
if ( context._is_snowpark_connect_compatible_mode and self._ops_after_agg is not None @@ -6850,14 +6848,17 @@ def _build_post_agg_df( avg("salary").alias("avg_salary"), ) .orderBy(col("avg_salary").desc()) - .limit(2) .filter(col("headcount") > 1) + .limit(2) Even though `orderBy` is the first operation, we must re-order the `filter` to be first because SQL syntax requires HAVING, ORDER BY, and LIMIT clauses to appear in that specific order. - We use `_agg_base_plan` and `_agg_base_select_statement` to re-construct SQL with this constraint. + Note that LIMIT itself does not commute with ORDER BY and FILTER, so if another FILTER or + ORDER BY appears after a LIMIT, we must generate a new sub-query. This invariant is enforced + when chaining new filter/order by operations. + This method should only be called in SCOS compatibility mode (context._is_snowpark_connect_compatible_mode). """ current = self._agg_base_plan