diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 3dfe045c84..17eca6e5bf 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -11,7 +11,7 @@ import sys from collections import Counter from decimal import Decimal -from functools import cached_property +from functools import cached_property, reduce from logging import getLogger from types import ModuleType from typing import ( @@ -54,6 +54,9 @@ create_join_type, ) from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark._internal.analyzer.binary_expression import ( + And, +) from snowflake.snowpark._internal.analyzer.expression import ( Attribute, Expression, @@ -689,7 +692,14 @@ def __init__( self._statement_params = None self.is_cached: bool = is_cached #: Whether the dataframe is cached. + # Internal state variables used to construct flattened GROUP BY clauses in the correct order + # in SCOS compatibility mode. + # See comments on `_build_post_agg_df` for details. self._ops_after_agg = None + self._agg_base_plan = None + self._agg_base_select_statement = None + self._pending_havings = [] + self._pending_order_bys = [] # Whether all columns are VARIANT data type, # which support querying nested fields via dot notations @@ -2114,28 +2124,25 @@ def filter( stmt = _ast_stmt # In snowpark_connect_compatible mode, we need to handle - # the filtering for dataframe after aggregation without nesting using HAVING + # the filtering for dataframe after aggregation without nesting using HAVING. + # We defer the HAVING expression and rebuild the plan from the + # aggregate base so that SQL clauses are emitted in the correct order + # (HAVING -> ORDER BY -> LIMIT) regardless of the user's call order. + # If there is a LIMIT earlier in the expression tree, then we must produce a new + # sub-query from this filter to ensure correctness. 
if ( context._is_snowpark_connect_compatible_mode and self._ops_after_agg is not None - and "filter" not in self._ops_after_agg + and "limit" not in self._ops_after_agg ): - having_plan = Filter(filter_col_expr, self._plan, is_having=True) - if self._select_statement: - df = self._with_plan( - self._session._analyzer.create_select_statement( - from_=self._session._analyzer.create_select_snowflake_plan( - having_plan, analyzer=self._session._analyzer - ), - analyzer=self._session._analyzer, - ), - _ast_stmt=stmt, - ) - else: - df = self._with_plan(having_plan, _ast_stmt=stmt) - df._ops_after_agg = self._ops_after_agg.copy() - df._ops_after_agg.add("filter") - return df + new_ops = self._ops_after_agg.copy() + new_ops.add("filter") + return self._build_post_agg_df( + ops_after_agg=new_ops, + pending_havings=self._pending_havings + [filter_col_expr], + pending_order_bys=self._pending_order_bys, + _ast_stmt=stmt, + ) else: if self._select_statement: return self._with_plan( @@ -2330,28 +2337,25 @@ def sort( ) # In snowpark_connect_compatible mode, we need to handle - # the sorting for dataframe after aggregation without nesting + # the sorting for dataframe after aggregation without nesting. + # We defer the ORDER BY expressions and rebuild the plan from + # the aggregate base in correct SQL clause order. + # If there is a LIMIT earlier in the expression tree, then we must produce a new + # sub-query from this sort to ensure correctness.
if ( context._is_snowpark_connect_compatible_mode and self._ops_after_agg is not None - and "sort" not in self._ops_after_agg + and "limit" not in self._ops_after_agg ): - sort_plan = Sort(sort_exprs, self._plan, is_order_by_append=True) - if self._select_statement: - df = self._with_plan( - self._session._analyzer.create_select_statement( - from_=self._session._analyzer.create_select_snowflake_plan( - sort_plan, analyzer=self._session._analyzer - ), - analyzer=self._session._analyzer, - ), - _ast_stmt=stmt, - ) - else: - df = self._with_plan(sort_plan, _ast_stmt=stmt) - df._ops_after_agg = self._ops_after_agg.copy() - df._ops_after_agg.add("sort") - return df + new_ops = self._ops_after_agg.copy() + new_ops.add("sort") + return self._build_post_agg_df( + ops_after_agg=new_ops, + pending_havings=self._pending_havings, + # New ordering clauses must be placed before previously-declared ones + pending_order_bys=sort_exprs + self._pending_order_bys, + _ast_stmt=stmt, + ) else: df = ( self._with_plan(self._select_statement.sort(sort_exprs)) @@ -3057,30 +3061,21 @@ def limit( stmt = None # In snowpark_connect_compatible mode, we need to handle - # the limit for dataframe after aggregation without nesting + # the limit for dataframe after aggregation without nesting. 
if ( context._is_snowpark_connect_compatible_mode and self._ops_after_agg is not None and "limit" not in self._ops_after_agg ): - limit_plan = Limit( - Literal(n), Literal(offset), self._plan, is_limit_append=True + new_ops = self._ops_after_agg.copy() + new_ops.add("limit") + return self._build_post_agg_df( + ops_after_agg=new_ops, + pending_havings=self._pending_havings, + pending_order_bys=self._pending_order_bys, + limit_parameters=(n, offset), + _ast_stmt=stmt, ) - if self._select_statement: - df = self._with_plan( - self._session._analyzer.create_select_statement( - from_=self._session._analyzer.create_select_snowflake_plan( - limit_plan, analyzer=self._session._analyzer - ), - analyzer=self._session._analyzer, - ), - _ast_stmt=stmt, - ) - else: - df = self._with_plan(limit_plan, _ast_stmt=stmt) - df._ops_after_agg = self._ops_after_agg.copy() - df._ops_after_agg.add("limit") - return df else: if self._select_statement: return self._with_plan( @@ -6835,6 +6830,72 @@ def dtypes(self) -> List[Tuple[str, str]]: ] return dtypes + def _build_post_agg_df( + self, + ops_after_agg: set[str], + pending_havings: list[Expression], + pending_order_bys: list[Expression], + limit_parameters: Optional[tuple[int, int]] = None, + _ast_stmt=None, + ) -> "DataFrame": + """ + When constructing group by aggregation queries in SCOS compatibility mode, we must ensure that + filter (HAVING), sorting (ORDER BY), and LIMIT clauses are emitted in the correct order, regardless of + the order in which the user specified those operations. For example: + + df.groupBy("dept").agg( + count("*").alias("headcount"), + avg("salary").alias("avg_salary"), + ) + .orderBy(col("avg_salary").desc()) + .filter(col("headcount") > 1) + .limit(2) + + Even though `orderBy` is the first operation, we must re-order the `filter` to be first because + SQL syntax requires HAVING, ORDER BY, and LIMIT clauses to appear in that specific order. 
+ We use `_agg_base_plan` and `_agg_base_select_statement` to re-construct SQL with this constraint. + + Note that LIMIT itself does not commute with ORDER BY and FILTER, so if another FILTER or + ORDER BY appears after a LIMIT, we must generate a new sub-query. This invariant is enforced + when chaining new filter/order by operations. + + This method should only be called in SCOS compatibility mode (context._is_snowpark_connect_compatible_mode). + """ + current = self._agg_base_plan + + if len(pending_havings) > 0: + current = Filter( + reduce( + lambda acc, expr: And(acc, expr), + pending_havings, + ), + current, + is_having=True, + ) + if len(pending_order_bys) > 0: + current = Sort(pending_order_bys, current, is_order_by_append=True) + if limit_parameters is not None: + n, offset = limit_parameters + current = Limit(Literal(n), Literal(offset), current, is_limit_append=True) + + if self._agg_base_select_statement is not None: + new_plan = self._session._analyzer.create_select_statement( + from_=self._session._analyzer.create_select_snowflake_plan( + current, analyzer=self._session._analyzer + ), + analyzer=self._session._analyzer, + ) + else: + new_plan = current + + df = self._with_plan(new_plan, _ast_stmt=_ast_stmt) + df._ops_after_agg = ops_after_agg + df._agg_base_plan = self._agg_base_plan + df._agg_base_select_statement = self._agg_base_select_statement + df._pending_havings = pending_havings + df._pending_order_bys = pending_order_bys + return df + def _with_plan(self, plan, _ast_stmt=None) -> "DataFrame": """ :param proto.Bind ast_stmt: The AST statement protobuf corresponding to this value. 
diff --git a/src/snowflake/snowpark/relational_grouped_dataframe.py b/src/snowflake/snowpark/relational_grouped_dataframe.py index 2f29726e4e..03a6201d3e 100644 --- a/src/snowflake/snowpark/relational_grouped_dataframe.py +++ b/src/snowflake/snowpark/relational_grouped_dataframe.py @@ -338,6 +338,8 @@ def agg( # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement if _emit_ast: df._ast_id = stmt.uid @@ -531,6 +533,8 @@ def end_partition( # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement if _emit_ast: stmt = working_dataframe._session._ast_batch.bind() @@ -766,6 +770,8 @@ def count(self, _emit_ast: bool = True, **kwargs) -> DataFrame: # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement # TODO: count seems similar to mean, min, .... Can we unify implementation here? 
if _emit_ast: @@ -815,6 +821,8 @@ def _function( # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement if _emit_ast: stmt = self._dataframe._session._ast_batch.bind() @@ -909,6 +917,8 @@ def ai_agg( # if no grouping exprs, there is already a LIMIT 1 in the query # see aggregate_statement in analyzer_utils.py df._ops_after_agg = set() if self._grouping_exprs else {"limit"} + df._agg_base_plan = df._plan + df._agg_base_select_statement = df._select_statement if _emit_ast: stmt = self._dataframe._session._ast_batch.bind() diff --git a/tests/integ/test_df_aggregate.py b/tests/integ/test_df_aggregate.py index 9a8795fb49..6bd2e8ee8d 100644 --- a/tests/integ/test_df_aggregate.py +++ b/tests/integ/test_df_aggregate.py @@ -928,6 +928,295 @@ def test_filter_sort_limit_snowpark_connect_compatible(session, sql_simplifier_e assert query6.upper().count("SELECT") == 4 if sql_simplifier_enabled else 5 +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="HAVING, ORDER BY append, and limit append are not supported in local testing mode", +) +def test_group_by_agg_sort_filter_sanity(session): + """ + Tests that post-aggregation clauses (HAVING, ORDER BY, LIMIT) are emitted in valid SQL order + regardless of the DataFrame call order (see SNOW-3266495). + + After a GROUP BY, HAVING must appear before ORDER BY, which in turn must appear before LIMIT. 
+ """ + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + df = session.createDataFrame( + [ + (1, "engineering", 80000), + (2, "engineering", 90000), + (3, "sales", 50000), + (4, "sales", 60000), + (5, "hr", 45000), + (6, "hr", 55000), + (7, "engineering", 85000), + ], + ["id", "dept", "salary"], + ) + agg_df = df.groupBy("dept").agg( + count("*").alias("headcount"), + avg("salary").alias("avg_salary"), + ) + # Checking against exact query text structure is less than ideal, but these tests need to + # verify the level of nesting of certain sub-queries, making it a necessary evil. + agg_query_base = """ + SELECT "DEPT", count(1) AS "HEADCOUNT", avg("SALARY") AS "AVG_SALARY" + FROM ( + SELECT "ID", "DEPT", "SALARY" FROM ( + SELECT $1 AS "ID", $2 AS "DEPT", $3 AS "SALARY" FROM VALUES + (1 :: INT, 'engineering' :: STRING, 80000 :: INT), + (2 :: INT, 'engineering' :: STRING, 90000 :: INT), + (3 :: INT, 'sales' :: STRING, 50000 :: INT), + (4 :: INT, 'sales' :: STRING, 60000 :: INT), + (5 :: INT, 'hr' :: STRING, 45000 :: INT), + (6 :: INT, 'hr' :: STRING, 55000 :: INT), + (7 :: INT, 'engineering' :: STRING, 85000 :: INT) + ) + ) + GROUP BY "DEPT" + """ + + def check_agg_sql(df, expected_sql): + assert Utils.normalize_sql(df.queries["queries"][0]) == Utils.normalize_sql( + expected_sql + ) + + base_expected_result = [ + Row("engineering", 3, 85000.0), + Row("sales", 2, 55000.0), + Row("hr", 2, 50000.0), + ] + + # sort -> filter: ORDER BY before HAVING in user code, but SQL must be HAVING before ORDER BY. + result1 = agg_df.orderBy(col("avg_salary").desc()).filter(col("headcount") > 1) + Utils.check_answer(result1, base_expected_result) + check_agg_sql( + result1, + f""" + {agg_query_base} + HAVING ("HEADCOUNT" > 1) + ORDER BY "AVG_SALARY" DESC NULLS LAST + """, + ) + + # filter -> sort: already in correct SQL clause order. 
+ result2 = agg_df.filter(col("headcount") > 1).orderBy(col("avg_salary").desc()) + Utils.check_answer(result2, base_expected_result) + check_agg_sql( + result2, + f""" + {agg_query_base} + HAVING ("HEADCOUNT" > 1) + ORDER BY "AVG_SALARY" DESC NULLS LAST + """, + ) + + # sort -> filter -> limit (must swap filter with sort) + result3 = ( + agg_df.orderBy(col("avg_salary").desc()) + .filter(col("headcount") > 1) + .limit(2) + ) + Utils.check_answer( + result3, + [ + Row("engineering", 3, 85000.0), + Row("sales", 2, 55000.0), + ], + ) + check_agg_sql( + result3, + f""" + {agg_query_base} + HAVING ("HEADCOUNT" > 1) + ORDER BY "AVG_SALARY" DESC NULLS LAST + LIMIT 2 OFFSET 0 + """, + ) + + # A new select between sort and filter should break the + # _ops_after_agg chain, so the subsequent filter uses a regular + # WHERE via subquery rather than a flattened HAVING. + result4 = ( + agg_df.orderBy(col("avg_salary").desc()) + .select("dept", "headcount", "avg_salary") + .filter(col("headcount") > 1) + ) + Utils.check_answer(result4, base_expected_result) + check_agg_sql( + result4, + ( + f""" + SELECT "DEPT", "HEADCOUNT", "AVG_SALARY" + FROM ( + {agg_query_base} + ORDER BY "AVG_SALARY" DESC NULLS LAST + ) + WHERE ("HEADCOUNT" > 1) + """ + if session.sql_simplifier_enabled + else f""" + SELECT * FROM ( + SELECT "DEPT", "HEADCOUNT", "AVG_SALARY" + FROM ( + {agg_query_base} + ORDER BY "AVG_SALARY" DESC NULLS LAST + ) + ) + WHERE ("HEADCOUNT" > 1) + """ + ), + ) + + # Repeated sort: all clauses are placed in a single ORDER BY clause, in the reverse + # order of their declaration. Note that referencing the same column multiple times is valid. 
+ result5 = ( + agg_df.orderBy(col("avg_salary").asc()) + .filter(col("headcount") > 1) + .orderBy(col("avg_salary").desc()) + .orderBy(col("headcount").asc()) + ) + Utils.check_answer( + result5, + [ + Row("sales", 2, 55000.0), + Row("hr", 2, 50000.0), + Row("engineering", 3, 85000.0), + ], + ) + check_agg_sql( + result5, + f""" + {agg_query_base} + HAVING ("HEADCOUNT" > 1) + ORDER BY "HEADCOUNT" ASC NULLS FIRST, + "AVG_SALARY" DESC NULLS LAST, + "AVG_SALARY" ASC NULLS FIRST + """, + ) + + # Repeated filter: each clause is ANDed together. + result6 = ( + agg_df.filter(col("headcount") > 1) + .orderBy(col("avg_salary").desc()) + .filter(col("avg_salary") > 50000) + ) + Utils.check_answer( + result6, + [ + Row("engineering", 3, 85000.0), + Row("sales", 2, 55000.0), + ], + ) + check_agg_sql( + result6, + f""" + {agg_query_base} + HAVING (("HEADCOUNT" > 1) AND ("AVG_SALARY" > 50000)) + ORDER BY "AVG_SALARY" DESC NULLS LAST + """, + ) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="HAVING, ORDER BY append, and limit append are not supported in local testing mode", +) +def test_group_by_agg_sort_filter_limit_ordering(session): + """ + Tests that aggregations involving HAVING and ORDER BY/LIMIT apply operations in the correct order. + + ORDER BY -> LIMIT -> FILTER and FILTER -> ORDER BY -> LIMIT do not generally commute. + """ + df = session.createDataFrame( + [ + (1, "engineering", 80000), + (2, "engineering", 90000), + (3, "sales", 50000), + (4, "sales", 60000), + (5, "hr", 45000), + (6, "hr", 55000), + (7, "engineering", 85000), + (8, "research", 90000), + (9, "research", 100000), + (10, "research", 140000), + (11, "AAA", 130000), + ], + ["id", "dept", "salary"], + ) + agg_df = df.groupBy("dept").agg( + count("*").alias("headcount"), + avg("salary").alias("avg_salary"), + ) + + # 1. ORDER BY -> LIMIT -> FILTER + # The ordering drops the AAA group before the filter occurs. 
+ result1 = ( + agg_df.orderBy(col("avg_salary").asc()) + .limit(3) + .filter(col("avg_salary") > 54000) + ) + Utils.check_answer( + result1, + [ + Row("sales", 2, 55000.0), + Row("engineering", 3, 85000.0), + ], + ) + + # 2. FILTER -> ORDER BY -> LIMIT + # This is different from case (1), as filtering occurs before ordering/limiting. + result2 = ( + agg_df.filter(col("avg_salary") > 54000) + .orderBy(col("avg_salary").asc()) + .limit(3) + ) + Utils.check_answer( + result2, + [ + Row("sales", 2, 55000.0), + Row("engineering", 3, 85000.0), + Row("research", 3, 110000.0), + ], + ) + + # 3. FILTER -> ORDER BY -> LIMIT -> FILTER + # The ordering drops the RESEARCH group, so even though it should survive the final filter, it + # gets dropped from the final result. + result3 = ( + agg_df.filter(col("headcount") > 1) + .orderBy(col("avg_salary").asc()) + .limit(3) + .filter(col("avg_salary") > 54000) + ) + Utils.check_answer( + result3, + [ + Row("sales", 2, 55000.0), + Row("engineering", 3, 85000.0), + ], + ) + + # 4. FILTER -> ORDER BY -> LIMIT -> LIMIT (case (2) followed by a second LIMIT) + # The final limit is not necessarily deterministic, but does not commute with the prior LIMIT. + result4 = result2.limit(1) + Utils.check_answer(result4, [Row("sales", 2, 55000.0)]) + + # 5. ORDER BY -> LIMIT -> ORDER BY -> LIMIT + # The RESEARCH group is dropped by the first ORDER BY + LIMIT. + # NOTE(review): ENGINEERING and RESEARCH tie on headcount (3), so which of them the first + # LIMIT drops depends on tie-breaking; consider adding a secondary sort key. + # The SALES and HR groups are dropped by the second ORDER BY + LIMIT. + result5 = ( + agg_df.orderBy(col("headcount").asc()) + .limit(4) + .orderBy(col("avg_salary").desc()) + .limit(2) + ) + Utils.check_answer( + result5, [Row("AAA", 1, 130000.0), Row("engineering", 3, 85000.0)] + ) + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="exclude_grouping_columns is not supported",