48
48
)
49
49
from .cache import memoize
50
50
from .lib import ArrayLayer , dask_array_type , sparse_array_type
51
+ from .options import OPTIONS
51
52
from .xrutils import (
52
53
_contains_cftime_datetimes ,
53
54
_to_pytimedelta ,
111
112
# _simple_combine.
112
113
DUMMY_AXIS = - 2
113
114
115
+
114
116
logger = logging .getLogger ("flox" )
115
117
116
118
@@ -215,8 +217,11 @@ def identity(x: T) -> T:
215
217
return x
216
218
217
219
218
- def _issorted (arr : np .ndarray ) -> bool :
219
- return bool ((arr [:- 1 ] <= arr [1 :]).all ())
220
+ def _issorted (arr : np .ndarray , ascending = True ) -> bool :
221
+ if ascending :
222
+ return bool ((arr [:- 1 ] <= arr [1 :]).all ())
223
+ else :
224
+ return bool ((arr [:- 1 ] >= arr [1 :]).all ())
220
225
221
226
222
227
def _is_arg_reduction (func : T_Agg ) -> bool :
@@ -299,7 +304,7 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:
299
304
def _get_optimal_chunks_for_groups (chunks , labels ):
300
305
chunkidx = np .cumsum (chunks ) - 1
301
306
# what are the groups at chunk boundaries
302
- labels_at_chunk_bounds = _unique (labels [chunkidx ])
307
+ labels_at_chunk_bounds = pd . unique (labels [chunkidx ])
303
308
# what's the last index of all groups
304
309
last_indexes = npg .aggregate_numpy .aggregate (labels , np .arange (len (labels )), func = "last" )
305
310
# what's the last index of groups at the chunk boundaries.
@@ -317,6 +322,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
317
322
Δl = abs (c - l )
318
323
if c == 0 or newchunkidx [- 1 ] > l :
319
324
continue
325
+ f = f .item () # noqa
326
+ l = l .item () # noqa
320
327
if Δf < Δl and f > newchunkidx [- 1 ]:
321
328
newchunkidx .append (f )
322
329
else :
@@ -708,7 +715,9 @@ def rechunk_for_cohorts(
708
715
return array .rechunk ({axis : newchunks })
709
716
710
717
711
def rechunk_for_blockwise(
    array: DaskArray, axis: T_Axis, labels: np.ndarray, *, force: bool = True
) -> tuple[T_MethodOpt, DaskArray]:
    """
    Rechunks array so that group boundaries line up with chunk boundaries, allowing
    embarrassingly parallel group reductions.

    Parameters
    ----------
    array : DaskArray
        Array to rechunk.
    axis : T_Axis
        Axis along which to rechunk.
    labels : np.ndarray
        Group labels along ``axis``; rechunking only helps when equal labels
        are contiguous (e.g. resampling-style ``[0, 0, 0, 1, 1, 2, 2]``).
    force : bool, optional
        When True (default), rechunk whenever a better chunking exists.
        When False, only rechunk if the change is small enough, as judged by
        the ``rechunk_blockwise_num_chunks_threshold`` and
        ``rechunk_blockwise_chunk_size_threshold`` entries in ``OPTIONS``.

    Returns
    -------
    method : "blockwise" or None
        ``"blockwise"`` when the returned array can be reduced blockwise;
        ``None`` when rechunking was declined because the thresholds were
        not met.
    array : DaskArray
        Possibly rechunked array.
    """
    chunks = array.chunks[axis]
    # A single chunk along `axis` is trivially blockwise-friendly.
    if len(chunks) == 1:
        return "blockwise", array

    # TODO: consider first rescaling to larger input chunks (up to
    # dask's configured array.chunk-size) before optimizing boundaries;
    # for now the existing chunking is used as-is.
    new_input_chunks = chunks

    # FIXME: this should be unnecessary?
    labels = factorize_((labels,), axes=())[0]
    newchunks = _get_optimal_chunks_for_groups(new_input_chunks, labels)
    if newchunks == chunks:
        # Group boundaries already line up; nothing to do.
        return "blockwise", array

    # How much does the proposed rechunk change the chunking?
    Δn = abs(len(newchunks) - len(new_input_chunks))
    if pass_num_chunks_threshold := (
        Δn / len(new_input_chunks) < OPTIONS["rechunk_blockwise_num_chunks_threshold"]
    ):
        logger.debug("blockwise rechunk passes num chunks threshold")
    if pass_chunk_size_threshold := (
        # we just pick the max because number of chunks may have changed.
        (abs(max(newchunks) - max(new_input_chunks)) / max(new_input_chunks))
        < OPTIONS["rechunk_blockwise_chunk_size_threshold"]
    ):
        logger.debug("blockwise rechunk passes chunk size change threshold")

    if force or (pass_num_chunks_threshold and pass_chunk_size_threshold):
        logger.debug("Rechunking to enable blockwise.")
        return "blockwise", array.rechunk({axis: newchunks})
    else:
        logger.debug("Didn't meet thresholds to do automatic rechunking for blockwise reductions.")
        return None, array
743
785
744
786
def reindex_numpy (array , from_ : pd .Index , to : pd .Index , fill_value , dtype , axis : int ):
@@ -2704,6 +2746,11 @@ def groupby_reduce(
2704
2746
has_dask = is_duck_dask_array (array ) or is_duck_dask_array (by_ )
2705
2747
has_cubed = is_duck_cubed_array (array ) or is_duck_cubed_array (by_ )
2706
2748
2749
+ if method is None and is_duck_dask_array (array ) and not any_by_dask and by_ .ndim == 1 and _issorted (by_ ):
2750
+ # Let's try rechunking for sorted 1D by.
2751
+ (single_axis ,) = axis_
2752
+ method , array = rechunk_for_blockwise (array , single_axis , by_ , force = False )
2753
+
2707
2754
is_first_last = _is_first_last_reduction (func )
2708
2755
if is_first_last :
2709
2756
if has_dask and nax != 1 :
@@ -2891,7 +2938,7 @@ def groupby_reduce(
2891
2938
2892
2939
# if preferred method is already blockwise, no need to rechunk
2893
2940
if preferred_method != "blockwise" and method == "blockwise" and by_ .ndim == 1 :
2894
- array = rechunk_for_blockwise (array , axis = - 1 , labels = by_ )
2941
+ _ , array = rechunk_for_blockwise (array , axis = - 1 , labels = by_ )
2895
2942
2896
2943
result , groups = partial_agg (
2897
2944
array = array ,
0 commit comments