Skip to content

Commit 1dafd57

Browse files
committed
Re-factoring of the clustering example.
1 parent 9926e78 commit 1dafd57

File tree

1 file changed

+130
-138
lines changed
  • tensorflow_model_optimization/python/examples/clustering/keras/mnist

1 file changed

+130
-138
lines changed

tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py

Lines changed: 130 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
# pylint: disable=missing-docstring
16-
"""Train a simple convnet on the MNIST dataset."""
16+
"""Train a simple convnet on the MNIST dataset and cluster it.
1717
18-
from __future__ import print_function
18+
This example is based on the sample that can be found here:
19+
https://www.tensorflow.org/model_optimization/guide/quantization/training_example
20+
"""
1921

22+
from __future__ import print_function
2023
import datetime
2124
import os
2225

@@ -29,12 +32,10 @@
2932
from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks
3033

3134
keras = tf.keras
32-
l = keras.layers
3335

3436
FLAGS = flags.FLAGS
3537

3638
batch_size = 128
37-
num_classes = 10
3839
epochs = 12
3940
epochs_fine_tuning = 4
4041

@@ -43,158 +44,149 @@
4344
'Output directory to hold tensorboard events')
4445

4546

46-
def build_sequential_model(input_shape):
47-
return tf.keras.Sequential([
48-
l.Conv2D(
49-
32, 5, padding='same', activation='relu', input_shape=input_shape),
50-
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
51-
l.BatchNormalization(),
52-
l.Conv2D(64, 5, padding='same', activation='relu'),
53-
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
54-
l.Flatten(),
55-
l.Dense(1024, activation='relu'),
56-
l.Dropout(0.4),
57-
l.Dense(num_classes, activation='softmax')
58-
])
47+
def load_mnist_dataset():
48+
mnist = keras.datasets.mnist
49+
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
5950

51+
# Normalize the input images so that each pixel value is between 0 and 1.
52+
train_images = train_images / 255.0
53+
test_images = test_images / 255.0
6054

61-
def build_functional_model(input_shape):
62-
inp = tf.keras.Input(shape=input_shape)
63-
x = l.Conv2D(32, 5, padding='same', activation='relu')(inp)
64-
x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
65-
x = l.BatchNormalization()(x)
66-
x = l.Conv2D(64, 5, padding='same', activation='relu')(x)
67-
x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
68-
x = l.Flatten()(x)
69-
x = l.Dense(1024, activation='relu')(x)
70-
x = l.Dropout(0.4)(x)
71-
out = l.Dense(num_classes, activation='softmax')(x)
72-
73-
return tf.keras.models.Model([inp], [out])
74-
75-
def train_and_save(models, x_train, y_train, x_test, y_test):
76-
for model in models:
77-
model.compile(
78-
loss=tf.keras.losses.categorical_crossentropy,
79-
optimizer='adam',
80-
metrics=['accuracy'])
81-
82-
# Print the model summary.
83-
model.summary()
84-
85-
# Model needs to be clustered after initial training
86-
# and having achieved good accuracy
87-
model.fit(
88-
x_train,
89-
y_train,
90-
batch_size=batch_size,
91-
epochs=epochs,
92-
verbose=1,
93-
validation_data=(x_test, y_test))
94-
score = model.evaluate(x_test, y_test, verbose=0)
95-
print('Test loss:', score[0])
96-
print('Test accuracy:', score[1])
97-
98-
print('Clustering model')
99-
100-
clustering_params = {
101-
'number_of_clusters': 8,
102-
'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED
103-
}
104-
105-
# Cluster model
106-
clustered_model = cluster.cluster_weights(model, **clustering_params)
107-
108-
# Use smaller learning rate for fine-tuning
109-
# clustered model
110-
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
111-
112-
clustered_model.compile(
113-
loss=tf.keras.losses.categorical_crossentropy,
114-
optimizer=opt,
115-
metrics=['accuracy'])
55+
return (train_images, train_labels), (test_images, test_labels)
56+
57+
def build_sequential_model():
58+
"Define the model architecture."
11659

