Skip to content

Commit a30989b

Browse files
authored
Simple audio tutorial: macOS fix, typo, do not mix label variables
1 parent 881afa3 commit a30989b

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

site/en/tutorials/audio/simple_audio.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@
178178
"outputs": [],
179179
"source": [
180180
"commands = np.array(tf.io.gfile.listdir(str(data_dir)))\n",
181-
"commands = commands[commands != 'README.md']\n",
181+
"commands = commands[(commands != 'README.md') & (commands != '.DS_Store')]\n",
182182
"print('Commands:', commands)"
183183
]
184184
},
@@ -266,7 +266,7 @@
266266
},
267267
"source": [
268268
"The `utils.audio_dataset_from_directory` function only returns up to two splits. It's a good idea to keep a test set separate from your validation set.\n",
269-
"Ideally you'd keep it in a separate directory, but in this case you can use `Dataset.shard` to split the validation set into two halves. Note that iterating over **any** shard will load **all** the data, and only keep it's fraction. "
269+
"Ideally you'd keep it in a separate directory, but in this case you can use `Dataset.shard` to split the validation set into two halves. Note that iterating over **any** shard will load **all** the data, and only keep its fraction. "
270270
]
271271
},
272272
{
@@ -547,7 +547,7 @@
547547
" c = i % cols\n",
548548
" ax = axes[r][c]\n",
549549
" plot_spectrogram(example_spectrograms[i].numpy(), ax)\n",
550-
" ax.set_title(commands[example_spect_labels[i].numpy()])\n",
550+
" ax.set_title(label_names[example_spect_labels[i].numpy()])\n",
551551
"\n",
552552
"plt.show()"
553553
]
@@ -609,7 +609,7 @@
609609
"source": [
610610
"input_shape = example_spectrograms.shape[1:]\n",
611611
"print('Input shape:', input_shape)\n",
612-
"num_labels = len(commands)\n",
612+
"num_labels = len(label_names)\n",
613613
"\n",
614614
"# Instantiate the `tf.keras.layers.Normalization` layer.\n",
615615
"norm_layer = layers.Normalization()\n",
@@ -797,8 +797,8 @@
797797
"confusion_mtx = tf.math.confusion_matrix(y_true, y_pred)\n",
798798
"plt.figure(figsize=(10, 8))\n",
799799
"sns.heatmap(confusion_mtx,\n",
800-
" xticklabels=commands,\n",
801-
" yticklabels=commands,\n",
800+
" xticklabels=label_names,\n",
801+
" yticklabels=label_names,\n",
802802
" annot=True, fmt='g')\n",
803803
"plt.xlabel('Prediction')\n",
804804
"plt.ylabel('Label')\n",

0 commit comments

Comments
 (0)