Skip to content

Commit 2a5afac

Browse files
Merge branch 'refs/heads/main' into feature/aherrera/SNOW-2432059-StringAndBinary-part1
2 parents 81652f7 + 5ce80df commit 2a5afac

File tree

13 files changed

+38
-20
lines changed

13 files changed

+38
-20
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787

8888
#### New Features
8989
- Added support for the `dtypes` parameter of `pd.get_dummies`
90+
- Added support for `nunique` in `df.pivot_table`, `df.agg` and other places where aggregate functions can be used.
9091

9192
#### Improvements
9293

docs/source/modin/supported/agg_supp.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ methods ``pd.pivot_table``, ``DataFrame.pivot_table``, and ``pd.crosstab``.
3838
| ``median`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``Y`` |
3939
| | ``N`` for ``axis=1``. | | | | |
4040
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+
41+
| ``nunique`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``Y`` |
42+
| | ``N`` for ``axis=1``. | | | | |
43+
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+
4144
| ``size`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``N`` |
4245
| | ``N`` for ``axis=1``. | | | | |
4346
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+

src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,11 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
578578
preserves_snowpark_pandas_types=False,
579579
supported_in_pivot=True,
580580
),
581+
"nunique": _SnowparkPandasAggregation(
582+
axis_0_aggregation=count_distinct,
583+
preserves_snowpark_pandas_types=False,
584+
supported_in_pivot=True,
585+
),
581586
**_create_pandas_to_snowpark_pandas_aggregation_map(
582587
(len, "size"),
583588
_SnowparkPandasAggregation(
@@ -719,11 +724,6 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
719724
preserves_snowpark_pandas_types=True,
720725
supported_in_pivot=False,
721726
),
722-
"nunique": _SnowparkPandasAggregation(
723-
axis_0_aggregation=count_distinct,
724-
preserves_snowpark_pandas_types=False,
725-
supported_in_pivot=False,
726-
),
727727
}
728728
)
729729

src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
min as min_,
2323
object_construct,
2424
sum as sum_,
25+
count_distinct,
2526
)
2627
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
2728
get_pandas_aggr_func_name,
@@ -768,7 +769,7 @@ def prepare_pivot_aggregation_for_handling_missing_and_null_values(
768769
bar | 0.0 | Nan | 0.0 | Nan
769770
foo | 1.0 | 1.0 | 0.0 | 1.0
770771
771-
To match pandas behavior, we do an upfront group-by aggregation for count and sum to get the correct
772+
To match pandas behavior, we do an upfront group-by aggregation for count, nunique and sum to get the correct
772773
values for all null values via snowflake query:
773774
774775
select a, b, coalesce(sum(C), 0) as sum_c, count(C) as cnt_c from df_small_data group by a, b;
@@ -792,16 +793,21 @@ def prepare_pivot_aggregation_for_handling_missing_and_null_values(
792793
Snowpark dataframe that has done an pre-pivot aggregation needed for matching pandas pivot behavior as
793794
described earlier.
794795
"""
795-
if snowpark_aggr_func in [sum_, count]:
796-
agg_expr = (
797-
coalesce(sum_(aggr_snowflake_quoted_identifier), pandas_lit(0)).as_(
796+
if snowpark_aggr_func in [sum_, count, count_distinct]:
797+
if snowpark_aggr_func == sum_:
798+
agg_expr = coalesce(
799+
sum_(aggr_snowflake_quoted_identifier), pandas_lit(0)
800+
).as_(aggr_snowflake_quoted_identifier)
801+
elif snowpark_aggr_func == count:
802+
agg_expr = count(aggr_snowflake_quoted_identifier).as_(
798803
aggr_snowflake_quoted_identifier
799804
)
800-
if snowpark_aggr_func == sum_
801-
else count(aggr_snowflake_quoted_identifier).as_(
805+
elif snowpark_aggr_func == count_distinct:
806+
agg_expr = count_distinct(aggr_snowflake_quoted_identifier).as_(
802807
aggr_snowflake_quoted_identifier
803808
)
804-
)
809+
else:
810+
raise NotImplementedError("Aggregate function not supported for pivot")
805811
pre_pivot_ordered_dataframe = pivot_ordered_dataframe.group_by(
806812
grouping_snowflake_quoted_identifiers, agg_expr
807813
)

tests/integ/modin/frame/test_aggregate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def native_df_multiindex() -> native_pd.DataFrame:
6565
3,
6666
),
6767
(lambda df: df.aggregate({"A": ["count", "max"], "B": [max, "min"]}), 2),
68+
(lambda df: df.aggregate("nunique"), 0),
6869
(
6970
lambda df: df.aggregate(
7071
x=pd.NamedAgg("A", "max"), y=("B", "min"), c=("A", "count")

tests/integ/modin/groupby/test_groupby_basic_agg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ def test_groupby_agg_with_int_dtypes(int_to_decimal_float_agg_method) -> None:
413413
np.min,
414414
min,
415415
sum,
416+
"nunique",
416417
np.std,
417418
"var",
418419
{"col2": "sum"},

tests/integ/modin/groupby/test_groupby_series.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_groupby_series_count_with_nan():
4949
np.median,
5050
np.std,
5151
"var",
52+
"nunique",
5253
[np.var],
5354
["sum", np.std],
5455
["sum", np.median, sum],

tests/integ/modin/pivot/test_pivot_table_dropna.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_pivot_table_single_all_aggfuncs_dropna_and_null_data(
107107
df_data_with_nulls_2,
108108
values,
109109
):
110-
expected_join_count = 10 if len(values) > 1 else 5
110+
expected_join_count = 12 if len(values) > 1 else 6
111111
with SqlCounter(query_count=1, join_count=expected_join_count):
112112
pivot_table_test_helper(
113113
df_data_with_nulls_2,
@@ -116,7 +116,7 @@ def test_pivot_table_single_all_aggfuncs_dropna_and_null_data(
116116
"columns": ["C"],
117117
"values": values,
118118
"dropna": False,
119-
"aggfunc": ["count", "sum", "min", "max", "mean"],
119+
"aggfunc": ["count", "sum", "min", "max", "mean", "nunique"],
120120
},
121121
)
122122

tests/integ/modin/pivot/test_pivot_table_margins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def test_pivot_table_multiple_columns_values_with_margins(
134134
),
135135
],
136136
)
137-
@sql_count_checker(query_count=1, join_count=5, union_count=1)
137+
@sql_count_checker(query_count=1, join_count=6, union_count=1)
138138
def test_pivot_table_multiple_pivot_values_null_data_with_margins(
139139
df_data_with_nulls, index, fill_value
140140
):
@@ -144,7 +144,7 @@ def test_pivot_table_multiple_pivot_values_null_data_with_margins(
144144
"index": index,
145145
"columns": "C",
146146
"values": "F",
147-
"aggfunc": ["count", "sum", "mean"],
147+
"aggfunc": ["count", "sum", "mean", "nunique"],
148148
"dropna": False,
149149
"fill_value": fill_value,
150150
"margins": True,

tests/integ/modin/pivot/test_pivot_table_multiple.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ def test_pivot_table_no_index_single_column_multiple_values(df_data):
3838
)
3939

4040

41-
@sql_count_checker(query_count=1, union_count=1, join_count=2)
41+
@sql_count_checker(query_count=1, union_count=1, join_count=4)
4242
def test_pivot_table_no_index_single_column_multiple_values_multiple_aggr_func(df_data):
4343
pivot_table_test_helper(
4444
df_data,
4545
{
4646
"columns": "B",
4747
"values": ["D", "E"],
48-
"aggfunc": ["mean", "max"],
48+
"aggfunc": ["mean", "max", "nunique"],
4949
},
5050
)
5151

@@ -119,7 +119,7 @@ def test_pivot_table_single_index_multiple_column_single_value(
119119
)
120120

121121

122-
@pytest.mark.parametrize("aggfunc", ["count", "sum", "min", "max", "mean"])
122+
@pytest.mark.parametrize("aggfunc", ["count", "sum", "min", "max", "mean", "nunique"])
123123
@pytest.mark.parametrize("values", ["D", ["D"]])
124124
@sql_count_checker(query_count=1)
125125
def test_pivot_table_no_index_multiple_column_single_value(df_data, aggfunc, values):

0 commit comments

Comments
 (0)