Skip to content

Commit 570444d

Browse files
karimnosseirtensorflower-gardener
authored andcommitted
Remove explicit line "experimental_new_converter = True" the converter launched and is now the default.
PiperOrigin-RevId: 318582066
1 parent 226ca71 commit 570444d

File tree

2 files changed

+99
-105
lines changed

2 files changed

+99
-105
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 99 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -14,123 +14,119 @@
1414
# ==============================================================================
1515
"""End-to-end tests for keras clustering API."""
1616

17-
import numpy as np
18-
import tensorflow as tf
17+
import os
1918
import tempfile
19+
2020
from absl.testing import parameterized
21-
from tensorflow.python.keras import keras_parameterized
21+
import numpy as np
22+
import tensorflow as tf
2223

24+
from tensorflow.python.keras import keras_parameterized
2325
from tensorflow_model_optimization.python.core.clustering.keras import cluster
2426
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2527
from tensorflow_model_optimization.python.core.keras import compat
26-
import os
2728

2829
keras = tf.keras
2930
layers = keras.layers
3031
test = tf.test
3132

3233
CentroidInitialization = cluster_config.CentroidInitialization
3334

35+
3436
class ClusterIntegrationTest(test.TestCase, parameterized.TestCase):
37+
"""Integration tests for clustering."""
38+
39+
def setUp(self):
40+
self.params = {
41+
"number_of_clusters": 8,
42+
"cluster_centroids_init": CentroidInitialization.LINEAR,
43+
}
44+
45+
self.x_train = np.array(
46+
[[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0], [5.0, 1.0]],
47+
dtype="float32",
48+
)
49+
50+
self.y_train = np.array(
51+
[[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
52+
dtype="float32",
53+
)
54+
55+
def dataset_generator(self):
56+
for x, y in zip(self.x_train, self.y_train):
57+
yield np.array([x]), np.array([y])
58+
59+
@staticmethod
60+
def _verify_tflite(tflite_file, x_test):
61+
interpreter = tf.lite.Interpreter(model_path=tflite_file)
62+
interpreter.allocate_tensors()
63+
input_index = interpreter.get_input_details()[0]["index"]
64+
output_index = interpreter.get_output_details()[0]["index"]
65+
x = x_test[0]
66+
x = x.reshape((1,) + x.shape)
67+
interpreter.set_tensor(input_index, x)
68+
interpreter.invoke()
69+
interpreter.get_tensor(output_index)
70+
71+
@keras_parameterized.run_all_keras_modes
72+
def testValuesRemainClusteredAfterTraining(self):
73+
"""Verifies that training a clustered model does not destroy the clusters."""
74+
original_model = keras.Sequential([
75+
layers.Dense(2, input_shape=(2,)),
76+
layers.Dense(2),
77+
])
78+
79+
clustered_model = cluster.cluster_weights(original_model, **self.params)
80+
81+
clustered_model.compile(
82+
loss=keras.losses.categorical_crossentropy,
83+
optimizer="adam",
84+
metrics=["accuracy"],
85+
)
86+
87+
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
88+
stripped_model = cluster.strip_clustering(clustered_model)
89+
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
90+
unique_weights = set(weights_as_list)
91+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
92+
93+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
94+
def testEndToEnd(self):
95+
"""Test End to End clustering."""
96+
original_model = keras.Sequential([
97+
layers.Dense(2, input_shape=(2,)),
98+
layers.Dense(2),
99+
])
100+
101+
clustered_model = cluster.cluster_weights(original_model, **self.params)
102+
103+
clustered_model.compile(
104+
loss=keras.losses.categorical_crossentropy,
105+
optimizer="adam",
106+
metrics=["accuracy"],
107+
)
108+
109+
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
110+
stripped_model = cluster.strip_clustering(clustered_model)
111+
112+
_, tflite_file = tempfile.mkstemp(".tflite")
113+
_, keras_file = tempfile.mkstemp(".h5")
114+
115+
if not compat.is_v1_apis():
116+
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
117+
else:
118+
tf.keras.models.save_model(stripped_model, keras_file)
119+
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
120+
121+
tflite_model = converter.convert()
122+
with open(tflite_file, "wb") as f:
123+
f.write(tflite_model)
124+
125+
self._verify_tflite(tflite_file, self.x_train)
126+
127+
os.remove(keras_file)
128+
os.remove(tflite_file)
35129

36-
"""
37-
Integration tests for clustering.
38-
"""
39-
def setUp(self):
40-
self.params = {
41-
"number_of_clusters": 8,
42-
"cluster_centroids_init": CentroidInitialization.LINEAR,
43-
}
44-
45-
self.x_train = np.array(
46-
[[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0], [5.0, 1.0]],
47-
dtype="float32",
48-
)
49-
50-
self.y_train = np.array(
51-
[[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
52-
dtype="float32",
53-
)
54-
55-
def dataset_generator(self):
56-
for x, y in zip(self.x_train, self.y_train):
57-
yield np.array([x]), np.array([y])
58-
59-
@staticmethod
60-
def _verify_tflite(tflite_file, x_test):
61-
interpreter = tf.lite.Interpreter(model_path=tflite_file)
62-
interpreter.allocate_tensors()
63-
input_index = interpreter.get_input_details()[0]["index"]
64-
output_index = interpreter.get_output_details()[0]["index"]
65-
x = x_test[0]
66-
x = x.reshape((1,) + x.shape)
67-
interpreter.set_tensor(input_index, x)
68-
interpreter.invoke()
69-
interpreter.get_tensor(output_index)
70-
71-
@keras_parameterized.run_all_keras_modes
72-
def testValuesRemainClusteredAfterTraining(self):
73-
74-
"""
75-
Verifies that training a clustered model does not destroy the clusters.
76-
"""
77-
original_model = keras.Sequential(
78-
[layers.Dense(2, input_shape=(2,)), layers.Dense(2),]
79-
)
80-
81-
clustered_model = cluster.cluster_weights(original_model, **self.params)
82-
83-
clustered_model.compile(
84-
loss=keras.losses.categorical_crossentropy,
85-
optimizer="adam",
86-
metrics=["accuracy"],
87-
)
88-
89-
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
90-
stripped_model = cluster.strip_clustering(clustered_model)
91-
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
92-
unique_weights = set(weights_as_list)
93-
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
94-
95-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
96-
def testEndToEnd(self):
97-
98-
"""
99-
Test End to End clustering.
100-
"""
101-
original_model = keras.Sequential(
102-
[layers.Dense(2, input_shape=(2,)), layers.Dense(2),]
103-
)
104-
105-
clustered_model = cluster.cluster_weights(original_model, **self.params)
106-
107-
clustered_model.compile(
108-
loss=keras.losses.categorical_crossentropy,
109-
optimizer="adam",
110-
metrics=["accuracy"],
111-
)
112-
113-
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
114-
stripped_model = cluster.strip_clustering(clustered_model)
115-
116-
_, tflite_file = tempfile.mkstemp(".tflite")
117-
_, keras_file = tempfile.mkstemp(".h5")
118-
119-
if not compat.is_v1_apis():
120-
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
121-
else:
122-
tf.keras.models.save_model(stripped_model, keras_file)
123-
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
124-
125-
converter.experimental_new_converter = True
126-
tflite_model = converter.convert()
127-
with open(tflite_file, "wb") as f:
128-
f.write(tflite_model)
129-
130-
self._verify_tflite(tflite_file, self.x_train)
131-
132-
os.remove(keras_file)
133-
os.remove(tflite_file)
134130

135131
if __name__ == "__main__":
136-
test.main()
132+
test.main()

tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def main(unused_argv):
112112
model = train(model, x_train, y_train, x_test, y_test)
113113

114114
converter = tf.lite.TFLiteConverter.from_keras_model(model)
115-
converter.experimental_new_converter = True
116115

117116
# Get a dense model as baseline
118117
tflite_model_dense = converter.convert()
@@ -153,7 +152,6 @@ def main(unused_argv):
153152
model = train(model, x_train, y_train, x_test, y_test)
154153

155154
converter = tf.lite.TFLiteConverter.from_keras_model(model)
156-
converter.experimental_new_converter = True
157155
converter._experimental_sparsify_model = True
158156

159157
tflite_model = converter.convert()

0 commit comments

Comments
 (0)