Skip to content

Commit 1322eff

Browse files
author
cmdupuis3
committed
Wrap batch dim generation in a flag
1 parent 0679395 commit 1322eff

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

xbatcher/generators.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,13 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
5353
out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs
5454
return out
5555

56-
def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
56+
def _maybe_stack_batch_dims(ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample'):
5757
batch_dims = [d for d in ds.dims if d not in input_dims]
5858
if len(batch_dims) < 2:
59-
return ds.expand_dims(stacked_dim_name, 0)
59+
if(squeeze_batch_dim):
60+
return ds
61+
else:
62+
return ds.expand_dims(stacked_dim_name, 0)
6063
ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
6164
# ensure correct order
6265
dim_order = (stacked_dim_name,) + tuple(input_dims)
@@ -89,6 +92,10 @@ class BatchGenerator:
8992
preload_batch : bool, optional
9093
If ``True``, each batch will be loaded into memory before reshaping /
9194
processing, triggering any dask arrays to be computed.
95+
squeeze_batch_dim : bool, optional
96+
If ``False``, each batch's dataset will have a "sample" dimension of size 1
97+
prepended to the array. This functionality is useful for interoperability
98+
with Keras / Tensorflow.
9299
93100
Yields
94101
------
@@ -104,6 +111,7 @@ def __init__(
104111
batch_dims={},
105112
concat_input_dims=False,
106113
preload_batch=True,
114+
squeeze_batch_dim=True
107115
):
108116

109117
self.ds = _as_xarray_dataset(ds)
@@ -113,6 +121,7 @@ def __init__(
113121
self.batch_dims = OrderedDict(batch_dims)
114122
self.concat_input_dims = concat_input_dims
115123
self.preload_batch = preload_batch
124+
self.squeeze_batch_dim = squeeze_batch_dim
116125

117126
def __iter__(self):
118127
for ds_batch in self._iterate_batch_dims(self.ds):
@@ -131,11 +140,11 @@ def __iter__(self):
131140
new_input_dims = [
132141
dim + new_dim_suffix for dim in self.input_dims
133142
]
134-
yield _maybe_stack_batch_dims(dsc, new_input_dims)
143+
yield _maybe_stack_batch_dims(dsc, new_input_dims, self.squeeze_batch_dim)
135144
else:
136145
for ds_input in input_generator:
137146
yield _maybe_stack_batch_dims(
138-
ds_input, list(self.input_dims)
147+
ds_input, list(self.input_dims), self.squeeze_batch_dim
139148
)
140149

141150
def _iterate_batch_dims(self, ds):

0 commit comments

Comments
 (0)