-
Notifications
You must be signed in to change notification settings - Fork 146
Expand file tree
/
Copy pathaggregation_utils.py
More file actions
1715 lines (1523 loc) · 68.4 KB
/
aggregation_utils.py
File metadata and controls
1715 lines (1523 loc) · 68.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
#
# This file contains utils functions used by aggregation functions.
#
import functools
from collections import defaultdict
from collections.abc import Hashable, Iterable
from functools import partial
from inspect import getmembers
from types import BuiltinFunctionType, MappingProxyType
from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Union
import numpy as np
from pandas._typing import AggFuncType, AggFuncTypeBase
from pandas.core.dtypes.common import (
is_dict_like,
is_list_like,
is_named_tuple,
is_numeric_dtype,
is_scalar,
)
from snowflake.snowpark._internal.type_utils import ColumnOrName
from snowflake.snowpark.column import CaseExpr, Column as SnowparkColumn
from snowflake.snowpark.functions import (
Column,
array_agg,
array_construct,
array_construct_compact,
array_contains,
array_flatten,
array_max,
array_min,
array_position,
builtin,
cast,
coalesce,
col,
count,
count_distinct,
get,
greatest,
iff,
is_null,
least,
listagg,
lit,
max as max_,
mean,
median,
min as min_,
parse_json,
skew,
stddev,
stddev_pop,
sum as sum_,
trunc,
var_pop,
variance,
when,
)
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import (
OrderedDataFrame,
OrderingColumn,
)
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
TimedeltaType,
)
from snowflake.snowpark.modin.plugin._internal.utils import (
from_pandas_label,
pandas_lit,
to_pandas_label,
)
from snowflake.snowpark.modin.plugin._typing import PandasLabelToSnowflakeIdentifierPair
from snowflake.snowpark.types import (
BooleanType,
DataType,
DoubleType,
IntegerType,
StringType,
)
# pandas label for a column holding aggregation-function names; presumably used
# when emitting multi-function aggregation results — TODO confirm against callers.
AGG_NAME_COL_LABEL = "AGG_FUNC_NAME"
# Map from each callable numpy attribute (e.g. np.sum) to its attribute name
# (e.g. "sum"), built by scanning every member of the numpy module.
_NUMPY_FUNCTION_TO_NAME = {
    function: name for name, function in getmembers(np) if callable(function)
}
def _array_agg_keepna(
    column_to_aggregate: ColumnOrName, ordering_columns: Iterable[OrderingColumn]
) -> Column:
    """
    Aggregate a column, including nulls, into an array by the given ordering columns.
    """
    # ARRAY_AGG drops nulls, so per the workaround in [1] each value `v` is
    # wrapped as the one-element array `[v]` before aggregating, and the
    # aggregated array-of-arrays is flattened afterwards. array_construct(NULL)
    # does not work, so nulls are substituted with parse_json(lit("null"))
    # per [2].
    # [1] https://stackoverflow.com/a/77422662
    # [2] https://github.com/snowflakedb/snowflake-connector-python/issues/1388#issuecomment-1371091831
    # HOWEVER it appears that this workaround only works for integer values.
    # See details in SNOW-1859090.
    null_safe_value = iff(
        is_null(column_to_aggregate),
        parse_json(lit("null")),
        Column(column_to_aggregate),
    )
    order_by = [
        ordering_column.snowpark_column for ordering_column in ordering_columns
    ]
    wrapped_agg = array_agg(array_construct(null_safe_value)).within_group(order_by)
    return array_flatten(wrapped_agg)
def column_quantile(
    column: SnowparkColumn,
    interpolation: Literal["linear", "lower", "higher", "midpoint", "nearest"],
    q: float,
) -> SnowparkColumn:
    """
    Compute the q-th quantile of `column` as a DOUBLE column expression.

    Only "linear" and "nearest" interpolation are supported; any other value
    fails the assertion below.
    """
    assert interpolation in (
        "linear",
        "nearest",
    ), f"unsupported interpolation method '{interpolation}'"
    # PERCENTILE_CONT interpolates between the nearest values if needed, while
    # PERCENTILE_DISC finds the nearest value
    agg_method = "percentile_cont" if interpolation == "linear" else "percentile_disc"
    # PERCENTILE_* returns DECIMAL; we cast to DOUBLE
    # example sql: SELECT CAST(PERCENTILE_CONT(0.25) WITHIN GROUP(ORDER BY a) AS DOUBLE) AS a FROM table
    return builtin(agg_method)(pandas_lit(q)).within_group(column).cast(DoubleType())
