From 42de55a440ace26bb155ea96808d887c5214bcfd Mon Sep 17 00:00:00 2001 From: arbennett Date: Mon, 13 Mar 2023 19:43:50 -0400 Subject: [PATCH 1/3] Allow partial batches --- xbatcher/generators.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 58799ba..f43d214 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -59,6 +59,7 @@ def __init__( batch_dims: Optional[Dict[Hashable, int]] = None, concat_input_bins: bool = True, preload_batch: bool = True, + return_partial: bool = False, ): if input_overlap is None: input_overlap = {} @@ -69,6 +70,7 @@ def __init__( self.batch_dims = dict(batch_dims) self.concat_input_dims = concat_input_bins self.preload_batch = preload_batch + self.return_partial = return_partial # Store helpful information based on arguments self._duplicate_batch_dims: Dict[Hashable, int] = { dim: length @@ -121,6 +123,7 @@ def _gen_patch_selectors( ds, dims=self._all_sliced_dims, overlap=self.input_overlap, + return_partial=self.return_partial ) return all_slices @@ -262,7 +265,7 @@ def _get_batch_in_range_per_batch(self, batch_multi_index): return batch_in_range_per_patch -def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]: +def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0, return_partial: bool = False) -> List[slice]: # return a list of slices to chop up a single dimension if overlap >= slice_size: raise ValueError( @@ -275,6 +278,8 @@ def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[sli end = start + slice_size if end <= dim_size: slices.append(slice(start, end)) + elif return_partial: + slices.append(slice(start, dim_size)) return slices @@ -283,6 +288,7 @@ def _iterate_through_dimensions( *, dims: Dict[Hashable, int], overlap: Dict[Hashable, int] = {}, + return_partial: bool = False, ) -> Iterator[Dict[Hashable, slice]]: dim_slices = [] for dim in dims: @@ -297,7 +303,7 @@ def _iterate_through_dimensions( f"for {dim}" ) dim_slices.append( - _gen_slices(dim_size=dim_size, slice_size=slice_size, overlap=slice_overlap) + _gen_slices(dim_size=dim_size, slice_size=slice_size, overlap=slice_overlap, return_partial=return_partial) ) for slices in itertools.product(*dim_slices): selector = dict(zip(dims, slices)) @@ -364,6 +370,9 @@ class BatchGenerator: preload_batch : bool, optional If ``True``, each batch will be loaded into memory before reshaping / processing, triggering any dask arrays to be computed. + return_partial: bool, optional + If ``True``, produce batches from edges when dims are not evenly divisible + by the input dim shapes Yields ------ @@ -379,6 +388,7 @@ def __init__( batch_dims: Dict[Hashable, int] = {}, concat_input_dims: bool = False, preload_batch: bool = True, + return_partial: bool = False, ): self.ds = ds self._batch_selectors: BatchSchema = BatchSchema( @@ -388,6 +398,7 @@ def __init__( batch_dims=batch_dims, concat_input_bins=concat_input_dims, preload_batch=preload_batch, + return_partial=return_partial, ) @property From 89f4927fa0f350d6dd5ac644e566040285e6000c Mon Sep 17 00:00:00 2001 From: arbennett Date: Tue, 9 Jul 2024 15:01:56 -0700 Subject: [PATCH 2/3] Add test for --- xbatcher/testing.py | 16 +++++++++++----- xbatcher/tests/test_generators.py | 9 +++++++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/xbatcher/testing.py b/xbatcher/testing.py index 72953fd..668b292 100644 --- a/xbatcher/testing.py +++ b/xbatcher/testing.py @@ -211,9 +211,10 @@ def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int: s : int Number of batches expected given ``input_dims`` and ``input_overlap``. """ + final_batch_counts = 0.5 if generator._batch_selectors.return_partial else 0 nbatches_from_input_dims = np.prod( [ - generator.ds.sizes[dim] // length + int(generator.ds.sizes[dim] / length + final_batch_counts) for dim, length in generator.input_dims.items() if generator.input_overlap.get(dim) is None and generator.batch_dims.get(dim) is None @@ -222,8 +223,8 @@ def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int: if generator.input_overlap: nbatches_from_input_overlap = np.prod( [ - (generator.ds.sizes[dim] - overlap) - // (generator.input_dims[dim] - overlap) + int((generator.ds.sizes[dim] - overlap) + / (generator.input_dims[dim] - overlap) + final_batch_counts) for dim, overlap in generator.input_overlap.items() ] ) @@ -242,17 +243,22 @@ def validate_generator_length(generator: BatchGenerator) -> None: generator : xbatcher.BatchGenerator The batch generator object. """ + non_input_batch_dims = _get_non_input_batch_dims(generator) duplicate_batch_dims = _get_duplicate_batch_dims(generator) + + # Add 0.5 if the generator is returning partial batches to account for + # the final batch that will be smaller than the rest. + final_batch_counts = 0.5 if generator._batch_selectors.return_partial else 0 nbatches_from_unique_batch_dims = np.prod( [ - generator.ds.sizes[dim] // length + int(generator.ds.sizes[dim] / length + final_batch_counts) for dim, length in non_input_batch_dims.items() ] ) nbatches_from_duplicate_batch_dims = np.prod( [ - generator.ds.sizes[dim] // length + int(generator.ds.sizes[dim] / length + final_batch_counts) for dim, length in duplicate_batch_dims.items() ] ) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 248dd03..b7cf59a 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -58,11 +58,16 @@ def test_constructor_dataarray(): @pytest.mark.parametrize("input_size", [5, 6]) -def test_generator_length(sample_ds_1d, input_size): +@pytest.mark.parametrize("return_partial", [True, False]) +def test_generator_length(sample_ds_1d, input_size, return_partial): """ " Test the length of the batch generator. """ - bg = BatchGenerator(sample_ds_1d, input_dims={"x": input_size}) + bg = BatchGenerator( + sample_ds_1d, + input_dims={"x": input_size}, + return_partial=return_partial + ) validate_generator_length(bg) From 74b77d24c0614bb654da1119a2c1f555e5342db8 Mon Sep 17 00:00:00 2001 From: arbennett Date: Tue, 9 Jul 2024 15:24:12 -0700 Subject: [PATCH 3/3] Update formatting --- xbatcher/testing.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/xbatcher/testing.py b/xbatcher/testing.py index 668b292..cd56287 100644 --- a/xbatcher/testing.py +++ b/xbatcher/testing.py @@ -211,6 +211,8 @@ def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int: s : int Number of batches expected given ``input_dims`` and ``input_overlap``. """ + # Add 0.5 if the generator is returning partial batches to account for + # the final batch that will be smaller than the rest. final_batch_counts = 0.5 if generator._batch_selectors.return_partial else 0 nbatches_from_input_dims = np.prod( [ @@ -223,8 +225,11 @@ def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int: if generator.input_overlap: nbatches_from_input_overlap = np.prod( [ - int((generator.ds.sizes[dim] - overlap) - / (generator.input_dims[dim] - overlap) + final_batch_counts) + int( + (generator.ds.sizes[dim] - overlap) + / (generator.input_dims[dim] - overlap) + + final_batch_counts + ) for dim, overlap in generator.input_overlap.items() ] )