|
74 | 74 | "id": "FbORZA_bQx1G"
|
75 | 75 | },
|
76 | 76 | "source": [
|
77 |
| - "Structural pruning weights from your model to make it sparse in specific pattern can accelerate model inference time with appropriate HW supports. \n", |
| 77 | + "Structural pruning weights from your model to make it sparse in specific pattern can accelerate model inference time with appropriate HW supports.\n", |
78 | 78 | "\n",
|
79 | 79 | "This tutorial shows you how to:\n",
|
80 | 80 | "* Define and train a model on the mnist dataset with a specific structural sparsity\n",
|
|
459 | 459 | "outputs": [],
|
460 | 460 | "source": [
|
461 | 461 | "# Load tflite file with the created pruned model\n",
|
462 |
| - "interpreter = tf.lite.Interpreter(model_path=tflite_file)\n", |
| 462 | + "interpreter = tf.lite.Interpreter(model_path=tflite_file, experimental_preserve_all_tensors=True)\n", |
463 | 463 | "interpreter.allocate_tensors()\n",
|
464 | 464 | "\n",
|
465 | 465 | "details = interpreter.get_tensor_details()\n",
|
|
630 | 630 | "outputs": [],
|
631 | 631 | "source": [
|
632 | 632 | "# Get weights of the convolutional layer that has been pruned with 2 by 4 sparsity.\n",
|
633 |
| - "tensor_name = 'structural_pruning/Conv2D'\n", |
634 |
| - "detail = [x for x in details if tensor_name in x[\"name\"]]\n", |
635 |
| - "tensor_data = interpreter.tensor(detail[1][\"index\"])()\n", |
| 633 | + "op_details = interpreter._get_ops_details()\n", |
| 634 | + "op_name = 'CONV_2D'\n", |
| 635 | + "op_detail = [x for x in op_details if op_name in x[\"op_name\"]]\n", |
| 636 | + "tensor_data = interpreter.tensor(op_detail[1][\"inputs\"][1])()\n", |
636 | 637 | "print(f\"Shape of the weight tensor is {tensor_data.shape}\")"
|
637 | 638 | ]
|
638 | 639 | },
|
|
0 commit comments