def _columns_coalescing_idxmax_idxmin_helper(
    *cols: SnowparkColumn,
    axis: Literal[0, 1],
    func: Literal["idxmax", "idxmin"],
    keepna: bool,
    pandas_column_labels: list,
    is_groupby: bool = False,
) -> SnowparkColumn:
    """
    Computes the index corresponding to the func for each row if axis=1 or column if axis=0.
    If all values in a row/column are NaN, then the result will be NaN.

    Parameters
    ----------
    *cols: SnowparkColumn
        A tuple of Snowpark Columns.
    axis: {0, 1}
        The axis to apply the func on.
    func: {"idxmax", "idxmin"}
        The function to apply.
    keepna: bool
        Whether to skip NaN Values.
    pandas_column_labels: list
        pandas index/column names.
    is_groupby: bool, default False
        Whether this aggregation runs inside a groupby; changes the NaN
        handling on axis=0 when keepna is set.

    Returns
    -------
    SnowparkColumn
        The aggregated column expression.
    """
    if axis == 0:
        extremum = max_(*cols) if func == "idxmax" else min_(*cols)
        # TODO SNOW-1316602: Support MultiIndex for DataFrame, Series, and DataFrameGroupBy cases.
        if len(pandas_column_labels) > 1:
            # The index is a MultiIndex, current logic does not support this.
            raise NotImplementedError(
                f"{func} is not yet supported when the index is a MultiIndex."
            )
        # TODO SNOW-1270521: max_by and min_by are not guaranteed to break tiebreaks deterministically
        extremum_position = (
            get(
                builtin("max_by")(
                    Column(pandas_column_labels[0]),
                    Column(*cols),
                    1,
                ),
                0,
            )
            if func == "idxmax"
            else get(
                builtin("min_by")(
                    Column(pandas_column_labels[0]),
                    Column(*cols),
                    1,
                ),
                0,
            )
        )
        if is_groupby and keepna:
            # When performing groupby, if a group has any NaN values in its column, the idxmax/idxmin of that column
            # will always be NaN. Therefore, we need to check whether there are any NaN values in each group.
            return iff(
                builtin("count_if")(Column(*cols).is_null()) > 0,
                pandas_lit(None),
                extremum_position,
            )
        else:
            # if extremum is null, i.e. there are no columns or all columns are
            # null, mark extremum_position as null, because our final expression has
            # to evaluate to null.
            return builtin("nvl2")(extremum, extremum_position, lit(None))
    else:
        column_array = array_construct(*cols)
        # extremum is null if there are no columns or all columns are null.
        # otherwise, extremum contains the extremal column, i.e. the max column for
        # idxmax and the min column for idxmin.
        extremum = (array_max if func == "idxmax" else array_min)(column_array)
        # extremum_position is the position of the first column with a value equal
        # to extremum.
        extremum_position = array_position(extremum, column_array)
        if keepna:
            # if any of the columns is null, mark extremum_position as null,
            # because our final expression has to evaluate to null. That's how we
            # "keep NA."
            extremum_position = iff(
                array_contains(lit(None), column_array), lit(None), extremum_position
            )
        else:
            # if extremum is null, i.e. there are no columns or all columns are
            # null, mark extremum_position as null, because our final expression has
            # to evaluate to null.
            extremum_position = builtin("nvl2")(extremum, extremum_position, lit(None))
        # If extremum_position is null, return null.
        return builtin("nvl2")(
            extremum_position,
            # otherwise, we create an array of all the column names using pandas_column_labels
            # and get the element of that array that is at extremum_position.
            get(
                array_construct(*(lit(c) for c in pandas_column_labels)),
                cast(extremum_position, "int"),
            ),
            lit(None),
        )
