Skip to content

Commit d7aca89

Browse files
committed
defer having/sort/limit construction
1 parent f1d02bf commit d7aca89

File tree

3 files changed

+257
-52
lines changed

3 files changed

+257
-52
lines changed

src/snowflake/snowpark/dataframe.py

Lines changed: 101 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,15 @@ def __init__(
689689

690690
self._statement_params = None
691691
self.is_cached: bool = is_cached #: Whether the dataframe is cached.
692+
# Internal state variables used to construct flattened GROUP BY clauses in the correct order
693+
# in SCOS compatibility mode.
694+
# See comments on `_build_post_agg_df` for details.
692695
self._ops_after_agg = None
696+
self._agg_base_plan = None
697+
self._agg_base_select_statement = None
698+
self._pending_having = None
699+
self._pending_order_by = None
700+
self._pending_limit = None
693701

694702
# Whether all columns are VARIANT data type,
695703
# which support querying nested fields via dot notations
@@ -2114,28 +2122,24 @@ def filter(
21142122
stmt = _ast_stmt
21152123

21162124
# In snowpark_connect_compatible mode, we need to handle
2117-
# the filtering for dataframe after aggregation without nesting using HAVING
2125+
# the filtering for dataframe after aggregation without nesting using HAVING.
2126+
# We defer the HAVING expression and rebuild the plan from the
2127+
# aggregate base so that SQL clauses are emitted in the correct order
2128+
# (HAVING -> ORDER BY -> LIMIT) regardless of the user's call order.
21182129
if (
21192130
context._is_snowpark_connect_compatible_mode
21202131
and self._ops_after_agg is not None
21212132
and "filter" not in self._ops_after_agg
21222133
):
2123-
having_plan = Filter(filter_col_expr, self._plan, is_having=True)
2124-
if self._select_statement:
2125-
df = self._with_plan(
2126-
self._session._analyzer.create_select_statement(
2127-
from_=self._session._analyzer.create_select_snowflake_plan(
2128-
having_plan, analyzer=self._session._analyzer
2129-
),
2130-
analyzer=self._session._analyzer,
2131-
),
2132-
_ast_stmt=stmt,
2133-
)
2134-
else:
2135-
df = self._with_plan(having_plan, _ast_stmt=stmt)
2136-
df._ops_after_agg = self._ops_after_agg.copy()
2137-
df._ops_after_agg.add("filter")
2138-
return df
2134+
new_ops = self._ops_after_agg.copy()
2135+
new_ops.add("filter")
2136+
return self._build_post_agg_df(
2137+
ops_after_agg=new_ops,
2138+
pending_having=filter_col_expr,
2139+
pending_order_by=self._pending_order_by,
2140+
pending_limit=self._pending_limit,
2141+
_ast_stmt=stmt,
2142+
)
21392143
else:
21402144
if self._select_statement:
21412145
return self._with_plan(
@@ -2330,28 +2334,23 @@ def sort(
23302334
)
23312335

23322336
# In snowpark_connect_compatible mode, we need to handle
2333-
# the sorting for dataframe after aggregation without nesting
2337+
# the sorting for dataframe after aggregation without nesting.
2338+
# We defer the ORDER BY expressions and rebuild the plan from
2339+
# the aggregate base in correct SQL clause order.
23342340
if (
23352341
context._is_snowpark_connect_compatible_mode
23362342
and self._ops_after_agg is not None
23372343
and "sort" not in self._ops_after_agg
23382344
):
2339-
sort_plan = Sort(sort_exprs, self._plan, is_order_by_append=True)
2340-
if self._select_statement:
2341-
df = self._with_plan(
2342-
self._session._analyzer.create_select_statement(
2343-
from_=self._session._analyzer.create_select_snowflake_plan(
2344-
sort_plan, analyzer=self._session._analyzer
2345-
),
2346-
analyzer=self._session._analyzer,
2347-
),
2348-
_ast_stmt=stmt,
2349-
)
2350-
else:
2351-
df = self._with_plan(sort_plan, _ast_stmt=stmt)
2352-
df._ops_after_agg = self._ops_after_agg.copy()
2353-
df._ops_after_agg.add("sort")
2354-
return df
2345+
new_ops = self._ops_after_agg.copy()
2346+
new_ops.add("sort")
2347+
return self._build_post_agg_df(
2348+
ops_after_agg=new_ops,
2349+
pending_having=self._pending_having,
2350+
pending_order_by=sort_exprs,
2351+
pending_limit=self._pending_limit,
2352+
_ast_stmt=stmt,
2353+
)
23552354
else:
23562355
df = (
23572356
self._with_plan(self._select_statement.sort(sort_exprs))
@@ -3057,30 +3056,23 @@ def limit(
30573056
stmt = None
30583057

30593058
# In snowpark_connect_compatible mode, we need to handle
3060-
# the limit for dataframe after aggregation without nesting
3059+
# the limit for dataframe after aggregation without nesting.
3060+
# We defer the LIMIT values and rebuild the plan from
3061+
# the aggregate base in correct SQL clause order.
30613062
if (
30623063
context._is_snowpark_connect_compatible_mode
30633064
and self._ops_after_agg is not None
30643065
and "limit" not in self._ops_after_agg
30653066
):
3066-
limit_plan = Limit(
3067-
Literal(n), Literal(offset), self._plan, is_limit_append=True
3067+
new_ops = self._ops_after_agg.copy()
3068+
new_ops.add("limit")
3069+
return self._build_post_agg_df(
3070+
ops_after_agg=new_ops,
3071+
pending_having=self._pending_having,
3072+
pending_order_by=self._pending_order_by,
3073+
pending_limit=(n, offset),
3074+
_ast_stmt=stmt,
30683075
)
3069-
if self._select_statement:
3070-
df = self._with_plan(
3071-
self._session._analyzer.create_select_statement(
3072-
from_=self._session._analyzer.create_select_snowflake_plan(
3073-
limit_plan, analyzer=self._session._analyzer
3074-
),
3075-
analyzer=self._session._analyzer,
3076-
),
3077-
_ast_stmt=stmt,
3078-
)
3079-
else:
3080-
df = self._with_plan(limit_plan, _ast_stmt=stmt)
3081-
df._ops_after_agg = self._ops_after_agg.copy()
3082-
df._ops_after_agg.add("limit")
3083-
return df
30843076
else:
30853077
if self._select_statement:
30863078
return self._with_plan(
@@ -6835,6 +6827,63 @@ def dtypes(self) -> List[Tuple[str, str]]:
68356827
]
68366828
return dtypes
68376829

6830+
def _build_post_agg_df(
6831+
self,
6832+
ops_after_agg,
6833+
pending_having,
6834+
pending_order_by,
6835+
pending_limit,
6836+
_ast_stmt=None,
6837+
) -> "DataFrame":
6838+
"""
6839+
When constructing group by aggregation queries in SCOS compatibility mode, we must ensure that
6840+
filter (HAVING), sorting (ORDER BY), and LIMIT clauses are emitted in the correct order, regardless of
6841+
the order in which the user specified those operations. For example:
6842+
6843+
df.groupBy("dept").agg(
6844+
count("*").alias("headcount"),
6845+
avg("salary").alias("avg_salary"),
6846+
)
6847+
.orderBy(col("avg_salary").desc())
6848+
.limit(2)
6849+
.filter(col("headcount") > 1)
6850+
6851+
Even though `orderBy` is the first operation, we must re-order the `filter` to be first because
6852+
SQL syntax requires HAVING, ORDER BY, and LIMIT clauses to appear in that specific order.
6853+
6854+
We use `_agg_base_plan` and `_agg_base_select_statement` to re-construct SQL with this constraint.
6855+
6856+
This method should only be called in SCOS compatibility mode (context._is_snowpark_connect_compatible_mode).
6857+
"""
6858+
current = self._agg_base_plan
6859+
6860+
if pending_having is not None:
6861+
current = Filter(pending_having, current, is_having=True)
6862+
if pending_order_by is not None:
6863+
current = Sort(pending_order_by, current, is_order_by_append=True)
6864+
if pending_limit is not None:
6865+
n, offset = pending_limit
6866+
current = Limit(Literal(n), Literal(offset), current, is_limit_append=True)
6867+
6868+
if self._agg_base_select_statement is not None:
6869+
new_plan = self._session._analyzer.create_select_statement(
6870+
from_=self._session._analyzer.create_select_snowflake_plan(
6871+
current, analyzer=self._session._analyzer
6872+
),
6873+
analyzer=self._session._analyzer,
6874+
)
6875+
else:
6876+
new_plan = current
6877+
6878+
df = self._with_plan(new_plan, _ast_stmt=_ast_stmt)
6879+
df._ops_after_agg = ops_after_agg
6880+
df._agg_base_plan = self._agg_base_plan
6881+
df._agg_base_select_statement = self._agg_base_select_statement
6882+
df._pending_having = pending_having
6883+
df._pending_order_by = pending_order_by
6884+
df._pending_limit = pending_limit
6885+
return df
6886+
68386887
def _with_plan(self, plan, _ast_stmt=None) -> "DataFrame":
68396888
"""
68406889
:param proto.Bind ast_stmt: The AST statement protobuf corresponding to this value.

src/snowflake/snowpark/relational_grouped_dataframe.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ def agg(
338338
# if no grouping exprs, there is already a LIMIT 1 in the query
339339
# see aggregate_statement in analyzer_utils.py
340340
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
341+
df._agg_base_plan = df._plan
342+
df._agg_base_select_statement = df._select_statement
341343

342344
if _emit_ast:
343345
df._ast_id = stmt.uid
@@ -531,6 +533,8 @@ def end_partition(
531533
# if no grouping exprs, there is already a LIMIT 1 in the query
532534
# see aggregate_statement in analyzer_utils.py
533535
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
536+
df._agg_base_plan = df._plan
537+
df._agg_base_select_statement = df._select_statement
534538

535539
if _emit_ast:
536540
stmt = working_dataframe._session._ast_batch.bind()
@@ -766,6 +770,8 @@ def count(self, _emit_ast: bool = True, **kwargs) -> DataFrame:
766770
# if no grouping exprs, there is already a LIMIT 1 in the query
767771
# see aggregate_statement in analyzer_utils.py
768772
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
773+
df._agg_base_plan = df._plan
774+
df._agg_base_select_statement = df._select_statement
769775

770776
# TODO: count seems similar to mean, min, .... Can we unify implementation here?
771777
if _emit_ast:
@@ -815,6 +821,8 @@ def _function(
815821
# if no grouping exprs, there is already a LIMIT 1 in the query
816822
# see aggregate_statement in analyzer_utils.py
817823
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
824+
df._agg_base_plan = df._plan
825+
df._agg_base_select_statement = df._select_statement
818826

819827
if _emit_ast:
820828
stmt = self._dataframe._session._ast_batch.bind()
@@ -909,6 +917,8 @@ def ai_agg(
909917
# if no grouping exprs, there is already a LIMIT 1 in the query
910918
# see aggregate_statement in analyzer_utils.py
911919
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
920+
df._agg_base_plan = df._plan
921+
df._agg_base_select_statement = df._select_statement
912922

913923
if _emit_ast:
914924
stmt = self._dataframe._session._ast_batch.bind()

0 commit comments

Comments
 (0)