Skip to content

Commit e5cd825

Browse files
committed
Extract labels directly from y_test.
1 parent 9ad2e9e commit e5cd825

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

site/en/guide/core/mlp_core.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -901,9 +901,9 @@
901901
" label_ind = (y_test == label)\n",
902902
" # extract predictions for specific true label\n",
903903
" pred_label = test_classes[label_ind]\n",
904-
" label_filled = tf.cast(tf.fill(pred_label.shape[0], label), tf.int64)\n",
904+
" labels = y_test[label_ind]\n",
905905
" # compute class-wise accuracy\n",
906-
" label_accs[accuracy_score(pred_label, label_filled).numpy()] = label\n",
906+
" label_accs[accuracy_score(pred_label, labels).numpy()] = label\n",
907907
"for key in sorted(label_accs):\n",
908908
" print(f\"Digit {label_accs[key]}: {key:.3f}\")"
909909
]

0 commit comments

Comments
 (0)