|
74 | 74 | "id": "FbORZA_bQx1G" |
75 | 75 | }, |
76 | 76 | "source": [ |
77 | | - "Structural pruning weights from your model to make it sparse in specific pattern can accelerate model inference time with appropriate HW supports. \n", |
| 77 | + "Structural pruning weights from your model to make it sparse in specific pattern can accelerate model inference time with appropriate HW supports.\n", |
78 | 78 | "\n", |
79 | 79 | "This tutorial shows you how to:\n", |
80 | 80 | "* Define and train a model on the mnist dataset with a specific structural sparsity\n", |
|
459 | 459 | "outputs": [], |
460 | 460 | "source": [ |
461 | 461 | "# Load tflite file with the created pruned model\n", |
462 | | - "interpreter = tf.lite.Interpreter(model_path=tflite_file)\n", |
| 462 | + "interpreter = tf.lite.Interpreter(model_path=tflite_file, experimental_preserve_all_tensors=True)\n", |
463 | 463 | "interpreter.allocate_tensors()\n", |
464 | 464 | "\n", |
465 | 465 | "details = interpreter.get_tensor_details()\n", |
|
630 | 630 | "outputs": [], |
631 | 631 | "source": [ |
632 | 632 | "# Get weights of the convolutional layer that has been pruned with 2 by 4 sparsity.\n", |
633 | | - "tensor_name = 'structural_pruning/Conv2D'\n", |
634 | | - "detail = [x for x in details if tensor_name in x[\"name\"]]\n", |
635 | | - "tensor_data = interpreter.tensor(detail[1][\"index\"])()\n", |
| 633 | + "op_details = interpreter._get_ops_details()\n", |
| 634 | + "op_name = 'CONV_2D'\n", |
| 635 | + "op_detail = [x for x in op_details if op_name in x[\"op_name\"]]\n", |
| 636 | + "tensor_data = interpreter.tensor(op_detail[1][\"inputs\"][1])()\n", |
636 | 637 | "print(f\"Shape of the weight tensor is {tensor_data.shape}\")" |
637 | 638 | ] |
638 | 639 | }, |
|
0 commit comments