Skip to content

Commit 3fdb62f

Browse files
Merge pull request #1558 from rabinadk1:patch-1
PiperOrigin-RevId: 311608883
2 parents 4c10d93 + 29ec013 commit 3fdb62f

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

site/en/tutorials/generative/adversarial_fgsm.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@
155155
"# Helper function to preprocess the image so that it can be inputted in MobileNetV2\n",
156156
"def preprocess(image):\n",
157157
" image = tf.cast(image, tf.float32)\n",
158-
" image = image/255\n",
159158
" image = tf.image.resize(image, (224, 224))\n",
159+
" image = tf.keras.applications.mobilenet_v2.preprocess_input(image)\n",
160160
" image = image[None, ...]\n",
161161
" return image\n",
162162
"\n",
@@ -215,7 +215,7 @@
215215
"outputs": [],
216216
"source": [
217217
"plt.figure()\n",
218-
"plt.imshow(image[0])\n",
218+
"plt.imshow(image[0]*0.5+0.5) # To change [-1, 1] to [0,1]\n",
219219
"_, image_class, class_confidence = get_imagenet_label(image_probs)\n",
220220
"plt.title('{} : {:.2f}% Confidence'.format(image_class, class_confidence*100))\n",
221221
"plt.show()"
@@ -285,7 +285,7 @@
285285
"label = tf.reshape(label, (1, image_probs.shape[-1]))\n",
286286
"\n",
287287
"perturbations = create_adversarial_pattern(image, label)\n",
288-
"plt.imshow(perturbations[0])"
288+
"plt.imshow(perturbations[0]*0.5+0.5); # To change [-1, 1] to [0,1]"
289289
]
290290
},
291291
{
@@ -311,7 +311,7 @@
311311
"def display_images(image, description):\n",
312312
" _, label, confidence = get_imagenet_label(pretrained_model.predict(image))\n",
313313
" plt.figure()\n",
314-
" plt.imshow(image[0])\n",
314+
" plt.imshow(image[0]*0.5+0.5)\n",
315315
" plt.title('{} \\n {} : {:.2f}% Confidence'.format(description,\n",
316316
" label, confidence*100))\n",
317317
" plt.show()"
@@ -333,7 +333,7 @@
333333
"\n",
334334
"for i, eps in enumerate(epsilons):\n",
335335
" adv_x = image + eps*perturbations\n",
336-
" adv_x = tf.clip_by_value(adv_x, 0, 1)\n",
336+
" adv_x = tf.clip_by_value(adv_x, -1, 1)\n",
337337
" display_images(adv_x, descriptions[i])"
338338
]
339339
},

0 commit comments

Comments
 (0)