Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 116 additions & 55 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
from collections import Counter
from decimal import Decimal
from functools import cached_property
from functools import cached_property, reduce
from logging import getLogger
from types import ModuleType
from typing import (
Expand Down Expand Up @@ -54,6 +54,9 @@
create_join_type,
)
from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
from snowflake.snowpark._internal.analyzer.binary_expression import (
And,
)
from snowflake.snowpark._internal.analyzer.expression import (
Attribute,
Expression,
Expand Down Expand Up @@ -689,7 +692,14 @@ def __init__(

self._statement_params = None
self.is_cached: bool = is_cached #: Whether the dataframe is cached.
# Internal state variables used to construct flattened GROUP BY clauses in the correct order
# in SCOS compatibility mode.
# See comments on `_build_post_agg_df` for details.
self._ops_after_agg = None
self._agg_base_plan = None
self._agg_base_select_statement = None
self._pending_havings = []
self._pending_order_bys = []

# Whether all columns are VARIANT data type,
# which support querying nested fields via dot notations
Expand Down Expand Up @@ -2114,28 +2124,25 @@ def filter(
stmt = _ast_stmt

# In snowpark_connect_compatible mode, we need to handle
# the filtering for dataframe after aggregation without nesting using HAVING
# the filtering for dataframe after aggregation without nesting using HAVING.
# We defer the HAVING expression and rebuild the plan from the
# aggregate base so that SQL clauses are emitted in the correct order
# (HAVING -> ORDER BY -> LIMIT) regardless of the user's call order.
# If there is a LIMIT earlier in the expression tree, then we must produce a new
# sub-query from this filter to ensure correctness.
if (
context._is_snowpark_connect_compatible_mode
and self._ops_after_agg is not None
and "filter" not in self._ops_after_agg
and "limit" not in self._ops_after_agg
):
having_plan = Filter(filter_col_expr, self._plan, is_having=True)
if self._select_statement:
df = self._with_plan(
self._session._analyzer.create_select_statement(
from_=self._session._analyzer.create_select_snowflake_plan(
having_plan, analyzer=self._session._analyzer
),
analyzer=self._session._analyzer,
),
_ast_stmt=stmt,
)
else:
df = self._with_plan(having_plan, _ast_stmt=stmt)
df._ops_after_agg = self._ops_after_agg.copy()
df._ops_after_agg.add("filter")
return df
new_ops = self._ops_after_agg.copy()
new_ops.add("filter")
return self._build_post_agg_df(
ops_after_agg=new_ops,
pending_havings=self._pending_havings + [filter_col_expr],
pending_order_bys=self._pending_order_bys,
_ast_stmt=stmt,
)
else:
if self._select_statement:
return self._with_plan(
Expand Down Expand Up @@ -2330,28 +2337,25 @@ def sort(
)

# In snowpark_connect_compatible mode, we need to handle
# the sorting for dataframe after aggregation without nesting
# the sorting for dataframe after aggregation without nesting.
# We defer the ORDER BY expressions and rebuild the plan from
# the aggregate base in correct SQL clause order.
# If there is a LIMIT earlier in the expression tree, then we must produce a new
# sub-query from this sort to ensure correctness.
if (
context._is_snowpark_connect_compatible_mode
and self._ops_after_agg is not None
and "sort" not in self._ops_after_agg
and "limit" not in self._ops_after_agg
):
sort_plan = Sort(sort_exprs, self._plan, is_order_by_append=True)
if self._select_statement:
df = self._with_plan(
self._session._analyzer.create_select_statement(
from_=self._session._analyzer.create_select_snowflake_plan(
sort_plan, analyzer=self._session._analyzer
),
analyzer=self._session._analyzer,
),
_ast_stmt=stmt,
)
else:
df = self._with_plan(sort_plan, _ast_stmt=stmt)
df._ops_after_agg = self._ops_after_agg.copy()
df._ops_after_agg.add("sort")
return df
new_ops = self._ops_after_agg.copy()
new_ops.add("sort")
return self._build_post_agg_df(
ops_after_agg=new_ops,
pending_havings=self._pending_havings,
# New ordering clauses must be placed before previously-declared ones
pending_order_bys=sort_exprs + self._pending_order_bys,
_ast_stmt=stmt,
)
else:
df = (
self._with_plan(self._select_statement.sort(sort_exprs))
Expand Down Expand Up @@ -3057,30 +3061,21 @@ def limit(
stmt = None

# In snowpark_connect_compatible mode, we need to handle
# the limit for dataframe after aggregation without nesting
# the limit for dataframe after aggregation without nesting.
if (
context._is_snowpark_connect_compatible_mode
and self._ops_after_agg is not None
and "limit" not in self._ops_after_agg
):
limit_plan = Limit(
Literal(n), Literal(offset), self._plan, is_limit_append=True
new_ops = self._ops_after_agg.copy()
new_ops.add("limit")
return self._build_post_agg_df(
ops_after_agg=new_ops,
pending_havings=self._pending_havings,
pending_order_bys=self._pending_order_bys,
limit_parameters=(n, offset),
_ast_stmt=stmt,
)
if self._select_statement:
df = self._with_plan(
self._session._analyzer.create_select_statement(
from_=self._session._analyzer.create_select_snowflake_plan(
limit_plan, analyzer=self._session._analyzer
),
analyzer=self._session._analyzer,
),
_ast_stmt=stmt,
)
else:
df = self._with_plan(limit_plan, _ast_stmt=stmt)
df._ops_after_agg = self._ops_after_agg.copy()
df._ops_after_agg.add("limit")
return df
else:
if self._select_statement:
return self._with_plan(
Expand Down Expand Up @@ -6835,6 +6830,72 @@ def dtypes(self) -> List[Tuple[str, str]]:
]
return dtypes

def _build_post_agg_df(
    self,
    ops_after_agg: set[str],
    pending_havings: list[Expression],
    pending_order_bys: list[Expression],
    limit_parameters: Optional[tuple[int, int]] = None,
    _ast_stmt=None,
) -> "DataFrame":
    """
    Rebuild a post-aggregation DataFrame from the aggregate base plan so that the
    generated SQL emits HAVING, ORDER BY, and LIMIT in that fixed syntactic order,
    independent of the order in which the user chained the operations. For example:

        df.groupBy("dept").agg(
            count("*").alias("headcount"),
            avg("salary").alias("avg_salary"),
        )
        .orderBy(col("avg_salary").desc())
        .filter(col("headcount") > 1)
        .limit(2)

    Although `orderBy` is called before `filter`, the filter must become a HAVING
    clause that precedes ORDER BY in the emitted SQL. `_agg_base_plan` and
    `_agg_base_select_statement` anchor the reconstruction at the aggregation.

    LIMIT does not commute with HAVING or ORDER BY, so any filter/sort chained
    after a LIMIT must instead spawn a new sub-query; that invariant is enforced
    by the callers (`filter`, `sort`, `limit`), not here.

    Only call this in SCOS compatibility mode
    (``context._is_snowpark_connect_compatible_mode``).

    :param ops_after_agg: names of post-aggregation operations applied so far.
    :param pending_havings: deferred HAVING predicates, AND-ed together.
    :param pending_order_bys: deferred ORDER BY expressions, highest priority first.
    :param limit_parameters: optional ``(n, offset)`` pair for a LIMIT clause.
    :param _ast_stmt: AST statement to attach to the resulting DataFrame.
    :return: a new DataFrame carrying the rebuilt plan and the deferred state.
    """
    plan = self._agg_base_plan

    if pending_havings:
        # Fold all deferred predicates into a single conjunction for HAVING.
        combined = pending_havings[0]
        for predicate in pending_havings[1:]:
            combined = And(combined, predicate)
        plan = Filter(combined, plan, is_having=True)

    if pending_order_bys:
        plan = Sort(pending_order_bys, plan, is_order_by_append=True)

    if limit_parameters is not None:
        row_count, row_offset = limit_parameters
        plan = Limit(
            Literal(row_count), Literal(row_offset), plan, is_limit_append=True
        )

    if self._agg_base_select_statement is None:
        new_plan = plan
    else:
        analyzer = self._session._analyzer
        new_plan = analyzer.create_select_statement(
            from_=analyzer.create_select_snowflake_plan(plan, analyzer=analyzer),
            analyzer=analyzer,
        )

    result = self._with_plan(new_plan, _ast_stmt=_ast_stmt)
    # Carry the aggregation anchor and deferred clauses forward so further
    # chained operations can keep rebuilding from the same base.
    result._ops_after_agg = ops_after_agg
    result._agg_base_plan = self._agg_base_plan
    result._agg_base_select_statement = self._agg_base_select_statement
    result._pending_havings = pending_havings
    result._pending_order_bys = pending_order_bys
    return result

def _with_plan(self, plan, _ast_stmt=None) -> "DataFrame":
"""
:param proto.Bind ast_stmt: The AST statement protobuf corresponding to this value.
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/relational_grouped_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ def agg(
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

if _emit_ast:
df._ast_id = stmt.uid
Expand Down Expand Up @@ -531,6 +533,8 @@ def end_partition(
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

if _emit_ast:
stmt = working_dataframe._session._ast_batch.bind()
Expand Down Expand Up @@ -766,6 +770,8 @@ def count(self, _emit_ast: bool = True, **kwargs) -> DataFrame:
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

# TODO: count seems similar to mean, min, .... Can we unify implementation here?
if _emit_ast:
Expand Down Expand Up @@ -815,6 +821,8 @@ def _function(
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

if _emit_ast:
stmt = self._dataframe._session._ast_batch.bind()
Expand Down Expand Up @@ -909,6 +917,8 @@ def ai_agg(
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

if _emit_ast:
stmt = self._dataframe._session._ast_batch.bind()
Expand Down
Loading
Loading