Skip to content

Commit 714b624

Browse files
weiji14 and Max Jones authored
Use .sizes instead of .dims for xr.Dataset/xr.DataArray compatibility (#71)
* Use .sizes instead of .dims for xr.Dataset/xr.DataArray compatibility

  Removes need for using `_as_xarray_dataset` so xr.DataArray inputs are preserved as xr.DataArray objects on the returned output.
* Remove unused _as_xarray_dataset function
* Fix conversion to torch named tensors by changing frozen dict to tuple
* Fix KeyError for PyTorch tests
* Fix AttributeError on keras.py
* Mention that unit tests are only for DataArray and not Dataset
* Remove concat_dim from CustomTFDataset to support xr.DataArray only
* Test torch.Tensor shapes to be more precise than just the batch size

Co-authored-by: Max Jones <[email protected]>
1 parent 0ded974 commit 714b624

File tree

7 files changed

+45
-64
lines changed

7 files changed

+45
-64
lines changed

xbatcher/accessors.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
@xr.register_dataset_accessor('batch')
88
class BatchAccessor:
99
def __init__(self, xarray_obj):
10-
'''
10+
"""
1111
Batch accessor returning a BatchGenerator object via the `generator method`
12-
'''
12+
"""
1313
self._obj = xarray_obj
1414

1515
def generator(self, *args, **kwargs):
16-
'''
16+
"""
1717
Return a BatchGenerator via the batch accessor
1818
1919
Parameters
@@ -22,7 +22,7 @@ def generator(self, *args, **kwargs):
2222
Positional arguments to pass to the `BatchGenerator` constructor.
2323
**kwargs : dict
2424
Keyword arguments to pass to the `BatchGenerator` constructor.
25-
'''
25+
"""
2626
return BatchGenerator(self._obj, *args, **kwargs)
2727

2828

@@ -38,7 +38,11 @@ def to_tensor(self):
3838
return torch.tensor(self._obj.data)
3939

4040
def to_named_tensor(self):
41-
"""Convert this DataArray to a torch.Tensor with named dimensions"""
41+
"""
42+
Convert this DataArray to a torch.Tensor with named dimensions.
43+
44+
See https://pytorch.org/docs/stable/named_tensor.html
45+
"""
4246
import torch
4347

44-
return torch.tensor(self._obj.data, names=self._obj.dims)
48+
return torch.tensor(self._obj.data, names=tuple(self._obj.sizes))

xbatcher/generators.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,6 @@
77
import xarray as xr
88

99

10-
def _as_xarray_dataset(ds):
11-
# maybe coerce to xarray dataset
12-
if isinstance(ds, xr.Dataset):
13-
return ds
14-
else:
15-
return ds.to_dataset()
16-
17-
1810
def _slices(dimsize, size, overlap=0):
1911
# return a list of slices to chop up a single dimension
2012
if overlap >= size:
@@ -34,7 +26,7 @@ def _slices(dimsize, size, overlap=0):
3426
def _iterate_through_dataset(ds, dims, overlap={}):
3527
dim_slices = []
3628
for dim in dims:
37-
dimsize = ds.dims[dim]
29+
dimsize = ds.sizes[dim]
3830
size = dims[dim]
3931
olap = overlap.get(dim, 0)
4032
if size > dimsize:
@@ -66,7 +58,7 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
6658

6759

6860
def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
69-
batch_dims = [d for d in ds.dims if d not in input_dims]
61+
batch_dims = [d for d in ds.sizes if d not in input_dims]
7062
if len(batch_dims) < 2:
7163
return ds
7264
ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
@@ -121,7 +113,7 @@ def __init__(
121113
preload_batch: bool = True,
122114
):
123115

124-
self.ds = _as_xarray_dataset(ds)
116+
self.ds = ds
125117
# should be a dict
126118
self.input_dims = OrderedDict(input_dims)
127119
self.input_overlap = input_overlap

xbatcher/loaders/keras.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Any, Callable, Optional, Tuple
22

33
import tensorflow as tf
4-
import xarray as xr
54

65
# Notes:
76
# This module includes one Keras dataset, which can be provided to model.fit().
@@ -18,9 +17,8 @@ def __init__(
1817
*,
1918
transform: Optional[Callable] = None,
2019
target_transform: Optional[Callable] = None,
21-
dim: str = 'new_dim',
2220
) -> None:
23-
'''
21+
"""
2422
Keras Dataset adapter for Xbatcher
2523
2624
Parameters
@@ -31,38 +29,18 @@ def __init__(
3129
A function/transform that takes in an array and returns a transformed version.
3230
target_transform : callable, optional
3331
A function/transform that takes in the target and transforms it.
34-
dim : str, 'new_dim'
35-
Name of dim to pass to :func:`xarray.concat` as the dimension
36-
to concatenate all variables along.
37-
'''
32+
"""
3833
self.X_generator = X_generator
3934
self.y_generator = y_generator
4035
self.transform = transform
4136
self.target_transform = target_transform
42-
self.concat_dim = dim
4337

4438
def __len__(self) -> int:
4539
return len(self.X_generator)
4640

4741
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
48-
X_batch = tf.convert_to_tensor(
49-
xr.concat(
50-
(
51-
self.X_generator[idx][key]
52-
for key in list(self.X_generator[idx].keys())
53-
),
54-
self.concat_dim,
55-
).data
56-
)
57-
y_batch = tf.convert_to_tensor(
58-
xr.concat(
59-
(
60-
self.y_generator[idx][key]
61-
for key in list(self.y_generator[idx].keys())
62-
),
63-
self.concat_dim,
64-
).data
65-
)
42+
X_batch = tf.convert_to_tensor(self.X_generator[idx].data)
43+
y_batch = tf.convert_to_tensor(self.y_generator[idx].data)
6644

6745
# TODO: Should the transformations be applied before tensor conversion?
6846
if self.transform:

xbatcher/loaders/torch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def __getitem__(self, idx) -> Tuple[Any, Any]:
5454

5555
# TODO: figure out the dataset -> array workflow
5656
# currently hardcoding a variable name
57-
X_batch = self.X_generator[idx]['x'].torch.to_tensor()
58-
y_batch = self.y_generator[idx]['y'].torch.to_tensor()
57+
X_batch = self.X_generator[idx].torch.to_tensor()
58+
y_batch = self.y_generator[idx].torch.to_tensor()
5959

6060
if self.transform:
6161
X_batch = self.transform(X_batch)
@@ -85,4 +85,4 @@ def __init__(
8585

8686
def __iter__(self):
8787
for xb, yb in zip(self.X_generator, self.y_generator):
88-
yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor())
88+
yield (xb.torch.to_tensor(), yb.torch.to_tensor())

xbatcher/tests/test_generators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ def sample_ds_3d():
3434
return ds
3535

3636

37-
def test_constructor_coerces_to_dataset():
37+
def test_constructor_dataarray():
3838
da = xr.DataArray(np.random.rand(10), dims='x', name='foo')
3939
bg = BatchGenerator(da, input_dims={'x': 2})
40-
assert isinstance(bg.ds, xr.Dataset)
41-
assert bg.ds.equals(da.to_dataset())
40+
assert isinstance(bg.ds, xr.DataArray)
41+
assert bg.ds.equals(da)
4242

4343

4444
@pytest.mark.parametrize('bsize', [5, 6])

xbatcher/tests/test_keras_loaders.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def ds_xy():
2424
return ds
2525

2626

27-
def test_custom_dataset(ds_xy):
27+
def test_custom_dataarray(ds_xy):
2828

2929
x = ds_xy['x']
3030
y = ds_xy['y']
@@ -36,15 +36,16 @@ def test_custom_dataset(ds_xy):
3636

3737
# test __getitem__
3838
x_batch, y_batch = dataset[0]
39-
assert len(x_batch) == len(y_batch)
39+
assert x_batch.shape == (10, 5)
40+
assert y_batch.shape == (10,)
4041
assert tf.is_tensor(x_batch)
4142
assert tf.is_tensor(y_batch)
4243

4344
# test __len__
4445
assert len(dataset) == len(x_gen)
4546

4647

47-
def test_custom_dataset_with_transform(ds_xy):
48+
def test_custom_dataarray_with_transform(ds_xy):
4849

4950
x = ds_xy['x']
5051
y = ds_xy['y']
@@ -62,7 +63,8 @@ def y_transform(batch):
6263
x_gen, y_gen, transform=x_transform, target_transform=y_transform
6364
)
6465
x_batch, y_batch = dataset[0]
65-
assert len(x_batch) == len(y_batch)
66+
assert x_batch.shape == (10, 5)
67+
assert y_batch.shape == (10,)
6668
assert tf.is_tensor(x_batch)
6769
assert tf.is_tensor(y_batch)
6870
assert tf.experimental.numpy.all(x_batch == 1)

