@@ -94,6 +94,9 @@ class _ImageEncoder:
94
94
95
95
def __post_init__ (self ):
96
96
self .np_dtype = dtype_utils .cast_to_numpy (self .dtype )
97
+ # When encoding isn't defined, default to PNG.
98
+ if self .encoding_format is None :
99
+ self .encoding_format = 'png'
97
100
98
101
# TODO(tfds): Should deprecate the TFGraph runner in favor of simpler
99
102
# implementation
@@ -123,15 +126,12 @@ def _encode_image(self, np_image: np.ndarray) -> bytes:
123
126
"""Returns np_image encoded as jpeg or png."""
124
127
_validate_np_array (np_image , shape = self .shape , dtype = self .np_dtype )
125
128
126
- # When encoding isn't defined, default to PNG.
127
129
# Should we be more strict about explicitly define the encoding (raise
128
130
# error / warning instead) ?
129
131
# It has created subtle issues for imagenet_corrupted: images are read as
130
132
# JPEG images to apply some processing, but final image saved as PNG
131
133
# (default) rather than JPEG.
132
- return self ._runner .run (
133
- _ENCODE_FN [self .encoding_format or 'png' ](), np_image
134
- )
134
+ return self ._runner .run (_ENCODE_FN [self .encoding_format ](), np_image )
135
135
136
136
def _encode_pil_image (self , pil_image ) -> bytes :
137
137
"""Encode a PIL Image object to bytes.
@@ -144,7 +144,7 @@ def _encode_pil_image(self, pil_image) -> bytes:
144
144
"""
145
145
check_pil_import_or_raise_error ()
146
146
buffer = io .BytesIO ()
147
- pil_image .save (buffer , format = self .encoding_format or pil_image . format )
147
+ pil_image .save (buffer , format = self .encoding_format )
148
148
return buffer .getvalue ()
149
149
150
150
def decode_image (self , img : tf .Tensor ) -> tf .Tensor :
0 commit comments