
Commit 208bb04

Merge pull request #2240 from MedAymenF:mlp-core-fix
PiperOrigin-RevId: 554991816
2 parents 9287030 + 1057e8b commit 208bb04


site/en/guide/core/mlp_core.ipynb

Lines changed: 7 additions & 7 deletions
@@ -94,7 +94,7 @@
 "\n",
 "When these perceptrons are stacked, they form structures called dense layers which can then be connected to build a neural network. A dense layer's equation is similar to that of a perceptron's but uses a weight matrix and a bias vector instead: \n",
 "\n",
-"$$Y = \\mathrm{W}⋅\\mathrm{X} + \\vec{b}$$\n",
+"$$Z = \\mathrm{W}⋅\\mathrm{X} + \\vec{b}$$\n",
 "\n",
 "where\n",
 "\n",
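
The rename above follows the convention that Z denotes a dense layer's pre-activation, which the activation function then turns into the layer output. As a rough, standalone shape check of that equation (the tensors below are invented for illustration and are not part of the notebook):

import tensorflow as tf

# Hypothetical shapes: a batch of 4 examples with 3 features feeding a 2-unit layer.
X = tf.random.normal(shape=(4, 3))
W = tf.random.normal(shape=(3, 2))
b = tf.zeros(shape=(2,))

# The pre-activation Z from the equation; in code the batch is multiplied as X @ W
# so that rows stay aligned with examples.
Z = tf.add(tf.matmul(X, W), b)
print(Z.shape)  # (4, 2): one 2-dimensional pre-activation per example
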
@@ -234,7 +234,7 @@
 },
 "outputs": [],
 "source": [
-"sns.countplot(y_viz.numpy());\n",
+"sns.countplot(x=y_viz.numpy());\n",
 "plt.xlabel('Digits')\n",
 "plt.title(\"MNIST Digit Distribution\");"
 ]
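
In recent seaborn releases the first positional argument of countplot is data rather than x, so the vector is now passed by keyword. A self-contained sketch of the fixed call (the random digits below stand in for y_viz.numpy()):

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Invented stand-in for y_viz.numpy(): 1000 digit labels in [0, 9].
digits = np.random.randint(0, 10, size=1000)

# Passing the vector as x= matches the updated call in the notebook.
sns.countplot(x=digits)
plt.xlabel('Digits')
plt.title("MNIST Digit Distribution")
plt.show()
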
@@ -386,8 +386,8 @@
 "    if not self.built:\n",
 "      # Infer the input dimension based on first call\n",
 "      self.in_dim = x.shape[1]\n",
-"      # Initialize the weights and biases using Xavier scheme\n",
-"      self.w = tf.Variable(xavier_init(shape=(self.in_dim, self.out_dim)))\n",
+"      # Initialize the weights and biases\n",
+"      self.w = tf.Variable(self.weight_init(shape=(self.in_dim, self.out_dim)))\n",
 "      self.b = tf.Variable(tf.zeros(shape=(self.out_dim,)))\n",
 "      self.built = True\n",
 "    # Compute the forward pass\n",
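
With this change the layer reads its initializer from self.weight_init rather than calling a module-level xavier_init directly, so the initialization scheme becomes a constructor argument. A rough sketch of how such a layer can fit together (a simplified reconstruction under that assumption, not the notebook's exact code; xavier_init is only the assumed default):

import tensorflow as tf

def xavier_init(shape):
  # Glorot/Xavier-style uniform initialization, used here only as a default.
  in_dim, out_dim = shape
  limit = tf.sqrt(6.0 / tf.cast(in_dim + out_dim, tf.float32))
  return tf.random.uniform(shape=(in_dim, out_dim), minval=-limit, maxval=limit)

class DenseLayer(tf.Module):

  def __init__(self, out_dim, weight_init=xavier_init, activation=tf.identity):
    # Store the output width, the configurable initializer, and the activation.
    self.out_dim = out_dim
    self.weight_init = weight_init
    self.activation = activation
    self.built = False

  def __call__(self, x):
    if not self.built:
      # Infer the input dimension based on the first call.
      self.in_dim = x.shape[1]
      # Initialize the weights and biases with the stored initializer.
      self.w = tf.Variable(self.weight_init(shape=(self.in_dim, self.out_dim)))
      self.b = tf.Variable(tf.zeros(shape=(self.out_dim,)))
      self.built = True
    # Compute the forward pass: Z = XW + b, then the activation.
    z = tf.add(tf.matmul(x, self.w), self.b)
    return self.activation(z)

A layer with a different scheme would then be built as, for example, DenseLayer(out_dim=10, weight_init=some_other_init) without touching the class body.
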
@@ -875,9 +875,9 @@
 "  label_ind = (y_test == label)\n",
 "  # extract predictions for specific true label\n",
 "  pred_label = test_classes[label_ind]\n",
-"  label_filled = tf.cast(tf.fill(pred_label.shape[0], label), tf.int64)\n",
+"  labels = y_test[label_ind]\n",
 "  # compute class-wise accuracy\n",
-"  label_accs[accuracy_score(pred_label, label_filled).numpy()] = label\n",
+"  label_accs[accuracy_score(pred_label, labels).numpy()] = label\n",
 "for key in sorted(label_accs):\n",
 "  print(f\"Digit {label_accs[key]}: {key:.3f}\")"
 ]
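
Since label_ind keeps only the positions whose true label equals label, y_test[label_ind] is just the value label repeated, so the result matches the tf.fill tensor it replaces while reading more directly. A small self-contained illustration of the masking (accuracy_score below is an assumed stand-in for the notebook's helper, and the tensors are invented):

import tensorflow as tf

def accuracy_score(y_pred, y_true):
  # Assumed stand-in for the notebook's metric: fraction of matching entries.
  return tf.reduce_mean(tf.cast(y_pred == y_true, tf.float32))

# Invented predictions and ground truth for six test examples.
test_classes = tf.constant([3, 1, 3, 1, 3, 1], dtype=tf.int64)
y_test = tf.constant([3, 1, 3, 3, 1, 1], dtype=tf.int64)

label_accs = {}
for label in (1, 3):
  label_ind = (y_test == label)
  # tf.boolean_mask performs the same selection as the notebook's boolean indexing.
  pred_label = tf.boolean_mask(test_classes, label_ind)
  labels = tf.boolean_mask(y_test, label_ind)
  label_accs[label] = accuracy_score(pred_label, labels).numpy()

for label, acc in sorted(label_accs.items()):
  print(f"Digit {label}: {acc:.3f}")
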
@@ -906,7 +906,7 @@
 "  plt.figure(figsize=(10,10))\n",
 "  confusion = sk_metrics.confusion_matrix(test_labels.numpy(), \n",
 "                                          test_classes.numpy())\n",
-"  confusion_normalized = confusion / confusion.sum(axis=1)\n",
+"  confusion_normalized = confusion / confusion.sum(axis=1, keepdims=True)\n",
 "  axis_labels = range(10)\n",
 "  ax = sns.heatmap(\n",
 "      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,\n",

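Without keepdims, confusion.sum(axis=1) has shape (10,), and NumPy broadcasting aligns it with the column axis, so entry (i, j) gets divided by the total of row j rather than row i. Keeping the reduced axis yields a (10, 1) column vector, and each row is divided by its own total. A tiny NumPy check of the difference (the 2x2 matrix is invented):

import numpy as np

# Invented confusion matrix: rows are true classes, columns are predictions.
confusion = np.array([[8, 2],
                      [1, 4]])

# Shape (2,): broadcasts across rows, so column j is divided by the sum of row j.
wrong = confusion / confusion.sum(axis=1)
# Shape (2, 1): broadcasts across columns, so row i is divided by its own sum.
right = confusion / confusion.sum(axis=1, keepdims=True)

print(right.sum(axis=1))  # [1. 1.] -- each row is now a proper distribution
print(wrong.sum(axis=1))  # [1.2 0.9] -- rows no longer sum to one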