xbatcher/tests/test_torch_loaders.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ def test_map_dataset(ds_xy):
3636

3737
# test __getitem__
3838
x_batch, y_batch = dataset[0]
39-
assert len(x_batch) == len(y_batch)
39+
assert x_batch.shape == (10, 5)
40+
assert y_batch.shape == (10,)
4041
assert isinstance(x_batch, torch.Tensor)
4142

4243
idx = torch.tensor([0])
4344
x_batch, y_batch = dataset[idx]
44-
assert len(x_batch) == len(y_batch)
45+
assert x_batch.shape == (10, 5)
46+
assert y_batch.shape == (10,)
4547
assert isinstance(x_batch, torch.Tensor)
4648

4749
with pytest.raises(NotImplementedError):
@@ -55,13 +57,14 @@ def test_map_dataset(ds_xy):
5557
loader = torch.utils.data.DataLoader(dataset)
5658

5759
for x_batch, y_batch in loader:
58-
assert len(x_batch) == len(y_batch)
60+
assert x_batch.shape == (1, 10, 5)
61+
assert y_batch.shape == (1, 10)
5962
assert isinstance(x_batch, torch.Tensor)
6063

6164
# TODO: why does pytorch add an extra dimension (length 1) to x_batch
62-
assert x_gen[-1]['x'].shape == x_batch.shape[1:]
63-
# TODO: also need to revisit the variable extraction bits here
64-
assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
65+
assert x_gen[-1].shape == x_batch.shape[1:]
66+
# TODO: add test for xarray.Dataset
67+
assert np.array_equal(x_gen[-1], x_batch[0, :, :])
6568

