Skip to content

Commit 7b0bd95

Browse files
author
cmdupuis3
committed
Prepend batch dimension if it's not there (partial fix?)
1 parent d98ad21 commit 7b0bd95

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

xbatcher/generators.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _iterate_through_dataset(ds, dims, overlap={}):
3434
size = dims[dim]
3535
olap = overlap.get(dim, 0)
3636
dim_slices.append(_slices(dimsize, size, olap))
37-
37+
3838
for slices in itertools.product(*dim_slices):
3939
selector = {key: slice for key, slice in zip(dims, slices)}
4040
yield ds.isel(**selector)
@@ -53,11 +53,10 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
5353
out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs
5454
return out
5555

56-
5756
def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
5857
batch_dims = [d for d in ds.dims if d not in input_dims]
5958
if len(batch_dims) < 2:
60-
return ds
59+
return ds.expand_dims(stacked_dim_name, 0)
6160
ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
6261
# ensure correct order
6362
dim_order = (stacked_dim_name,) + tuple(input_dims)

0 commit comments

Comments
 (0)