Skip to content

Commit 888621b

Browse files
Merge pull request #422 from SaoirseARM:toupstream/cluster_saved_model
PiperOrigin-RevId: 318079120
2 parents a396089 + 66e206d commit 888621b

File tree

4 files changed

+128
-56
lines changed

4 files changed

+128
-56
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,5 +137,6 @@ py_test(
137137
deps = [
138138
":cluster",
139139
# tensorflow dep1,
140+
"//tensorflow_model_optimization/python/core/keras:compat",
140141
],
141142
)

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,9 @@ def _strip_clustering_wrapper(layer):
216216
# If the value was not clustered(e.g. bias), we still store a valid
217217
# reference to the tensor. We use this reference to get the value
218218
new_weight_value = k.batch_get_value([weight])[0]
219-
layer.layer.add_weight(
220-
name=name,
221-
shape=new_weight_value.shape,
222-
initializer=initializers.Constant(new_weight_value),
223-
trainable=True
224-
)
219+
setattr(layer.layer,
220+
name,
221+
k.variable(new_weight_value, name=name))
225222
# When all weights are filled with the values, just return the underlying
226223
# layer since it is now fully autonomous from its wrapper
227224
return layer.layer

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 102 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
import numpy as np
1818
import tensorflow as tf
19-
19+
import tempfile
2020
from absl.testing import parameterized
2121
from tensorflow.python.keras import keras_parameterized
2222

2323
from tensorflow_model_optimization.python.core.clustering.keras import cluster
2424
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
25+
from tensorflow_model_optimization.python.core.keras import compat
26+
import os
2527

2628
keras = tf.keras
2729
layers = keras.layers
@@ -30,55 +32,105 @@
3032
CentroidInitialization = cluster_config.CentroidInitialization
3133

3234
class ClusterIntegrationTest(test.TestCase, parameterized.TestCase):
33-
"""Integration tests for clustering."""
3435

