|
13 | 13 | # limitations under the License.
|
14 | 14 | # ==============================================================================
|
15 | 15 | # pylint: disable=missing-docstring
|
16 |
| -"""Train a simple convnet on the MNIST dataset.""" |
| 16 | +"""Train a simple convnet on the MNIST dataset and cluster it. |
17 | 17 |
|
18 |
| -from __future__ import print_function |
| 18 | +This example is based on the sample that can be found here: |
| 19 | +https://www.tensorflow.org/model_optimization/guide/quantization/training_example |
| 20 | +""" |
19 | 21 |
|
20 |
| -import tempfile |
| 22 | +from __future__ import print_function |
21 | 23 |
|
22 | 24 | from absl import app as absl_app
|
23 | 25 | from absl import flags
|
|
27 | 29 | from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
|
28 | 30 |
|
29 | 31 | keras = tf.keras
|
30 |
| -l = keras.layers |
31 | 32 |
|
32 | 33 | FLAGS = flags.FLAGS
|
33 | 34 |
|
34 | 35 | batch_size = 128
|
35 |
| -num_classes = 10 |
36 | 36 | epochs = 12
|
37 | 37 | epochs_fine_tuning = 4
|
38 | 38 |
|
|
41 | 41 | 'Output directory to hold tensorboard events')
|
42 | 42 |
|
43 | 43 |
|
44 |
| -def build_sequential_model(input_shape): |
45 |
| - return tf.keras.Sequential([ |
46 |
| - l.Conv2D( |
47 |
| - 32, 5, padding='same', activation='relu', input_shape=input_shape), |
48 |
| - l.MaxPooling2D((2, 2), (2, 2), padding='same'), |
49 |
| - l.BatchNormalization(), |
50 |
| - l.Conv2D(64, 5, padding='same', activation='relu'), |
51 |
| - l.MaxPooling2D((2, 2), (2, 2), padding='same'), |
52 |
| - l.Flatten(), |
53 |
| - l.Dense(1024, activation='relu'), |
54 |
| - l.Dropout(0.4), |
55 |
| - l.Dense(num_classes, activation='softmax') |
56 |
| - ]) |
| 44 | +def load_mnist_dataset(): |
| 45 | + mnist = keras.datasets.mnist |
| 46 | + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() |
57 | 47 |
|
| 48 | + # Normalize the input image so that each pixel value is between 0 to 1. |
| 49 | + train_images = train_images / 255.0 |
| 50 | + test_images = test_images / 255.0 |
58 | 51 |
|
59 |
| -def build_functional_model(input_shape): |
60 |
| - inp = tf.keras.Input(shape=input_shape) |
61 |
| - x = l.Conv2D(32, 5, padding='same', activation='relu')(inp) |
62 |
| - x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x) |
63 |
| - x = l.BatchNormalization()(x) |
64 |
| - x = l.Conv2D(64, 5, padding='same', activation='relu')(x) |
65 |
| - x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x) |
66 |
| - x = l.Flatten()(x) |
67 |
| - x = l.Dense(1024, activation='relu')(x) |
68 |
| - x = l.Dropout(0.4)(x) |
69 |
| - out = l.Dense(num_classes, activation='softmax')(x) |
70 |
| - |
71 |
| - return tf.keras.models.Model([inp], [out]) |
72 |
| - |
73 |
| -def train_and_save(models, x_train, y_train, x_test, y_test): |
74 |
| - for model in models: |
75 |
| - model.compile( |
76 |
| - loss=tf.keras.losses.categorical_crossentropy, |
77 |
| - optimizer='adam', |
78 |
| - metrics=['accuracy']) |
79 |
| - |
80 |
| - # Print the model summary. |
81 |
| - model.summary() |
82 |
| - |
83 |
| - # Model needs to be clustered after initial training |
84 |
| - # and having achieved good accuracy |
85 |
| - model.fit( |
86 |
| - x_train, |
87 |
| - y_train, |
88 |
| - batch_size=batch_size, |
89 |
| - epochs=epochs, |
90 |
| - verbose=1, |
91 |
| - validation_data=(x_test, y_test)) |
92 |
| - score = model.evaluate(x_test, y_test, verbose=0) |
93 |
| - print('Test loss:', score[0]) |
94 |
| - print('Test accuracy:', score[1]) |
95 |
| - |
96 |
| - print('Clustering model') |
97 |
| - |
98 |
| - clustering_params = { |
99 |
| - 'number_of_clusters': 8, |
100 |
| - 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED |
101 |
| - } |
102 |
| - |
103 |
| - # Cluster model |
104 |
| - clustered_model = cluster.cluster_weights(model, **clustering_params) |
105 |
| - |
106 |
| - # Use smaller learning rate for fine-tuning |
107 |
| - # clustered model |
108 |
| - opt = tf.keras.optimizers.Adam(learning_rate=1e-5) |
109 |
| - |
110 |
| - clustered_model.compile( |
111 |
| - loss=tf.keras.losses.categorical_crossentropy, |
112 |
| - optimizer=opt, |
113 |
| - metrics=['accuracy']) |
| 52 | + return (train_images, train_labels), (test_images, test_labels) |
114 | 53 |
|
115 |
| - # Fine-tune model |
116 |
| - clustered_model.fit( |
117 |
| - x_train, |
118 |
| - y_train, |
119 |
| - batch_size=batch_size, |
120 |
| - epochs=epochs_fine_tuning, |
121 |
| - verbose=1, |
122 |
| - validation_data=(x_test, y_test)) |
123 | 54 |
|
124 |
| - score = clustered_model.evaluate(x_test, y_test, verbose=0) |
125 |
| - print('Clustered Model Test loss:', score[0]) |
126 |
| - print('Clustered Model Test accuracy:', score[1]) |
| 55 | +def build_sequential_model(): |
| 56 | + "Define the model architecture." |
127 | 57 |
|
128 |
| - #Ensure accuracy persists after stripping the model |
129 |
| - stripped_model = cluster.strip_clustering(clustered_model) |
| 58 | + return keras.Sequential([ |
| 59 | + keras.layers.InputLayer(input_shape=(28, 28)), |
| 60 | + keras.layers.Reshape(target_shape=(28, 28, 1)), |
| 61 | + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'), |
| 62 | + keras.layers.MaxPooling2D(pool_size=(2, 2)), |
| 63 | + keras.layers.Flatten(), |
| 64 | + keras.layers.Dense(10) |
| 65 | + ]) |
130 | 66 |
|
131 |
| - stripped_model.compile( |
132 |
| - loss=tf.keras.losses.categorical_crossentropy, |
| 67 | + |
| 68 | +def train_model(model, x_train, y_train, x_test, y_test): |
| 69 | + model.compile( |
| 70 | + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), |
| 71 | + optimizer='adam', |
| 72 | + metrics=['accuracy']) |
| 73 | + |
| 74 | + # Print the model summary. |
| 75 | + model.summary() |
| 76 | + |
| 77 | + # Model needs to be clustered after initial training |
| 78 | + # and having achieved good accuracy |
| 79 | + model.fit( |
| 80 | + x_train, |
| 81 | + y_train, |
| 82 | + batch_size=batch_size, |
| 83 | + epochs=epochs, |
| 84 | + verbose=1, |
| 85 | + validation_split=0.1) |
| 86 | + |
| 87 | + score = model.evaluate(x_test, y_test, verbose=0) |
| 88 | + print('Test loss:', score[0]) |
| 89 | + print('Test accuracy:', score[1]) |
| 90 | + |
| 91 | + return model |
| 92 | + |
| 93 | + |
| 94 | +def cluster_model(model, x_train, y_train, x_test, y_test): |
| 95 | + print('Clustering model') |
| 96 | + |
| 97 | + clustering_params = { |
| 98 | + 'number_of_clusters': 8, |
| 99 | + 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED |
| 100 | + } |
| 101 | + |
| 102 | + # Cluster model |
| 103 | + clustered_model = cluster.cluster_weights(model, **clustering_params) |
| 104 | + |
| 105 | + # Use smaller learning rate for fine-tuning |
| 106 | + # clustered model |
| 107 | + opt = tf.keras.optimizers.Adam(learning_rate=1e-5) |
| 108 | + |
| 109 | + clustered_model.compile( |
| 110 | + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), |
| 111 | + optimizer=opt, |
| 112 | + metrics=['accuracy']) |
| 113 | + |
| 114 | + # Fine-tune clustered model |
| 115 | + clustered_model.fit( |
| 116 | + x_train, |
| 117 | + y_train, |
| 118 | + batch_size=batch_size, |
| 119 | + epochs=epochs_fine_tuning, |
| 120 | + verbose=1, |
| 121 | + validation_split=0.1) |
| 122 | + |
| 123 | + score = clustered_model.evaluate(x_test, y_test, verbose=0) |
| 124 | + print('Clustered model test loss:', score[0]) |
| 125 | + print('Clustered model test accuracy:', score[1]) |
| 126 | + |
| 127 | + return clustered_model |
| 128 | + |
| 129 | + |
| 130 | +def test_clustered_model(clustered_model, x_test, y_test): |
| 131 | + # Ensure accuracy persists after serializing/deserializing the model |
| 132 | + clustered_model.save('clustered_model.h5') |
| 133 | + # To deserialize the clustered model, use the clustering scope |
| 134 | + with cluster.cluster_scope(): |
| 135 | + loaded_clustered_model = keras.models.load_model('clustered_model.h5') |
| 136 | + |
| 137 | + # Checking that the deserialized model's accuracy matches the clustered model |
| 138 | + score = loaded_clustered_model.evaluate(x_test, y_test, verbose=0) |
| 139 | + print('Deserialized model test loss:', score[0]) |
| 140 | + print('Deserialized model test accuracy:', score[1]) |
| 141 | + |
| 142 | + # Ensure accuracy persists after stripping the model |
| 143 | + stripped_model = cluster.strip_clustering(loaded_clustered_model) |
| 144 | + stripped_model.compile( |
| 145 | + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), |
133 | 146 | optimizer='adam',
|
134 | 147 | metrics=['accuracy'])
|
135 |
| - stripped_model.save('stripped_model.h5') |
136 | 148 |
|
137 |
| - # To acquire the stripped model, |
138 |
| - # deserialize with clustering scope |
139 |
| - with cluster.cluster_scope(): |
140 |
| - loaded_model = keras.models.load_model('stripped_model.h5') |
| 149 | + # Checking that the stripped model's accuracy matches the clustered model |
| 150 | + score = stripped_model.evaluate(x_test, y_test, verbose=0) |
| 151 | + print('Stripped model test loss:', score[0]) |
| 152 | + print('Stripped model test accuracy:', score[1]) |
141 | 153 |
|
142 |
| - # Checking that the stripped model's accuracy matches the clustered model |
143 |
| - score = loaded_model.evaluate(x_test, y_test, verbose=0) |
144 |
| - print('Stripped Model Test loss:', score[0]) |
145 |
| - print('Stripped Model Test accuracy:', score[1]) |
146 | 154 |
|
147 | 155 | def main(unused_argv):
|
148 | 156 | if FLAGS.enable_eager:
|
149 | 157 | print('Running in Eager mode.')
|
150 | 158 | tf.compat.v1.enable_eager_execution()
|
151 | 159 |
|
152 |
| - # input image dimensions |
153 |
| - img_rows, img_cols = 28, 28 |
154 |
| - |
155 | 160 | # the data, shuffled and split between train and test sets
|
156 |
| - (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() |
157 |
| - |
158 |
| - if tf.keras.backend.image_data_format() == 'channels_first': |
159 |
| - x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) |
160 |
| - x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) |
161 |
| - input_shape = (1, img_rows, img_cols) |
162 |
| - else: |
163 |
| - x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) |
164 |
| - x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) |
165 |
| - input_shape = (img_rows, img_cols, 1) |
166 |
| - |
167 |
| - x_train = x_train.astype('float32') |
168 |
| - x_test = x_test.astype('float32') |
169 |
| - x_train /= 255 |
170 |
| - x_test /= 255 |
| 161 | + (x_train, y_train), (x_test, y_test) = load_mnist_dataset() |
| 162 | + |
171 | 163 | print('x_train shape:', x_train.shape)
|
172 | 164 | print(x_train.shape[0], 'train samples')
|
173 | 165 | print(x_test.shape[0], 'test samples')
|
174 | 166 |
|
175 |
| - # convert class vectors to binary class matrices |
176 |
| - y_train = tf.keras.utils.to_categorical(y_train, num_classes) |
177 |
| - y_test = tf.keras.utils.to_categorical(y_test, num_classes) |
178 |
| - |
179 |
| - sequential_model = build_sequential_model(input_shape) |
180 |
| - functional_model = build_functional_model(input_shape) |
181 |
| - models = [sequential_model, functional_model] |
182 |
| - train_and_save(models, x_train, y_train, x_test, y_test) |
| 167 | + # Build model |
| 168 | + model = build_sequential_model() |
| 169 | + # Train model |
| 170 | + model = train_model(model, x_train, y_train, x_test, y_test) |
| 171 | + # Cluster and fine-tune model |
| 172 | + clustered_model = cluster_model(model, x_train, y_train, x_test, y_test) |
| 173 | + # Test clustered model (serialize/deserialize, strip clustering) |
| 174 | + test_clustered_model(clustered_model, x_test, y_test) |
183 | 175 |
|
184 | 176 |
|
185 | 177 | if __name__ == '__main__':
|
|
0 commit comments