1
+ from typing import Union
2
+
1
3
import xarray as xr
2
4
3
5
from .generators import BatchGenerator
4
6
5
7
8
+ def _as_xarray_dataarray (xr_obj : Union [xr .Dataset , xr .DataArray ]) -> xr .DataArray :
9
+ """
10
+ Convert xarray.Dataset to xarray.DataArray if needed, so that it can
11
+ be converted into a Tensor object.
12
+ """
13
+ if isinstance (xr_obj , xr .Dataset ):
14
+ xr_obj = xr_obj .to_array ().squeeze (dim = "variable" )
15
+
16
+ return xr_obj
17
+
18
+
6
19
@xr .register_dataarray_accessor ("batch" )
7
20
@xr .register_dataset_accessor ("batch" )
8
21
class BatchAccessor :
@@ -26,31 +39,32 @@ def generator(self, *args, **kwargs):
26
39
return BatchGenerator (self ._obj , * args , ** kwargs )
27
40
28
41
42
+ @xr .register_dataarray_accessor ("tf" )
43
+ @xr .register_dataset_accessor ("tf" )
44
+ class TFAccessor :
45
+ def __init__ (self , xarray_obj ):
46
+ self ._obj = xarray_obj
47
+
48
+ def to_tensor (self ):
49
+ """Convert this DataArray to a tensorflow.Tensor"""
50
+ import tensorflow as tf
51
+
52
+ dataarray = _as_xarray_dataarray (xr_obj = self ._obj )
53
+
54
+ return tf .convert_to_tensor (dataarray .data )
55
+
56
+
29
57
@xr .register_dataarray_accessor ("torch" )
30
58
@xr .register_dataset_accessor ("torch" )
31
59
class TorchAccessor :
32
60
def __init__ (self , xarray_obj ):
33
61
self ._obj = xarray_obj
34
62
35
- def _as_xarray_dataarray (self , xr_obj ):
36
- """
37
- Convert xarray.Dataset to xarray.DataArray if needed, so that it can
38
- be converted into a torch.Tensor object.
39
- """
40
- try :
41
- # Convert xr.Dataset to xr.DataArray
42
- dataarray = xr_obj .to_array ().squeeze (dim = "variable" )
43
- except AttributeError : # 'DataArray' object has no attribute 'to_array'
44
- # If object is already an xr.DataArray
45
- dataarray = xr_obj
46
-
47
- return dataarray
48
-
49
63
def to_tensor (self ):
50
64
"""Convert this DataArray to a torch.Tensor"""
51
65
import torch
52
66
53
- dataarray = self . _as_xarray_dataarray (xr_obj = self ._obj )
67
+ dataarray = _as_xarray_dataarray (xr_obj = self ._obj )
54
68
55
69
return torch .tensor (data = dataarray .data )
56
70
@@ -62,6 +76,6 @@ def to_named_tensor(self):
62
76
"""
63
77
import torch
64
78
65
- dataarray = self . _as_xarray_dataarray (xr_obj = self ._obj )
79
+ dataarray = _as_xarray_dataarray (xr_obj = self ._obj )
66
80
67
81
return torch .tensor (data = dataarray .data , names = tuple (dataarray .sizes ))
0 commit comments