|
67 | 67 | "source": [
|
68 | 68 | "Welcome to the guide on the structural pruning M by N.\n",
|
69 | 69 | "\n",
|
70 |
| - "Before reading this tutorial it is recommended to get familiar with the concept of pruning and APIs for unstructured pruning:\n", |
| 70 | + "Before reading this tutorial it is recommended to get familiar with the concept of pruning and APIs for random pruning:\n", |
71 | 71 | "* General overview of the pruning technique for the model optimization, see the [overview](https://www.tensorflow.org/model_optimization/guide/pruning).\n",
|
72 | 72 | "* Usage of API's on a single end-to-end example, see the [pruning example](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras).\n",
|
73 | 73 | "\n",
|
|
80 | 80 | "id": "FbORZA_bQx1G"
|
81 | 81 | }
|
82 | 82 | },
|
| 83 | + { |
| 84 | + "cell_type": "markdown", |
| 85 | + "source": [ |
| 86 | + "## Structural pruning M by N" |
| 87 | + ], |
| 88 | + "metadata": {} |
| 89 | + }, |
| 90 | + { |
| 91 | + "cell_type": "markdown", |
| 92 | + "source": [ |
| 93 | + "Structural pruning zeroes out model weights at the beginning of the training\n", |
| 94 | + "process according to the following pattern: M weights are set to zero in the\n", |
| 95 | + "block of N weights. It is important to notice that this pattern affects only the last dimension of the weight tensor for the model that is converted by TensorFlow Lite. For example, `Conv2D` layer weights in TensorFlow Lite have the structure [channel_out, height, width, channel_in] and `Dense` layer weights have the structure [channel_out, channel_in]. The sparsity pattern is applied to the weights in the last dimension: channel_in.\n", |
| 96 | + "Special hardware can benefit from this type of sparsity in the model and inference time can have a speedup up to 2x. Because this pattern lock in sparsity is more restrictive, the accuracy achieved after fine-tuning is worse than with the magnitude-based pruning.\n", |
| 97 | + "It is important to indicate that the pattern is valid only for the model that is converted to tflite.\n", |
| 98 | + "If the model is quantized, then the accuracy could be improved using [collaborative optimization technique](https://blog.tensorflow.org/2021/10/Collaborative-Optimizations.html): Sparsity preserving quantization aware training." |
| 99 | + ], |
| 100 | + "metadata": {} |
| 101 | + }, |
83 | 102 | {
|
84 | 103 | "cell_type": "markdown",
|
85 | 104 | "source": [
|
|
117 | 136 | "execution_count": null,
|
118 | 137 | "source": [
|
119 | 138 | "import tensorflow as tf\n",
|
| 139 | + "from tensorflow import keras\n", |
120 | 140 | "\n",
|
121 | 141 | "import tensorflow_model_optimization as tfmot\n",
|
122 |
| - "prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude\n", |
123 |
| - "\n", |
124 |
| - "from tensorflow import keras" |
| 142 | + "prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude" |
125 | 143 | ],
|
126 | 144 | "outputs": [],
|
127 | 145 | "metadata": {}
|
|
154 | 172 | "cell_type": "markdown",
|
155 | 173 | "source": [
|
156 | 174 | "Define parameters for pruning and specify the type of structural pruning that will be used: (2, 4).\n",
|
157 |
| - "It means that in a block of four elements, two with the lowest magnitude will be set to zero.\n", |
| 175 | + "It means that in a block of four elements, at least two with the lowest magnitude will be set to zero.\n", |
158 | 176 | "\n",
|
159 | 177 | "We don't set `pruning_schedule` parameter. By default, the pruning mask is defined at the first step and it is not updated during the training."
|
160 | 178 | ],
|
|
174 | 192 | {
|
175 | 193 | "cell_type": "markdown",
|
176 | 194 | "source": [
|
177 |
| - "Define parameters for unstructured pruning with the same target sparsity: 50%." |
| 195 | + "Define parameters for random pruning with the target sparsity: 50%." |
178 | 196 | ],
|
179 | 197 | "metadata": {}
|
180 | 198 | },
|
181 | 199 | {
|
182 | 200 | "cell_type": "code",
|
183 | 201 | "execution_count": null,
|
184 | 202 | "source": [
|
185 |
| - "pruning_params_unstructured = {\n", |
| 203 | + "pruning_params_sparsity_0_5 = {\n", |
186 | 204 | " 'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.5,\n",
|
187 | 205 | " begin_step=0,\n",
|
188 | 206 | " frequency=100)\n",
|
|
198 | 216 | "\n",
|
199 | 217 | "In the example below, we prune only some of the layers. We prune `Conv2D` layer with the biggest number of parameters and an internal `Dense` layer.\n",
|
200 | 218 | "\n",
|
201 |
| - "It is important to notice that even if we marked the first `Conv2D` layer to be structural pruned, it is not structurally pruned, because the number of input channels is 1. Therefore, we prune the first `Conv2D` layer with the unstructured pruning.\n" |
| 219 | + "It is important to notice that even if we marked the first `Conv2D` layer to be structural pruned, it is not structurally pruned. We pruned in the channel directory, so we need to have at least 2 (or m channels) input channels. In this case, the number of input channels is 1. Therefore, we prune the first `Conv2D` layer with the random pruning.\n" |
202 | 220 | ],
|
203 | 221 | "metadata": {}
|
204 | 222 | },
|
|
211 | 229 | " keras.layers.Conv2D(\n",
|
212 | 230 | " 32, 5, padding='same', activation='relu',\n",
|
213 | 231 | " input_shape=(28, 28, 1),\n",
|
214 |
| - " name=\"unstructured_pruning\"),\n", |
215 |
| - " **pruning_params_unstructured),\n", |
| 232 | + " name=\"pruning_sparsity_0_5\"),\n", |
| 233 | + " **pruning_params_sparsity_0_5),\n", |
216 | 234 | " keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),\n",
|
217 | 235 | " prune_low_magnitude(\n",
|
218 | 236 | " keras.layers.Conv2D(\n",
|
|
264 | 282 | " callbacks=tfmot.sparsity.keras.UpdatePruningStep(),\n",
|
265 | 283 | " validation_split=0.1)\n",
|
266 | 284 | "\n",
|
267 |
| - "_, model_for_pruning_accuracy = model.evaluate(test_images, test_labels, verbose=0)\n", |
268 |
| - "print('Pruned test accuracy:', model_for_pruning_accuracy)" |
| 285 | + "_, pruned_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)\n", |
| 286 | + "print('Pruned test accuracy:', pruned_model_accuracy)" |
269 | 287 | ],
|
270 | 288 | "outputs": [],
|
271 | 289 | "metadata": {}
|
|
313 | 331 | {
|
314 | 332 | "cell_type": "markdown",
|
315 | 333 | "source": [
|
316 |
| - "## Visualize and check weights." |
| 334 | + "## Visualize and check weights" |
317 | 335 | ],
|
318 | 336 | "metadata": {}
|
319 | 337 | },
|
320 | 338 | {
|
321 | 339 | "cell_type": "markdown",
|
322 | 340 | "source": [
|
323 |
| - "Now let visualize the weights structure in the `Dense` layer pruned with 2/4 sparsity. At first, we need to extract these weights from the tflite file." |
| 341 | + "Now let's visualize the weights structure in the `Dense` layer pruned with 2 by 4 sparsity. At first, we need to extract these weights from the tflite file." |
324 | 342 | ],
|
325 | 343 | "metadata": {}
|
326 | 344 | },
|
|
347 | 365 | {
|
348 | 366 | "cell_type": "markdown",
|
349 | 367 | "source": [
|
350 |
| - "To check that we selected the layer that has been pruned, let us check the shape of the weight tensor." |
| 368 | + "To verify that we selected the correct layer that has been pruned, let us print the shape of the weight tensor." |
351 | 369 | ],
|
352 | 370 | "metadata": {}
|
353 | 371 | },
|
|
376 | 394 | "import matplotlib.pyplot as plt\n",
|
377 | 395 | "import numpy as np\n",
|
378 | 396 | "\n",
|
| 397 | + "# The value 24 is chosen for convenience.\n", |
379 | 398 | "width = height = 24\n",
|
380 | 399 | "\n",
|
381 | 400 | "subset_values_to_display = tensor_data[0:height, 0:width]\n",
|
|
450 | 469 | "cell_type": "code",
|
451 | 470 | "execution_count": null,
|
452 | 471 | "source": [
|
453 |
| - "# Let us get weights of the convolutional layer that has been pruned with 2/4 sparsity.\n", |
| 472 | + "# Let us get weights of the convolutional layer that has been pruned with 2 by 4 sparsity.\n", |
454 | 473 | "tensor_name = 'structural_pruning/Conv2D'\n",
|
455 | 474 | "detail = [x for x in details if tensor_name in x[\"name\"]]\n",
|
456 | 475 | "tensor_data = interpreter.tensor(detail[1][\"index\"])()\n",
|
|
491 | 510 | {
|
492 | 511 | "cell_type": "markdown",
|
493 | 512 | "source": [
|
494 |
| - "Let's see how unstructured weights look. We extract them and display a subset of the weight tensor." |
| 513 | + "Let's see how randomly pruned weights look. We extract them and display a subset of the weight tensor." |
495 | 514 | ],
|
496 | 515 | "metadata": {}
|
497 | 516 | },
|
498 | 517 | {
|
499 | 518 | "cell_type": "code",
|
500 | 519 | "execution_count": null,
|
501 | 520 | "source": [
|
502 |
| - "# Let us get weights of the convolutional layer that has been pruned with unstructured pruning.\n", |
503 |
| - "tensor_name = 'unstructured_pruning/Conv2D'\n", |
| 521 | + "# Let us get weights of the convolutional layer that has been pruned with random pruning.\n", |
| 522 | + "tensor_name = 'pruning_sparsity_0_5/Conv2D'\n", |
504 | 523 | "detail = [x for x in details if tensor_name in x[\"name\"]]\n",
|
505 | 524 | "tensor_data = interpreter.tensor(detail[0][\"index\"])()\n",
|
506 | 525 | "print(f\"Shape of the weight tensor is {tensor_data.shape}\")"
|
|
533 | 552 | {
|
534 | 553 | "cell_type": "markdown",
|
535 | 554 | "source": [
|
536 |
| - "There is a python script included in the TensorFlow Model Optimization Toolkit that could be used to check whether which layers in the model from the given flite file have the structurally pruned weights: [`check_sparsity_m_by_n.py`](https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/sparsity/keras/tools/check_sparsity_m_by_n.py)." |
| 555 | + "There is a python script included in the TensorFlow Model Optimization Toolkit that could be used to check whether which layers in the model from the given flite file have the structurally pruned weights: [`check_sparsity_m_by_n.py`](https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/sparsity/keras/tools/check_sparsity_m_by_n.py). The usage of this tool for the case of 2 by 4 is shown below:" |
| 556 | + ], |
| 557 | + "metadata": {} |
| 558 | + }, |
| 559 | + { |
| 560 | + "cell_type": "markdown", |
| 561 | + "source": [ |
| 562 | + "`python ./tensorflow_model_optimization/python/core/sparsity/keras/tools/check_sparsity_m_by_n.py --model_tflite=pruned_model.tflite --m_by_n=2,4`" |
537 | 563 | ],
|
538 | 564 | "metadata": {}
|
539 | 565 | }
|
|
0 commit comments