Skip to content

Commit 8da9fb2

Browse files
markmcdcopybara-github
authored andcommitted
Use consistent dtypes for augmentation layers.
We were calling the `data_augmentation()` layer interactively with `uint8` images, and then during training using `float32` images. This change ensures we use consistent dtypes in both instances. PiperOrigin-RevId: 430239809
1 parent f292e1c commit 8da9fb2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

site/en/tutorials/images/data_augmentation.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@
318318
"outputs": [],
319319
"source": [
320320
"# Add the image to a batch.\n",
321-
"image = tf.expand_dims(image, 0)"
321+
"image = tf.cast(tf.expand_dims(image, 0), tf.float32)"
322322
]
323323
},
324324
{

0 commit comments

Comments
 (0)