|
178 | 178 | "outputs": [],
|
179 | 179 | "source": [
|
180 | 180 | "commands = np.array(tf.io.gfile.listdir(str(data_dir)))\n",
|
181 |
| - "commands = commands[commands != 'README.md']\n", |
| 181 | + "commands = commands[(commands != 'README.md') & (commands != '.DS_Store')]\n", |
182 | 182 | "print('Commands:', commands)"
|
183 | 183 | ]
|
184 | 184 | },
|
|
266 | 266 | },
|
267 | 267 | "source": [
|
268 | 268 | "The `utils.audio_dataset_from_directory` function only returns up to two splits. It's a good idea to keep a test set separate from your validation set.\n",
|
269 |
| - "Ideally you'd keep it in a separate directory, but in this case you can use `Dataset.shard` to split the validation set into two halves. Note that iterating over **any** shard will load **all** the data, and only keep it's fraction. " |
| 269 | + "Ideally you'd keep it in a separate directory, but in this case you can use `Dataset.shard` to split the validation set into two halves. Note that iterating over **any** shard will load **all** the data, and only keep its fraction. " |
270 | 270 | ]
|
271 | 271 | },
|
272 | 272 | {
|
|
547 | 547 | " c = i % cols\n",
|
548 | 548 | " ax = axes[r][c]\n",
|
549 | 549 | " plot_spectrogram(example_spectrograms[i].numpy(), ax)\n",
|
550 |
| - " ax.set_title(commands[example_spect_labels[i].numpy()])\n", |
| 550 | + " ax.set_title(label_names[example_spect_labels[i].numpy()])\n", |
551 | 551 | "\n",
|
552 | 552 | "plt.show()"
|
553 | 553 | ]
|
|
609 | 609 | "source": [
|
610 | 610 | "input_shape = example_spectrograms.shape[1:]\n",
|
611 | 611 | "print('Input shape:', input_shape)\n",
|
612 |
| - "num_labels = len(commands)\n", |
| 612 | + "num_labels = len(label_names)\n", |
613 | 613 | "\n",
|
614 | 614 | "# Instantiate the `tf.keras.layers.Normalization` layer.\n",
|
615 | 615 | "norm_layer = layers.Normalization()\n",
|
|
797 | 797 | "confusion_mtx = tf.math.confusion_matrix(y_true, y_pred)\n",
|
798 | 798 | "plt.figure(figsize=(10, 8))\n",
|
799 | 799 | "sns.heatmap(confusion_mtx,\n",
|
800 |
| - " xticklabels=commands,\n", |
801 |
| - " yticklabels=commands,\n", |
| 800 | + " xticklabels=label_names,\n", |
| 801 | + " yticklabels=label_names,\n", |
802 | 802 | " annot=True, fmt='g')\n",
|
803 | 803 | "plt.xlabel('Prediction')\n",
|
804 | 804 | "plt.ylabel('Label')\n",
|
|
0 commit comments