Skip to content

Commit 6104bf3

Browse files
author
Joseph Hamman
committed
add torch accessor
1 parent 04480ba commit 6104bf3

File tree

4 files changed

+44
-4
lines changed

4 files changed

+44
-4
lines changed

xbatcher/accessors.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,21 @@ def generator(self, *args, **kwargs):
2424
Keyword arguments to pass to the `BatchGenerator` constructor.
2525
'''
2626
return BatchGenerator(self._obj, *args, **kwargs)
27+
28+
29+
@xr.register_dataarray_accessor('torch')
30+
class TorchAccessor:
31+
def __init__(self, xarray_obj):
32+
self._obj = xarray_obj
33+
34+
def to_tensor(self):
35+
"""Convert this DataArray to a torch.Tensor"""
36+
import torch
37+
38+
return torch.tensor(self._obj.data)
39+
40+
def to_named_tensor(self):
41+
"""Convert this DataArray to a torch.Tensor with named dimensions"""
42+
import torch
43+
44+
return torch.tensor(self._obj.data, names=self._obj.dims)

xbatcher/loaders/torch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def __getitem__(self, idx) -> Tuple[Any, Any]:
4949

5050
# TODO: figure out the dataset -> array workflow
5151
# currently hardcoding a variable name
52-
X_batch = self.X_generator[idx]['x'].data
53-
y_batch = self.y_generator[idx]['y'].data
52+
X_batch = self.X_generator[idx]['x'].torch.to_tensor()
53+
y_batch = self.y_generator[idx]['y'].torch.to_tensor()
5454

5555
if self.transform:
5656
X_batch = self.transform(X_batch)
@@ -80,4 +80,4 @@ def __init__(
8080

8181
def __iter__(self):
8282
for xb, yb in zip(self.X_generator, self.y_generator):
83-
yield (xb['x'].data, yb['y'].data)
83+
yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor())

xbatcher/tests/test_accessors.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,25 @@ def test_batch_accessor_da(sample_ds_3d):
3838
assert isinstance(bg_acc, BatchGenerator)
3939
for batch_class, batch_acc in zip(bg_class, bg_acc):
4040
assert batch_class.equals(batch_acc)
41+
42+
43+
def test_torch_to_tensor(sample_ds_3d):
44+
torch = pytest.importorskip('torch')
45+
46+
da = sample_ds_3d['foo']
47+
t = da.torch.to_tensor()
48+
assert isinstance(t, torch.Tensor)
49+
assert t.names == (None, None, None)
50+
assert t.shape == da.shape
51+
np.testing.assert_array_equal(t, da.values)
52+
53+
54+
def test_torch_to_named_tensor(sample_ds_3d):
55+
torch = pytest.importorskip('torch')
56+
57+
da = sample_ds_3d['foo']
58+
t = da.torch.to_named_tensor()
59+
assert isinstance(t, torch.Tensor)
60+
assert t.names == da.dims
61+
assert t.shape == da.shape
62+
np.testing.assert_array_equal(t, da.values)

xbatcher/tests/test_torch_loaders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_map_dataset(ds_xy):
3737
# test __getitem__
3838
x_batch, y_batch = dataset[0]
3939
assert len(x_batch) == len(y_batch)
40-
assert isinstance(x_batch, np.ndarray)
40+
assert isinstance(x_batch, torch.Tensor)
4141

4242
# test __len__
4343
assert len(dataset) == len(x_gen)

0 commit comments

Comments
 (0)