35-
@keras_parameterized.run_all_keras_modes
36-
def testValuesRemainClusteredAfterTraining(self):
3736
"""
38-
Verifies that training a clustered model does not destroy the clusters.
37+
Integration tests for clustering.
3938
"""
40-
number_of_clusters = 10
41-
original_model = keras.Sequential([
42-
layers.Dense(2, input_shape=(2,)),
43-
layers.Dense(2),
44-
])
45-
46-
clustered_model = cluster.cluster_weights(
47-
original_model,
48-
number_of_clusters=number_of_clusters,
49-
cluster_centroids_init=CentroidInitialization.LINEAR
50-
)
51-
52-
clustered_model.compile(
53-
loss=keras.losses.categorical_crossentropy,
54-
optimizer='adam',
55-
metrics=['accuracy']
56-
)
57-
58-
def dataset_generator():
59-
x_train = np.array([
60-
[0, 1],
61-
[2, 0],
62-
[0, 3],
63-
[4, 1],
64-
[5, 1],
65-
])
66-
y_train = np.array([
67-
[0, 1],
68-
[1, 0],
69-
[1, 0],
70-
[0, 1],
71-
[0, 1],
72-
])
73-
for x, y in zip(x_train, y_train):
74-
yield np.array([x]), np.array([y])
75-
76-
clustered_model.fit_generator(dataset_generator(), steps_per_epoch=1)
77-
stripped_model = cluster.strip_clustering(clustered_model)
78-
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
79-
unique_weights = set(weights_as_list)
80-
self.assertLessEqual(len(unique_weights), number_of_clusters)
81-
82-
83-
if __name__ == '__main__':
84-
test.main()
39+
def setUp(self):
40+
self.params = {
41+
"number_of_clusters": 8,
42+
"cluster_centroids_init": CentroidInitialization.LINEAR,
43+
}
44+
45+
self.x_train = np.array(
46+
[[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0], [5.0, 1.0]],
47+
dtype="float32",
48+
)
49+
50+
self.y_train = np.array(
51+
[[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
52+
dtype="float32",
53+
)
54+
55+
def dataset_generator(self):
56+
for x, y in zip(self.x_train, self.y_train):
57+
yield np.array([x]), np.array([y])
58+
59+
@staticmethod
60+
def _verify_tflite(tflite_file, x_test):
61+
interpreter = tf.lite.Interpreter(model_path=tflite_file)
62+
interpreter.allocate_tensors()
63+
input_index = interpreter.get_input_details()[0]["index"]
64+
output_index = interpreter.get_output_details()[0]["index"]
65+
x = x_test[0]
66+
x = x.reshape((1,) + x.shape)
67+
interpreter.set_tensor(input_index, x)
68+
interpreter.invoke()
69+
interpreter.get_tensor(output_index)
70+
71+
@keras_parameterized.run_all_keras_modes
72+
def testValuesRemainClusteredAfterTraining(self):
73+
74+
"""
75+
Verifies that training a clustered model does not destroy the clusters.
76+
"""
77+
original_model = keras.Sequential(
78+
[layers.Dense(2, input_shape=(2,)), layers.Dense(2),]
79+
)
80+
81+
clustered_model = cluster.cluster_weights(original_model, **self.params)
82+
83+
clustered_model.compile(
84+
loss=keras.losses.categorical_crossentropy,
85+
optimizer="adam",
86+
metrics=["accuracy"],
87+
)
88+
89+
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
90+
stripped_model = cluster.strip_clustering(clustered_model)
91+
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
92+
unique_weights = set(weights_as_list)
93+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
94+
95+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
96+
def testEndToEnd(self):
97+
98+
"""
99+
Test End to End clustering.
100+
"""
101+
original_model = keras.Sequential(
102+
[layers.Dense(2, input_shape=(2,)), layers.Dense(2),]
103+
)
104+
105+
clustered_model = cluster.cluster_weights(original_model, **self.params)
106+
107+
clustered_model.compile(
108+
loss=keras.losses.categorical_crossentropy,
109+
optimizer="adam",
110+
metrics=["accuracy"],
111+
)
112+
113+
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
114+
stripped_model = cluster.strip_clustering(clustered_model)
115+
116+
_, tflite_file = tempfile.mkstemp(".tflite")
117+
_, keras_file = tempfile.mkstemp(".h5")
118+
119+
if not compat.is_v1_apis():
120+
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
121+
else:
122+
tf.keras.models.save_model(stripped_model, keras_file)
123+
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
124+
125+
converter.experimental_new_converter = True
126+
tflite_model = converter.convert()
127+
with open(tflite_file, "wb") as f:
128+
f.write(tflite_model)
129+
130+
self._verify_tflite(tflite_file, self.x_train)
131+
132+
os.remove(keras_file)
133+
os.remove(tflite_file)
134+
135+
if __name__ == "__main__":
136+
test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,28 @@ def testClusterWeightsStrippedWeights(self):
446446
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
447447
self.assertEqual(len(stripped_model.get_weights()), cluster_weight_length)
448448

449+
@keras_parameterized.run_all_keras_modes
450+
def testStrippedKernel(self):
451+
"""
452+
Verifies that stripping the clustering wrappers from a functional model
453+
restores the layers kernel and the layers weight array to the new clustered weight value .
454+
"""
455+
i1 = keras.Input(shape=(1, 1, 1))
456+
x1 = layers.Conv2D(1, 1)(i1)
457+
outputs = x1
458+
model = keras.Model(inputs=[i1], outputs=outputs)
459+
460+
clustered_model = cluster.cluster_weights(model, **self.params)
461+
clustered_conv2d_layer = clustered_model.layers[1]
462+
clustered_kernel = clustered_conv2d_layer.layer.kernel
463+
stripped_model = cluster.strip_clustering(clustered_model)
464+
stripped_conv2d_layer = stripped_model.layers[1]
465+
466+
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
467+
self.assertIsNot(stripped_conv2d_layer.kernel, clustered_kernel)
468+
self.assertEqual(stripped_conv2d_layer.kernel,
469+
stripped_conv2d_layer.weights[0])
470+
449471
@keras_parameterized.run_all_keras_modes
450472
def testStripSelectivelyClusteredFunctionalModel(self):
451473
"""

0 commit comments

Comments
 (0)