3
3
import copy
4
4
import warnings
5
5
from functools import partial
6
+ from typing import TYPE_CHECKING , Any , Callable , TypedDict
6
7
7
8
import numpy as np
8
9
import numpy_groupies as npg
10
+ from numpy .typing import DTypeLike
9
11
10
12
from . import aggregate_flox , aggregate_npg , xrdtypes as dtypes , xrutils
11
13
14
+ if TYPE_CHECKING :
15
+ FuncTuple = tuple [Callable | str , ...]
16
+
12
17
13
18
def _is_arg_reduction (func : str | Aggregation ) -> bool :
14
19
if isinstance (func , str ) and func in ["argmin" , "argmax" , "nanargmax" , "nanargmin" ]:
@@ -18,6 +23,17 @@ def _is_arg_reduction(func: str | Aggregation) -> bool:
18
23
return False
19
24
20
25
26
+ class AggDtypeInit (TypedDict ):
27
+ final : DTypeLike | None
28
+ intermediate : tuple [DTypeLike , ...]
29
+
30
+
31
+ class AggDtype (TypedDict ):
32
+ final : np .dtype
33
+ numpy : tuple [np .dtype | type [np .intp ], ...]
34
+ intermediate : tuple [np .dtype | type [np .intp ], ...]
35
+
36
+
21
37
def generic_aggregate (
22
38
group_idx ,
23
39
array ,
@@ -57,7 +73,7 @@ def generic_aggregate(
57
73
return result
58
74
59
75
60
- def _normalize_dtype (dtype , array_dtype , fill_value = None ):
76
+ def _normalize_dtype (dtype : DTypeLike , array_dtype : np . dtype , fill_value = None ) -> np . dtype :
61
77
if dtype is None :
62
78
dtype = array_dtype
63
79
if dtype is np .floating :
@@ -103,16 +119,16 @@ def __init__(
103
119
self ,
104
120
name ,
105
121
* ,
106
- numpy = None ,
107
- chunk ,
108
- combine ,
109
- preprocess = None ,
110
- aggregate = None ,
111
- finalize = None ,
122
+ numpy : str | FuncTuple | None = None ,
123
+ chunk : str | FuncTuple | None ,
124
+ combine : str | FuncTuple | None ,
125
+ preprocess : Callable | None = None ,
126
+ aggregate : Callable | None = None ,
127
+ finalize : Callable | None = None ,
112
128
fill_value = None ,
113
129
final_fill_value = dtypes .NA ,
114
130
dtypes = None ,
115
- final_dtype = None ,
131
+ final_dtype : DTypeLike | None = None ,
116
132
reduction_type = "reduce" ,
117
133
):
118
134
"""
@@ -162,15 +178,15 @@ def __init__(
162
178
self .preprocess = preprocess
163
179
# Use "chunk_reduce" or "chunk_argreduce"
164
180
self .reduction_type = reduction_type
165
- self .numpy = (numpy ,) if numpy else (self .name ,)
181
+ self .numpy : FuncTuple = (numpy ,) if numpy else (self .name ,)
166
182
# initialize blockwise reduction
167
- self .chunk = _atleast_1d (chunk )
183
+ self .chunk : FuncTuple = _atleast_1d (chunk )
168
184
# how to aggregate results after first round of reduction
169
- self .combine = _atleast_1d (combine )
185
+ self .combine : FuncTuple = _atleast_1d (combine )
170
186
# final aggregation
171
- self .aggregate = aggregate if aggregate else self .combine [0 ]
187
+ self .aggregate : Callable | str = aggregate if aggregate else self .combine [0 ]
172
188
# finalize results (see mean)
173
- self .finalize = finalize if finalize else lambda x : x
189
+ self .finalize : Callable | None = finalize
174
190
175
191
self .fill_value = {}
176
192
# This is used for the final reindexing
@@ -180,13 +196,15 @@ def __init__(
180
196
# They should make sense when aggregated together with results from other blocks
181
197
self .fill_value ["intermediate" ] = self ._normalize_dtype_fill_value (fill_value , "fill_value" )
182
198
183
- self .dtype = {}
184
- self .dtype [name ] = final_dtype
185
- self .dtype ["intermediate" ] = self ._normalize_dtype_fill_value (dtypes , "dtype" )
199
+ self .dtype_init : AggDtypeInit = {
200
+ "final" : final_dtype ,
201
+ "intermediate" : self ._normalize_dtype_fill_value (dtypes , "dtype" ),
202
+ }
203
+ self .dtype : AggDtype = None # type: ignore
186
204
187
205
# The following are set by _initialize_aggregation
188
- self .finalize_kwargs = {}
189
- self .min_count = None
206
+ self .finalize_kwargs : dict [ Any , Any ] = {}
207
+ self .min_count : int | None = None
190
208
191
209
def _normalize_dtype_fill_value (self , value , name ):
192
210
value = _atleast_1d (value )
@@ -211,15 +229,15 @@ def __dask_tokenize__(self):
211
229
self .dtype ,
212
230
)
213
231
214
- def __repr__ (self ):
232
+ def __repr__ (self ) -> str :
215
233
return "\n " .join (
216
234
(
217
- f"{ self .name } , fill: { np . unique ( self .fill_value .values ()) } , dtype: { self .dtype } " ,
218
- f"chunk: { self .chunk } " ,
219
- f"combine: { self .combine } " ,
220
- f"aggregate: { self .aggregate } " ,
221
- f"finalize: { self .finalize } " ,
222
- f"min_count: { self .min_count } " ,
235
+ f"{ self .name !r } , fill: { self .fill_value .values ()!r } , dtype: { self .dtype } " ,
236
+ f"chunk: { self .chunk !r } " ,
237
+ f"combine: { self .combine !r } " ,
238
+ f"aggregate: { self .aggregate !r } " ,
239
+ f"finalize: { self .finalize !r } " ,
240
+ f"min_count: { self .min_count !r } " ,
223
241
)
224
242
)
225
243
@@ -484,7 +502,7 @@ def _initialize_aggregation(
484
502
array_dtype ,
485
503
fill_value ,
486
504
min_count : int | None ,
487
- finalize_kwargs ,
505
+ finalize_kwargs : dict [ Any , Any ] | None ,
488
506
) -> Aggregation :
489
507
if not isinstance (func , Aggregation ):
490
508
try :
@@ -502,24 +520,30 @@ def _initialize_aggregation(
502
520
503
521
# np.dtype(None) == np.dtype("float64")!!!
504
522
# so check for not None
505
- if dtype is not None and not isinstance (dtype , np .dtype ):
506
- dtype = np .dtype (dtype )
523
+ dtype_ : np .dtype | None = (
524
+ np .dtype (dtype ) if dtype is not None and not isinstance (dtype , np .dtype ) else dtype
525
+ )
507
526
508
- agg .dtype [func ] = _normalize_dtype (dtype or agg .dtype [func ], array_dtype , fill_value )
509
- agg .dtype ["numpy" ] = (agg .dtype [func ],)
510
- agg .dtype ["intermediate" ] = [
511
- _normalize_dtype (int_dtype , np .result_type (array_dtype , agg .dtype [func ]), int_fv )
512
- if int_dtype is None
513
- else int_dtype
514
- for int_dtype , int_fv in zip (agg .dtype ["intermediate" ], agg .fill_value ["intermediate" ])
515
- ]
527
+ final_dtype = _normalize_dtype (dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value )
528
+ agg .dtype = {
529
+ "final" : final_dtype ,
530
+ "numpy" : (final_dtype ,),
531
+ "intermediate" : tuple (
532
+ _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
533
+ if int_dtype is None
534
+ else np .dtype (int_dtype )
535
+ for int_dtype , int_fv in zip (
536
+ agg .dtype_init ["intermediate" ], agg .fill_value ["intermediate" ]
537
+ )
538
+ ),
539
+ }
516
540
517
541
# Replace sentinel fill values according to dtype
518
542
agg .fill_value ["intermediate" ] = tuple (
519
543
_get_fill_value (dt , fv )
520
544
for dt , fv in zip (agg .dtype ["intermediate" ], agg .fill_value ["intermediate" ])
521
545
)
522
- agg .fill_value [func ] = _get_fill_value (agg .dtype [func ], agg .fill_value [func ])
546
+ agg .fill_value [func ] = _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
523
547
524
548
fv = fill_value if fill_value is not None else agg .fill_value [agg .name ]
525
549
if _is_arg_reduction (agg ):
0 commit comments