Skip to content

Commit e728bd2

Browse files
authored
Merge pull request #1 from rabernat/rename-batch-dim
rename batch dim
2 parents 3e41939 + 4bc5368 commit e728bd2

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

xbatcher/generators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
4848
return out
4949

5050

51-
def _maybe_stack_batch_dims(ds, input_dims):
51+
def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
5252
batch_dims = list(set(ds.dims) - set(input_dims))
5353
if len(batch_dims) < 2:
5454
return ds
55-
ds_stack = ds.stack(batch=batch_dims)
55+
ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
5656
# ensure correct order
57-
dim_order = ('batch',) + tuple(input_dims)
57+
dim_order = (stacked_dim_name,) + tuple(input_dims)
5858
return ds_stack.transpose(*dim_order)
5959

6060

xbatcher/tests/test_generators.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ def test_batch_3d_1d_input(sample_ds_3d, bsize):
6060
assert isinstance(ds_batch, xr.Dataset)
6161
assert ds_batch.dims['x'] == bsize
6262
# time and y should be collapsed into batch dimension
63-
assert ds_batch.dims['batch'] == sample_ds_3d.dims['y'] * sample_ds_3d.dims['time']
63+
assert ds_batch.dims['sample'] == sample_ds_3d.dims['y'] * sample_ds_3d.dims['time']
6464
expected_slice = slice(bsize*n, bsize*(n+1))
6565
ds_batch_expected = (sample_ds_3d.isel(x=expected_slice)
66-
.stack(batch=['y', 'time'])
67-
.transpose('batch', 'x'))
66+
.stack(sample=['time', 'y'])
67+
.transpose('sample', 'x'))
6868
print(ds_batch)
6969
print(ds_batch_expected)
7070
assert ds_batch.equals(ds_batch_expected)
@@ -93,6 +93,6 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
9393
assert isinstance(ds_batch, xr.Dataset)
9494
assert ds_batch.dims['x_input'] == xbsize
9595
assert ds_batch.dims['y_input'] == bsize
96-
assert ds_batch.dims['batch'] == ((sample_ds_3d.dims['x']//xbsize) *
96+
assert ds_batch.dims['sample'] == ((sample_ds_3d.dims['x']//xbsize) *
9797
(sample_ds_3d.dims['y']//bsize) *
9898
sample_ds_3d.dims['time'])

0 commit comments

Comments
 (0)