6669

6770
def test_map_dataset_with_transform(ds_xy):
@@ -82,7 +85,8 @@ def y_transform(batch):
8285
x_gen, y_gen, transform=x_transform, target_transform=y_transform
8386
)
8487
x_batch, y_batch = dataset[0]
85-
assert len(x_batch) == len(y_batch)
88+
assert x_batch.shape == (10, 5)
89+
assert y_batch.shape == (10,)
8690
assert isinstance(x_batch, torch.Tensor)
8791
assert (x_batch == 1).all()
8892
assert (y_batch == -1).all()
@@ -102,10 +106,11 @@ def test_iterable_dataset(ds_xy):
102106
loader = torch.utils.data.DataLoader(dataset)
103107

104108
for x_batch, y_batch in loader:
105-
assert len(x_batch) == len(y_batch)
109+
assert x_batch.shape == (1, 10, 5)
110+
assert y_batch.shape == (1, 10)
106111
assert isinstance(x_batch, torch.Tensor)
107112

108113
# TODO: why does pytorch add an extra dimension (length 1) to x_batch
109-
assert x_gen[-1]['x'].shape == x_batch.shape[1:]
110-
# TODO: also need to revisit the variable extraction bits here
111-
assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
114+
assert x_gen[-1].shape == x_batch.shape[1:]
115+
# TODO: add test for xarray.Dataset
116+
assert np.array_equal(x_gen[-1], x_batch[0, :, :])

0 commit comments

Comments
 (0)