class _SnowparkPandasAggregation(NamedTuple):
    """
    A representation of a Snowpark pandas aggregation.

    This structure gives us a common representation for an aggregation that may
    have multiple aliases, like "sum" and np.sum.
    """

    # This field tells whether, if the types of all the inputs of the function
    # are the same instance of SnowparkPandasType, the type of the result is the
    # same instance of SnowparkPandasType. Note that this definition applies
    # whether the aggregation is on axis=0 or axis=1. For example, the sum of
    # a single timedelta column on axis 0 is another timedelta column.
    # Equivalently, the sum of two timedelta columns along axis 1 is also
    # another timedelta column. Therefore, preserves_snowpark_pandas_types for
    # sum would be True.
    preserves_snowpark_pandas_types: bool

    # Whether Snowflake PIVOT supports this aggregation on axis 0. It seems
    # that Snowflake PIVOT supports any aggregation expressed as a single
    # function call applied to a single column, e.g. MAX(A), BOOLOR_AND(A)
    supported_in_pivot: bool

    # This callable takes a single Snowpark column as input and aggregates the
    # column on axis=0. If None, Snowpark pandas does not support this
    # aggregation on axis=0.
    axis_0_aggregation: Optional[Callable] = None

    # This callable takes one or more Snowpark columns as input and aggregates
    # the columns on axis=1 with skipna=True, i.e. not including nulls in the
    # aggregation. If None, Snowpark pandas does not support this aggregation
    # on axis=1 with skipna=True.
    axis_1_aggregation_skipna: Optional[Callable] = None

    # This callable takes one or more Snowpark columns as input and aggregates
    # the columns on axis=1 with skipna=False, i.e. including nulls in the
    # aggregation. If None, Snowpark pandas does not support this aggregation
    # on axis=1 with skipna=False.
    axis_1_aggregation_keepna: Optional[Callable] = None
class SnowflakeAggFunc(NamedTuple):
    """
    A Snowflake aggregation, including information about how the aggregation acts on SnowparkPandasType.
    """

    # The aggregation function in Snowpark.
    # For aggregation on axis=0, this field should take a single Snowpark
    # column and return the aggregated column.
    # For aggregation on axis=1, this field should take an arbitrary number
    # of Snowpark columns and return the aggregated column.
    snowpark_aggregation: Callable

    # This field tells whether, if the types of all the inputs of the function
    # are the same instance of SnowparkPandasType, the type of the result is the
    # same instance of SnowparkPandasType. Note that this definition applies
    # whether the aggregation is on axis=0 or axis=1. For example, the sum of
    # a single timedelta column on axis 0 is another timedelta column.
    # Equivalently, the sum of two timedelta columns along axis 1 is also
    # another timedelta column. Therefore, preserves_snowpark_pandas_types for
    # sum would be True.
    preserves_snowpark_pandas_types: bool

    # Whether Snowflake PIVOT supports this aggregation on axis 0. It seems
    # that Snowflake PIVOT supports any aggregation expressed as a single
    # function call applied to a single column, e.g. MAX(A), BOOLOR_AND(A).
    # This field only makes sense for axis 0 aggregation.
    supported_in_pivot: bool
class AggFuncWithLabel(
    NamedTuple(
        "AggFuncWithLabel", [("func", AggFuncTypeBase), ("pandas_label", Hashable)]
    )
):
    """
    Internal representation of a named aggregation: an aggregation function
    paired with the pandas label to put on the column it produces.
    """

    # Temporary workaround for a modin bug:
    # https://github.com/modin-project/modin/issues/7594
    # The query compiler caster may call this constructor with a single list as argument,
    # trying to match the behavior of vanilla `tuple`.
    # Go back to directly using a NamedTuple after this is fixed.
    # # The aggregate function
    # func: AggFuncTypeBase
    # # The label to provide the new column produced by `func`.
    # pandas_label: Hashable
    def __new__(
        cls,
        func: Union[AggFuncTypeBase, list[Any]],
        pandas_label: Optional[Hashable] = None,
    ) -> "AggFuncWithLabel":
        # Normal two-argument construction.
        if pandas_label is not None:
            return super().__new__(cls, func, pandas_label)
        # With only one argument, accept a two-element list (modin workaround).
        if not isinstance(func, list):
            raise TypeError(
                "AggFuncWithLabel.__new__() missing 1 required positional argument: 'pandas_label'"
            )
        assert (
            len(func) == 2
        ), "AggFuncWithLabel was constructed with too many arguments in list"
        return super().__new__(cls, *func)
class AggFuncInfo(NamedTuple):
    """
    Information needed to distinguish between dummy and normal aggregate functions.
    """

    # The aggregate function
    func: AggFuncTypeBase

    # If true, the aggregate function is applied to "NULL" rather than a column
    is_dummy_agg: bool

    # If specified, the pandas label to provide the new column generated by this aggregate
    # function. Used in conjunction with pd.NamedAgg.
    post_agg_pandas_label: Optional[Hashable] = None
class AggregationSupportResult(NamedTuple):
    """
    Information needed to return the first unsupported aggregate function if any.
    """

    # Whether the function is supported for aggregation in snowflake.
    is_valid: bool

    # The unsupported function used for aggregation ("" when is_valid).
    unsupported_function: str

    # The kwargs for the unsupported function ({} when is_valid).
    unsupported_kwargs: dict[str, Any]
