diff --git a/src/python/tensorflow_cloud/core/tests/testdata/keras_tuner_cifar_example.py b/src/python/tensorflow_cloud/core/tests/testdata/keras_tuner_cifar_example.py index 60f3a9c6..1c2d4773 100644 --- a/src/python/tensorflow_cloud/core/tests/testdata/keras_tuner_cifar_example.py +++ b/src/python/tensorflow_cloud/core/tests/testdata/keras_tuner_cifar_example.py @@ -102,9 +102,9 @@ def scale(image, label): return image, label -train_dataset = train_dataset.map(scale).cache() +train_dataset = train_dataset.map(scale,num_parallel_calls=tf.data.AUTOTUNE).cache() train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE) -test_dataset = test_dataset.map(scale).batch(BATCH_SIZE) +test_dataset = test_dataset.map(scale,num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE) tuner.search( train_dataset,