
Commit 0ded974

Author: Max Jones
Prototype keras data loader (#73)
1 parent: 213264c

6 files changed: +148, -2 lines

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 pytest
+tensorflow
 torch
 coverage
 pytest-cov

doc/api.rst

Lines changed: 3 additions & 0 deletions
@@ -30,3 +30,6 @@ Dataloaders
 
 .. autoclass:: xbatcher.loaders.torch.IterableDataset
     :members:
+
+.. autoclass:: xbatcher.loaders.keras.CustomTFDataset
+    :members:

doc/conf.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def setup(app):
     app.connect('autodoc-skip-member', skip)
 
 
-autodoc_mock_imports = ['torch']
+autodoc_mock_imports = ['torch', 'tensorflow']
 
 # link to github issues
 extlinks = {

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9
 
 [isort]
 known_first_party=xbatcher
-known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
+known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,tensorflow,torch,xarray
 multi_line_output=3
 include_trailing_comma=True
 force_grid_wrap=0

xbatcher/loaders/keras.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
from typing import Any, Callable, Optional, Tuple

import tensorflow as tf
import xarray as xr

# Notes:
# This module includes one Keras dataset, which can be provided to model.fit().
# - The CustomTFDataset provides an indexable interface
# Assumptions made:
# - The dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)


class CustomTFDataset(tf.keras.utils.Sequence):
    def __init__(
        self,
        X_generator,
        y_generator,
        *,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        dim: str = 'new_dim',
    ) -> None:
        '''
        Keras Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
        y_generator : xbatcher.BatchGenerator
        transform : callable, optional
            A function/transform that takes in an array and returns a transformed version.
        target_transform : callable, optional
            A function/transform that takes in the target and transforms it.
        dim : str, 'new_dim'
            Name of dim to pass to :func:`xarray.concat` as the dimension
            to concatenate all variables along.
        '''
        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform
        self.concat_dim = dim

    def __len__(self) -> int:
        return len(self.X_generator)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        X_batch = tf.convert_to_tensor(
            xr.concat(
                (
                    self.X_generator[idx][key]
                    for key in list(self.X_generator[idx].keys())
                ),
                self.concat_dim,
            ).data
        )
        y_batch = tf.convert_to_tensor(
            xr.concat(
                (
                    self.y_generator[idx][key]
                    for key in list(self.y_generator[idx].keys())
                ),
                self.concat_dim,
            ).data
        )

        # TODO: Should the transformations be applied before tensor conversion?
        if self.transform:
            X_batch = self.transform(X_batch)

        if self.target_transform:
            y_batch = self.target_transform(y_batch)
        return X_batch, y_batch
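
Because CustomTFDataset subclasses tf.keras.utils.Sequence, it can be passed straight to model.fit(). The sketch below is illustrative only and not part of this commit: the dataset construction mirrors the test fixture further down, while the toy model, its layer sizes, and the training settings are assumptions; input and output shapes are read off the first batch rather than hard-coded, since the loader stacks variables along a new concat dimension.

# Illustrative sketch (not part of this commit): feeding CustomTFDataset to
# model.fit(). The model is a placeholder; shapes are taken from the first
# batch so the example adapts to however the loader stacks variables.
import numpy as np
import tensorflow as tf
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.keras import CustomTFDataset

ds = xr.Dataset(
    {
        'x': (['sample', 'feature'], np.random.random((100, 5))),
        'y': (['sample'], np.random.random(100)),
    }
)

# One generator for the inputs, one for the targets, batched along 'sample'
x_gen = BatchGenerator(ds['x'], {'sample': 10})
y_gen = BatchGenerator(ds['y'], {'sample': 10})
dataset = CustomTFDataset(x_gen, y_gen)

# Inspect the first batch to size a toy model (assumed architecture)
x_batch, y_batch = dataset[0]
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=tuple(int(d) for d in x_batch.shape[1:])),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(int(y_batch.shape[-1])),
    ]
)
model.compile(optimizer='sgd', loss='mse')

# keras.utils.Sequence objects are accepted directly by fit()
model.fit(dataset, epochs=1, verbose=0)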

xbatcher/tests/test_keras_loaders.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
import numpy as np
import pytest
import xarray as xr

tf = pytest.importorskip('tensorflow')

from xbatcher import BatchGenerator
from xbatcher.loaders.keras import CustomTFDataset


@pytest.fixture(scope='module')
def ds_xy():
    n_samples = 100
    n_features = 5
    ds = xr.Dataset(
        {
            'x': (
                ['sample', 'feature'],
                np.random.random((n_samples, n_features)),
            ),
            'y': (['sample'], np.random.random(n_samples)),
        },
    )
    return ds


def test_custom_dataset(ds_xy):

    x = ds_xy['x']
    y = ds_xy['y']

    x_gen = BatchGenerator(x, {'sample': 10})
    y_gen = BatchGenerator(y, {'sample': 10})

    dataset = CustomTFDataset(x_gen, y_gen)

    # test __getitem__
    x_batch, y_batch = dataset[0]
    assert len(x_batch) == len(y_batch)
    assert tf.is_tensor(x_batch)
    assert tf.is_tensor(y_batch)

    # test __len__
    assert len(dataset) == len(x_gen)


def test_custom_dataset_with_transform(ds_xy):

    x = ds_xy['x']
    y = ds_xy['y']

    x_gen = BatchGenerator(x, {'sample': 10})
    y_gen = BatchGenerator(y, {'sample': 10})

    def x_transform(batch):
        return batch * 0 + 1

    def y_transform(batch):
        return batch * 0 - 1

    dataset = CustomTFDataset(
        x_gen, y_gen, transform=x_transform, target_transform=y_transform
    )
    x_batch, y_batch = dataset[0]
    assert len(x_batch) == len(y_batch)
    assert tf.is_tensor(x_batch)
    assert tf.is_tensor(y_batch)
    assert tf.experimental.numpy.all(x_batch == 1)
    assert tf.experimental.numpy.all(y_batch == -1)