def _columns_coalescing_min(*cols: SnowparkColumn) -> SnowparkColumn:
    """
    Computes the minimum value in each row, skipping NaN values. If all values in a row are NaN,
    then the result will be NaN.

    Example SQL:
    SELECT ARRAY_MIN(ARRAY_CONSTRUCT_COMPACT(a, b, c)) AS min
    FROM VALUES (10, 1, NULL), (NULL, NULL, NULL) AS t (a, b, c);

    Result:
    --------
    | min  |
    --------
    |  1   |
    --------
    | NULL |
    --------
    """
    # ARRAY_CONSTRUCT_COMPACT drops NULLs, so ARRAY_MIN only sees non-null values.
    return array_min(array_construct_compact(*cols))
def _columns_coalescing_max(*cols: SnowparkColumn) -> SnowparkColumn:
    """
    Computes the maximum value in each row, skipping NaN values. If all values in a row are NaN,
    then the result will be NaN.

    Example SQL:
    SELECT ARRAY_MAX(ARRAY_CONSTRUCT_COMPACT(a, b, c)) AS max
    FROM VALUES (10, 1, NULL), (NULL, NULL, NULL) AS t (a, b, c);

    Result:
    --------
    | max  |
    --------
    |  10  |
    --------
    | NULL |
    --------
    """
    # ARRAY_CONSTRUCT_COMPACT drops NULLs, so ARRAY_MAX only sees non-null values.
    return array_max(array_construct_compact(*cols))
def _columns_count(*cols: SnowparkColumn) -> SnowparkColumn:
    """
    Counts the number of non-NULL values in each row.

    Example SQL:
    SELECT NVL2(a, 1, 0) + NVL2(b, 1, 0) + NVL2(c, 1, 0) AS count
    FROM VALUES (10, 1, NULL), (NULL, NULL, NULL) AS t (a, b, c);

    Result:
    ---------
    | count |
    ---------
    |   2   |
    ---------
    |   0   |
    ---------
    """
    # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark
    # sum_, since Snowpark sum_ gets the sum of all rows within a single column.
    # NVL2(col, x, y) returns x if col is NULL, and y otherwise.
    return sum(builtin("nvl2")(col, pandas_lit(1), pandas_lit(0)) for col in cols)
def _columns_count_keep_nulls(*cols: SnowparkColumn) -> SnowparkColumn:
    """
    Counts the number of values (including NULL) in each row, i.e. simply the
    number of columns passed in.
    """
    # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark
    # sum_, since Snowpark sum_ gets the sum of all rows within a single column.
    return sum(pandas_lit(1) for _ in cols)
def _columns_coalescing_sum(*cols: SnowparkColumn) -> SnowparkColumn:
    """
    Sums all non-NaN elements in each row. If all elements are NaN, returns 0.

    Example SQL:
    SELECT ZEROIFNULL(a) + ZEROIFNULL(b) + ZEROIFNULL(c) AS sum
    FROM VALUES (10, 1, NULL), (NULL, NULL, NULL) AS t (a, b, c);

    Result:
    -------
    | sum |
    -------
    | 11  |
    -------
    |  0  |
    -------
    """
    # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark
    # sum_, since Snowpark sum_ gets the sum of all rows within a single column.
    return sum(builtin("zeroifnull")(col) for col in cols)
def _column_first_value(
    column: SnowparkColumn,
    row_position_snowflake_quoted_identifier: str,
    ignore_nulls: bool,
) -> SnowparkColumn:
    """
    Returns the first value (ordered by `row_position_snowflake_quoted_identifier`) over the specified group.

    Parameters
    ----------
    column: Snowpark Column
        The Snowpark column to aggregate.
    row_position_snowflake_quoted_identifier: str
        The Snowflake quoted identifier of the column to order by.
    ignore_nulls: bool
        Whether or not to ignore nulls.

    Returns
    -------
    The aggregated Snowpark Column.
    """
    order_key = col(row_position_snowflake_quoted_identifier)
    if ignore_nulls:
        # Null out the ordering key wherever the value is null so MIN_BY
        # never selects a row whose value is null.
        order_key = iff(col(column).is_null(), pandas_lit(None), order_key)
    return builtin("min_by")(col(column), order_key)
