Skip to content

Commit e885373

Browse files
committed
think I fixed it
1 parent 374dec0 commit e885373

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

xbatcher/generators.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,13 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
4141
# remove input_dims coordinates from datasets, rename the dimensions
4242
# then put intput_dims back in as coordinates
4343
out = ds.copy()
44-
out = (out.drop(input_dims)
45-
.rename({dim: dim + suffix for dim in input_dims}))
4644
for dim in input_dims:
47-
out.coords[dim] = dim + suffix, ds[dim].values
45+
newdim = dim + suffix
46+
out = out.rename({dim: newdim})
47+
# extra steps needed if there is a coordinate
48+
if newdim in out:
49+
out = out.drop(newdim)
50+
out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs
4851
return out
4952

5053

xbatcher/tests/test_generators.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_batch_1d_no_coordinate(sample_ds_1d, bsize):
5050

5151
@pytest.mark.parametrize("bsize", [5, 10])
5252
def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize):
53-
# fix for #3
53+
# test for #3
5454
ds_dropped = sample_ds_1d.drop('x')
5555
bg = BatchGenerator(ds_dropped, input_dims={'x': bsize},
5656
concat_input_dims=True)
@@ -60,7 +60,6 @@ def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize):
6060
assert ds_batch.dims['input_batch'] == sample_ds_1d.dims['x']//bsize
6161
assert 'x' not in ds_batch.coords
6262

63-
6463
@pytest.mark.parametrize("olap", [1, 4])
6564
def test_batch_1d_overlap(sample_ds_1d, olap):
6665
bsize = 10

0 commit comments

Comments
 (0)