Skip to content

Commit 1ed8c72

Browse files
committed
Normalize the file path before writing for the cats_vs_dogs.py dataset
1 parent b68aa45 commit 1ed8c72

File tree

1 file changed

+82
-80
lines changed

1 file changed

+82
-80
lines changed

tensorflow_datasets/image_classification/cats_vs_dogs.py

Lines changed: 82 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""Cats vs Dogs dataset."""
1717

1818
import io
19+
import os
1920
import re
2021
import zipfile
2122

@@ -42,90 +43,91 @@
4243
)
4344
_NUM_CORRUPT_IMAGES = 1738
4445
_DESCRIPTION = (
45-
"A large set of images of cats and dogs. "
46-
"There are %d corrupted images that are dropped." % _NUM_CORRUPT_IMAGES
46+
"A large set of images of cats and dogs. "
47+
"There are %d corrupted images that are dropped." % _NUM_CORRUPT_IMAGES
4748
)
4849

4950
# Matches archive member paths like 'PetImages/Cat/123.jpg', accepting both
# '/' and '\' as path separators (the zip may use either depending on how it
# was built); group(1) captures the class directory name ('Cat' or 'Dog').
_NAME_RE = re.compile(r"^PetImages[\\/](Cat|Dog)[\\/]\d+\.jpg$")
5051

5152

5253
class CatsVsDogs(tfds.core.GeneratorBasedBuilder):
53-
"""Cats vs Dogs."""
54-
55-
VERSION = tfds.core.Version("4.0.1")
56-
RELEASE_NOTES = {
57-
"4.0.0": "New split API (https://tensorflow.org/datasets/splits)",
58-
"4.0.1": (
59-
"Recoding images in generator to fix corrupt JPEG data warnings"
60-
" (https://github.com/tensorflow/datasets/issues/2188)"
61-
),
62-
}
63-
64-
def _info(self):
65-
return tfds.core.DatasetInfo(
66-
builder=self,
67-
description=_DESCRIPTION,
68-
features=tfds.features.FeaturesDict({
69-
"image": tfds.features.Image(),
70-
"image/filename": tfds.features.Text(), # eg 'PetImages/Dog/0.jpg'
71-
"label": tfds.features.ClassLabel(names=["cat", "dog"]),
72-
}),
73-
supervised_keys=("image", "label"),
74-
homepage=(
75-
"https://www.microsoft.com/en-us/download/details.aspx?id=54765"
54+
"""Cats vs Dogs."""
55+
56+
VERSION = tfds.core.Version("4.0.1")
57+
RELEASE_NOTES = {
58+
"4.0.0": "New split API (https://tensorflow.org/datasets/splits)",
59+
"4.0.1": (
60+
"Recoding images in generator to fix corrupt JPEG data warnings"
61+
" (https://github.com/tensorflow/datasets/issues/2188)"
7662
),
77-
citation=_CITATION,
78-
)
79-
80-
def _split_generators(self, dl_manager):
81-
path = dl_manager.download(_URL)
82-
83-
# There is no predefined train/val/test split for this dataset.
84-
return [
85-
tfds.core.SplitGenerator(
86-
name=tfds.Split.TRAIN,
87-
gen_kwargs={
88-
"archive": dl_manager.iter_archive(path),
89-
},
90-
),
91-
]
92-
93-
def _generate_examples(self, archive):
94-
"""Generate Cats vs Dogs images and labels given a directory path."""
95-
num_skipped = 0
96-
for fname, fobj in archive:
97-
res = _NAME_RE.match(fname)
98-
if not res: # README file, ...
99-
continue
100-
label = res.group(1).lower()
101-
if tf.compat.as_bytes("JFIF") not in fobj.peek(10):
102-
num_skipped += 1
103-
continue
104-
105-
# Some images caused 'Corrupt JPEG data...' messages during training or
106-
# any other iteration recoding them once fixes the issue (discussion:
107-
# https://github.com/tensorflow/datasets/issues/2188).
108-
# Those messages are now displayed when generating the dataset instead.
109-
img_data = fobj.read()
110-
img_tensor = tf.image.decode_image(img_data)
111-
img_recoded = tf.io.encode_jpeg(img_tensor)
112-
113-
# Converting the recoded image back into a zip file container.
114-
buffer = io.BytesIO()
115-
with zipfile.ZipFile(buffer, "w") as new_zip:
116-
new_zip.writestr(fname, img_recoded.numpy())
117-
new_fobj = zipfile.ZipFile(buffer).open(fname)
118-
119-
record = {
120-
"image": new_fobj,
121-
"image/filename": fname,
122-
"label": label,
123-
}
124-
yield fname, record
125-
126-
if num_skipped != _NUM_CORRUPT_IMAGES:
127-
raise ValueError(
128-
"Expected %d corrupt images, but found %d"
129-
% (_NUM_CORRUPT_IMAGES, num_skipped)
130-
)
131-
logging.warning("%d images were corrupted and were skipped", num_skipped)
63+
}
64+
65+
def _info(self):
  """Returns the `tfds.core.DatasetInfo` describing this dataset."""
  features = tfds.features.FeaturesDict({
      "image": tfds.features.Image(),
      # Archive-relative path of the source file, e.g. 'PetImages/Dog/0.jpg'.
      "image/filename": tfds.features.Text(),
      "label": tfds.features.ClassLabel(names=["cat", "dog"]),
  })
  return tfds.core.DatasetInfo(
      builder=self,
      description=_DESCRIPTION,
      features=features,
      supervised_keys=("image", "label"),
      homepage=(
          "https://www.microsoft.com/en-us/download/details.aspx?id=54765"
      ),
      citation=_CITATION,
  )
80+
81+
def _split_generators(self, dl_manager):
  """Downloads the archive and declares the dataset's splits.

  The source data ships without a predefined train/val/test partition, so
  everything is exposed as a single TRAIN split.
  """
  archive_path = dl_manager.download(_URL)
  train_split = tfds.core.SplitGenerator(
      name=tfds.Split.TRAIN,
      gen_kwargs={"archive": dl_manager.iter_archive(archive_path)},
  )
  return [train_split]
93+
94+
def _generate_examples(self, archive):
  """Yields `(key, example)` pairs for each valid JPEG in the archive.

  Entries whose path does not look like a pet image are ignored; entries
  missing the JFIF header are counted as corrupt and skipped. At the end the
  corrupt count is validated against the known figure for this dataset.
  """
  corrupt_count = 0
  for raw_name, handle in archive:
    # Normalize the archive path before matching and before using it as the
    # example key / filename feature.
    path = os.path.normpath(raw_name)
    match = _NAME_RE.match(path)
    if match is None:  # Non-image entries such as a README.
      continue
    label = match.group(1).lower()

    # Corrupt files lack the JFIF marker near the start of the header; count
    # them so the total can be validated below, then drop them.
    if tf.compat.as_bytes("JFIF") not in handle.peek(10):
      corrupt_count += 1
      continue

    # Some images caused 'Corrupt JPEG data...' messages during training or
    # any other iteration; re-encoding them once fixes the issue (discussion:
    # https://github.com/tensorflow/datasets/issues/2188). Those messages are
    # now displayed when generating the dataset instead.
    recoded = tf.io.encode_jpeg(tf.image.decode_image(handle.read()))

    # Repack the recoded bytes into an in-memory zip member so downstream
    # code still receives a file-like object, as with the original archive.
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as repacked:
      repacked.writestr(path, recoded.numpy())
    member = zipfile.ZipFile(zip_buffer).open(path)

    yield path, {
        "image": member,
        "image/filename": path,
        "label": label,
    }

  if corrupt_count != _NUM_CORRUPT_IMAGES:
    raise ValueError(
        "Expected %d corrupt images, but found %d"
        % (_NUM_CORRUPT_IMAGES, corrupt_count)
    )
  logging.warning("%d images were corrupted and were skipped", corrupt_count)

0 commit comments

Comments
 (0)