Skip to content

Commit 849a969

Browse files
committed
Simplify the clustering example
The trainable parameters were reduced (from 3,274,698 to 20,410) to: reduce over-parameterization, make the example faster to run, and make it easier to visualize the clustering results.
1 parent 168c113 commit 849a969

File tree

1 file changed

+115
-123
lines changed
  • tensorflow_model_optimization/python/examples/clustering/keras/mnist

1 file changed

+115
-123
lines changed

tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py

Lines changed: 115 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
# pylint: disable=missing-docstring
16-
"""Train a simple convnet on the MNIST dataset."""
16+
"""Train a simple convnet on the MNIST dataset and cluster it.
1717
18-
from __future__ import print_function
18+
This example is based on the sample that can be found here:
19+
https://www.tensorflow.org/model_optimization/guide/quantization/training_example
20+
"""
1921

20-
import tempfile
22+
from __future__ import print_function
2123

2224
from absl import app as absl_app
2325
from absl import flags
@@ -27,12 +29,10 @@
2729
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2830

2931
keras = tf.keras
30-
l = keras.layers
3132

3233
FLAGS = flags.FLAGS
3334

3435
batch_size = 128
35-
num_classes = 10
3636
epochs = 12
3737
epochs_fine_tuning = 4
3838

@@ -41,145 +41,137 @@
4141
'Output directory to hold tensorboard events')
4242

4343

44-
def build_sequential_model(input_shape):
45-
return tf.keras.Sequential([
46-
l.Conv2D(
47-
32, 5, padding='same', activation='relu', input_shape=input_shape),
48-
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
49-
l.BatchNormalization(),
50-
l.Conv2D(64, 5, padding='same', activation='relu'),
51-
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
52-
l.Flatten(),
53-
l.Dense(1024, activation='relu'),
54-
l.Dropout(0.4),
55-
l.Dense(num_classes, activation='softmax')
56-
])
44+
def load_mnist_dataset():
45+
mnist = keras.datasets.mnist
46+
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
5747

48+
# Normalize the input images so that each pixel value is between 0 and 1.
49+
train_images = train_images / 255.0
50+
test_images = test_images / 255.0
5851

59-
def build_functional_model(input_shape):
60-
inp = tf.keras.Input(shape=input_shape)
61-
x = l.Conv2D(32, 5, padding='same', activation='relu')(inp)
62-
x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
63-
x = l.BatchNormalization()(x)
64-
x = l.Conv2D(64, 5, padding='same', activation='relu')(x)
65-
x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
66-
x = l.Flatten()(x)
67-
x = l.Dense(1024, activation='relu')(x)
68-
x = l.Dropout(0.4)(x)
69-
out = l.Dense(num_classes, activation='softmax')(x)
70-
71-
return tf.keras.models.Model([inp], [out])
72-
73-
def train_and_save(models, x_train, y_train, x_test, y_test):
74-
for model in models:
75-
model.compile(
76-
loss=tf.keras.losses.categorical_crossentropy,
77-
optimizer='adam',
78-
metrics=['accuracy'])
79-
80-
# Print the model summary.
81-
model.summary()
82-
83-
# Model needs to be clustered after initial training
84-
# and having achieved good accuracy
85-
model.fit(
86-
x_train,
87-
y_train,
88-
batch_size=batch_size,
89-
epochs=epochs,
90-
verbose=1,
91-
validation_data=(x_test, y_test))
92-
score = model.evaluate(x_test, y_test, verbose=0)
93-
print('Test loss:', score[0])
94-
print('Test accuracy:', score[1])
95-
96-
print('Clustering model')
97-
98-
clustering_params = {
99-
'number_of_clusters': 8,
100-
'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED
101-
}
102-
103-
# Cluster model
104-
clustered_model = cluster.cluster_weights(model, **clustering_params)
105-
106-
# Use smaller learning rate for fine-tuning
107-
# clustered model
108-
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
109-
110-
clustered_model.compile(
111-
loss=tf.keras.losses.categorical_crossentropy,
112-
optimizer=opt,
113-
metrics=['accuracy'])
52+
return (train_images, train_labels), (test_images, test_labels)
11453

