Skip to content

Commit cb2cd85

Browse files
TFAccessor and _as_xarray_dataarray update (#107)
1 parent ce6fbfd commit cb2cd85

File tree

2 files changed

+73
-16
lines changed

2 files changed

+73
-16
lines changed

xbatcher/accessors.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,21 @@
1+
from typing import Union
2+
13
import xarray as xr
24

35
from .generators import BatchGenerator
46

57

8+
def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
9+
"""
10+
Convert xarray.Dataset to xarray.DataArray if needed, so that it can
11+
be converted into a Tensor object.
12+
"""
13+
if isinstance(xr_obj, xr.Dataset):
14+
xr_obj = xr_obj.to_array().squeeze(dim="variable")
15+
16+
return xr_obj
17+
18+
619
@xr.register_dataarray_accessor("batch")
720
@xr.register_dataset_accessor("batch")
821
class BatchAccessor:
@@ -26,31 +39,32 @@ def generator(self, *args, **kwargs):
2639
return BatchGenerator(self._obj, *args, **kwargs)
2740

2841

42+
@xr.register_dataarray_accessor("tf")
43+
@xr.register_dataset_accessor("tf")
44+
class TFAccessor:
45+
def __init__(self, xarray_obj):
46+
self._obj = xarray_obj
47+
48+
def to_tensor(self):
49+
"""Convert this DataArray to a tensorflow.Tensor"""
50+
import tensorflow as tf
51+
52+
dataarray = _as_xarray_dataarray(xr_obj=self._obj)
53+
54+
return tf.convert_to_tensor(dataarray.data)
55+
56+
2957
@xr.register_dataarray_accessor("torch")
3058
@xr.register_dataset_accessor("torch")
3159
class TorchAccessor:
3260
def __init__(self, xarray_obj):
3361
self._obj = xarray_obj
3462

35-
def _as_xarray_dataarray(self, xr_obj):
36-
"""
37-
Convert xarray.Dataset to xarray.DataArray if needed, so that it can
38-
be converted into a torch.Tensor object.
39-
"""
40-
try:
41-
# Convert xr.Dataset to xr.DataArray
42-
dataarray = xr_obj.to_array().squeeze(dim="variable")
43-
except AttributeError: # 'DataArray' object has no attribute 'to_array'
44-
# If object is already an xr.DataArray
45-
dataarray = xr_obj
46-
47-
return dataarray
48-
4963
def to_tensor(self):
5064
"""Convert this DataArray to a torch.Tensor"""
5165
import torch
5266

53-
dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
67+
dataarray = _as_xarray_dataarray(xr_obj=self._obj)
5468

5569
return torch.tensor(data=dataarray.data)
5670

@@ -62,6 +76,6 @@ def to_named_tensor(self):
6276
"""
6377
import torch
6478

65-
dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
79+
dataarray = _as_xarray_dataarray(xr_obj=self._obj)
6680

6781
return torch.tensor(data=dataarray.data, names=tuple(dataarray.sizes))

xbatcher/tests/test_accessors.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,30 @@ def sample_ds_3d():
2222
return ds
2323

2424

25+
@pytest.fixture(scope="module")
26+
def sample_dataArray():
27+
return xr.DataArray(np.zeros((2, 4), dtype="i4"), dims=("x", "y"), name="foo")
28+
29+
30+
@pytest.fixture(scope="module")
31+
def sample_Dataset():
32+
return xr.Dataset(
33+
{
34+
"x": xr.DataArray(np.arange(10), dims="x"),
35+
"foo": xr.DataArray(np.ones(10, dtype="float"), dims="x"),
36+
}
37+
)
38+
39+
40+
def test_as_xarray_dataarray(sample_dataArray, sample_Dataset):
41+
assert isinstance(
42+
xbatcher.accessors._as_xarray_dataarray(sample_dataArray), xr.DataArray
43+
)
44+
assert isinstance(
45+
xbatcher.accessors._as_xarray_dataarray(sample_Dataset), xr.DataArray
46+
)
47+
48+
2549
def test_batch_accessor_ds(sample_ds_3d):
2650
bg_class = BatchGenerator(sample_ds_3d, input_dims={"x": 5})
2751
bg_acc = sample_ds_3d.batch.generator(input_dims={"x": 5})
@@ -40,6 +64,25 @@ def test_batch_accessor_da(sample_ds_3d):
4064
assert batch_class.equals(batch_acc)
4165

4266

67+
@pytest.mark.parametrize(
68+
"foo_var",
69+
[
70+
"foo", # xr.DataArray
71+
["foo"], # xr.Dataset
72+
],
73+
)
74+
def test_tf_to_tensor(sample_ds_3d, foo_var):
75+
tf = pytest.importorskip("tensorflow")
76+
77+
foo = sample_ds_3d[foo_var]
78+
t = foo.tf.to_tensor()
79+
assert isinstance(t, tf.Tensor)
80+
assert t.shape == tuple(foo.sizes.values())
81+
82+
foo_array = foo.to_array().squeeze() if hasattr(foo, "to_array") else foo
83+
np.testing.assert_array_equal(t, foo_array.values)
84+
85+
4386
@pytest.mark.parametrize(
4487
"foo_var",
4588
[

0 commit comments

Comments
 (0)