Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 101 additions & 52 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,15 @@ def __init__(

self._statement_params = None
self.is_cached: bool = is_cached #: Whether the dataframe is cached.
# Internal state variables used to construct flattened GROUP BY clauses in the correct order
# in SCOS compatibility mode.
# See comments on `_build_post_agg_df` for details.
self._ops_after_agg = None
self._agg_base_plan = None
self._agg_base_select_statement = None
self._pending_having = None
self._pending_order_by = None
self._pending_limit = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious how multiple filter would affect the plan generation

df1 = df.groupBy(...).agg(...)
df2 = df1.filter(...).limit().filter(...)
df3 = df1.filter(...).filter(...).limit()
df4 = df1.limit().filter(...).filter(...)
  1. do df2,3,4 output the same query?
  2. what's the behavior in spark and do we align with spark behavior after your code change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. It looks like in Spark, filter is not commutative across a sequence like df.filter(...).orderBy(...).limit(...).filter(...) (the final filter sees only the deterministic subset of rows produced by the prior order/limit). I'll need to do some more testing to see what this means for SQL generation, and whether the cases you mentioned have similar problems.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sfc-gh-aling I added some test cases covering this behavior, and checked the output against spark. The changes are:

  1. Operations that occur after a LIMIT now always produce a new sub-query, since FILTER -> LIMIT and LIMIT -> FILTER are not equivalent.
  2. Consecutive filter operations are now conjoined into a single HAVING clause. I don't think SQL has the short-circuiting evaluation behavior that imperative languages have, so I believe this should always be equivalent. The Spark explain plans I looked at also combined consecutive filter clauses into a single operator.
  3. Consecutive ordering operations are now combined into a single ORDER BY, with the last ordering clause appearing first in the SQL.

Most of these cases were previously broken in SCOS, as the only sequence of operations that would have produced valid SQL was df.groupby(...).agg(...).filter(...).orderBy(...).limit(...), where the operations appeared in the same order as that required by SQL.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, thanks for checking the behavior!
re1: does orderBy also produce a subquery like limit?

Copy link
Copy Markdown
Contributor Author

@sfc-gh-joshi sfc-gh-joshi Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ORDER BY does not create a new sub-query, since FILTER and ORDER BY are semantically commutative.


# Whether all columns are VARIANT data type,
# which support querying nested fields via dot notations
Expand Down Expand Up @@ -2114,28 +2122,24 @@ def filter(
stmt = _ast_stmt

# In snowpark_connect_compatible mode, we need to handle
# the filtering for dataframe after aggregation without nesting using HAVING
# the filtering for dataframe after aggregation without nesting using HAVING.
# We defer the HAVING expression and rebuild the plan from the
# aggregate base so that SQL clauses are emitted in the correct order
# (HAVING -> ORDER BY -> LIMIT) regardless of the user's call order.
if (
context._is_snowpark_connect_compatible_mode
and self._ops_after_agg is not None
and "filter" not in self._ops_after_agg
):
having_plan = Filter(filter_col_expr, self._plan, is_having=True)
if self._select_statement:
df = self._with_plan(
self._session._analyzer.create_select_statement(
from_=self._session._analyzer.create_select_snowflake_plan(
having_plan, analyzer=self._session._analyzer
),
analyzer=self._session._analyzer,
),
_ast_stmt=stmt,
)
else:
df = self._with_plan(having_plan, _ast_stmt=stmt)
df._ops_after_agg = self._ops_after_agg.copy()
df._ops_after_agg.add("filter")
return df
new_ops = self._ops_after_agg.copy()
new_ops.add("filter")
return self._build_post_agg_df(
ops_after_agg=new_ops,
pending_having=filter_col_expr,
pending_order_by=self._pending_order_by,
pending_limit=self._pending_limit,
_ast_stmt=stmt,
)
else:
if self._select_statement:
return self._with_plan(
Expand Down Expand Up @@ -2330,28 +2334,23 @@ def sort(
)

# In snowpark_connect_compatible mode, we need to handle
# the sorting for dataframe after aggregation without nesting
# the sorting for dataframe after aggregation without nesting.
# We defer the ORDER BY expressions and rebuild the plan from
# the aggregate base in correct SQL clause order.
if (
context._is_snowpark_connect_compatible_mode
and self._ops_after_agg is not None
and "sort" not in self._ops_after_agg
):
sort_plan = Sort(sort_exprs, self._plan, is_order_by_append=True)
if self._select_statement:
df = self._with_plan(
self._session._analyzer.create_select_statement(
from_=self._session._analyzer.create_select_snowflake_plan(
sort_plan, analyzer=self._session._analyzer
),
analyzer=self._session._analyzer,
),
_ast_stmt=stmt,
)
else:
df = self._with_plan(sort_plan, _ast_stmt=stmt)
df._ops_after_agg = self._ops_after_agg.copy()
df._ops_after_agg.add("sort")
return df
new_ops = self._ops_after_agg.copy()
new_ops.add("sort")
return self._build_post_agg_df(
ops_after_agg=new_ops,
pending_having=self._pending_having,
pending_order_by=sort_exprs,
pending_limit=self._pending_limit,
_ast_stmt=stmt,
)
else:
df = (
self._with_plan(self._select_statement.sort(sort_exprs))
Expand Down Expand Up @@ -3057,30 +3056,23 @@ def limit(
stmt = None

# In snowpark_connect_compatible mode, we need to handle
# the limit for dataframe after aggregation without nesting
# the limit for dataframe after aggregation without nesting.
# We defer the LIMIT values and rebuild the plan from
# the aggregate base in correct SQL clause order.
if (
context._is_snowpark_connect_compatible_mode
and self._ops_after_agg is not None
and "limit" not in self._ops_after_agg
):
limit_plan = Limit(
Literal(n), Literal(offset), self._plan, is_limit_append=True
new_ops = self._ops_after_agg.copy()
new_ops.add("limit")
return self._build_post_agg_df(
ops_after_agg=new_ops,
pending_having=self._pending_having,
pending_order_by=self._pending_order_by,
pending_limit=(n, offset),
_ast_stmt=stmt,
)
if self._select_statement:
df = self._with_plan(
self._session._analyzer.create_select_statement(
from_=self._session._analyzer.create_select_snowflake_plan(
limit_plan, analyzer=self._session._analyzer
),
analyzer=self._session._analyzer,
),
_ast_stmt=stmt,
)
else:
df = self._with_plan(limit_plan, _ast_stmt=stmt)
df._ops_after_agg = self._ops_after_agg.copy()
df._ops_after_agg.add("limit")
return df
else:
if self._select_statement:
return self._with_plan(
Expand Down Expand Up @@ -6835,6 +6827,63 @@ def dtypes(self) -> List[Tuple[str, str]]:
]
return dtypes

def _build_post_agg_df(
    self,
    ops_after_agg,
    pending_having,
    pending_order_by,
    pending_limit,
    _ast_stmt=None,
) -> "DataFrame":
    """
    Rebuild the plan for a DataFrame derived from a group-by aggregation in SCOS
    compatibility mode.

    SQL requires HAVING, ORDER BY, and LIMIT clauses to appear in that fixed
    order, but the user may chain ``filter``/``sort``/``limit`` on the
    aggregated DataFrame in any order. For example::

        df.groupBy("dept").agg(
            count("*").alias("headcount"),
            avg("salary").alias("avg_salary"),
        )
        .orderBy(col("avg_salary").desc())
        .limit(2)
        .filter(col("headcount") > 1)

    Here ``filter`` is called last, yet its HAVING clause must be emitted
    first. Rather than stacking operations in call order, we keep the pending
    HAVING / ORDER BY / LIMIT pieces and re-derive the plan from
    ``_agg_base_plan`` (and ``_agg_base_select_statement``) so the clauses
    always nest in the order SQL requires.

    This method should only be called in SCOS compatibility mode
    (``context._is_snowpark_connect_compatible_mode``).
    """
    # Layer the deferred clauses onto the aggregate base in SQL order:
    # HAVING first, then ORDER BY, then LIMIT.
    plan = self._agg_base_plan
    if pending_having is not None:
        plan = Filter(pending_having, plan, is_having=True)
    if pending_order_by is not None:
        plan = Sort(pending_order_by, plan, is_order_by_append=True)
    if pending_limit is not None:
        limit_n, limit_offset = pending_limit
        plan = Limit(
            Literal(limit_n), Literal(limit_offset), plan, is_limit_append=True
        )

    if self._agg_base_select_statement is None:
        new_plan = plan
    else:
        # Mirror the base DataFrame's select-statement wrapping so the
        # resulting plan shape matches what the analyzer produced originally.
        analyzer = self._session._analyzer
        new_plan = analyzer.create_select_statement(
            from_=analyzer.create_select_snowflake_plan(plan, analyzer=analyzer),
            analyzer=analyzer,
        )

    result = self._with_plan(new_plan, _ast_stmt=_ast_stmt)
    # Carry the SCOS bookkeeping forward so subsequent filter/sort/limit
    # calls can rebuild from the same aggregate base.
    result._ops_after_agg = ops_after_agg
    result._agg_base_plan = self._agg_base_plan
    result._agg_base_select_statement = self._agg_base_select_statement
    result._pending_having = pending_having
    result._pending_order_by = pending_order_by
    result._pending_limit = pending_limit
    return result

def _with_plan(self, plan, _ast_stmt=None) -> "DataFrame":
"""
:param proto.Bind ast_stmt: The AST statement protobuf corresponding to this value.
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/relational_grouped_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ def agg(
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

if _emit_ast:
df._ast_id = stmt.uid
Expand Down Expand Up @@ -531,6 +533,8 @@ def end_partition(
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

if _emit_ast:
stmt = working_dataframe._session._ast_batch.bind()
Expand Down Expand Up @@ -766,6 +770,8 @@ def count(self, _emit_ast: bool = True, **kwargs) -> DataFrame:
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

# TODO: count seems similar to mean, min, .... Can we unify implementation here?
if _emit_ast:
Expand Down Expand Up @@ -815,6 +821,8 @@ def _function(
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

if _emit_ast:
stmt = self._dataframe._session._ast_batch.bind()
Expand Down Expand Up @@ -909,6 +917,8 @@ def ai_agg(
# if no grouping exprs, there is already a LIMIT 1 in the query
# see aggregate_statement in analyzer_utils.py
df._ops_after_agg = set() if self._grouping_exprs else {"limit"}
df._agg_base_plan = df._plan
df._agg_base_select_statement = df._select_statement

if _emit_ast:
stmt = self._dataframe._session._ast_batch.bind()
Expand Down
Loading
Loading