
Commit f61e698

SNOW-1844466: Support more aggregation functions in pivot methods. (#2915)
Add support for aggregations ``"count"``, ``"median"``, ``np.median``, ``"skew"``, ``"std"``, ``np.std``, ``"var"``, and ``np.var`` in ``pd.pivot_table()``, ``DataFrame.pivot_table()``, and ``pd.crosstab()``. Snowflake PIVOT now supports all those aggregations.

This commit also expands pivot and crosstab tests to include some aggregation functions we do not yet support due to Snowflake's PIVOT limitations.

Fixes SNOW-1844466

---------

Signed-off-by: sfc-gh-mvashishtha <mahesh.vashishtha@snowflake.com>
1 parent fe13793 commit f61e698
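
As a rough illustration of the calls this commit concerns, the sketch below runs a few of the newly listed aggregations through stock pandas. It assumes that Snowpark pandas mirrors the pandas signatures for `pivot_table` and `crosstab`; the sample frame and its column names (`store`, `item`, `sales`) are invented for the example, and no Snowflake session is involved.

```python
import pandas as pd  # assumption: stock pandas as a stand-in for Snowpark pandas

# Hypothetical sample data, invented for this sketch.
df = pd.DataFrame(
    {
        "store": ["a", "a", "a", "b", "b", "b"],
        "item": ["x", "x", "y", "x", "y", "y"],
        "sales": [1.0, 3.0, 5.0, 2.0, 4.0, 8.0],
    }
)

# "median" through the module-level function.
median = pd.pivot_table(
    df, index="store", columns="item", values="sales", aggfunc="median"
)

# "var" through the DataFrame method (sample variance, so single-row
# groups come out as NaN).
variance = df.pivot_table(
    index="store", columns="item", values="sales", aggfunc="var"
)

# "count" through crosstab, which needs both values and aggfunc.
counts = pd.crosstab(df["store"], df["item"], values=df["sales"], aggfunc="count")

print(median)
print(variance)
print(counts)
```

Aggregations that the backend cannot push into PIVOT are expected to raise a not-implemented error in Snowpark pandas rather than fall back silently, as the `pivot_utils.py` change in this commit shows.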

12 files changed, with 228 additions and 123 deletions.

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -112,6 +112,9 @@
 - Added support for `DataFrame.pop` and `Series.pop`.
 - Added support for `first` and `last` in `DataFrameGroupBy.agg` and `SeriesGroupBy.agg`.
 - Added support for `Index.drop_duplicates`.
+- Added support for aggregations `"count"`, `"median"`, `np.median`,
+  `"skew"`, `"std"`, `np.std` `"var"`, and `np.var` in
+  `pd.pivot_table()`, `DataFrame.pivot_table()`, and `pd.crosstab()`.

 #### Bug Fixes

docs/source/modin/supported/agg_supp.rst

Lines changed: 46 additions & 54 deletions
Large diffs are not rendered by default.

docs/source/modin/supported/dataframe_supported.rst

Lines changed: 5 additions & 3 deletions
@@ -305,9 +305,11 @@ Methods
 | | | | any ``argfunc`` is not "count", "mean", "min", |
 | | | | "max", or "sum". N if ``index`` is None, |
 | | | | ``margins`` is True and ``aggfunc`` is "count" |
-| | | | or "mean" or a dictionary. N if ``index`` is None |
-| | | | and ``aggfunc`` is a dictionary containing |
-| | | | lists of aggfuncs to apply. |
+| | | | or "mean" or a dictionary. ``N`` if ``index`` is |
+| | | | None and ``aggfunc`` is a dictionary containing |
+| | | | lists of aggfuncs to apply. ``N`` if ``aggfunc`` is|
+| | | | an `unsupported aggregation |
+| | | | function <agg_supp.html>`_ for pivot. |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``plot`` | D | | Performed locally on the client |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+

docs/source/modin/supported/general_supported.rst

Lines changed: 4 additions & 4 deletions
@@ -18,8 +18,8 @@ Data manipulations
 | ``concat`` | P | ``levels`` is not supported, | |
 | | | ``copy`` is ignored | |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``crosstab`` | P | | ``N`` if ``aggfunc`` is not one of |
-| | | | "count", "mean", "min", "max", or "sum", or |
+| ``crosstab`` | P | | ``N`` if ``aggfunc`` is not a `supported |
+| | | | aggregation function <agg_supp.html>`_, |
 | | | | margins is True, normalize is "all" or True, |
 | | | | and values is passed. |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
@@ -50,8 +50,8 @@ Data manipulations
 | ``pivot`` | P | | See ``pivot_table`` |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``pivot_table`` | P | ``observed``, ``margins``, | ``N`` if ``index``, ``columns``, or ``values`` is |
-| | | ``sort`` | not str; or MultiIndex; or any ``argfunc`` is not |
-| | | | "count", "mean", "min", "max", or "sum" |
+| | | ``sort`` | not str; or MultiIndex; or any ``aggfunc`` is not a|
+| | | | `supported aggregation function <agg_supp.html>`_ |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``qcut`` | P | | ``N`` if ``labels!=False`` or ``retbins=True``. |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+

src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py

Lines changed: 35 additions & 3 deletions
@@ -265,6 +265,11 @@ class _SnowparkPandasAggregation(NamedTuple):
     # sum would be True.
     preserves_snowpark_pandas_types: bool

+    # Whether Snowflake PIVOT supports this aggregation on axis 0. It seems
+    # that Snowflake PIVOT supports any aggregation expressed as as single
+    # function call applied to a single column, e.g. MAX(A), BOOLOR_AND(A)
+    supported_in_pivot: bool
+
     # This callable takes a single Snowpark column as input and aggregates the
     # column on axis=0. If None, Snowpark pandas does not support this
     # aggregation on axis=0.
@@ -305,6 +310,12 @@ class SnowflakeAggFunc(NamedTuple):
     # sum would be True.
     preserves_snowpark_pandas_types: bool

+    # Whether Snowflake PIVOT supports this aggregation on axis 0. It seems
+    # that Snowflake PIVOT supports any aggregation expressed as as single
+    # function call applied to a single column, e.g. MAX(A), BOOLOR_AND(A).
+    # This field only makes sense for axis 0 aggregation.
+    supported_in_pivot: bool
+

 class AggFuncWithLabel(NamedTuple):
     """
@@ -523,6 +534,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
             axis_0_aggregation=count,
             axis_1_aggregation_skipna=_columns_count,
             preserves_snowpark_pandas_types=False,
+            supported_in_pivot=True,
         ),
         **_create_pandas_to_snowpark_pandas_aggregation_map(
             (len, "size"),
@@ -532,47 +544,53 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
                 axis_1_aggregation_keepna=_columns_count_keep_nulls,
                 axis_1_aggregation_skipna=_columns_count_keep_nulls,
                 preserves_snowpark_pandas_types=False,
+                supported_in_pivot=False,
             ),
         ),
         "first": _SnowparkPandasAggregation(
             axis_0_aggregation=_column_first_value,
             axis_1_aggregation_keepna=lambda *cols: cols[0],
             axis_1_aggregation_skipna=lambda *cols: coalesce(*cols),
             preserves_snowpark_pandas_types=True,
+            supported_in_pivot=False,
         ),
         "last": _SnowparkPandasAggregation(
             axis_0_aggregation=_column_last_value,
             axis_1_aggregation_keepna=lambda *cols: cols[-1],
             axis_1_aggregation_skipna=lambda *cols: coalesce(*(cols[::-1])),
             preserves_snowpark_pandas_types=True,
+            supported_in_pivot=False,
         ),
         **_create_pandas_to_snowpark_pandas_aggregation_map(
             ("mean", np.mean),
             _SnowparkPandasAggregation(
                 axis_0_aggregation=mean,
                 preserves_snowpark_pandas_types=True,
+                supported_in_pivot=True,
             ),
         ),
         **_create_pandas_to_snowpark_pandas_aggregation_map(
-            ("min", np.min),
+            ("min", np.min, min),
             _SnowparkPandasAggregation(
                 axis_0_aggregation=min_,
                 axis_1_aggregation_keepna=least,
                 axis_1_aggregation_skipna=_columns_coalescing_min,
                 preserves_snowpark_pandas_types=True,
+                supported_in_pivot=True,
             ),
         ),
         **_create_pandas_to_snowpark_pandas_aggregation_map(
-            ("max", np.max),
+            ("max", np.max, max),
             _SnowparkPandasAggregation(
                 axis_0_aggregation=max_,
                 axis_1_aggregation_keepna=greatest,
                 axis_1_aggregation_skipna=_columns_coalescing_max,
                 preserves_snowpark_pandas_types=True,
+                supported_in_pivot=True,
             ),
         ),
         **_create_pandas_to_snowpark_pandas_aggregation_map(
-            ("sum", np.sum),
+            ("sum", np.sum, sum),
             _SnowparkPandasAggregation(
                 axis_0_aggregation=sum_,
                 # IMPORTANT: count and sum use python builtin sum to invoke
@@ -581,13 +599,15 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
                 axis_1_aggregation_keepna=lambda *cols: sum(cols),
                 axis_1_aggregation_skipna=_columns_coalescing_sum,
                 preserves_snowpark_pandas_types=True,
+                supported_in_pivot=True,
             ),
         ),
         **_create_pandas_to_snowpark_pandas_aggregation_map(
             ("median", np.median),
             _SnowparkPandasAggregation(
                 axis_0_aggregation=median,
                 preserves_snowpark_pandas_types=True,
+                supported_in_pivot=True,
             ),
         ),
         "idxmax": _SnowparkPandasAggregation(
@@ -597,6 +617,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
             axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
             axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
             preserves_snowpark_pandas_types=False,
+            supported_in_pivot=False,
         ),
         "idxmin": _SnowparkPandasAggregation(
             axis_0_aggregation=functools.partial(
@@ -605,30 +626,35 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
             axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
             axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
             preserves_snowpark_pandas_types=False,
+            supported_in_pivot=False,
         ),
         "skew": _SnowparkPandasAggregation(
             axis_0_aggregation=skew,
             preserves_snowpark_pandas_types=True,
+            supported_in_pivot=True,
         ),
         "all": _SnowparkPandasAggregation(
             # all() for a column with no non-null values is NULL in Snowflake, but True in pandas.
             axis_0_aggregation=lambda c: coalesce(
                 builtin("booland_agg")(col(c)), pandas_lit(True)
             ),
             preserves_snowpark_pandas_types=False,
+            supported_in_pivot=False,
         ),
         "any": _SnowparkPandasAggregation(
             # any() for a column with no non-null values is NULL in Snowflake, but False in pandas.
             axis_0_aggregation=lambda c: coalesce(
                 builtin("boolor_agg")(col(c)), pandas_lit(False)
             ),
             preserves_snowpark_pandas_types=False,
+            supported_in_pivot=False,
         ),
         **_create_pandas_to_snowpark_pandas_aggregation_map(
             ("std", np.std),
             _SnowparkPandasAggregation(
                 axis_0_aggregation=stddev,
                 preserves_snowpark_pandas_types=True,
+                supported_in_pivot=True,
             ),
         ),
         **_create_pandas_to_snowpark_pandas_aggregation_map(
@@ -638,19 +664,23 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
                 # variance units are the square of the input column units, so
                 # variance does not preserve types.
                 preserves_snowpark_pandas_types=False,
+                supported_in_pivot=True,
             ),
         ),
         "array_agg": _SnowparkPandasAggregation(
             axis_0_aggregation=array_agg,
             preserves_snowpark_pandas_types=False,
+            supported_in_pivot=False,
         ),
         "quantile": _SnowparkPandasAggregation(
             axis_0_aggregation=column_quantile,
             preserves_snowpark_pandas_types=True,
+            supported_in_pivot=False,
         ),
         "nunique": _SnowparkPandasAggregation(
             axis_0_aggregation=count_distinct,
             preserves_snowpark_pandas_types=False,
+            supported_in_pivot=False,
         ),
     }
 )
@@ -762,6 +792,7 @@ def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn:
     return SnowflakeAggFunc(
         snowpark_aggregation=snowpark_aggregation,
         preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types,
+        supported_in_pivot=snowpark_pandas_aggregation.supported_in_pivot,
     )


@@ -800,6 +831,7 @@ def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn:
     return SnowflakeAggFunc(
         snowpark_aggregation,
         preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types,
+        supported_in_pivot=snowpark_pandas_aggregation.supported_in_pivot,
     )
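
The pattern in this file, one record per aggregation carrying a `supported_in_pivot` flag and shared by several aliases, can be sketched in isolation. The version below is a simplified illustration only: `Aggregation`, `alias_map`, and `REGISTRY` are hypothetical names, plain Python functions stand in for Snowpark column functions, and the alias helper is only assumed to behave like `_create_pandas_to_snowpark_pandas_aggregation_map`.

```python
from statistics import median
from types import MappingProxyType
from typing import Callable, Dict, Hashable, Iterable, NamedTuple


class Aggregation(NamedTuple):
    """Hypothetical stand-in for _SnowparkPandasAggregation."""

    # Whether the backend's PIVOT can apply this aggregation directly.
    supported_in_pivot: bool
    # Plain Python callable standing in for a Snowpark column function.
    axis_0_aggregation: Callable


def alias_map(
    aliases: Iterable[Hashable], aggregation: Aggregation
) -> Dict[Hashable, Aggregation]:
    """Map every alias (a string or a callable) to the same record."""
    return {alias: aggregation for alias in aliases}


# Read-only registry keyed by both names and function objects, so that
# "min" and the builtin min resolve to one shared record.
REGISTRY = MappingProxyType(
    {
        **alias_map(("min", min), Aggregation(True, min)),
        **alias_map(("max", max), Aggregation(True, max)),
        **alias_map(("sum", sum), Aggregation(True, sum)),
        **alias_map(("median", median), Aggregation(True, median)),
        **alias_map(("size", len), Aggregation(False, len)),
    }
)

print(REGISTRY["min"] is REGISTRY[min])
print(REGISTRY[len].supported_in_pivot)
```

Keeping the flag on the record means that adding an alias such as the builtin `min` needs no separate change to the pivot code path.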
src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py

Lines changed: 1 addition & 4 deletions
@@ -45,7 +45,6 @@
     extract_pandas_label_from_snowflake_quoted_identifier,
     from_pandas_label,
     get_distinct_rows,
-    is_supported_snowflake_pivot_agg_func,
     pandas_lit,
     random_name_for_temp_object,
     to_pandas_label,
@@ -522,9 +521,7 @@ def single_pivot_helper(
         data_column_pandas_labels: new data column pandas labels for this pivot result
     """
     snowflake_agg_func = get_snowflake_agg_func(pandas_aggr_func_name, {}, axis=0)
-    if snowflake_agg_func is None or not is_supported_snowflake_pivot_agg_func(
-        snowflake_agg_func.snowpark_aggregation
-    ):
+    if snowflake_agg_func is None or not snowflake_agg_func.supported_in_pivot:
         # TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations
         raise ErrorMessage.not_implemented(
             f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments."
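
The guard above reduces to a check on a single boolean field. A minimal standalone sketch of that shape follows; `AggFunc` and `require_pivot_support` are hypothetical names, and the built-in `NotImplementedError` stands in for `ErrorMessage.not_implemented`.

```python
from typing import Callable, NamedTuple, Optional


class AggFunc(NamedTuple):
    """Hypothetical stand-in for SnowflakeAggFunc."""

    snowpark_aggregation: Callable
    preserves_snowpark_pandas_types: bool
    supported_in_pivot: bool


def require_pivot_support(agg_func: Optional[AggFunc], name: str) -> AggFunc:
    """Return agg_func if pivot can use it; otherwise raise."""
    # None means the aggregation is not known at all; a False flag means
    # it is known but cannot be pushed into PIVOT.
    if agg_func is None or not agg_func.supported_in_pivot:
        raise NotImplementedError(
            f"pivot_table does not yet support the aggregation {name!r}"
        )
    return agg_func


median_agg = AggFunc(sorted, True, supported_in_pivot=True)
size_agg = AggFunc(len, False, supported_in_pivot=False)

print(require_pivot_support(median_agg, "median").supported_in_pivot)

for candidate, label in ((size_agg, "size"), (None, "mystery")):
    try:
        require_pivot_support(candidate, label)
    except NotImplementedError as exc:
        print(exc)
```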

src/snowflake/snowpark/modin/plugin/_internal/utils.py

Lines changed: 1 addition & 29 deletions
@@ -8,7 +8,7 @@
 import traceback
 from collections.abc import Hashable, Iterable, Sequence
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import modin.pandas as pd
 import numpy as np
@@ -42,14 +42,9 @@
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import (
     col,
-    count,
     equal_nan,
     floor,
     iff,
-    max as max_,
-    mean,
-    min as min_,
-    sum as sum_,
     to_char,
     to_timestamp_ntz,
     to_timestamp_tz,
@@ -1196,29 +1191,6 @@ def is_snowpark_pandas_dataframe_or_series_type(obj: Any) -> bool:
     return isinstance(obj, (pd.DataFrame, pd.Series))


-# TODO: (SNOW-853334) Support other agg functions (any, all, prod, median, skew, kurt, sem, var, std, mad, etc)
-snowflake_pivot_agg_func_supported = [
-    count,
-    mean,
-    min_,
-    max_,
-    sum_,
-]
-
-
-def is_supported_snowflake_pivot_agg_func(agg_func: Callable) -> bool:
-    """
-    Check if the aggregation function is supported with snowflake pivot. Current supported
-    aggregation functions are the functions that can be mapped to snowflake builtin function.
-
-    Args:
-        agg_func: str or Callable. the aggregation function to check
-    Returns:
-        Whether it is valid to implement with snowflake or not.
-    """
-    return agg_func in snowflake_pivot_agg_func_supported
-
-
 def convert_snowflake_string_constant_to_python_string(identifier: str) -> str:
     """
     Convert a snowflake string constant to a python constant, this removes surrounding single quotes
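
The commit does not say why the allow-list of function objects was dropped in favor of a flag. One general property of the removed approach is that a membership test on a list of functions compares the function objects themselves, so any wrapper around an allowed function no longer matches. The sketch below shows only that general Python behavior; `SUPPORTED` and `wrap` are hypothetical names, and builtin functions stand in for Snowpark functions.

```python
from statistics import mean

# Allow-list of function objects, in the style of the removed
# snowflake_pivot_agg_func_supported list (stand-in functions only).
SUPPORTED = [min, max, sum, mean]


def wrap(func):
    """Return a new callable that forwards to func, as an adapter might."""

    def wrapped(values):
        return func(values)

    return wrapped


wrapped_max = wrap(max)

print(max in SUPPORTED)          # the original object is found
print(wrapped_max in SUPPORTED)  # the wrapper is a different object
print(wrapped_max([1, 5, 3]))    # yet it computes the same result
```

A flag stored next to the aggregation's definition avoids depending on which function object happens to reach the check.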

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 0 additions & 6 deletions
@@ -9094,12 +9094,6 @@ def pivot_table(
         if not sort:
             raise NotImplementedError("Not implemented not sorted")

-        # TODO: (SNOW-853334) Support callable agg functions
-        if aggfunc and callable(aggfunc):
-            raise NotImplementedError(
-                f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(aggfunc, agg_kwargs={})} with the given arguments."
-            )
-
         if columns is not None and isinstance(columns, Hashable):
             columns = [columns]
