Skip to content

Commit 6990f99

Browse files
feat: add builder
1 parent b8a3bf2 commit 6990f99

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

tensorflow_datasets/datasets/pneumoniamnist/pneumoniamnist_dataset_builder.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""pneumoniamnist dataset."""
22

33
import tensorflow_datasets.public_api as tfds
4-
4+
import numpy as np
55

66
class Builder(tfds.core.GeneratorBasedBuilder):
77
"""DatasetBuilder for pneumoniamnist dataset."""
@@ -24,19 +24,27 @@ def _info(self) -> tfds.core.DatasetInfo:
2424

2525
def _split_generators(self, dl_manager: tfds.download.DownloadManager):
2626
"""Returns SplitGenerators."""
27-
# TODO(pneumoniamnist): Downloads the data and defines the splits
28-
path = dl_manager.download_and_extract('https://todo-data-url')
2927

30-
# TODO(pneumoniamnist): Returns the Dict[split names, Iterator[Key, Example]]
28+
path = dl_manager.download('https://zenodo.org/records/10519652/files/pneumoniamnist.npz?download=1')
29+
30+
raw_data = np.load(path, allow_pickle=True)
31+
train_images = np.expand_dims(raw_data.f.train_images, axis=-1)
32+
val_images = np.expand_dims(raw_data.f.val_images, axis=-1)
33+
test_images = np.expand_dims(raw_data.f.test_images, axis=-1)
34+
train_labels = np.squeeze(raw_data.f.train_labels)
35+
val_labels = np.squeeze(raw_data.f.val_labels)
36+
test_labels = np.squeeze(raw_data.f.test_labels)
37+
3138
return {
32-
'train': self._generate_examples(path / 'train_imgs'),
39+
'train': self._generate_examples(train_images, train_labels),
40+
'val': self._generate_examples(val_images, val_labels),
41+
'test': self._generate_examples(test_images, test_labels),
3342
}
3443

35-
def _generate_examples(self, path):
44+
def _generate_examples(self, images, labels):
3645
"""Yields examples."""
37-
# TODO(pneumoniamnist): Yields (key, example) tuples from the dataset
38-
for f in path.glob('*.jpeg'):
39-
yield 'key', {
40-
'image': f,
41-
'label': 'yes',
46+
for idx, (image, label) in enumerate(zip(images, labels)):
47+
yield idx, {
48+
'image': image,
49+
'label': int(np.squeeze(label)),
4250
}

0 commit comments

Comments
 (0)