18
18
import os
19
19
import xml .etree .ElementTree
20
20
21
+ from etils .epath import Path
21
22
from tensorflow_datasets .core .utils .lazy_imports_utils import tensorflow as tf
22
23
import tensorflow_datasets .public_api as tfds
23
24
64
65
]
65
66
_SPECIES_CLASSES = ["Cat" , "Dog" ]
66
67
67
- # List of samples with corrupt image files
68
- _SKIP_SAMPLES = [
68
+ # Samples whose image files are corrupt (mostly a wrong file format); they are re-encoded as JPEG during dataset generation.
69
+ _CORRUPT_SAMPLES = [
69
70
"beagle_116" ,
70
71
"chihuahua_121" ,
71
72
"Abyssinian_5" ,
80
81
"Egyptian_Mau_191"
81
82
]
82
83
84
+ _EMPTY_BBOX = tfds .features .BBox (0. , 0. , 0. , 0. )
85
+
86
+
83
87
def _get_head_bbox (annon_filepath ):
84
88
"""Read head bbox from annotation XML file."""
85
- with tf . io . gfile . GFile (annon_filepath , "r" ) as f :
89
+ with Path (annon_filepath ). open ( "r" ) as f :
86
90
root = xml .etree .ElementTree .parse (f ).getroot ()
87
91
88
- # Disable pytype to avoid attribute-error due to find returning
89
- # Optional[Element]
90
- # pytype: disable=attribute-error
91
- size = root .find ("size" )
92
+ size = root .find ("size" ) # pytype: disable=annotation-type-mismatch
92
93
width = float (size .find ("width" ).text )
93
94
height = float (size .find ("height" ).text )
94
95
@@ -106,7 +107,10 @@ def _get_head_bbox(annon_filepath):
106
107
class Builder (tfds .core .GeneratorBasedBuilder ):
107
108
"""Oxford-IIIT pet dataset."""
108
109
109
- VERSION = tfds .core .Version ("3.2.0" )
110
+ VERSION = tfds .core .Version ("4.0.0" )
111
+ RELEASE_NOTES = {
112
+ '4.0.0' : 'Add head bounding boxes. Fix corrupt images. Update dataset URL.'
113
+ }
110
114
111
115
def _info (self ):
112
116
return self .dataset_info_from_configs (
@@ -118,7 +122,7 @@ def _info(self):
118
122
"segmentation_mask" : tfds .features .Image (
119
123
shape = (None , None , 1 ), use_colormap = True
120
124
),
121
- "head " : tfds .features .BBoxFeature ()
125
+ "head_bbox " : tfds .features .BBoxFeature ()
122
126
}),
123
127
supervised_keys = ("image" , "label" ),
124
128
homepage = "http://www.robots.ox.ac.uk/~vgg/data/pets/" ,
@@ -162,13 +166,23 @@ def _split_generators(self, dl_manager):
162
166
def _generate_examples (
163
167
self , images_dir_path , annotations_dir_path , images_list_file
164
168
):
165
- with tf . io . gfile . GFile (images_list_file , "r" ) as images_list :
169
+ with Path (images_list_file ). open ( "r" ) as images_list :
166
170
for line in images_list :
167
171
image_name , label , species , _ = line .strip ().split (" " )
168
172
169
- # skip corrupt samples
170
- if image_name in _SKIP_SAMPLES :
171
- continue
173
+ image_path = os .path .join (images_dir_path , image_name + ".jpg" )
174
+
175
+ if image_name in _CORRUPT_SAMPLES :
176
+ # Some images trigger 'Corrupt JPEG data...' messages during training or any other iteration over the dataset.
177
+ # Re-encoding them once as JPEG fixes the issue (discussion: https://github.com/tensorflow/datasets/issues/2188).
178
+ with Path (image_path ).open ("rb" ) as image_file :
179
+ img_data = image_file .read ()
180
+ img_tensor = tf .image .decode_image (img_data )
181
+ if tf .shape (img_tensor )[- 1 ] == 4 : # some files have an alpha channel -> remove
182
+ img_tensor = img_tensor [:, :, :- 1 ]
183
+ img_recoded = tf .io .encode_jpeg (img_tensor )
184
+ with Path (image_path ).open ("wb" ) as image_file :
185
+ image_file .write (img_recoded .numpy ())
172
186
173
187
trimaps_dir_path = os .path .join (annotations_dir_path , "trimaps" )
174
188
xmls_dir_path = os .path .join (annotations_dir_path , "xmls" )
@@ -181,16 +195,16 @@ def _generate_examples(
181
195
182
196
try :
183
197
head_bbox = _get_head_bbox (os .path .join (xmls_dir_path , xml_name ))
184
- except tf . errors . NotFoundError :
198
+ except FileNotFoundError as e :
185
199
# test samples do not have an annotation file
186
- head_bbox = tfds . features . BBox ( 0. , 0. , 0. , 0. )
200
+ head_bbox = _EMPTY_BBOX
187
201
188
202
record = {
189
203
"image" : os .path .join (images_dir_path , image_name ),
190
204
"label" : int (label ),
191
205
"species" : species ,
192
206
"file_name" : image_name ,
193
207
"segmentation_mask" : os .path .join (trimaps_dir_path , trimap_name ),
194
- "head " : head_bbox
208
+ "head_bbox " : head_bbox
195
209
}
196
210
yield image_name , record
0 commit comments