Skip to content

Commit 786f63a

Browse files
MarkDaoust authored and copybara-github committed
Move the loop out of the tf.function, so this doesn't OOM.
+ collapse code for plotting functions

PiperOrigin-RevId: 420295363
1 parent 6914e83 commit 786f63a

File tree

1 file changed

+46
-22
lines changed

1 file changed

+46
-22
lines changed

site/en/tutorials/interpretability/integrated_gradients.ipynb

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@
334334
},
335335
"outputs": [],
336336
"source": [
337+
"#@title\n",
337338
"fig = plt.figure(figsize=(12, 5))\n",
338339
"ax0 = fig.add_subplot(121)\n",
339340
"ax0.plot(x, f(x), marker='o')\n",
@@ -695,13 +696,23 @@
695696
"cell_type": "code",
696697
"execution_count": null,
697698
"metadata": {
698-
"id": "mCH8sAf3TTJ2"
699+
"id": "FQWwcI0Wr0AX"
699700
},
700701
"outputs": [],
701702
"source": [
702703
"pred = model(interpolated_images)\n",
703-
"pred_proba = tf.nn.softmax(pred, axis=-1)[:, 555]\n",
704-
"\n",
704+
"pred_proba = tf.nn.softmax(pred, axis=-1)[:, 555]"
705+
]
706+
},
707+
{
708+
"cell_type": "code",
709+
"execution_count": null,
710+
"metadata": {
711+
"id": "mCH8sAf3TTJ2"
712+
},
713+
"outputs": [],
714+
"source": [
715+
"#@title\n",
705716
"plt.figure(figsize=(10, 4))\n",
706717
"ax1 = plt.subplot(1, 2, 1)\n",
707718
"ax1.plot(alphas, pred_proba)\n",
@@ -873,48 +884,59 @@
873884
},
874885
"outputs": [],
875886
"source": [
876-
"@tf.function\n",
877887
"def integrated_gradients(baseline,\n",
878888
" image,\n",
879889
" target_class_idx,\n",
880890
" m_steps=50,\n",
881891
" batch_size=32):\n",
882-
" # 1. Generate alphas.\n",
892+
" # Generate alphas.\n",
883893
" alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)\n",
884894
"\n",
885-
" # Initialize TensorArray outside loop to collect gradients. \n",
886-
" gradient_batches = tf.TensorArray(tf.float32, size=m_steps+1)\n",
895+
" # Collect gradients. \n",
896+
" gradient_batches = []\n",
887897
" \n",
888898
" # Iterate alphas range and batch computation for speed, memory efficiency, and scaling to larger m_steps.\n",
889899
" for alpha in tf.range(0, len(alphas), batch_size):\n",
890900
" from_ = alpha\n",
891901
" to = tf.minimum(from_ + batch_size, len(alphas))\n",
892902
" alpha_batch = alphas[from_:to]\n",
893903
"\n",
894-
" # 2. Generate interpolated inputs between baseline and input.\n",
895-
" interpolated_path_input_batch = interpolate_images(baseline=baseline,\n",
896-
" image=image,\n",
897-
" alphas=alpha_batch)\n",
898-
"\n",
899-
" # 3. Compute gradients between model outputs and interpolated inputs.\n",
900-
" gradient_batch = compute_gradients(images=interpolated_path_input_batch,\n",
901-
" target_class_idx=target_class_idx)\n",
902-
" \n",
903-
" # Write batch indices and gradients to extend TensorArray.\n",
904-
" gradient_batches = gradient_batches.scatter(tf.range(from_, to), gradient_batch) \n",
905-
" \n",
904+
" gradient_batch = one_batch(baseline, image, alpha_batch, target_class_idx)\n",
905+
" gradient_batches.append(gradient_batch)\n",
906+
" \n",
906907
" # Stack path gradients together row-wise into single tensor.\n",
907-
" total_gradients = gradient_batches.stack()\n",
908+
" total_gradients = tf.stack(gradient_batch)\n",
908909
"\n",
909-
" # 4. Integral approximation through averaging gradients.\n",
910+
" # Integral approximation through averaging gradients.\n",
910911
" avg_gradients = integral_approximation(gradients=total_gradients)\n",
911912
"\n",
912-
" # 5. Scale integrated gradients with respect to input.\n",
913+
" # Scale integrated gradients with respect to input.\n",
913914
" integrated_gradients = (image - baseline) * avg_gradients\n",
914915
"\n",
915916
" return integrated_gradients"
916917
]
917918
},
919+
{
920+
"cell_type": "code",
921+
"execution_count": null,
922+
"metadata": {
923+
"id": "dszwB_Sp0CX0"
924+
},
925+
"outputs": [],
926+
"source": [
927+
"@tf.function\n",
928+
"def one_batch(baseline, image, alpha_batch, target_class_idx):\n",
929+
" # Generate interpolated inputs between baseline and input.\n",
930+
" interpolated_path_input_batch = interpolate_images(baseline=baseline,\n",
931+
" image=image,\n",
932+
" alphas=alpha_batch)\n",
933+
"\n",
934+
" # Compute gradients between model outputs and interpolated inputs.\n",
935+
" gradient_batch = compute_gradients(images=interpolated_path_input_batch,\n",
936+
" target_class_idx=target_class_idx)\n",
937+
" return gradient_batch"
938+
]
939+
},
918940
{
919941
"cell_type": "code",
920942
"execution_count": null,
@@ -984,6 +1006,7 @@
9841006
},
9851007
"outputs": [],
9861008
"source": [
1009+
"#@title\n",
9871010
"def plot_img_attributions(baseline,\n",
9881011
" image,\n",
9891012
" target_class_idx,\n",
@@ -1128,6 +1151,7 @@
11281151
"colab": {
11291152
"collapsed_sections": [],
11301153
"name": "integrated_gradients.ipynb",
1154+
"private_outputs": true,
11311155
"toc_visible": true
11321156
},
11331157
"kernelspec": {

0 commit comments

Comments
 (0)