@@ -115,60 +115,6 @@ def generic_aggregate(
115
115
return result
116
116
117
117
118
- def _normalize_dtype (dtype : DTypeLike , array_dtype : np .dtype , fill_value = None ) -> np .dtype :
119
- if dtype is None :
120
- dtype = array_dtype
121
- if dtype is np .floating :
122
- # mean, std, var always result in floating
123
- # but we preserve the array's dtype if it is floating
124
- if array_dtype .kind in "fcmM" :
125
- dtype = array_dtype
126
- else :
127
- dtype = np .dtype ("float64" )
128
- elif not isinstance (dtype , np .dtype ):
129
- dtype = np .dtype (dtype )
130
- if fill_value not in [None , dtypes .INF , dtypes .NINF , dtypes .NA ]:
131
- dtype = np .result_type (dtype , fill_value )
132
- return dtype
133
-
134
-
135
- def _maybe_promote_int (dtype ) -> np .dtype :
136
- # https://numpy.org/doc/stable/reference/generated/numpy.prod.html
137
- # The dtype of a is used by default unless a has an integer dtype of less precision
138
- # than the default platform integer.
139
- if not isinstance (dtype , np .dtype ):
140
- dtype = np .dtype (dtype )
141
- if dtype .kind == "i" :
142
- dtype = np .result_type (dtype , np .intp )
143
- elif dtype .kind == "u" :
144
- dtype = np .result_type (dtype , np .uintp )
145
- return dtype
146
-
147
-
148
- def _get_fill_value (dtype , fill_value ):
149
- """Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
150
- if fill_value in [None , dtypes .NA ] and dtype .kind in "US" :
151
- return ""
152
- if fill_value == dtypes .INF or fill_value is None :
153
- return dtypes .get_pos_infinity (dtype , max_for_int = True )
154
- if fill_value == dtypes .NINF :
155
- return dtypes .get_neg_infinity (dtype , min_for_int = True )
156
- if fill_value == dtypes .NA :
157
- if np .issubdtype (dtype , np .floating ) or np .issubdtype (dtype , np .complexfloating ):
158
- return np .nan
159
- # This is madness, but npg checks that fill_value is compatible
160
- # with array dtype even if the fill_value is never used.
161
- elif np .issubdtype (dtype , np .integer ):
162
- return dtypes .get_neg_infinity (dtype , min_for_int = True )
163
- elif np .issubdtype (dtype , np .timedelta64 ):
164
- return np .timedelta64 ("NaT" )
165
- elif np .issubdtype (dtype , np .datetime64 ):
166
- return np .datetime64 ("NaT" )
167
- else :
168
- return None
169
- return fill_value
170
-
171
-
172
118
def _atleast_1d (inp , min_length : int = 1 ):
173
119
if xrutils .is_scalar (inp ):
174
120
inp = (inp ,) * min_length
@@ -646,7 +592,7 @@ def last(self) -> AlignedArrays:
646
592
# TODO: automate?
647
593
engine = "flox" ,
648
594
dtype = self .array .dtype ,
649
- fill_value = _get_fill_value (self .array .dtype , dtypes .NA ),
595
+ fill_value = dtypes . _get_fill_value (self .array .dtype , dtypes .NA ),
650
596
expected_groups = None ,
651
597
)
652
598
return AlignedArrays (array = reduced ["intermediates" ][0 ], group_idx = reduced ["groups" ])
@@ -829,7 +775,9 @@ def _initialize_aggregation(
829
775
np .dtype (dtype ) if dtype is not None and not isinstance (dtype , np .dtype ) else dtype
830
776
)
831
777
832
- final_dtype = _normalize_dtype (dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value )
778
+ final_dtype = dtypes ._normalize_dtype (
779
+ dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value
780
+ )
833
781
if agg .name not in [
834
782
"first" ,
835
783
"last" ,
@@ -841,14 +789,14 @@ def _initialize_aggregation(
841
789
"nanmax" ,
842
790
"topk" ,
843
791
]:
844
- final_dtype = _maybe_promote_int (final_dtype )
792
+ final_dtype = dtypes . _maybe_promote_int (final_dtype )
845
793
agg .dtype = {
846
794
"user" : dtype , # Save to automatically choose an engine
847
795
"final" : final_dtype ,
848
796
"numpy" : (final_dtype ,),
849
797
"intermediate" : tuple (
850
798
(
851
- _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
799
+ dtypes . _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
852
800
if int_dtype is None
853
801
else np .dtype (int_dtype )
854
802
)
@@ -863,10 +811,10 @@ def _initialize_aggregation(
863
811
# Replace sentinel fill values according to dtype
864
812
agg .fill_value ["user" ] = fill_value
865
813
agg .fill_value ["intermediate" ] = tuple (
866
- _get_fill_value (dt , fv )
814
+ dtypes . _get_fill_value (dt , fv )
867
815
for dt , fv in zip (agg .dtype ["intermediate" ], agg .fill_value ["intermediate" ])
868
816
)
869
- agg .fill_value [func ] = _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
817
+ agg .fill_value [func ] = dtypes . _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
870
818
871
819
fv = fill_value if fill_value is not None else agg .fill_value [agg .name ]
872
820
if _is_arg_reduction (agg ):
0 commit comments