Skip to content

Commit 16988dc

Browse files
sfc-gh-dyadav, sfc-gh-jrose, sfc-gh-aalam
authored
Fix allowing multiple aggregates in Pivot fix (#3171)
Co-authored-by: Jamison Rose <[email protected]> Co-authored-by: Afroz Alam <[email protected]>
1 parent 0b08ec5 commit 16988dc

File tree

6 files changed

+137
-31
lines changed

6 files changed

+137
-31
lines changed

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@
5050
BinaryExpression,
5151
)
5252
from snowflake.snowpark._internal.analyzer.binary_plan_node import (
53+
FullOuter,
5354
Join,
5455
SetOperation,
55-
Union as UnionPlan,
56+
UsingJoin,
5657
)
5758
from snowflake.snowpark._internal.analyzer.datatype_mapper import (
5859
numeric_to_sql_without_cast,
@@ -1167,7 +1168,10 @@ def do_resolve_with_resolved_children(
11671168
pivot_values = None
11681169

11691170
plan = None
1171+
11701172
for agg_expr in logical_plan.aggregates:
1173+
# We only allow pivot on more than one aggregates when it on a groupby clause
1174+
join_columns: List[str] | None = None
11711175
if (
11721176
len(logical_plan.grouping_columns) != 0
11731177
and agg_expr.children is not None
@@ -1186,6 +1190,10 @@ def do_resolve_with_resolved_children(
11861190
], # aggregate column is first child in logical_plan.aggregates
11871191
logical_plan.pivot_column,
11881192
]
1193+
join_columns = [
1194+
self.analyze(expression, df_aliased_col_name_to_real_col_name)
1195+
for expression in logical_plan.grouping_columns
1196+
]
11891197
child = self.plan_builder.project(
11901198
[
11911199
self.analyze(col, df_aliased_col_name_to_real_col_name)
@@ -1202,9 +1210,7 @@ def do_resolve_with_resolved_children(
12021210
logical_plan.pivot_column, df_aliased_col_name_to_real_col_name
12031211
),
12041212
pivot_values,
1205-
self.analyze(
1206-
logical_plan.aggregates[0], df_aliased_col_name_to_real_col_name
1207-
),
1213+
self.analyze(agg_expr, df_aliased_col_name_to_real_col_name),
12081214
self.analyze(
12091215
logical_plan.default_on_null,
12101216
df_aliased_col_name_to_real_col_name,
@@ -1213,6 +1219,8 @@ def do_resolve_with_resolved_children(
12131219
else None,
12141220
child,
12151221
logical_plan,
1222+
len(logical_plan.aggregates)
1223+
> 1, # we need to alias the names with agg function when we have more than one agg functions on the pivot
12161224
)
12171225

12181226
# If this is a dynamic pivot, then we can't use child.schema_query which is used in the schema_query
@@ -1225,15 +1233,35 @@ def do_resolve_with_resolved_children(
12251233
# table as it may not exist at later point in time when dataframe.schema is called.
12261234
pivot_plan.schema_query = pivot_plan.queries[-1].sql
12271235

1228-
# union multiple aggregations
1229-
# https://docs.snowflake.com/en/sql-reference/constructs/pivot#dynamic-pivot-with-multiple-aggregations-using-union
1236+
# using join here to have the output similar to what spark have
1237+
# both the aggregations are happening over the same set of columns and pivot values
1238+
# we will receive left and right both pivot table with same set of groupby columns and columns corresponding to pivot values
1239+
# to differentiate between columns corresponding to pivot values for aggregation function they will have name suffixed by agg fun
1240+
# join would keep the group by column same and append the columns corresponding to pivot values for multiple agg functions
1241+
# output would look similar to below for a statement like
1242+
# df.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum("year", "salary").show()
1243+
# +-------+---------------+---------------+-------------------+-------------------+
1244+
# | name|Sales_sum(year)|Sales_sum(year)|Marketing_sum(year)|Marketing_sum(year)|
1245+
# +-------+---------------+---------------+-------------------+-------------------+
1246+
# | Scott| NULL| NULL| NULL| NULL|
1247+
# | James| 4039| 4039| NULL| NULL|
1248+
# | Jen| NULL| NULL| NULL| NULL|
1249+
# |Michael| 2020| 2020| NULL| NULL|
1250+
12301251
if plan is None:
12311252
plan = pivot_plan
1232-
else:
1233-
union_plan = UnionPlan(plan, pivot_plan, is_all=False)
1234-
plan = self.plan_builder.set_operator(
1235-
plan, pivot_plan, union_plan.sql, union_plan
1253+
elif join_columns is not None:
1254+
plan = self.plan_builder.join(
1255+
plan,
1256+
pivot_plan,
1257+
UsingJoin(FullOuter(), join_columns),
1258+
"",
1259+
"",
1260+
logical_plan,
1261+
self.session.conf.get("use_constant_subquery_alias", False),
12361262
)
1263+
# we have a check in relational_grouped_dataframe.py which will prevent a case where there are more than one aggregate
1264+
# without having a grouping condition which is essential to create join_columns
12371265

12381266
assert plan is not None
12391267
return plan

src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
is_sql_select_statement,
4747
quote_name,
4848
random_name_for_temp_object,
49+
unwrap_single_quote,
4950
)
5051
from snowflake.snowpark.row import Row
5152
from snowflake.snowpark.types import DataType
@@ -63,6 +64,7 @@
6364
LEFT_BRACKET = "["
6465
RIGHT_BRACKET = "]"
6566
AS = " AS "
67+
EXCLUDE = " EXCLUDE "
6668
AND = " AND "
6769
OR = " OR "
6870
NOT = " NOT "
@@ -1248,7 +1250,9 @@ def pivot_statement(
12481250
aggregate: str,
12491251
default_on_null: Optional[str],
12501252
child: str,
1253+
should_alias_column_with_agg: bool,
12511254
) -> str:
1255+
select_str = STAR
12521256
if isinstance(pivot_values, str):
12531257
# The subexpression in this case already includes parenthesis.
12541258
values_str = pivot_values
@@ -1258,10 +1262,24 @@ def pivot_statement(
12581262
+ (ANY if pivot_values is None else COMMA.join(pivot_values))
12591263
+ RIGHT_PARENTHESIS
12601264
)
1265+
if pivot_values is not None and should_alias_column_with_agg:
1266+
quoted_names = [quote_name(value) for value in pivot_values]
1267+
# unwrap_single_quote on the value to match the output closer to what spark generates
1268+
aliased_names = [
1269+
quote_name(f"{unwrap_single_quote(value)}_{aggregate}")
1270+
for value in pivot_values
1271+
]
1272+
aliased_string = [
1273+
f"{quoted_name}{AS}{aliased_name}"
1274+
for aliased_name, quoted_name in zip(aliased_names, quoted_names)
1275+
]
1276+
exclude_str = COMMA.join(quoted_names)
1277+
aliased_str = COMMA.join(aliased_string)
1278+
select_str = f"{STAR}{EXCLUDE}{LEFT_PARENTHESIS}{exclude_str}{RIGHT_PARENTHESIS}, {aliased_str}"
12611279

12621280
return (
12631281
SELECT
1264-
+ STAR
1282+
+ select_str
12651283
+ FROM
12661284
+ LEFT_PARENTHESIS
12671285
+ child

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1105,10 +1105,16 @@ def pivot(
11051105
default_on_null: Optional[str],
11061106
child: SnowflakePlan,
11071107
source_plan: Optional[LogicalPlan],
1108+
should_alias_column_with_agg: bool,
11081109
) -> SnowflakePlan:
11091110
return self.build(
11101111
lambda x: pivot_statement(
1111-
pivot_column, pivot_values, aggregate, default_on_null, x
1112+
pivot_column,
1113+
pivot_values,
1114+
aggregate,
1115+
default_on_null,
1116+
x,
1117+
should_alias_column_with_agg,
11121118
),
11131119
child,
11141120
source_plan,

src/snowflake/snowpark/_internal/error_message.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,14 @@ def DF_CROSS_TAB_COUNT_TOO_LARGE(
137137
error_code="1107",
138138
)
139139

140+
@staticmethod
141+
def DF_PIVOT_ONLY_SUPPORT_ONE_AGG_EXPR() -> SnowparkDataframeException:
142+
return SnowparkDataframeException(
143+
"You can apply only one aggregate expression to a RelationalGroupedDataFrame "
144+
"returned by the pivot() method unless the pivot is applied with a groupby clause.",
145+
error_code="1109",
146+
)
147+
140148
@staticmethod
141149
def DF_DATAFRAME_IS_NOT_QUALIFIED_FOR_SCALAR_QUERY(
142150
count: int, columns: str

src/snowflake/snowpark/relational_grouped_dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
66
import inspect
77

8+
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
89
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
910
import snowflake.snowpark.context as context
1011
from snowflake.connector.options import pandas
@@ -225,6 +226,8 @@ def _to_df(
225226
self._dataframe._select_statement or self._dataframe._plan,
226227
)
227228
elif isinstance(self._group_type, _PivotType):
229+
if len(agg_exprs) != 1 and len(unaliased_grouping) == 0:
230+
raise SnowparkClientExceptionMessages.DF_PIVOT_ONLY_SUPPORT_ONE_AGG_EXPR()
228231
group_plan = Pivot(
229232
unaliased_grouping,
230233
self._group_type.pivot_col,

tests/integ/scala/test_dataframe_aggregate_suite.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from decimal import Decimal
77
from math import sqrt
8+
import re
89
from typing import NamedTuple
910

1011
import pytest
@@ -15,6 +16,7 @@
1516
)
1617
from snowflake.snowpark.column import Column
1718
from snowflake.snowpark.exceptions import (
19+
SnowparkDataframeException,
1820
SnowparkSQLException,
1921
)
2022
from snowflake.snowpark.functions import (
@@ -419,39 +421,80 @@ class MonthlySales(NamedTuple):
419421
reason="Multiple aggregations are not supported in local testing mode",
420422
)
421423
def test_pivot_multiple_aggs(session):
422-
# 1) SUM and AVG
423-
Utils.check_answer(
424+
with pytest.raises(
425+
SnowparkDataframeException,
426+
match=re.escape(
427+
"You can apply only one aggregate expression to a RelationalGroupedDataFrame returned by the pivot() method unless the pivot is applied with a groupby clause."
428+
),
429+
):
430+
TestData.monthly_sales(session).pivot(
431+
"month", ["JAN", "FEB", "MAR", "APR"]
432+
).agg([sum(col("amount")), avg(col("amount"))]).sort(col("empid"))
433+
434+
df = (
424435
TestData.monthly_sales(session)
436+
.groupBy(col("empid"))
425437
.pivot("month", ["JAN", "FEB", "MAR", "APR"])
426438
.agg([sum(col("amount")), avg(col("amount"))])
427-
.sort(col("empid")),
439+
.sort(col("empid"))
440+
)
441+
442+
assert [f.name for f in df.schema.fields] == [
443+
"EMPID",
444+
'"JAN_sum(""AMOUNT"")"',
445+
'"FEB_sum(""AMOUNT"")"',
446+
'"MAR_sum(""AMOUNT"")"',
447+
'"APR_sum(""AMOUNT"")"',
448+
'"JAN_avg(""AMOUNT"")"',
449+
'"FEB_avg(""AMOUNT"")"',
450+
'"MAR_avg(""AMOUNT"")"',
451+
'"APR_avg(""AMOUNT"")"',
452+
]
453+
454+
Utils.check_answer(
455+
df,
428456
[
429-
Row(1, 10400, 8000, 11000, 18000),
430-
Row(2, 39500, 90700, 12000, 5300),
457+
Row(1, 10400, 8000, 11000, 18000, 5200.0, 4000.0, 5500.0, 9000.0),
458+
Row(
459+
2,
460+
39500,
461+
90700,
462+
12000,
463+
5300,
464+
19750.0,
465+
45350.0,
466+
6000.0,
467+
2650.0,
468+
),
431469
],
432470
)
433471

434-
# 2) MIN and MAX
435-
Utils.check_answer(
472+
df = (
436473
TestData.monthly_sales(session)
474+
.groupBy(col("empid"))
437475
.pivot("month", ["JAN", "FEB", "MAR", "APR"])
438476
.agg([min(col("amount")), max(col("amount"))])
439-
.sort(col("empid")),
440-
[
441-
Row(1, 400, 3000, 5000, 8000),
442-
Row(2, 4500, 200, 2500, 800),
443-
],
477+
.sort(col("empid"))
444478
)
445479

446-
# 3) AVG and COUNT_DISTINCT
480+
assert [f.name for f in df.schema.fields] == [
481+
"EMPID",
482+
'"JAN_min(""AMOUNT"")"',
483+
'"FEB_min(""AMOUNT"")"',
484+
'"MAR_min(""AMOUNT"")"',
485+
'"APR_min(""AMOUNT"")"',
486+
'"JAN_max(""AMOUNT"")"',
487+
'"FEB_max(""AMOUNT"")"',
488+
'"MAR_max(""AMOUNT"")"',
489+
'"APR_max(""AMOUNT"")"',
490+
]
491+
492+
# 2) MIN and MAX
447493
Utils.check_answer(
448-
TestData.monthly_sales(session)
449-
.pivot("month", ["JAN", "FEB", "MAR", "APR"])
450-
.agg([avg(col("amount")), count_distinct(col("amount"))])
451-
.sort(col("empid")),
494+
df,
452495
[
453-
Row(1, 5200, 4000, 5500, 9000),
454-
Row(2, 19750, 45350, 6000, 2650),
496+
Row(1, 400, 3000, 5000, 8000, 10000, 5000, 6000, 10000),
497+
Row(2, 4500, 200, 2500, 800, 35000, 90500, 9500, 4500),
455498
],
456499
)
457500

0 commit comments

Comments
 (0)