# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=missing-docstring
"""Train a simple model with MultiHeadAttention layer on MNIST dataset
and prune it.
"""
import tensorflow as tf

from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper

# Fix the global seed so successive runs of this example are reproducible.
tf.random.set_seed(42)

# Convenience alias for the pruning schedule used below.
ConstantSparsity = pruning_schedule.ConstantSparsity

# Load MNIST dataset (downloads on first use).
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
39+
# Define the model: self-attention over the 28x28 image (rows attend to
# rows), flattened into a 10-way logits layer.
# NOTE: the input tensor is named `model_input` (not `input`) to avoid
# shadowing the Python builtin.
model_input = tf.keras.layers.Input(shape=(28, 28))
x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name="mha")(
    query=model_input, value=model_input
)
x = tf.keras.layers.Flatten()(x)
out = tf.keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs=model_input, outputs=out)

# Train the digit classification model. Labels are integer class ids, so we
# use sparse crossentropy on the raw logits.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

model.fit(
    train_images, train_labels, epochs=10, validation_split=0.1,
)

# Report baseline (unpruned) test performance.
score = model.evaluate(test_images, test_labels, verbose=0)
print('Model test loss:', score[0])
print('Model test accuracy:', score[1])
63+
# Define parameters for pruning.

batch_size = 128
epochs = 3
validation_split = 0.1  # 10% of training set will be used for validation set.

callbacks = [
    # Required: keeps the pruning step counter in sync with training steps.
    pruning_callbacks.UpdatePruningStep(),
    # Optional: writes sparsity/threshold summaries for TensorBoard.
    pruning_callbacks.PruningSummaries(log_dir='/tmp/logs')
]

pruning_params = {
    # Hold 75% sparsity, starting at step 2000, updating masks every 100 steps.
    'pruning_schedule': ConstantSparsity(0.75, begin_step=2000, frequency=100)
}

# Wrap the trained model so its prunable layers are pruned during fine-tuning.
model_for_pruning = prune.prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

model_for_pruning.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=callbacks,
    validation_split=validation_split,
)

# Report test performance of the pruned, fine-tuned model.
score = model_for_pruning.evaluate(test_images, test_labels, verbose=0)
print('Pruned model test loss:', score[0])
print('Pruned model test accuracy:', score[1])