|
334 | 334 | },
|
335 | 335 | "outputs": [],
|
336 | 336 | "source": [
|
| 337 | + "#@title\n", |
337 | 338 | "fig = plt.figure(figsize=(12, 5))\n",
|
338 | 339 | "ax0 = fig.add_subplot(121)\n",
|
339 | 340 | "ax0.plot(x, f(x), marker='o')\n",
|
|
695 | 696 | "cell_type": "code",
|
696 | 697 | "execution_count": null,
|
697 | 698 | "metadata": {
|
698 |
| - "id": "mCH8sAf3TTJ2" |
| 699 | + "id": "FQWwcI0Wr0AX" |
699 | 700 | },
|
700 | 701 | "outputs": [],
|
701 | 702 | "source": [
|
702 | 703 | "pred = model(interpolated_images)\n",
|
703 |
| - "pred_proba = tf.nn.softmax(pred, axis=-1)[:, 555]\n", |
704 |
| - "\n", |
| 704 | + "pred_proba = tf.nn.softmax(pred, axis=-1)[:, 555]" |
| 705 | + ] |
| 706 | + }, |
| 707 | + { |
| 708 | + "cell_type": "code", |
| 709 | + "execution_count": null, |
| 710 | + "metadata": { |
| 711 | + "id": "mCH8sAf3TTJ2" |
| 712 | + }, |
| 713 | + "outputs": [], |
| 714 | + "source": [ |
| 715 | + "#@title\n", |
705 | 716 | "plt.figure(figsize=(10, 4))\n",
|
706 | 717 | "ax1 = plt.subplot(1, 2, 1)\n",
|
707 | 718 | "ax1.plot(alphas, pred_proba)\n",
|
|
873 | 884 | },
|
874 | 885 | "outputs": [],
|
875 | 886 | "source": [
|
876 |
| - "@tf.function\n", |
877 | 887 | "def integrated_gradients(baseline,\n",
|
878 | 888 | " image,\n",
|
879 | 889 | " target_class_idx,\n",
|
880 | 890 | " m_steps=50,\n",
|
881 | 891 | " batch_size=32):\n",
|
882 |
| - " # 1. Generate alphas.\n", |
| 892 | + " # Generate alphas.\n", |
883 | 893 | " alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)\n",
|
884 | 894 | "\n",
|
885 |
| - " # Initialize TensorArray outside loop to collect gradients. \n", |
886 |
| - " gradient_batches = tf.TensorArray(tf.float32, size=m_steps+1)\n", |
| 895 | + " # Collect gradients. \n", |
| 896 | + " gradient_batches = []\n", |
887 | 897 | " \n",
|
888 | 898 | " # Iterate alphas range and batch computation for speed, memory efficiency, and scaling to larger m_steps.\n",
|
889 | 899 | " for alpha in tf.range(0, len(alphas), batch_size):\n",
|
890 | 900 | " from_ = alpha\n",
|
891 | 901 | " to = tf.minimum(from_ + batch_size, len(alphas))\n",
|
892 | 902 | " alpha_batch = alphas[from_:to]\n",
|
893 | 903 | "\n",
|
894 |
| - " # 2. Generate interpolated inputs between baseline and input.\n", |
895 |
| - " interpolated_path_input_batch = interpolate_images(baseline=baseline,\n", |
896 |
| - " image=image,\n", |
897 |
| - " alphas=alpha_batch)\n", |
898 |
| - "\n", |
899 |
| - " # 3. Compute gradients between model outputs and interpolated inputs.\n", |
900 |
| - " gradient_batch = compute_gradients(images=interpolated_path_input_batch,\n", |
901 |
| - " target_class_idx=target_class_idx)\n", |
902 |
| - " \n", |
903 |
| - " # Write batch indices and gradients to extend TensorArray.\n", |
904 |
| - " gradient_batches = gradient_batches.scatter(tf.range(from_, to), gradient_batch) \n", |
905 |
| - " \n", |
| 904 | + " gradient_batch = one_batch(baseline, image, alpha_batch, target_class_idx)\n", |
| 905 | + " gradient_batches.append(gradient_batch)\n", |
| 906 | + " \n", |
906 | 907 | " # Stack path gradients together row-wise into single tensor.\n",
|
907 |
| - " total_gradients = gradient_batches.stack()\n", |
| 908 | + " total_gradients = tf.stack(gradient_batch)\n", |
908 | 909 | "\n",
|
909 |
| - " # 4. Integral approximation through averaging gradients.\n", |
| 910 | + " # Integral approximation through averaging gradients.\n", |
910 | 911 | " avg_gradients = integral_approximation(gradients=total_gradients)\n",
|
911 | 912 | "\n",
|
912 |
| - " # 5. Scale integrated gradients with respect to input.\n", |
| 913 | + " # Scale integrated gradients with respect to input.\n", |
913 | 914 | " integrated_gradients = (image - baseline) * avg_gradients\n",
|
914 | 915 | "\n",
|
915 | 916 | " return integrated_gradients"
|
916 | 917 | ]
|
917 | 918 | },
|
| 919 | + { |
| 920 | + "cell_type": "code", |
| 921 | + "execution_count": null, |
| 922 | + "metadata": { |
| 923 | + "id": "dszwB_Sp0CX0" |
| 924 | + }, |
| 925 | + "outputs": [], |
| 926 | + "source": [ |
| 927 | + "@tf.function\n", |
| 928 | + "def one_batch(baseline, image, alpha_batch, target_class_idx):\n", |
| 929 | + " # Generate interpolated inputs between baseline and input.\n", |
| 930 | + " interpolated_path_input_batch = interpolate_images(baseline=baseline,\n", |
| 931 | + " image=image,\n", |
| 932 | + " alphas=alpha_batch)\n", |
| 933 | + "\n", |
| 934 | + " # Compute gradients between model outputs and interpolated inputs.\n", |
| 935 | + " gradient_batch = compute_gradients(images=interpolated_path_input_batch,\n", |
| 936 | + " target_class_idx=target_class_idx)\n", |
| 937 | + " return gradient_batch" |
| 938 | + ] |
| 939 | + }, |
918 | 940 | {
|
919 | 941 | "cell_type": "code",
|
920 | 942 | "execution_count": null,
|
|
984 | 1006 | },
|
985 | 1007 | "outputs": [],
|
986 | 1008 | "source": [
|
| 1009 | + "#@title\n", |
987 | 1010 | "def plot_img_attributions(baseline,\n",
|
988 | 1011 | " image,\n",
|
989 | 1012 | " target_class_idx,\n",
|
|
1128 | 1151 | "colab": {
|
1129 | 1152 | "collapsed_sections": [],
|
1130 | 1153 | "name": "integrated_gradients.ipynb",
|
| 1154 | + "private_outputs": true, |
1131 | 1155 | "toc_visible": true
|
1132 | 1156 | },
|
1133 | 1157 | "kernelspec": {
|
|
0 commit comments