@@ -60,11 +60,11 @@ def test_batch_3d_1d_input(sample_ds_3d, bsize):
60
60
assert isinstance (ds_batch , xr .Dataset )
61
61
assert ds_batch .dims ['x' ] == bsize
62
62
# time and y should be collapsed into batch dimension
63
- assert ds_batch .dims ['batch ' ] == sample_ds_3d .dims ['y' ] * sample_ds_3d .dims ['time' ]
63
+ assert ds_batch .dims ['sample ' ] == sample_ds_3d .dims ['y' ] * sample_ds_3d .dims ['time' ]
64
64
expected_slice = slice (bsize * n , bsize * (n + 1 ))
65
65
ds_batch_expected = (sample_ds_3d .isel (x = expected_slice )
66
- .stack (batch = ['y ' , 'time ' ])
67
- .transpose ('batch ' , 'x' ))
66
+ .stack (sample = ['time ' , 'y ' ])
67
+ .transpose ('sample ' , 'x' ))
68
68
print (ds_batch )
69
69
print (ds_batch_expected )
70
70
assert ds_batch .equals (ds_batch_expected )
@@ -93,6 +93,6 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
93
93
assert isinstance (ds_batch , xr .Dataset )
94
94
assert ds_batch .dims ['x_input' ] == xbsize
95
95
assert ds_batch .dims ['y_input' ] == bsize
96
- assert ds_batch .dims ['batch ' ] == ((sample_ds_3d .dims ['x' ]// xbsize ) *
96
+ assert ds_batch .dims ['sample ' ] == ((sample_ds_3d .dims ['x' ]// xbsize ) *
97
97
(sample_ds_3d .dims ['y' ]// bsize ) *
98
98
sample_ds_3d .dims ['time' ])
0 commit comments