1+ import functools
12from typing import Any , Callable , Hashable , Iterable , Optional , Tuple , Union
23
3- import dask .array as da
44import numpy as np
55from xarray import Dataset
66
7+ import sgkit .distarray as da
78from sgkit import variables
89from sgkit .model import get_contigs , num_contigs
910from sgkit .utils import conditional_merge_datasets , create_dataset
@@ -510,8 +511,15 @@ def window_statistic(
510511 and len (window_stops ) == 1
511512 and window_stops == np .array ([values .shape [0 ]])
512513 ):
514+ out = da .map_blocks (
515+ functools .partial (statistic , ** kwargs ),
516+ values ,
517+ dtype = dtype ,
518+ chunks = values .chunks [1 :],
519+ drop_axis = 0 ,
520+ )
513521 # call expand_dims to add back window dimension (size 1)
514- return da .expand_dims (statistic ( values , ** kwargs ) , axis = 0 )
522+ return da .expand_dims (out , axis = 0 )
515523
516524 window_lengths = window_stops - window_starts
517525 depth = np .max (window_lengths ) # type: ignore[no-untyped-call]
@@ -536,10 +544,10 @@ def window_statistic(
536544
537545 chunk_offsets = _sizes_to_start_offsets (windows_per_chunk )
538546
539- def blockwise_moving_stat (x : ArrayLike , block_info : Any = None ) -> ArrayLike :
540- if block_info is None or len ( block_info ) == 0 :
547+ def blockwise_moving_stat (x : ArrayLike , block_id : Any = None ) -> ArrayLike :
548+ if block_id is None :
541549 return np .array ([])
542- chunk_number = block_info [ 0 ][ "chunk-location" ] [0 ]
550+ chunk_number = block_id [0 ]
543551 chunk_offset_start = chunk_offsets [chunk_number ]
544552 chunk_offset_stop = chunk_offsets [chunk_number + 1 ]
545553 chunk_window_starts = rel_window_starts [chunk_offset_start :chunk_offset_stop ]
@@ -559,8 +567,9 @@ def blockwise_moving_stat(x: ArrayLike, block_info: Any = None) -> ArrayLike:
559567 depth = {0 : depth }
560568 # new chunks are same except in first axis
561569 new_chunks = tuple ([tuple (windows_per_chunk )] + list (desired_chunks [1 :])) # type: ignore
562- return values .map_overlap (
570+ return da .map_overlap (
563571 blockwise_moving_stat ,
572+ values ,
564573 dtype = dtype ,
565574 chunks = new_chunks ,
566575 depth = depth ,
0 commit comments