Commit 1c5febf

Author: Joseph Hamman (committed)
Commit message: [loaders refactor] initial commit
Parent: 3bc614a

File tree: 6 files changed, +230 −11 lines

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9
 
 [isort]
 known_first_party=xbatcher
-known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,xarray
+known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
 multi_line_output=3
 include_trailing_comma=True
 force_grid_wrap=0

xbatcher/generators.py

Lines changed: 45 additions & 10 deletions
@@ -2,6 +2,8 @@
 
 import itertools
 from collections import OrderedDict
+from collections.abc import Iterator
+from typing import Any, Dict, Hashable
 
 import xarray as xr
 
@@ -99,12 +101,12 @@ class BatchGenerator:
 
     def __init__(
         self,
-        ds,
-        input_dims,
-        input_overlap={},
-        batch_dims={},
-        concat_input_dims=False,
-        preload_batch=True,
+        ds: xr.Dataset,
+        input_dims: Dict[Hashable, int],
+        input_overlap: Dict[Hashable, int] = {},
+        batch_dims: Dict[Hashable, int] = {},
+        concat_input_dims: bool = False,
+        preload_batch: bool = True,
     ):
 
         self.ds = _as_xarray_dataset(ds)
@@ -115,7 +117,38 @@ def __init__(
         self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch
 
-    def __iter__(self):
+        self._batches: Dict[
+            int, Any
+        ] = self._gen_batches()  # dict cache for batches
+        # in the future, we can make this a lru cache or similar thing (cachey?)
+
+    def __iter__(self) -> Iterator[xr.Dataset]:
+        for batch in self._batches.values():
+            yield batch
+
+    def __len__(self) -> int:
+        return len(self._batches)
+
+    def __getitem__(self, idx: int) -> xr.Dataset:
+
+        if not isinstance(idx, int):
+            raise NotImplementedError(
+                f'{type(self).__name__}.__getitem__ currently requires a single integer key'
+            )
+
+        if idx < 0:
+            idx = list(self._batches)[idx]
+
+        if idx in self._batches:
+            return self._batches[idx]
+        else:
+            raise IndexError('list index out of range')
+
+    def _gen_batches(self) -> dict:
+        # in the future, we will want to do the batch generation lazily
+        # going the eager route for now is allowing me to fill out the loader api
+        # but it is likely to perform poorly.
+        batches = []
         for ds_batch in self._iterate_batch_dims(self.ds):
             if self.preload_batch:
                 ds_batch.load()
@@ -132,13 +165,15 @@ def __iter__(self):
                 new_input_dims = [
                     dim + new_dim_suffix for dim in self.input_dims
                 ]
-                yield _maybe_stack_batch_dims(dsc, new_input_dims)
+                batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
             else:
                 for ds_input in input_generator:
-                    yield _maybe_stack_batch_dims(
-                        ds_input, list(self.input_dims)
+                    batches.append(
+                        _maybe_stack_batch_dims(ds_input, list(self.input_dims))
                     )
 
+        return dict(zip(range(len(batches)), batches))
+
     def _iterate_batch_dims(self, ds):
         return _iterate_through_dataset(ds, self.batch_dims)

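For orientation, a minimal sketch of how the indexable interface added above behaves; the example dataset and the variable name 'foo' are illustrative only, not part of the commit:

    import numpy as np
    import xarray as xr
    from xbatcher import BatchGenerator

    # hypothetical example data; any dataset with an 'x' dimension would do
    ds = xr.Dataset({'foo': (['x'], np.random.random(100))})
    bg = BatchGenerator(ds, input_dims={'x': 10})

    len(bg)    # 10: batches are generated eagerly in __init__ and cached in a dict
    bg[0]      # first cached batch, an xr.Dataset with dims {'x': 10}
    bg[-1]     # negative indices are remapped to the corresponding cache key
    list(bg)   # __iter__ yields the cached batches in insertion order
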
xbatcher/loaders/__init__.py

Whitespace-only changes (new empty file marking the loaders subpackage).

xbatcher/loaders/torch.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+
+# Notes:
+# This module includes two PyTorch datasets.
+#  - The MapDataset provides an indexable interface
+#  - The IterableDataset provides a simple iterable interface
+# Both can be provided as arguments to the Torch DataLoader
+# Assumptions made:
+#  - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
+# TODOs:
+#  - sort out xarray -> numpy pattern. Currently there is a hardcoded variable name for x/y
+#  - need to test with additional dataset parameters (e.g. transforms)
+
+
+class MapDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        X_generator,
+        y_generator,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        '''
+        PyTorch Dataset adapter for Xbatcher
+
+        Parameters
+        ----------
+        X_generator : xbatcher.BatchGenerator
+        y_generator : xbatcher.BatchGenerator
+        transform : callable, optional
+            A function/transform that takes in an array and returns a transformed version.
+        target_transform : callable, optional
+            A function/transform that takes in the target and transforms it.
+        '''
+        self.X_generator = X_generator
+        self.y_generator = y_generator
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self) -> int:
+        return len(self.X_generator)
+
+    def __getitem__(self, idx) -> Tuple[Any, Any]:
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+            assert len(idx) == 1
+
+        # TODO: figure out the dataset -> array workflow
+        # currently hardcoding a variable name
+        X_batch = self.X_generator[idx]['x'].data
+        y_batch = self.y_generator[idx]['y'].data
+
+        if self.transform:
+            X_batch = self.transform(X_batch)
+
+        if self.target_transform:
+            y_batch = self.target_transform(y_batch)
+        print('x_batch.shape', X_batch.shape)
+        return X_batch, y_batch
+
+
+class IterableDataset(torch.utils.data.IterableDataset):
+    def __init__(
+        self,
+        X_generator,
+        y_generator,
+    ) -> None:
+        '''
+        PyTorch Dataset adapter for Xbatcher
+
+        Parameters
+        ----------
+        X_generator : xbatcher.BatchGenerator
+        y_generator : xbatcher.BatchGenerator
+        '''
+
+        self.X_generator = X_generator
+        self.y_generator = y_generator
+
+    def __iter__(self):
+        for xb, yb in zip(self.X_generator, self.y_generator):
+            yield (xb['x'].data, yb['y'].data)

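A minimal sketch of how these adapters are meant to be wired into a torch DataLoader, assuming generators built from variables literally named 'x' and 'y' (the hardcoded names flagged in the TODOs above); the torch.tensor transforms are an illustrative choice, not part of the commit:

    import numpy as np
    import torch
    import xarray as xr

    from xbatcher import BatchGenerator
    from xbatcher.loaders.torch import MapDataset

    # example data; variable names match those hardcoded in the loader
    ds = xr.Dataset(
        {
            'x': (['sample', 'feature'], np.random.random((100, 5))),
            'y': (['sample'], np.random.random(100)),
        }
    )

    x_gen = BatchGenerator(ds['x'], {'sample': 10})
    y_gen = BatchGenerator(ds['y'], {'sample': 10})

    # optional transforms convert each numpy batch to a tensor before collation
    dataset = MapDataset(
        x_gen, y_gen, transform=torch.tensor, target_transform=torch.tensor
    )
    loader = torch.utils.data.DataLoader(dataset)

    for x_batch, y_batch in loader:
        print(x_batch.shape, y_batch.shape)  # one xbatcher batch per DataLoader item
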
xbatcher/tests/test_generators.py

Lines changed: 22 additions & 0 deletions
@@ -18,6 +18,28 @@ def sample_ds_1d():
     return ds
 
 
+@pytest.mark.parametrize('bsize', [5, 6])
+def test_batcher_lenth(sample_ds_1d, bsize):
+    bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
+    assert len(bg) == sample_ds_1d.dims['x'] // bsize
+
+
+def test_batcher_getitem(sample_ds_1d):
+    bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10})
+
+    # first batch
+    assert bg[0].dims['x'] == 10
+    # last batch
+    assert bg[-1].dims['x'] == 10
+    # raises IndexError for out of range index
+    with pytest.raises(IndexError, match=r'list index out of range'):
+        bg[9999999]
+
+    # raises NotImplementedError for iterable index
+    with pytest.raises(NotImplementedError):
+        bg[[1, 2, 3]]
+
+
 # TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
 # Should we enforce that each batch size always has to be the same
 @pytest.mark.parametrize('bsize', [5, 10])

xbatcher/tests/test_torch_loaders.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+import numpy as np
+import pytest
+import xarray as xr
+
+torch = pytest.importorskip('torch')
+
+from xbatcher import BatchGenerator
+from xbatcher.loaders.torch import IterableDataset, MapDataset
+
+
+@pytest.fixture(scope='module')
+def ds_xy():
+    n_samples = 100
+    n_features = 5
+    ds = xr.Dataset(
+        {
+            'x': (
+                ['sample', 'feature'],
+                np.random.random((n_samples, n_features)),
+            ),
+            'y': (['sample'], np.random.random(n_samples)),
+        },
+    )
+    return ds
+
+
+def test_map_dataset(ds_xy):
+
+    x = ds_xy['x']
+    y = ds_xy['y']
+
+    x_gen = BatchGenerator(x, {'sample': 10})
+    y_gen = BatchGenerator(y, {'sample': 10})
+
+    dataset = MapDataset(x_gen, y_gen)
+
+    # test __getitem__
+    x_batch, y_batch = dataset[0]
+    assert len(x_batch) == len(y_batch)
+    assert isinstance(x_batch, np.ndarray)
+
+    # test __len__
+    assert len(dataset) == len(x_gen)
+
+    # test integration with torch DataLoader
+    loader = torch.utils.data.DataLoader(dataset)
+
+    for x_batch, y_batch in loader:
+        assert len(x_batch) == len(y_batch)
+        assert isinstance(x_batch, torch.Tensor)
+
+    # TODO: why does pytorch add an extra dimension (length 1) to x_batch
+    assert x_gen[-1]['x'].shape == x_batch.shape[1:]
+    # TODO: also need to revisit the variable extraction bits here
+    assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
+
+
+def test_iterable_dataset(ds_xy):
+
+    x = ds_xy['x']
+    y = ds_xy['y']
+
+    x_gen = BatchGenerator(x, {'sample': 10})
+    y_gen = BatchGenerator(y, {'sample': 10})
+
+    dataset = IterableDataset(x_gen, y_gen)
+
+    # test integration with torch DataLoader
+    loader = torch.utils.data.DataLoader(dataset)
+
+    for x_batch, y_batch in loader:
+        assert len(x_batch) == len(y_batch)
+        assert isinstance(x_batch, torch.Tensor)
+
+    # TODO: why does pytorch add an extra dimension (length 1) to x_batch
+    assert x_gen[-1]['x'].shape == x_batch.shape[1:]
+    # TODO: also need to revisit the variable extraction bits here
+    assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
