Skip to content

Commit bb5d40c

Browse files
fix: label arrays
1 parent 15bd08e commit bb5d40c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tensorflow_datasets/datasets/pneumoniamnist/pneumoniamnist_dataset_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def _split_generators(self, dl_manager: tfds.download.DownloadManager):
3535
train_images = np.expand_dims(raw_data.f.train_images, axis=-1)
3636
val_images = np.expand_dims(raw_data.f.val_images, axis=-1)
3737
test_images = np.expand_dims(raw_data.f.test_images, axis=-1)
38-
train_labels = np.squeeze(raw_data.f.train_labels)
39-
val_labels = np.squeeze(raw_data.f.val_labels)
40-
test_labels = np.squeeze(raw_data.f.test_labels)
38+
train_labels = raw_data.f.train_labels.flatten()
39+
val_labels = raw_data.f.val_labels.flatten()
40+
test_labels = raw_data.f.test_labels.flatten()
4141

4242
return {
4343
'train': self._generate_examples(train_images, train_labels),

0 commit comments

Comments
 (0)