|
13 | 13 | # limitations under the License.
|
14 | 14 | # ==============================================================================
|
15 | 15 | # pylint: disable=missing-docstring
|
16 |
| -"""Train a simple convnet on the MNIST dataset.""" |
| 16 | +"""Train a simple convnet on the MNIST dataset and cluster it. |
17 | 17 |
|
18 |
| -from __future__ import print_function |
| 18 | +This example is based on the sample that can be found here: |
| 19 | +https://www.tensorflow.org/model_optimization/guide/quantization/training_example |
| 20 | +""" |
19 | 21 |
|
| 22 | +from __future__ import print_function |
20 | 23 | import datetime
|
21 | 24 | import os
|
22 | 25 |
|
|
29 | 32 | from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks
|
30 | 33 |
|
31 | 34 | keras = tf.keras
|
32 |
| -l = keras.layers |
33 | 35 |
|
34 | 36 | FLAGS = flags.FLAGS
|
35 | 37 |
|
36 | 38 | batch_size = 128
|
37 |
| -num_classes = 10 |
38 | 39 | epochs = 12
|
39 | 40 | epochs_fine_tuning = 4
|
40 | 41 |
|
|
43 | 44 | 'Output directory to hold tensorboard events')
|
44 | 45 |
|
45 | 46 |
|
46 |
| -def build_sequential_model(input_shape): |
47 |
| - return tf.keras.Sequential([ |
48 |
| - l.Conv2D( |
49 |
| - 32, 5, padding='same', activation='relu', input_shape=input_shape), |
50 |
| - l.MaxPooling2D((2, 2), (2, 2), padding='same'), |
51 |
| - l.BatchNormalization(), |
52 |
| - l.Conv2D(64, 5, padding='same', activation='relu'), |
53 |
| - l.MaxPooling2D((2, 2), (2, 2), padding='same'), |
54 |
| - l.Flatten(), |
55 |
| - l.Dense(1024, activation='relu'), |
56 |
| - l.Dropout(0.4), |
57 |
| - l.Dense(num_classes, activation='softmax') |
58 |
| - ]) |
| 47 | +def load_mnist_dataset(): |
| 48 | + mnist = keras.datasets.mnist |
| 49 | + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() |
59 | 50 |
|
| 51 | + # Normalize the input image so that each pixel value is between 0 and 1. |
| 52 | + train_images = train_images / 255.0 |
| 53 | + test_images = test_images / 255.0 |
60 | 54 |
|
61 |
| -def build_functional_model(input_shape): |
62 |
| - inp = tf.keras.Input(shape=input_shape) |
63 |
| - x = l.Conv2D(32, 5, padding='same', activation='relu')(inp) |
64 |
| - x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x) |
65 |
| - x = l.BatchNormalization()(x) |
66 |
| - x = l.Conv2D(64, 5, padding='same', activation='relu')(x) |
67 |
| - x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x) |
68 |
| - x = l.Flatten()(x) |
69 |
| - x = l.Dense(1024, activation='relu')(x) |
70 |
| - x = l.Dropout(0.4)(x) |
71 |
| - out = l.Dense(num_classes, activation='softmax')(x) |
72 |
| - |
73 |
| - return tf.keras.models.Model([inp], [out]) |
74 |
| - |
75 |
| -def train_and_save(models, x_train, y_train, x_test, y_test): |
76 |
| - for model in models: |
77 |
| - model.compile( |
78 |
| - loss=tf.keras.losses.categorical_crossentropy, |
79 |
| - optimizer='adam', |
80 |
| - metrics=['accuracy']) |
81 |
| - |
82 |
| - # Print the model summary. |
83 |
| - model.summary() |
84 |
| - |
85 |
| - # Model needs to be clustered after initial training |
86 |
| - # and having achieved good accuracy |
87 |
| - model.fit( |
88 |
| - x_train, |
89 |
| - y_train, |
90 |
| - batch_size=batch_size, |
91 |
| - epochs=epochs, |
92 |
| - verbose=1, |
93 |
| - validation_data=(x_test, y_test)) |
94 |
| - score = model.evaluate(x_test, y_test, verbose=0) |
95 |
| - print('Test loss:', score[0]) |
96 |
| - print('Test accuracy:', score[1]) |
97 |
| - |
98 |
| - print('Clustering model') |
99 |
| - |
100 |
| - clustering_params = { |
101 |
| - 'number_of_clusters': 8, |
102 |
| - 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED |
103 |
| - } |
104 |
| - |
105 |
| - # Cluster model |
106 |
| - clustered_model = cluster.cluster_weights(model, **clustering_params) |
107 |
| - |
108 |
| - # Use smaller learning rate for fine-tuning |
109 |
| - # clustered model |
110 |
| - opt = tf.keras.optimizers.Adam(learning_rate=1e-5) |
111 |
| - |
112 |
| - clustered_model.compile( |
113 |
| - loss=tf.keras.losses.categorical_crossentropy, |
114 |
| - optimizer=opt, |
115 |
| - metrics=['accuracy']) |
| 55 | + return (train_images, train_labels), (test_images, test_labels) |
| 56 | + |
| 57 | +def build_sequential_model(): |
| 58 | + "Define the model architecture." |
116 | 59 |
|
117 |
| - # Add callback for tensorboard summaries |
118 |
| - log_dir = os.path.join( |
119 |
| - FLAGS.output_dir, |
120 |
| - datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering")) |
121 |
| - callbacks = [ |
122 |
| - clustering_callbacks.ClusteringSummaries( |
123 |
| - log_dir, |
124 |
| - cluster_update_freq='epoch', |
125 |
| - update_freq='batch', |
126 |
| - histogram_freq=1) |
127 |
| - ] |
128 |
| - |
129 |
| - # Fine-tune model |
130 |
| - clustered_model.fit( |
131 |
| - x_train, |
132 |
| - y_train, |
133 |
| - batch_size=batch_size, |
134 |
| - epochs=epochs_fine_tuning, |
135 |
| - verbose=1, |
136 |
| - callbacks=callbacks, |
137 |
| - validation_data=(x_test, y_test)) |
138 |
| - |
139 |
| - score = clustered_model.evaluate(x_test, y_test, verbose=0) |
140 |
| - print('Clustered Model Test loss:', score[0]) |
141 |
| - print('Clustered Model Test accuracy:', score[1]) |
142 |
| - |
143 |
| - #Ensure accuracy persists after stripping the model |
144 |
| - stripped_model = cluster.strip_clustering(clustered_model) |
145 |
| - |
146 |
| - stripped_model.compile( |
147 |
| - loss=tf.keras.losses.categorical_crossentropy, |
| 60 | + return keras.Sequential([ |
| 61 | + keras.layers.InputLayer(input_shape=(28, 28)), |
| 62 | + keras.layers.Reshape(target_shape=(28, 28, 1)), |
| 63 | + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'), |
| 64 | + keras.layers.MaxPooling2D(pool_size=(2, 2)), |
| 65 | + keras.layers.Flatten(), |
| 66 | + keras.layers.Dense(10) |
| 67 | + ]) |
| 68 | + |
| 69 | + |
| 70 | +def train_model(model, x_train, y_train, x_test, y_test): |
| 71 | + model.compile( |
| 72 | + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), |
| 73 | + optimizer='adam', |
| 74 | + metrics=['accuracy']) |
| 75 | + |
| 76 | + # Print the model summary. |
| 77 | + model.summary() |
| 78 | + |
| 79 | + # The model should be clustered only after initial training, |
| 80 | + # once it has achieved good accuracy |
| 81 | + model.fit( |
| 82 | + x_train, |
| 83 | + y_train, |
| 84 | + batch_size=batch_size, |
| 85 | + epochs=epochs, |
| 86 | + verbose=1, |
| 87 | + validation_split=0.1) |
| 88 | + |
| 89 | + score = model.evaluate(x_test, y_test, verbose=0) |
| 90 | + print('Test loss:', score[0]) |
| 91 | + print('Test accuracy:', score[1]) |
| 92 | + |
| 93 | + return model |
| 94 | + |
| 95 | + |
| 96 | +def cluster_model(model, x_train, y_train, x_test, y_test): |
| 97 | + print('Clustering model') |
| 98 | + |
| 99 | + clustering_params = { |
| 100 | + 'number_of_clusters': 8, |
| 101 | + 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED |
| 102 | + } |
| 103 | + |
| 104 | + # Cluster model |
| 105 | + clustered_model = cluster.cluster_weights(model, **clustering_params) |
| 106 | + |
| 107 | + # Use a smaller learning rate for fine-tuning |
| 108 | + # the clustered model |
| 109 | + opt = tf.keras.optimizers.Adam(learning_rate=1e-5) |
| 110 | + |
| 111 | + clustered_model.compile( |
| 112 | + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), |
| 113 | + optimizer=opt, |
| 114 | + metrics=['accuracy']) |
| 115 | + |
| 116 | + # Add callback for tensorboard summaries |
| 117 | + log_dir = os.path.join( |
| 118 | + FLAGS.output_dir, |
| 119 | + datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering")) |
| 120 | + callbacks = [ |
| 121 | + clustering_callbacks.ClusteringSummaries( |
| 122 | + log_dir, |
| 123 | + cluster_update_freq='epoch', |
| 124 | + update_freq='batch', |
| 125 | + histogram_freq=1) |
| 126 | + ] |
| 127 | + |
| 128 | + # Fine-tune clustered model |
| 129 | + clustered_model.fit( |
| 130 | + x_train, |
| 131 | + y_train, |
| 132 | + batch_size=batch_size, |
| 133 | + epochs=epochs_fine_tuning, |
| 134 | + verbose=1, |
| 135 | + callbacks=callbacks, |
| 136 | + validation_split=0.1) |
| 137 | + |
| 138 | + score = clustered_model.evaluate(x_test, y_test, verbose=0) |
| 139 | + print('Clustered model test loss:', score[0]) |
| 140 | + print('Clustered model test accuracy:', score[1]) |
| 141 | + |
| 142 | + return clustered_model |
| 143 | + |
| 144 | + |
| 145 | +def test_clustered_model(clustered_model, x_test, y_test): |
| 146 | + # Ensure accuracy persists after serializing/deserializing the model |
| 147 | + clustered_model.save('clustered_model.h5') |
| 148 | + # To deserialize the clustered model, use the clustering scope |
| 149 | + with cluster.cluster_scope(): |
| 150 | + loaded_clustered_model = keras.models.load_model('clustered_model.h5') |
| 151 | + |
| 152 | + # Checking that the deserialized model's accuracy matches the clustered model |
| 153 | + score = loaded_clustered_model.evaluate(x_test, y_test, verbose=0) |
| 154 | + print('Deserialized model test loss:', score[0]) |
| 155 | + print('Deserialized model test accuracy:', score[1]) |
| 156 | + |
| 157 | + # Ensure accuracy persists after stripping the model |
| 158 | + stripped_model = cluster.strip_clustering(loaded_clustered_model) |
| 159 | + stripped_model.compile( |
| 160 | + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), |
148 | 161 | optimizer='adam',
|
149 | 162 | metrics=['accuracy'])
|
150 |
| - stripped_model.save('stripped_model.h5') |
151 | 163 |
|
152 |
| - # To acquire the stripped model, |
153 |
| - # deserialize with clustering scope |
154 |
| - with cluster.cluster_scope(): |
155 |
| - loaded_model = keras.models.load_model('stripped_model.h5') |
| 164 | + # Checking that the stripped model's accuracy matches the clustered model |
| 165 | + score = stripped_model.evaluate(x_test, y_test, verbose=0) |
| 166 | + print('Stripped model test loss:', score[0]) |
| 167 | + print('Stripped model test accuracy:', score[1]) |
156 | 168 |
|
157 |
| - # Checking that the stripped model's accuracy matches the clustered model |
158 |
| - score = loaded_model.evaluate(x_test, y_test, verbose=0) |
159 |
| - print('Stripped Model Test loss:', score[0]) |
160 |
| - print('Stripped Model Test accuracy:', score[1]) |
161 | 169 |
|
162 | 170 | def main(unused_argv):
|
163 | 171 | if FLAGS.enable_eager:
|
164 | 172 | print('Running in Eager mode.')
|
165 | 173 | tf.compat.v1.enable_eager_execution()
|
166 | 174 |
|
167 |
| - # input image dimensions |
168 |
| - img_rows, img_cols = 28, 28 |
169 |
| - |
170 | 175 | # the data, shuffled and split between train and test sets
|
171 |
| - (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() |
172 |
| - |
173 |
| - if tf.keras.backend.image_data_format() == 'channels_first': |
174 |
| - x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) |
175 |
| - x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) |
176 |
| - input_shape = (1, img_rows, img_cols) |
177 |
| - else: |
178 |
| - x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) |
179 |
| - x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) |
180 |
| - input_shape = (img_rows, img_cols, 1) |
181 |
| - |
182 |
| - x_train = x_train.astype('float32') |
183 |
| - x_test = x_test.astype('float32') |
184 |
| - x_train /= 255 |
185 |
| - x_test /= 255 |
| 176 | + (x_train, y_train), (x_test, y_test) = load_mnist_dataset() |
| 177 | + |
186 | 178 | print('x_train shape:', x_train.shape)
|
187 | 179 | print(x_train.shape[0], 'train samples')
|
188 | 180 | print(x_test.shape[0], 'test samples')
|
189 | 181 |
|
190 |
| - # convert class vectors to binary class matrices |
191 |
| - y_train = tf.keras.utils.to_categorical(y_train, num_classes) |
192 |
| - y_test = tf.keras.utils.to_categorical(y_test, num_classes) |
193 |
| - |
194 |
| - sequential_model = build_sequential_model(input_shape) |
195 |
| - functional_model = build_functional_model(input_shape) |
196 |
| - models = [sequential_model, functional_model] |
197 |
| - train_and_save(models, x_train, y_train, x_test, y_test) |
| 182 | + # Build model |
| 183 | + model = build_sequential_model() |
| 184 | + # Train model |
| 185 | + model = train_model(model, x_train, y_train, x_test, y_test) |
| 186 | + # Cluster and fine-tune model |
| 187 | + clustered_model = cluster_model(model, x_train, y_train, x_test, y_test) |
| 188 | + # Test clustered model (serialize/deserialize, strip clustering) |
| 189 | + test_clustered_model(clustered_model, x_test, y_test) |
198 | 190 |
|
199 | 191 |
|
200 | 192 | if __name__ == '__main__':
|
|
0 commit comments