115-
# Fine-tune model
116-
clustered_model.fit(
117-
x_train,
118-
y_train,
119-
batch_size=batch_size,
120-
epochs=epochs_fine_tuning,
121-
verbose=1,
122-
validation_data=(x_test, y_test))
12354

124-
score = clustered_model.evaluate(x_test, y_test, verbose=0)
125-
print('Clustered Model Test loss:', score[0])
126-
print('Clustered Model Test accuracy:', score[1])
55+
def build_sequential_model():
56+
"Define the model architecture."
12757

128-
#Ensure accuracy persists after stripping the model
129-
stripped_model = cluster.strip_clustering(clustered_model)
58+
return keras.Sequential([
59+
keras.layers.InputLayer(input_shape=(28, 28)),
60+
keras.layers.Reshape(target_shape=(28, 28, 1)),
61+
keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
62+
keras.layers.MaxPooling2D(pool_size=(2, 2)),
63+
keras.layers.Flatten(),
64+
keras.layers.Dense(10)
65+
])
13066

131-
stripped_model.compile(
132-
loss=tf.keras.losses.categorical_crossentropy,
67+
68+
def train_model(model, x_train, y_train, x_test, y_test):
  """Compile, train and evaluate `model`, then hand it back.

  The model is compiled with a sparse categorical cross-entropy loss that
  expects logits, the 'adam' optimizer and an accuracy metric. It is then
  fitted on the training data using the module-level `batch_size` and
  `epochs` settings, with `validation_split=0.1` passed to `fit`.

  Args:
    model: The Keras model to train.
    x_train: Training inputs.
    y_train: Training labels.
    x_test: Test inputs used for the final evaluation.
    y_test: Test labels used for the final evaluation.

  Returns:
    The same `model` object after training.
  """
  logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  model.compile(loss=logits_loss, optimizer='adam', metrics=['accuracy'])

  # Show the architecture and parameter counts before training starts.
  model.summary()

  # Clustering is applied afterwards, so the model is first trained to a
  # good accuracy here.
  fit_options = {
      'batch_size': batch_size,
      'epochs': epochs,
      'verbose': 1,
      'validation_split': 0.1,
  }
  model.fit(x_train, y_train, **fit_options)

  test_results = model.evaluate(x_test, y_test, verbose=0)
  print('Test loss:', test_results[0])
  print('Test accuracy:', test_results[1])

  return model
92+
93+
94+
def cluster_model(model, x_train, y_train, x_test, y_test):
  """Apply weight clustering to `model`, fine-tune it and evaluate it.

  The model is wrapped with `cluster.cluster_weights` using 8 clusters and
  density-based centroid initialization. The wrapped model is compiled with
  an Adam optimizer (learning rate 1e-5), fine-tuned for the module-level
  `epochs_fine_tuning` epochs and evaluated on the test data.

  Args:
    model: The already-trained Keras model to cluster.
    x_train: Training inputs used for fine-tuning.
    y_train: Training labels used for fine-tuning.
    x_test: Test inputs used for the final evaluation.
    y_test: Test labels used for the final evaluation.

  Returns:
    The clustered, fine-tuned model.
  """
  print('Clustering model')

  # Wrap the model so that its weights are clustered.
  clustered_model = cluster.cluster_weights(
      model,
      number_of_clusters=8,
      cluster_centroids_init=(
          cluster_config.CentroidInitialization.DENSITY_BASED))

  # Use a small learning rate for fine-tuning the clustered model.
  fine_tune_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
  logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  clustered_model.compile(
      loss=logits_loss,
      optimizer=fine_tune_optimizer,
      metrics=['accuracy'])

  # Fine-tune the clustered model.
  fit_options = {
      'batch_size': batch_size,
      'epochs': epochs_fine_tuning,
      'verbose': 1,
      'validation_split': 0.1,
  }
  clustered_model.fit(x_train, y_train, **fit_options)

  test_results = clustered_model.evaluate(x_test, y_test, verbose=0)
  print('Clustered model test loss:', test_results[0])
  print('Clustered model test accuracy:', test_results[1])

  return clustered_model
128+
129+
130+
def test_clustered_model(clustered_model, x_test, y_test):
131+
# Ensure accuracy persists after serializing/deserializing the model
132+
clustered_model.save('clustered_model.h5')
133+
# To deserialize the clustered model, use the clustering scope
134+
with cluster.cluster_scope():
135+
loaded_clustered_model = keras.models.load_model('clustered_model.h5')
136+
137+
# Checking that the deserialized model's accuracy matches the clustered model
138+
score = loaded_clustered_model.evaluate(x_test, y_test, verbose=0)
139+
print('Deserialized model test loss:', score[0])
140+
print('Deserialized model test accuracy:', score[1])
141+
142+
# Ensure accuracy persists after stripping the model
143+
stripped_model = cluster.strip_clustering(loaded_clustered_model)
144+
stripped_model.compile(
145+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
133146
optimizer='adam',
134147
metrics=['accuracy'])
135-
stripped_model.save('stripped_model.h5')
136148

137-
# To acquire the stripped model,
138-
# deserialize with clustering scope
139-
with cluster.cluster_scope():
140-
loaded_model = keras.models.load_model('stripped_model.h5')
149+
# Checking that the stripped model's accuracy matches the clustered model
150+
score = stripped_model.evaluate(x_test, y_test, verbose=0)
151+
print('Stripped model test loss:', score[0])
152+
print('Stripped model test accuracy:', score[1])
141153

142-
# Checking that the stripped model's accuracy matches the clustered model
143-
score = loaded_model.evaluate(x_test, y_test, verbose=0)
144-
print('Stripped Model Test loss:', score[0])
145-
print('Stripped Model Test accuracy:', score[1])
146154

147155
def main(unused_argv):
148156
if FLAGS.enable_eager:
149157
print('Running in Eager mode.')
150158
tf.compat.v1.enable_eager_execution()
151159

152-
# input image dimensions
153-
img_rows, img_cols = 28, 28
154-
155160
# the data, shuffled and split between train and test sets
156-
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
157-
158-
if tf.keras.backend.image_data_format() == 'channels_first':
159-
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
160-
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
161-
input_shape = (1, img_rows, img_cols)
162-
else:
163-
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
164-
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
165-
input_shape = (img_rows, img_cols, 1)
166-
167-
x_train = x_train.astype('float32')
168-
x_test = x_test.astype('float32')
169-
x_train /= 255
170-
x_test /= 255
161+
(x_train, y_train), (x_test, y_test) = load_mnist_dataset()
162+
171163
print('x_train shape:', x_train.shape)
172164
print(x_train.shape[0], 'train samples')
173165
print(x_test.shape[0], 'test samples')
174166

175-
# convert class vectors to binary class matrices
176-
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
177-
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
178-
179-
sequential_model = build_sequential_model(input_shape)
180-
functional_model = build_functional_model(input_shape)
181-
models = [sequential_model, functional_model]
182-
train_and_save(models, x_train, y_train, x_test, y_test)
167+
# Build model
168+
model = build_sequential_model()
169+
# Train model
170+
model = train_model(model, x_train, y_train, x_test, y_test)
171+
# Cluster and fine-tune model
172+
clustered_model = cluster_model(model, x_train, y_train, x_test, y_test)
173+
# Test clustered model (serialize/deserialize, strip clustering)
174+
test_clustered_model(clustered_model, x_test, y_test)
183175

184176

185177
if __name__ == '__main__':

0 commit comments

Comments
 (0)