
Commit 1e6c0f5 (1 parent: 57c901f)

refactor complete; no tests

2 files changed (+94, -33 lines)

xbatcher/generators.py (89 additions, 28 deletions)
@@ -11,6 +11,52 @@ def _as_xarray_dataset(ds):
     else:
         return ds.to_dataset()
 
+def _slices(dimsize, size, overlap=0):
+    # return a list of slices to chop up a single dimension
+    slices = []
+    stride = size - overlap
+    assert stride > 0
+    assert stride < dimsize
+    for start in range(0, dimsize, stride):
+        end = start+size
+        if end <= dimsize:
+            slices.append(slice(start, end))
+    return slices
+
+
+def _iterate_through_dataset(ds, dims, overlap={}):
+    dim_slices = []
+    for dim in dims:
+        dimsize = ds.dims[dim]
+        size = dims[dim]
+        olap = overlap.get(dim, 0)
+        dim_slices.append(_slices(dimsize, size, olap))
+
+    for slices in itertools.product(*dim_slices):
+        selector = {key: slice for key, slice in zip(dims, slices)}
+        yield ds.isel(**selector)
+
+
+def _drop_input_dims(ds, input_dims, suffix='_input'):
+    # remove input_dims coordinates from datasets, rename the dimensions,
+    # then put input_dims back in as coordinates
+    out = ds.copy()
+    out = (out.drop(input_dims)
+           .rename({dim: dim + suffix for dim in input_dims}))
+    for dim in input_dims:
+        out.coords[dim] = dim + suffix, ds[dim].values
+    return out
+
+
+def _maybe_stack_batch_dims(ds, input_dims):
+    batch_dims = list(set(ds.dims) - set(input_dims))
+    if len(batch_dims) < 2:
+        return ds
+    ds_stack = ds.stack(batch=batch_dims)
+    # ensure correct order
+    dim_order = ('batch',) + tuple(input_dims)
+    return ds_stack.transpose(*dim_order)
+
 
 class BatchGenerator:
     """Create generator for iterating through xarray datarrays / datasets in
@@ -20,44 +66,59 @@ class BatchGenerator:
     ----------
     ds : ``xarray.Dataset`` or ``xarray.DataArray``
         The data to iterate over
-    batch_sizes : dict
-        A dictionary specifying the size of the batch in each dimension,
-        e.g. ``{'time': 100, 'latitude': 30}``
-    overlap : dict, optional
+    input_dims : dict
+        A dictionary specifying the size of the inputs in each dimension,
+        e.g. ``{'lat': 30, 'lon': 30}``
+        These are the dimensions the ML library will see. All other dimensions
+        will be stacked into one dimension called ``batch``.
+    input_overlap : dict, optional
         A dictionary specifying the overlap along each dimension
+        e.g. ``{'lat': 3, 'lon': 3}``
+    batch_dims : dict, optional
+        A dictionary specifying the size of the batch along each dimension
+        e.g. ``{'time': 10}``. These will always be iterated over.
+    concat_input_dims : bool, optional
+        If ``True``, the dimension chunks specified in ``input_dims`` will be
+        concatenated and stacked into the batch dimension. If ``False``, they
+        will be iterated over.
+    preload_batch : bool, optional
+        If ``True``, each batch will be loaded into memory before reshaping /
+        processing, triggering any dask arrays to be computed.
 
     Yields
     ------
     ds_slice : ``xarray.Dataset`` or ``xarray.DataArray``
-        Slices of the array matching the given batch size specification
+        Slices of the array matching the given batch size specification.
     """
 
-    def __init__(self, ds, batch_sizes, overlap={}):
+    def __init__(self, ds, input_dims, input_overlap={}, batch_dims={},
+                 concat_input_dims=False, preload_batch=True):
+
         self.ds = _as_xarray_dataset(ds)
         # should be a dict
-        self.batch_sizes = OrderedDict(batch_sizes)
-        self.batch_dims = list(self.batch_sizes)
-        # make overlap is defined for each batch size defined
-        self.overlap = {k: overlap.get(k, 0) for k in self.batch_dims}
+        self.input_dims = OrderedDict(input_dims)
+        self.input_overlap = input_overlap
+        self.batch_dims = OrderedDict(batch_dims)
+        self.concat_input_dims = concat_input_dims
+        self.preload_batch = preload_batch
 
 
     def __iter__(self):
-        for slices in itertools.product(*[self._iterate_dim(dim)
-                                          for dim in self.batch_dims]):
-            selector = {key: slice for key, slice in zip(self.batch_dims, slices)}
-            yield self.ds.isel(**selector)
-
-
-    def _iterate_dim(self, dim):
-        dimsize = self.ds.dims[dim]
-        size = self.batch_sizes[dim]
-        overlap = self.overlap[dim]
-        stride = size - overlap
-        assert stride > 0
-        assert stride < dimsize
-        for start in range(0, dimsize, stride):
-            end = start+size
-            if end <= dimsize:
-                yield slice(start, end)
+        for ds_batch in self._iterate_batch_dims(self.ds):
+            if self.preload_batch:
+                ds_batch.load()
+            input_generator = self._iterate_input_dims(ds_batch)
+            if self.concat_input_dims:
+                all_dsets = [_drop_input_dims(ds_input, list(self.input_dims))
+                             for ds_input in input_generator]
+                dsc = xr.concat(all_dsets, dim='input_batch')
+                yield _maybe_stack_batch_dims(dsc, list(self.input_dims))
             else:
-                return
+                for ds_input in input_generator:
+                    yield _maybe_stack_batch_dims(ds_input, list(self.input_dims))
+
+    def _iterate_batch_dims(self, ds):
+        return _iterate_through_dataset(ds, self.batch_dims)
+
+    def _iterate_input_dims(self, ds):
+        return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)
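
For orientation, here is a hedged usage sketch of the refactored constructor. The keyword names follow the diff above; the dataset, its dimension names, and the sizes are invented for illustration, and the import path is assumed from the file layout.

    # Sketch only -- dataset and sizes are hypothetical.
    import numpy as np
    import xarray as xr

    from xbatcher.generators import BatchGenerator  # import path assumed

    ds = xr.Dataset(
        {'foo': (('time', 'lat', 'lon'), np.random.rand(100, 90, 90))}
    )

    bg = BatchGenerator(
        ds,
        input_dims={'lat': 30, 'lon': 30},    # the dims the ML library will see
        input_overlap={'lat': 3, 'lon': 3},   # sliding-window overlap along lat/lon
        batch_dims={'time': 10},              # always iterated over, in chunks of 10
    )

    for ds_batch in bg:
        # Each item is an xarray.Dataset with 30x30 lat/lon windows. Leftover
        # dims are stacked into a 'batch' dim only when there are two or more
        # of them (see _maybe_stack_batch_dims above); here 'time' stays as-is.
        print(ds_batch.dims)
        break

With concat_input_dims=True, the lat/lon windows within each time chunk would instead be concatenated along a new input_batch dimension, per the xr.concat call in __iter__ above.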

xbatcher/tests/test_generators.py (5 additions, 5 deletions)
@@ -17,7 +17,7 @@ def sample_ds_1d():
 # Should we enforce that each batch size always has to be the same
 @pytest.mark.parametrize("bsize", [5, 10])
 def test_batch_1d(sample_ds_1d, bsize):
-    bg = BatchGenerator(sample_ds_1d, batch_sizes={'x': bsize})
+    bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
     for n, ds_batch in enumerate(bg):
         assert isinstance(ds_batch, xr.Dataset)
         # TODO: maybe relax this? see comment above
@@ -30,8 +30,8 @@ def test_batch_1d(sample_ds_1d, bsize):
 @pytest.mark.parametrize("olap", [1, 4])
 def test_batch_1d_overlap(sample_ds_1d, olap):
     bsize = 10
-    bg = BatchGenerator(sample_ds_1d, batch_sizes={'x': bsize},
-                        overlap={'x': olap})
+    bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize},
+                        input_overlap={'x': olap})
     stride = bsize-olap
     for n, ds_batch in enumerate(bg):
         assert isinstance(ds_batch, xr.Dataset)
@@ -55,7 +55,7 @@ def sample_ds_2d():
 def test_batch_2d(sample_ds_2d, bsize):
 
     # first do the iteration over just one dimension
-    bg = BatchGenerator(sample_ds_2d, batch_sizes={'x': bsize})
+    bg = BatchGenerator(sample_ds_2d, input_dims={'x': bsize})
     for n, ds_batch in enumerate(bg):
         assert isinstance(ds_batch, xr.Dataset)
         assert ds_batch.dims['x'] == bsize
@@ -66,7 +66,7 @@ def test_batch_2d(sample_ds_2d, bsize):
 
     # now iterate over both
     xbsize = 20
-    bg = BatchGenerator(sample_ds_2d, batch_sizes={'y': bsize, 'x': xbsize})
+    bg = BatchGenerator(sample_ds_2d, input_dims={'y': bsize, 'x': xbsize})
     for n, ds_batch in enumerate(bg):
         assert isinstance(ds_batch, xr.Dataset)
         assert ds_batch.dims['x'] == xbsize
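
For reference, the stride arithmetic that test_batch_1d_overlap relies on can be sketched as follows. The length of the sample_ds_1d fixture is not shown in this diff, so a length of 100 is assumed purely for illustration.

    # Hypothetical check mirroring test_batch_1d_overlap: window size 10,
    # overlap 4, so the stride is 6 and only fully contained windows count.
    dim_len, bsize, olap = 100, 10, 4
    stride = bsize - olap
    expected_batches = sum(1 for start in range(0, dim_len, stride)
                           if start + bsize <= dim_len)
    print(stride, expected_batches)  # 6 16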