117-
# Add callback for tensorboard summaries
118-
log_dir = os.path.join(
119-
FLAGS.output_dir,
120-
datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering"))
121-
callbacks = [
122-
clustering_callbacks.ClusteringSummaries(
123-
log_dir,
124-
cluster_update_freq='epoch',
125-
update_freq='batch',
126-
histogram_freq=1)
127-
]
128-
129-
# Fine-tune model
130-
clustered_model.fit(
131-
x_train,
132-
y_train,
133-
batch_size=batch_size,
134-
epochs=epochs_fine_tuning,
135-
verbose=1,
136-
callbacks=callbacks,
137-
validation_data=(x_test, y_test))
138-
139-
score = clustered_model.evaluate(x_test, y_test, verbose=0)
140-
print('Clustered Model Test loss:', score[0])
141-
print('Clustered Model Test accuracy:', score[1])
142-
143-
#Ensure accuracy persists after stripping the model
144-
stripped_model = cluster.strip_clustering(clustered_model)
145-
146-
stripped_model.compile(
147-
loss=tf.keras.losses.categorical_crossentropy,
60+
return keras.Sequential([
61+
keras.layers.InputLayer(input_shape=(28, 28)),
62+
keras.layers.Reshape(target_shape=(28, 28, 1)),
63+
keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
64+
keras.layers.MaxPooling2D(pool_size=(2, 2)),
65+
keras.layers.Flatten(),
66+
keras.layers.Dense(10)
67+
])
68+
69+
70+
def train_model(model, x_train, y_train, x_test, y_test):
71+
model.compile(
72+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
73+
optimizer='adam',
74+
metrics=['accuracy'])
75+
76+
# Print the model summary.
77+
model.summary()
78+
79+
# The model needs to be clustered after initial training,
80+
# once it has achieved good accuracy.
81+
model.fit(
82+
x_train,
83+
y_train,
84+
batch_size=batch_size,
85+
epochs=epochs,
86+
verbose=1,
87+
validation_split=0.1)
88+
89+
score = model.evaluate(x_test, y_test, verbose=0)
90+
print('Test loss:', score[0])
91+
print('Test accuracy:', score[1])
92+
93+
return model
94+
95+
96+
def cluster_model(model, x_train, y_train, x_test, y_test):
97+
print('Clustering model')
98+
99+
clustering_params = {
100+
'number_of_clusters': 8,
101+
'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED
102+
}
103+
104+
# Cluster model
105+
clustered_model = cluster.cluster_weights(model, **clustering_params)
106+
107+
# Use a smaller learning rate for fine-tuning
108+
# the clustered model.
109+
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
110+
111+
clustered_model.compile(
112+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
113+
optimizer=opt,
114+
metrics=['accuracy'])
115+
116+
# Add a callback for TensorBoard summaries
117+
log_dir = os.path.join(
118+
FLAGS.output_dir,
119+
datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering"))
120+
callbacks = [
121+
clustering_callbacks.ClusteringSummaries(
122+
log_dir,
123+
cluster_update_freq='epoch',
124+
update_freq='batch',
125+
histogram_freq=1)
126+
]
127+
128+
# Fine-tune clustered model
129+
clustered_model.fit(
130+
x_train,
131+
y_train,
132+
batch_size=batch_size,
133+
epochs=epochs_fine_tuning,
134+
verbose=1,
135+
callbacks=callbacks,
136+
validation_split=0.1)
137+
138+
score = clustered_model.evaluate(x_test, y_test, verbose=0)
139+
print('Clustered model test loss:', score[0])
140+
print('Clustered model test accuracy:', score[1])
141+
142+
return clustered_model
143+
144+
145+
def test_clustered_model(clustered_model, x_test, y_test):
146+
# Ensure accuracy persists after serializing/deserializing the model
147+
clustered_model.save('clustered_model.h5')
148+
# To deserialize the clustered model, use the clustering scope
149+
with cluster.cluster_scope():
150+
loaded_clustered_model = keras.models.load_model('clustered_model.h5')
151+
152+
# Checking that the deserialized model's accuracy matches the clustered model
153+
score = loaded_clustered_model.evaluate(x_test, y_test, verbose=0)
154+
print('Deserialized model test loss:', score[0])
155+
print('Deserialized model test accuracy:', score[1])
156+
157+
# Ensure accuracy persists after stripping the model
158+
stripped_model = cluster.strip_clustering(loaded_clustered_model)
159+
stripped_model.compile(
160+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
148161
optimizer='adam',
149162
metrics=['accuracy'])
150-
stripped_model.save('stripped_model.h5')
151163

152-
# To acquire the stripped model,
153-
# deserialize with clustering scope
154-
with cluster.cluster_scope():
155-
loaded_model = keras.models.load_model('stripped_model.h5')
164+
# Checking that the stripped model's accuracy matches the clustered model
165+
score = stripped_model.evaluate(x_test, y_test, verbose=0)
166+
print('Stripped model test loss:', score[0])
167+
print('Stripped model test accuracy:', score[1])
156168

157-
# Checking that the stripped model's accuracy matches the clustered model
158-
score = loaded_model.evaluate(x_test, y_test, verbose=0)
159-
print('Stripped Model Test loss:', score[0])
160-
print('Stripped Model Test accuracy:', score[1])
161169

162170
def main(unused_argv):
163171
if FLAGS.enable_eager:
164172
print('Running in Eager mode.')
165173
tf.compat.v1.enable_eager_execution()
166174

167-
# input image dimensions
168-
img_rows, img_cols = 28, 28
169-
170175
# the data, shuffled and split between train and test sets
171-
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
172-
173-
if tf.keras.backend.image_data_format() == 'channels_first':
174-
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
175-
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
176-
input_shape = (1, img_rows, img_cols)
177-
else:
178-
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
179-
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
180-
input_shape = (img_rows, img_cols, 1)
181-
182-
x_train = x_train.astype('float32')
183-
x_test = x_test.astype('float32')
184-
x_train /= 255
185-
x_test /= 255
176+
(x_train, y_train), (x_test, y_test) = load_mnist_dataset()
177+
186178
print('x_train shape:', x_train.shape)
187179
print(x_train.shape[0], 'train samples')
188180
print(x_test.shape[0], 'test samples')
189181

190-
# convert class vectors to binary class matrices
191-
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
192-
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
193-
194-
sequential_model = build_sequential_model(input_shape)
195-
functional_model = build_functional_model(input_shape)
196-
models = [sequential_model, functional_model]
197-
train_and_save(models, x_train, y_train, x_test, y_test)
182+
# Build model
183+
model = build_sequential_model()
184+
# Train model
185+
model = train_model(model, x_train, y_train, x_test, y_test)
186+
# Cluster and fine-tune model
187+
clustered_model = cluster_model(model, x_train, y_train, x_test, y_test)
188+
# Test clustered model (serialize/deserialize, strip clustering)
189+
test_clustered_model(clustered_model, x_test, y_test)
198190

199191

200192
if __name__ == '__main__':

0 commit comments

Comments
 (0)