@@ -11,6 +11,52 @@ def _as_xarray_dataset(ds):
11
11
else :
12
12
return ds .to_dataset ()
13
13
14
+ def _slices (dimsize , size , overlap = 0 ):
15
+ # return a list of slices to chop up a single dimension
16
+ slices = []
17
+ stride = size - overlap
18
+ assert stride > 0
19
+ assert stride < dimsize
20
+ for start in range (0 , dimsize , stride ):
21
+ end = start + size
22
+ if end <= dimsize :
23
+ slices .append (slice (start , end ))
24
+ return slices
25
+
26
+
27
def _iterate_through_dataset(ds, dims, overlap=None):
    """Yield sub-datasets of ``ds`` chopped along the dimensions in ``dims``.

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to chop up.
    dims : dict
        Mapping of dimension name -> chunk size along that dimension.
    overlap : dict, optional
        Mapping of dimension name -> number of points shared by consecutive
        chunks; dimensions not listed default to 0 overlap.

    Yields
    ------
    xarray.Dataset
        One ``ds.isel(...)`` selection per combination of chunks.
    """
    # Fix: the original used a mutable default argument (`overlap={}`),
    # which is shared across calls; use a None sentinel instead.
    if overlap is None:
        overlap = {}
    dim_slices = [
        _slices(ds.dims[dim], size, overlap.get(dim, 0))
        for dim, size in dims.items()
    ]
    for combo in itertools.product(*dim_slices):
        # The original comprehension shadowed the builtin `slice` as its
        # loop variable; dict(zip(...)) avoids that.
        yield ds.isel(**dict(zip(dims, combo)))
38
+
39
+
40
+ def _drop_input_dims (ds , input_dims , suffix = '_input' ):
41
+ # remove input_dims coordinates from datasets, rename the dimensions
42
+ # then put intput_dims back in as coordinates
43
+ out = ds .copy ()
44
+ out = (out .drop (input_dims )
45
+ .rename ({dim : dim + suffix for dim in input_dims }))
46
+ for dim in input_dims :
47
+ out .coords [dim ] = dim + suffix , ds [dim ].values
48
+ return out
49
+
50
+
51
+ def _maybe_stack_batch_dims (ds , input_dims ):
52
+ batch_dims = list (set (ds .dims ) - set (input_dims ))
53
+ if len (batch_dims ) < 2 :
54
+ return ds
55
+ ds_stack = ds .stack (batch = batch_dims )
56
+ # ensure correct order
57
+ dim_order = ('batch' ,) + tuple (input_dims )
58
+ return ds_stack .transpose (* dim_order )
59
+
14
60
15
61
class BatchGenerator:
    """Create generator for iterating through xarray dataarrays / datasets in
    batches.

    Parameters
    ----------
    ds : ``xarray.Dataset`` or ``xarray.DataArray``
        The data to iterate over
    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension,
        e.g. ``{'lat': 30, 'lon': 30}``
        These are the dimensions the ML library will see. All other dimensions
        will be stacked into one dimension called ``batch``.
    input_overlap : dict, optional
        A dictionary specifying the overlap along each dimension
        e.g. ``{'lat': 3, 'lon': 3}``
    batch_dims : dict, optional
        A dictionary specifying the size of the batch along each dimension
        e.g. ``{'time': 10}``. These will always be iterated over.
    concat_input_dims : bool, optional
        If ``True``, the dimension chunks specified in ``input_dims`` will be
        concatenated and stacked into the batch dimension. If ``False``, they
        will be iterated over.
    preload_batch : bool, optional
        If ``True``, each batch will be loaded into memory before reshaping /
        processing, triggering any dask arrays to be computed.

    Yields
    ------
    ds_slice : ``xarray.Dataset`` or ``xarray.DataArray``
        Slices of the array matching the given batch size specification.
    """

    def __init__(self, ds, input_dims, input_overlap=None, batch_dims=None,
                 concat_input_dims=False, preload_batch=True):

        self.ds = _as_xarray_dataset(ds)
        # None sentinels replace the original mutable `{}` defaults, which
        # would have been shared across every BatchGenerator instance.
        self.input_dims = OrderedDict(input_dims)
        self.input_overlap = {} if input_overlap is None else input_overlap
        self.batch_dims = OrderedDict({} if batch_dims is None else batch_dims)
        self.concat_input_dims = concat_input_dims
        self.preload_batch = preload_batch

    def __iter__(self):
        for ds_batch in self._iterate_batch_dims(self.ds):
            if self.preload_batch:
                # trigger computation of any lazy (dask) arrays once per batch
                ds_batch.load()
            input_generator = self._iterate_input_dims(ds_batch)
            if self.concat_input_dims:
                all_dsets = [_drop_input_dims(ds_input, list(self.input_dims))
                             for ds_input in input_generator]
                # Bug fix: the original passed the undefined name
                # `all_batches` to xr.concat, raising NameError whenever
                # concat_input_dims=True; the list built above is `all_dsets`.
                dsc = xr.concat(all_dsets, dim='input_batch')
                yield _maybe_stack_batch_dims(dsc, list(self.input_dims))
            else:
                for ds_input in input_generator:
                    yield _maybe_stack_batch_dims(ds_input,
                                                  list(self.input_dims))

    def _iterate_batch_dims(self, ds):
        # outer loop: batch_dims chunks are always iterated over, no overlap
        return _iterate_through_dataset(ds, self.batch_dims)

    def _iterate_input_dims(self, ds):
        # inner loop: input_dims chunks the ML library sees, with optional overlap
        return _iterate_through_dataset(ds, self.input_dims,
                                        self.input_overlap)
0 commit comments