Skip to content

Commit 6da287e

Browse files
authored
SNOW-2185699: Support filtering after grouping by and aggregation (#3547)
1 parent 65e7d4e commit 6da287e

File tree

8 files changed: +353 -33 lines changed

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,7 @@ def do_resolve_with_resolved_children(
10301030
self.analyze(
10311031
logical_plan.condition, df_aliased_col_name_to_real_col_name
10321032
),
1033+
logical_plan.is_having,
10331034
resolved_children[logical_plan.child],
10341035
logical_plan,
10351036
)
@@ -1082,6 +1083,7 @@ def do_resolve_with_resolved_children(
10821083
self.analyze(x, df_aliased_col_name_to_real_col_name)
10831084
for x in logical_plan.order
10841085
],
1086+
logical_plan.is_order_by_append,
10851087
resolved_children[logical_plan.child],
10861088
logical_plan,
10871089
)

src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@
206206
UUID_COMMENT = "-- {}"
207207
MODEL = "MODEL"
208208
EXCLAMATION_MARK = "!"
209+
HAVING = " HAVING "
209210

210211
TEMPORARY_STRING_SET = frozenset(["temporary", "temp"])
211212

@@ -530,14 +531,17 @@ def project_statement(
530531

531532

532533
def filter_statement(
533-
condition: str, child: str, child_uuid: Optional[str] = None
534+
condition: str, is_having: bool, child: str, child_uuid: Optional[str] = None
534535
) -> str:
535-
return (
536-
project_statement([], child, child_uuid=child_uuid)
537-
+ NEW_LINE
538-
+ WHERE
539-
+ condition
540-
)
536+
if is_having:
537+
return child + NEW_LINE + HAVING + condition
538+
else:
539+
return (
540+
project_statement([], child, child_uuid=child_uuid)
541+
+ NEW_LINE
542+
+ WHERE
543+
+ condition
544+
)
541545

542546

543547
def sample_statement(
@@ -648,10 +652,17 @@ def aggregate_statement(
648652

649653

650654
def sort_statement(
651-
order: List[str], child: str, child_uuid: Optional[str] = None
655+
order: List[str],
656+
is_order_by_append: bool,
657+
child: str,
658+
child_uuid: Optional[str] = None,
652659
) -> str:
653660
return (
654-
project_statement([], child, child_uuid=child_uuid)
661+
(
662+
child
663+
if is_order_by_append
664+
else project_statement([], child, child_uuid=child_uuid)
665+
)
655666
+ NEW_LINE
656667
+ ORDER_BY
657668
+ NEW_LINE
@@ -736,7 +747,7 @@ def values_statement(output: List[Attribute], data: List[Row]) -> str:
736747

737748
def empty_values_statement(output: List[Attribute]) -> str:
738749
data = [Row(*[None] * len(output))]
739-
return filter_statement(UNSAT_FILTER, values_statement(output, data))
750+
return filter_statement(UNSAT_FILTER, False, values_statement(output, data))
740751

741752

742753
def set_operator_statement(left: str, right: str, operator: str) -> str:

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,12 +1078,14 @@ def aggregate(
10781078
def filter(
10791079
self,
10801080
condition: str,
1081+
is_having: bool,
10811082
child: SnowflakePlan,
10821083
source_plan: Optional[LogicalPlan],
10831084
) -> SnowflakePlan:
10841085
return self.build(
10851086
lambda x: filter_statement(
10861087
condition,
1088+
is_having,
10871089
x,
10881090
child_uuid=(
10891091
child.uuid
@@ -1135,12 +1137,14 @@ def sample_by(
11351137
def sort(
11361138
self,
11371139
order: List[str],
1140+
is_order_by_append: bool,
11381141
child: SnowflakePlan,
11391142
source_plan: Optional[LogicalPlan],
11401143
) -> SnowflakePlan:
11411144
return self.build(
11421145
lambda x: sort_statement(
11431146
order,
1147+
is_order_by_append,
11441148
x,
11451149
child_uuid=(
11461150
child.uuid

src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,15 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
8888

8989

9090
class Sort(UnaryNode):
91-
def __init__(self, order: List[SortOrder], child: LogicalPlan) -> None:
91+
def __init__(
92+
self,
93+
order: List[SortOrder],
94+
child: LogicalPlan,
95+
is_order_by_append: bool = False,
96+
) -> None:
9297
super().__init__(child)
9398
self.order = order
99+
self.is_order_by_append = is_order_by_append
94100

95101
@property
96102
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
@@ -242,13 +248,16 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
242248

243249

244250
class Filter(UnaryNode):
245-
def __init__(self, condition: Expression, child: LogicalPlan) -> None:
251+
def __init__(
252+
self, condition: Expression, child: LogicalPlan, is_having: bool = False
253+
) -> None:
246254
super().__init__(child)
247255
self.condition = condition
256+
self.is_having = is_having
248257

249258
@property
250259
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
251-
# child WHERE condition
260+
# child WHERE condition or HAVING condition
252261
return sum_node_complexities(
253262
{PlanNodeCategory.FILTER: 1},
254263
self.condition.cumulative_node_complexity,

src/snowflake/snowpark/dataframe.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,7 @@ def __init__(
616616

617617
self._statement_params = None
618618
self.is_cached: bool = is_cached #: Whether the dataframe is cached.
619+
self._is_grouped_by_and_aggregated = False
619620

620621
# Whether all columns are VARIANT data type,
621622
# which support querying nested fields via dot notations
@@ -1964,18 +1965,41 @@ def filter(
19641965
else:
19651966
stmt = _ast_stmt
19661967

1967-
if self._select_statement:
1968+
# In snowpark_connect_compatible mode, we need to handle
1969+
# the filtering for dataframe after aggregation without nesting using HAVING
1970+
if (
1971+
context._is_snowpark_connect_compatible_mode
1972+
and self._is_grouped_by_and_aggregated
1973+
):
1974+
having_plan = Filter(filter_col_expr, self._plan, is_having=True)
1975+
if self._select_statement:
1976+
df = self._with_plan(
1977+
self._session._analyzer.create_select_statement(
1978+
from_=self._session._analyzer.create_select_snowflake_plan(
1979+
having_plan, analyzer=self._session._analyzer
1980+
),
1981+
analyzer=self._session._analyzer,
1982+
),
1983+
_ast_stmt=stmt,
1984+
)
1985+
else:
1986+
df = self._with_plan(having_plan, _ast_stmt=stmt)
1987+
df._is_grouped_by_and_aggregated = True
1988+
return df
1989+
else:
1990+
if self._select_statement:
1991+
return self._with_plan(
1992+
self._select_statement.filter(filter_col_expr),
1993+
_ast_stmt=stmt,
1994+
)
19681995
return self._with_plan(
1969-
self._select_statement.filter(filter_col_expr),
1996+
Filter(
1997+
filter_col_expr,
1998+
self._plan,
1999+
is_having=False,
2000+
),
19702001
_ast_stmt=stmt,
19712002
)
1972-
return self._with_plan(
1973-
Filter(
1974-
filter_col_expr,
1975-
self._plan,
1976-
),
1977-
_ast_stmt=stmt,
1978-
)
19792003

19802004
@df_api_usage
19812005
@publicapi
@@ -2105,16 +2129,40 @@ def sort(
21052129
SortOrder(exprs[idx], orders[idx] if orders else Ascending())
21062130
)
21072131

2108-
df = (
2109-
self._with_plan(self._select_statement.sort(sort_exprs))
2110-
if self._select_statement
2111-
else self._with_plan(Sort(sort_exprs, self._plan))
2112-
)
2132+
# In snowpark_connect_compatible mode, we need to handle
2133+
# the sorting for dataframe after aggregation without nesting
2134+
if (
2135+
context._is_snowpark_connect_compatible_mode
2136+
and self._is_grouped_by_and_aggregated
2137+
):
2138+
sort_plan = Sort(sort_exprs, self._plan, is_order_by_append=True)
2139+
if self._select_statement:
2140+
df = self._with_plan(
2141+
self._session._analyzer.create_select_statement(
2142+
from_=self._session._analyzer.create_select_snowflake_plan(
2143+
sort_plan, analyzer=self._session._analyzer
2144+
),
2145+
analyzer=self._session._analyzer,
2146+
),
2147+
_ast_stmt=stmt,
2148+
)
2149+
else:
2150+
df = self._with_plan(sort_plan, _ast_stmt=stmt)
2151+
df._is_grouped_by_and_aggregated = True
2152+
return df
2153+
else:
2154+
df = (
2155+
self._with_plan(self._select_statement.sort(sort_exprs))
2156+
if self._select_statement
2157+
else self._with_plan(
2158+
Sort(sort_exprs, self._plan, is_order_by_append=False)
2159+
)
2160+
)
21132161

2114-
if _emit_ast:
2115-
df._ast_id = stmt.uid
2162+
if _emit_ast:
2163+
df._ast_id = stmt.uid
21162164

2117-
return df
2165+
return df
21182166

21192167
@experimental(version="1.5.0")
21202168
@publicapi

src/snowflake/snowpark/relational_grouped_dataframe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ def agg(
324324
agg_exprs.append(_str_to_expr(e[1], _emit_ast)(col_expr))
325325

326326
df = self._to_df(agg_exprs, _emit_ast=False)
327+
df._is_grouped_by_and_aggregated = True
327328

328329
if _emit_ast:
329330
df._ast_id = stmt.uid
@@ -514,6 +515,7 @@ def end_partition(
514515
),
515516
_emit_ast=False,
516517
)
518+
df._is_grouped_by_and_aggregated = True
517519

518520
if _emit_ast:
519521
stmt = working_dataframe._session._ast_batch.bind()
@@ -692,6 +694,7 @@ def count(self, _emit_ast: bool = True) -> DataFrame:
692694
],
693695
_emit_ast=False,
694696
)
697+
df._is_grouped_by_and_aggregated = True
695698

696699
# TODO: count seems similar to mean, min, .... Can we unify implementation here?
697700
if _emit_ast:
@@ -727,6 +730,7 @@ def _function(
727730
)._expression
728731
agg_exprs.append(expr)
729732
df = self._to_df(agg_exprs)
733+
df._is_grouped_by_and_aggregated = True
730734

731735
if _emit_ast:
732736
stmt = self._dataframe._session._ast_batch.bind()

0 commit comments

Comments
 (0)