def _column_last_value(
    column: SnowparkColumn,
    row_position_snowflake_quoted_identifier: str,
    ignore_nulls: bool,
) -> SnowparkColumn:
    """
    Returns the last value (ordered by `row_position_snowflake_quoted_identifier`) over the specified group.

    Parameters
    ----------
    column: Snowpark Column
        The Snowpark column to aggregate.
    row_position_snowflake_quoted_identifier: str
        The Snowflake quoted identifier of the column to order by.
    ignore_nulls: bool
        Whether or not to ignore nulls.

    Returns
    -------
    The aggregated Snowpark Column.
    """
    order_key = col(row_position_snowflake_quoted_identifier)
    if ignore_nulls:
        # Null out the ordering key wherever the value is null so MAX_BY
        # never selects a row whose value is null.
        order_key = iff(col(column).is_null(), pandas_lit(None), order_key)
    return builtin("max_by")(col(column), order_key)
def _create_pandas_to_snowpark_pandas_aggregation_map(
    pandas_functions: Iterable[AggFuncTypeBase],
    snowpark_pandas_aggregation: _SnowparkPandasAggregation,
) -> MappingProxyType[AggFuncTypeBase, _SnowparkPandasAggregation]:
    """
    Create a read-only map from the given pandas functions to the given _SnowparkPandasAggregation.

    Args:
        pandas_functions: The pandas functions that map to the given aggregation.
        snowpark_pandas_aggregation: The aggregation to map to.

    Returns:
        The map.
    """
    # Every alias points at the same aggregation object.
    return MappingProxyType(
        dict.fromkeys(pandas_functions, snowpark_pandas_aggregation)
    )
