Skip to content

Commit 05f55cd

Browse files
author
The TensorFlow Datasets Authors
committed
Improve uint64 handling in tfds.
1) For encoding (raw), bitcast to int64 before writing. This leaves values <= kint64max unchanged. 2) For decoding, read as int64, then bitcast to uint64. All other (uint) dtypes are handled by tf.io.decode_raw. This change could also be achieved by making tf.io.decode_raw support uint64, but the handling of values over kint64max in tf.Example.int64_values would still be necessary. PiperOrigin-RevId: 634029422
1 parent a1b575c commit 05f55cd

File tree

4 files changed

+46
-2
lines changed

4 files changed

+46
-2
lines changed

tensorflow_datasets/core/example_serializer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,9 @@ def _item_to_tf_feature(
187187
# Convert boolean to integer (tf.train.Example does not support bool)
188188
if v.dtype == np.bool_:
189189
v = v.astype(int)
190-
190+
if v.dtype == np.uint64:
191+
# We cannot store uint64 in tf.Example, so we bitcast to int64.
192+
v = v.view(np.int64)
191193
vals = v.flat # Convert v into a 1-d array (without extra copy)
192194
if dtype_utils.is_integer(v.dtype):
193195
return tf_feature_pb2.Feature(

tensorflow_datasets/core/example_serializer_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,12 @@ def test_add_ragged_fields_single_level_sequence(self, dtype):
262262
)
263263
self.assertEqual(out[1], tensor_info)
264264

265+
def test_uint_to_tf_feature_overflow(self):
266+
tensor_info = feature_lib.TensorInfo(shape=(), dtype=np.uint64)
267+
bigint = np.array((1 << 63) + 10, dtype=np.uint64)
268+
# Does not raise value error.
269+
example_serializer._item_to_tf_feature(bigint, tensor_info)
270+
265271
@parameterized.parameters((np.int64), (tf.int64))
266272
def test_item_to_tf_feature_incorrect_shape(self, dtype):
267273
# Test shape check in _item_to_tf_feature raises ValueError.

tensorflow_datasets/core/features/tensor_feature.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,15 +268,27 @@ def _get_value_and_shape(self, example_data):
268268
else:
269269
value = example_data
270270
shape = np_utils.to_np_shape(self._shape)
271+
if (
272+
self._dtype == np.uint64
273+
and not self._encoded_to_bytes
274+
and isinstance(value, np.ndarray)
275+
):
276+
# We can only store int64 inside tf.Example, so if we had a uint64, we
277+
# bitcasted it to int64 at encoding time. Thus, when decoding, we need to
278+
# bitcast it back to uint64.
279+
value = value.view(np.uint64)
271280
return value, shape
272281

273282
def decode_example(self, tfexample_data):
274283
"""See base class for details."""
275284
value, shape = self._get_value_and_shape(tfexample_data)
285+
decode_dtype = self.tf_dtype if self.tf_dtype != tf.uint64 else tf.int64
276286
if self._encoded_to_bytes:
277287
if self._encoding == Encoding.ZLIB:
278288
value = tf.io.decode_compressed(value, compression_type='ZLIB')
279-
value = tf.io.decode_raw(value, self.tf_dtype)
289+
value = tf.io.decode_raw(value, decode_dtype)
290+
if self.dtype == tf.uint64:
291+
value = tf.bitcast(value, tf.uint64)
280292
value = tf.reshape(value, shape)
281293

282294
return value

tensorflow_datasets/core/features/tensor_feature_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,30 @@ def test_shape_static(
101101
},
102102
)
103103

104+
@parameterized.parameters([
105+
features_lib.Encoding.BYTES,
106+
features_lib.Encoding.ZLIB,
107+
])
108+
def test_uint64_encoded_roundtrip(self, encoding: features_lib.Encoding):
109+
bigint = np.array((1 << 63) + 10, dtype=np.uint64)
110+
feature = features_lib.Tensor(shape=(), dtype=np.uint64, encoding=encoding)
111+
self.assertEqual(
112+
feature.decode_example(feature.encode_example(bigint)),
113+
bigint,
114+
)
115+
self.assertEqual(
116+
feature.decode_example_np(feature.encode_example(bigint)),
117+
bigint,
118+
)
119+
120+
def test_uint64_roundtrip(self):
121+
feature = features_lib.Tensor(shape=(), dtype=np.uint64)
122+
bigint = np.array((1 << 63) + 10, dtype=np.uint64)
123+
# since we are using tf.Example int64 to hold this result, we start with
124+
# the manually encoded (bitcasted) version of the value.
125+
self.assertEqual(feature.decode_example(bigint.view(np.int64)), bigint)
126+
self.assertEqual(feature.decode_example_np(bigint.view(np.int64)), bigint)
127+
104128
@parameterized.parameters([
105129
(np.int32, features_lib.Encoding.NONE),
106130
(tf.int32, features_lib.Encoding.NONE),

0 commit comments

Comments
 (0)