Skip to content

Commit e38d886

Browse files
abatterytensorflower-gardener
authored andcommitted
Fix structural pruning sparsity notebook
- Enabled tensor preservation option explicitly when creating TFLite interpreter. - The convolution weight search is now done via the operator lookup. PiperOrigin-RevId: 595845109
1 parent a57f59d commit e38d886

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_sparsity_2_by_4.ipynb

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
"id": "FbORZA_bQx1G"
7575
},
7676
"source": [
77-
"Structural pruning weights from your model to make it sparse in specific pattern can accelerate model inference time with appropriate HW supports. \n",
77+
"Structural pruning weights from your model to make it sparse in specific pattern can accelerate model inference time with appropriate HW supports.\n",
7878
"\n",
7979
"This tutorial shows you how to:\n",
8080
"* Define and train a model on the mnist dataset with a specific structural sparsity\n",
@@ -459,7 +459,7 @@
459459
"outputs": [],
460460
"source": [
461461
"# Load tflite file with the created pruned model\n",
462-
"interpreter = tf.lite.Interpreter(model_path=tflite_file)\n",
462+
"interpreter = tf.lite.Interpreter(model_path=tflite_file, experimental_preserve_all_tensors=True)\n",
463463
"interpreter.allocate_tensors()\n",
464464
"\n",
465465
"details = interpreter.get_tensor_details()\n",
@@ -630,9 +630,10 @@
630630
"outputs": [],
631631
"source": [
632632
"# Get weights of the convolutional layer that has been pruned with 2 by 4 sparsity.\n",
633-
"tensor_name = 'structural_pruning/Conv2D'\n",
634-
"detail = [x for x in details if tensor_name in x[\"name\"]]\n",
635-
"tensor_data = interpreter.tensor(detail[1][\"index\"])()\n",
633+
"op_details = interpreter._get_ops_details()\n",
634+
"op_name = 'CONV_2D'\n",
635+
"op_detail = [x for x in op_details if op_name in x[\"op_name\"]]\n",
636+
"tensor_data = interpreter.tensor(op_detail[1][\"inputs\"][1])()\n",
636637
"print(f\"Shape of the weight tensor is {tensor_data.shape}\")"
637638
]
638639
},

0 commit comments

Comments
 (0)