# Map between the pandas input aggregation function (str or numpy function) and
# _SnowparkPandasAggregation representing information about applying the
# aggregation in Snowpark pandas.
_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: MappingProxyType[
    AggFuncTypeBase, _SnowparkPandasAggregation
] = MappingProxyType(
    {
        "count": _SnowparkPandasAggregation(
            axis_0_aggregation=count,
            axis_1_aggregation_skipna=_columns_count,
            preserves_snowpark_pandas_types=False,
            supported_in_pivot=True,
        ),
        # nunique maps to COUNT(DISTINCT ...).
        "nunique": _SnowparkPandasAggregation(
            axis_0_aggregation=count_distinct,
            preserves_snowpark_pandas_types=False,
            supported_in_pivot=True,
        ),
        **_create_pandas_to_snowpark_pandas_aggregation_map(
            (len, "size"),
            _SnowparkPandasAggregation(
                # We must count the total number of rows regardless of if they're null.
                axis_0_aggregation=lambda _: builtin("count_if")(pandas_lit(True)),
                axis_1_aggregation_keepna=_columns_count_keep_nulls,
                axis_1_aggregation_skipna=_columns_count_keep_nulls,
                preserves_snowpark_pandas_types=False,
                supported_in_pivot=False,
            ),
        ),
        "first": _SnowparkPandasAggregation(
            axis_0_aggregation=_column_first_value,
            # On axis=1 with keepna, "first" is just the first column.
            axis_1_aggregation_keepna=lambda *cols: cols[0],
            # On axis=1 with skipna, take the first non-null value across columns.
            axis_1_aggregation_skipna=lambda *cols: coalesce(*cols),
            preserves_snowpark_pandas_types=True,
            supported_in_pivot=False,
        ),
        "last": _SnowparkPandasAggregation(
            axis_0_aggregation=_column_last_value,
            # On axis=1 with keepna, "last" is just the last column.
            axis_1_aggregation_keepna=lambda *cols: cols[-1],
            # Reverse the columns so coalesce picks the last non-null value.
            axis_1_aggregation_skipna=lambda *cols: coalesce(*(cols[::-1])),
            preserves_snowpark_pandas_types=True,
            supported_in_pivot=False,
        ),
        **_create_pandas_to_snowpark_pandas_aggregation_map(
            ("mean", np.mean),
            _SnowparkPandasAggregation(
                axis_0_aggregation=mean,
                preserves_snowpark_pandas_types=True,
                supported_in_pivot=True,
            ),
        ),
        **_create_pandas_to_snowpark_pandas_aggregation_map(
            ("min", np.min, min),
            _SnowparkPandasAggregation(
                axis_0_aggregation=min_,
                axis_1_aggregation_keepna=least,
                axis_1_aggregation_skipna=_columns_coalescing_min,
                preserves_snowpark_pandas_types=True,
                supported_in_pivot=True,
            ),
        ),
        **_create_pandas_to_snowpark_pandas_aggregation_map(
            ("max", np.max, max),
            _SnowparkPandasAggregation(
                axis_0_aggregation=max_,
                axis_1_aggregation_keepna=greatest,
                axis_1_aggregation_skipna=_columns_coalescing_max,
                preserves_snowpark_pandas_types=True,
                supported_in_pivot=True,
            ),
        ),
        **_create_pandas_to_snowpark_pandas_aggregation_map(
            ("sum", np.sum, sum),
            _SnowparkPandasAggregation(
                axis_0_aggregation=sum_,
                # IMPORTANT: count and sum use python builtin sum to invoke
                # __add__ on each column rather than Snowpark sum_, since
                # Snowpark sum_ gets the sum of all rows within a single column.
                axis_1_aggregation_keepna=lambda *cols: sum(cols),
                axis_1_aggregation_skipna=_columns_coalescing_sum,
                preserves_snowpark_pandas_types=True,
                supported_in_pivot=True,
            ),
        ),
        **_create_pandas_to_snowpark_pandas_aggregation_map(
            ("median", np.median),
            _SnowparkPandasAggregation(
                axis_0_aggregation=median,
                preserves_snowpark_pandas_types=True,
                supported_in_pivot=True,
            ),
        ),
        "idxmax": _SnowparkPandasAggregation(
            axis_0_aggregation=functools.partial(
                _columns_coalescing_idxmax_idxmin_helper, func="idxmax"
            ),
            axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
            axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
            preserves_snowpark_pandas_types=False,
            supported_in_pivot=False,
        ),
        "idxmin": _SnowparkPandasAggregation(
            axis_0_aggregation=functools.partial(
                _columns_coalescing_idxmax_idxmin_helper, func="idxmin"
            ),
            axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
            axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
            preserves_snowpark_pandas_types=False,
            supported_in_pivot=False,
        ),
        "skew": _SnowparkPandasAggregation(
            axis_0_aggregation=skew,
            preserves_snowpark_pandas_types=True,
            supported_in_pivot=True,
        ),
        "all": _SnowparkPandasAggregation(
            # all() for a column with no non-null values is NULL in Snowflake, but True in pandas.
            axis_0_aggregation=lambda c: coalesce(
                builtin("booland_agg")(col(c)), pandas_lit(True)
            ),
            preserves_snowpark_pandas_types=False,
            supported_in_pivot=False,
        ),
        "any": _SnowparkPandasAggregation(
            # any() for a column with no non-null values is NULL in Snowflake, but False in pandas.
            axis_0_aggregation=lambda c: coalesce(
                builtin("boolor_agg")(col(c)), pandas_lit(False)
            ),
            preserves_snowpark_pandas_types=False,
            supported_in_pivot=False,
        ),
        **_create_pandas_to_snowpark_pandas_aggregation_map(
            ("std", np.std),
            _SnowparkPandasAggregation(
                axis_0_aggregation=stddev,
                preserves_snowpark_pandas_types=True,
                supported_in_pivot=True,
            ),
        ),
        **_create_pandas_to_snowpark_pandas_aggregation_map(
            ("var", np.var),
            _SnowparkPandasAggregation(
                axis_0_aggregation=variance,
                # variance units are the square of the input column units, so
                # variance does not preserve types.
                preserves_snowpark_pandas_types=False,
                supported_in_pivot=True,
            ),
        ),
        "array_agg": _SnowparkPandasAggregation(
            axis_0_aggregation=array_agg,
            preserves_snowpark_pandas_types=False,
            supported_in_pivot=False,
        ),
        # quantile kwargs (interpolation, q) are bound later in
        # get_snowflake_agg_func.
        "quantile": _SnowparkPandasAggregation(
            axis_0_aggregation=column_quantile,
            preserves_snowpark_pandas_types=True,
            supported_in_pivot=False,
        ),
    }
)
class AggregateColumnOpParameters(NamedTuple):
    """
    Parameters/Information needed to apply aggregation on a Snowpark column correctly.
    """

    # Snowflake quoted identifier for the column to apply aggregation on
    snowflake_quoted_identifier: ColumnOrName

    # The Snowpark data type for the column to apply aggregation on
    data_type: DataType

    # pandas label for the new column produced after aggregation
    agg_pandas_label: Optional[Hashable]

    # Snowflake quoted identifier for the new Snowpark column produced after aggregation
    agg_snowflake_quoted_identifier: str

    # the snowflake aggregation function to apply on the column
    snowflake_agg_func: SnowflakeAggFunc

    # the columns specifying the order of rows in the column. This is only
    # relevant for aggregations that depend on row order, e.g. summing a string
    # column.
    ordering_columns: Iterable[OrderingColumn]
def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool:
    """Return True if `agg_func` has a Snowpark pandas aggregation mapping."""
    return agg_func in _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION
def get_snowflake_agg_func(
    agg_func: AggFuncTypeBase,
    agg_kwargs: dict[str, Any],
    axis: Literal[0, 1],
    _is_df_agg: bool = False,
) -> Optional[SnowflakeAggFunc]:
    """
    Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function.
    If no corresponding snowflake aggregation function can be found, return None.

    Args:
        agg_func: the pandas aggregation function (str or numpy function).
        agg_kwargs: keyword arguments for the aggregation, e.g. ddof, q,
            interpolation, skipna. Some aggregations are rewritten based on
            these values (see below).
        axis: 0 for column-wise aggregation, 1 for row-wise.
        _is_df_agg: whether the call originates from df.agg; "first"/"last"
            are rejected in that case.

    Returns:
        A SnowflakeAggFunc, or None if the aggregation is unsupported.
    """
    if axis == 1:
        return _generate_rowwise_aggregation_function(agg_func, agg_kwargs)
    snowpark_pandas_aggregation = (
        _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func)
    )
    if snowpark_pandas_aggregation is None:
        # We don't have any implementation at all for this aggregation.
        return None
    snowpark_aggregation = snowpark_pandas_aggregation.axis_0_aggregation
    if snowpark_aggregation is None:
        # We don't have an implementation on axis=0 for this aggregation.
        return None
    # Rewrite some aggregations according to `agg_kwargs.`
    if snowpark_aggregation == stddev or snowpark_aggregation == variance:
        # for aggregation function std and var, we only support ddof = 0 or ddof = 1.
        # when ddof is 1, std is mapped to stddev, var is mapped to variance
        # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop
        # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1
        ddof = agg_kwargs.get("ddof", 1)
        if ddof != 1 and ddof != 0:
            return None
        if ddof == 0:
            snowpark_aggregation = (
                stddev_pop if snowpark_aggregation == stddev else var_pop
            )
    elif snowpark_aggregation == column_quantile:
        interpolation = agg_kwargs.get("interpolation", "linear")
        q = agg_kwargs.get("q", 0.5)
        if interpolation not in ("linear", "nearest"):
            return None
        if not is_scalar(q):
            # SNOW-1062878 Because list-like q would return multiple rows, calling quantile
            # through the aggregate frontend in this manner is unsupported.
            return None

        # Bind interpolation and q; this nested def intentionally shadows
        # the `snowpark_aggregation` variable above.
        def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn:
            return column_quantile(col, interpolation, q)

    elif (
        snowpark_aggregation == _column_first_value
        or snowpark_aggregation == _column_last_value
    ):
        if _is_df_agg:
            # First and last are not supported for df.agg.
            return None
        ignore_nulls = agg_kwargs.get("skipna", True)
        row_position_snowflake_quoted_identifier = agg_kwargs.get(
            "_first_last_row_pos_col", None
        )
        snowpark_aggregation = functools.partial(
            snowpark_aggregation,
            ignore_nulls=ignore_nulls,
            row_position_snowflake_quoted_identifier=row_position_snowflake_quoted_identifier,
        )
    assert (
        snowpark_aggregation is not None
    ), "Internal error: Snowpark pandas should have identified a Snowpark aggregation."
    return SnowflakeAggFunc(
        snowpark_aggregation=snowpark_aggregation,
        preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types,
        supported_in_pivot=snowpark_pandas_aggregation.supported_in_pivot,
    )
def _generate_rowwise_aggregation_function(
    agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any]
) -> Optional[SnowflakeAggFunc]:
    """
    Get a callable taking *arg columns to apply for an aggregation.

    Unlike get_snowflake_agg_func, this function may return a wrapped composition of
    Snowflake builtin functions depending on the values of the specified kwargs.

    Returns:
        A SnowflakeAggFunc for the axis=1 aggregation, or None if the
        aggregation (with the given skipna setting) is unsupported.
    """
    snowpark_pandas_aggregation = (
        _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func)
    )
    if snowpark_pandas_aggregation is None:
        return None
    snowpark_aggregation = (
        snowpark_pandas_aggregation.axis_1_aggregation_skipna
        if agg_kwargs.get("skipna", True)
        else snowpark_pandas_aggregation.axis_1_aggregation_keepna
    )
    if snowpark_aggregation is None:
        return None
    min_count = agg_kwargs.get("min_count", 0)
    if min_count > 0:
        original_aggregation = snowpark_aggregation

        # Create a case statement to check if the number of non-null values exceeds min_count
        # when min_count > 0, if the number of not NULL values is < min_count, return NULL.
        # (This def intentionally shadows the `snowpark_aggregation` variable above.)
        def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn:
            return when(_columns_count(*cols) < min_count, pandas_lit(None)).otherwise(
                original_aggregation(*cols)
            )

    return SnowflakeAggFunc(
        snowpark_aggregation,
        preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types,
        supported_in_pivot=snowpark_pandas_aggregation.supported_in_pivot,
    )
def _is_supported_snowflake_agg_func(
    agg_func: AggFuncTypeBase,
    agg_kwargs: dict[str, Any],
    axis: Literal[0, 1],
    _is_df_agg: bool = False,
) -> AggregationSupportResult:
    """
    Check whether the aggregation function is supported with snowflake. Currently the
    supported aggregation functions are those that map to a snowflake builtin function.

    Args:
        agg_func: str or Callable. the aggregation function to check
        agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc.
            The value can be different for different aggregation functions.
        axis: 0 for column-wise aggregation, 1 for row-wise.
        _is_df_agg: whether the check is on behalf of df.agg.

    Returns:
        is_valid: bool. Whether the function is supported for aggregation in snowflake.
        unsupported_function: str. The unsupported function used for aggregation.
        unsupported_kwargs: dict. The kwargs for the unsupported function
    """
    if isinstance(agg_func, tuple) and len(agg_func) == 2:
        # For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
        # take the aggregation part of the named aggregation.
        if isinstance(agg_func, AggFuncWithLabel):
            agg_func = agg_func.func
        else:
            agg_func = agg_func[1]
    supported = (
        get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is not None
    )
    if not supported:
        return AggregationSupportResult(
            is_valid=False, unsupported_function=agg_func, unsupported_kwargs=agg_kwargs
        )
    return AggregationSupportResult(
        is_valid=True, unsupported_function="", unsupported_kwargs={}
    )
def _are_all_agg_funcs_supported_by_snowflake(
    agg_funcs: list[AggFuncTypeBase],
    agg_kwargs: dict[str, Any],
    axis: Literal[0, 1],
    _is_df_agg: bool = False,
) -> AggregationSupportResult:
    """
    Check if all aggregation functions in the given list are snowflake supported
    aggregation functions.

    Returns:
        is_valid: bool. Whether it is valid to implement with snowflake or not.
        unsupported_function: str. The first unsupported function, or "" if all are supported.
        unsupported_kwargs: dict. The kwargs for the first unsupported function, or {}.
    """
    for agg_func in agg_funcs:
        result = _is_supported_snowflake_agg_func(
            agg_func, agg_kwargs, axis, _is_df_agg
        )
        if not result.is_valid:
            # Only the first failure is reported.
            return result
    return AggregationSupportResult(True, "", {})
def check_is_aggregation_supported_in_snowflake(
agg_func: AggFuncType,
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> AggregationSupportResult:
"""
check if distributed implementation with snowflake is available for the aggregation
based on the input arguments.
Args:
agg_func: the aggregation function to apply
agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc.
The value can be different for different aggregation function.
_is_df_agg: whether or not this is being called by df.agg, since some functions are only supported
for groupby_agg.
Returns:
is_supported_func: bool. Whether it is valid to implement with snowflake or not.
unsupported_func: str. The unsupported function used for aggregation.
unsupported_kwargs: dict. The kwargs for the unsupported function.
"""
# validate agg_func, only snowflake builtin agg function or dict of snowflake builtin agg
# function can be implemented in distributed way.
# If there are multiple unsupported functions, the first unsupported function will be returned.
unsupported_func = ""
unsupported_kwargs: dict[str, Any] = {}
is_supported_func = True
if is_dict_like(agg_func):
for value in agg_func.values():
if is_list_like(value) and not is_named_tuple(value):
(
is_supported_func,
unsupported_func,
unsupported_kwargs,
) = _are_all_agg_funcs_supported_by_snowflake(
value, agg_kwargs, axis, _is_df_agg
)
else:
(
is_supported_func,
unsupported_func,
unsupported_kwargs,
) = _is_supported_snowflake_agg_func(
value, agg_kwargs, axis, _is_df_agg
)
if not is_supported_func:
return AggregationSupportResult(
is_supported_func, unsupported_func, unsupported_kwargs