|
14 | 14 | # ==============================================================================
|
15 | 15 | """End-to-end tests for keras clustering API."""
|
16 | 16 |
|
17 |
| -import numpy as np |
18 |
| -import tensorflow as tf |
| 17 | +import os |
19 | 18 | import tempfile
|
| 19 | + |
20 | 20 | from absl.testing import parameterized
|
21 |
| -from tensorflow.python.keras import keras_parameterized |
| 21 | +import numpy as np |
| 22 | +import tensorflow as tf |
22 | 23 |
|
| 24 | +from tensorflow.python.keras import keras_parameterized |
23 | 25 | from tensorflow_model_optimization.python.core.clustering.keras import cluster
|
24 | 26 | from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
|
25 | 27 | from tensorflow_model_optimization.python.core.keras import compat
|
26 |
| -import os |
27 | 28 |
|
28 | 29 | keras = tf.keras
|
29 | 30 | layers = keras.layers
|
30 | 31 | test = tf.test
|
31 | 32 |
|
32 | 33 | CentroidInitialization = cluster_config.CentroidInitialization
|
33 | 34 |
|
| 35 | + |
34 | 36 | class ClusterIntegrationTest(test.TestCase, parameterized.TestCase):
|
| 37 | + """Integration tests for clustering.""" |
| 38 | + |
| 39 | + def setUp(self): |
| 40 | + self.params = { |
| 41 | + "number_of_clusters": 8, |
| 42 | + "cluster_centroids_init": CentroidInitialization.LINEAR, |
| 43 | + } |
| 44 | + |
| 45 | + self.x_train = np.array( |
| 46 | + [[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0], [5.0, 1.0]], |
| 47 | + dtype="float32", |
| 48 | + ) |
| 49 | + |
| 50 | + self.y_train = np.array( |
| 51 | + [[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], |
| 52 | + dtype="float32", |
| 53 | + ) |
| 54 | + |
| 55 | + def dataset_generator(self): |
| 56 | + for x, y in zip(self.x_train, self.y_train): |
| 57 | + yield np.array([x]), np.array([y]) |
| 58 | + |
| 59 | + @staticmethod |
| 60 | + def _verify_tflite(tflite_file, x_test): |
| 61 | + interpreter = tf.lite.Interpreter(model_path=tflite_file) |
| 62 | + interpreter.allocate_tensors() |
| 63 | + input_index = interpreter.get_input_details()[0]["index"] |
| 64 | + output_index = interpreter.get_output_details()[0]["index"] |
| 65 | + x = x_test[0] |
| 66 | + x = x.reshape((1,) + x.shape) |
| 67 | + interpreter.set_tensor(input_index, x) |
| 68 | + interpreter.invoke() |
| 69 | + interpreter.get_tensor(output_index) |
| 70 | + |
| 71 | + @keras_parameterized.run_all_keras_modes |
| 72 | + def testValuesRemainClusteredAfterTraining(self): |
| 73 | + """Verifies that training a clustered model does not destroy the clusters.""" |
| 74 | + original_model = keras.Sequential([ |
| 75 | + layers.Dense(2, input_shape=(2,)), |
| 76 | + layers.Dense(2), |
| 77 | + ]) |
| 78 | + |
| 79 | + clustered_model = cluster.cluster_weights(original_model, **self.params) |
| 80 | + |
| 81 | + clustered_model.compile( |
| 82 | + loss=keras.losses.categorical_crossentropy, |
| 83 | + optimizer="adam", |
| 84 | + metrics=["accuracy"], |
| 85 | + ) |
| 86 | + |
| 87 | + clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) |
| 88 | + stripped_model = cluster.strip_clustering(clustered_model) |
| 89 | + weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist() |
| 90 | + unique_weights = set(weights_as_list) |
| 91 | + self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"]) |
| 92 | + |
| 93 | + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) |
| 94 | + def testEndToEnd(self): |
| 95 | + """Test End to End clustering.""" |
| 96 | + original_model = keras.Sequential([ |
| 97 | + layers.Dense(2, input_shape=(2,)), |
| 98 | + layers.Dense(2), |
| 99 | + ]) |
| 100 | + |
| 101 | + clustered_model = cluster.cluster_weights(original_model, **self.params) |
| 102 | + |
| 103 | + clustered_model.compile( |
| 104 | + loss=keras.losses.categorical_crossentropy, |
| 105 | + optimizer="adam", |
| 106 | + metrics=["accuracy"], |
| 107 | + ) |
| 108 | + |
| 109 | + clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) |
| 110 | + stripped_model = cluster.strip_clustering(clustered_model) |
| 111 | + |
| 112 | + _, tflite_file = tempfile.mkstemp(".tflite") |
| 113 | + _, keras_file = tempfile.mkstemp(".h5") |
| 114 | + |
| 115 | + if not compat.is_v1_apis(): |
| 116 | + converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model) |
| 117 | + else: |
| 118 | + tf.keras.models.save_model(stripped_model, keras_file) |
| 119 | + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) |
| 120 | + |
| 121 | + tflite_model = converter.convert() |
| 122 | + with open(tflite_file, "wb") as f: |
| 123 | + f.write(tflite_model) |
| 124 | + |
| 125 | + self._verify_tflite(tflite_file, self.x_train) |
| 126 | + |
| 127 | + os.remove(keras_file) |
| 128 | + os.remove(tflite_file) |
35 | 129 |
|
36 |
| - """ |
37 |
| - Integration tests for clustering. |
38 |
| - """ |
39 |
| - def setUp(self): |
40 |
| - self.params = { |
41 |
| - "number_of_clusters": 8, |
42 |
| - "cluster_centroids_init": CentroidInitialization.LINEAR, |
43 |
| - } |
44 |
| - |
45 |
| - self.x_train = np.array( |
46 |
| - [[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0], [5.0, 1.0]], |
47 |
| - dtype="float32", |
48 |
| - ) |
49 |
| - |
50 |
| - self.y_train = np.array( |
51 |
| - [[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], |
52 |
| - dtype="float32", |
53 |
| - ) |
54 |
| - |
55 |
| - def dataset_generator(self): |
56 |
| - for x, y in zip(self.x_train, self.y_train): |
57 |
| - yield np.array([x]), np.array([y]) |
58 |
| - |
59 |
| - @staticmethod |
60 |
| - def _verify_tflite(tflite_file, x_test): |
61 |
| - interpreter = tf.lite.Interpreter(model_path=tflite_file) |
62 |
| - interpreter.allocate_tensors() |
63 |
| - input_index = interpreter.get_input_details()[0]["index"] |
64 |
| - output_index = interpreter.get_output_details()[0]["index"] |
65 |
| - x = x_test[0] |
66 |
| - x = x.reshape((1,) + x.shape) |
67 |
| - interpreter.set_tensor(input_index, x) |
68 |
| - interpreter.invoke() |
69 |
| - interpreter.get_tensor(output_index) |
70 |
| - |
71 |
| - @keras_parameterized.run_all_keras_modes |
72 |
| - def testValuesRemainClusteredAfterTraining(self): |
73 |
| - |
74 |
| - """ |
75 |
| - Verifies that training a clustered model does not destroy the clusters. |
76 |
| - """ |
77 |
| - original_model = keras.Sequential( |
78 |
| - [layers.Dense(2, input_shape=(2,)), layers.Dense(2),] |
79 |
| - ) |
80 |
| - |
81 |
| - clustered_model = cluster.cluster_weights(original_model, **self.params) |
82 |
| - |
83 |
| - clustered_model.compile( |
84 |
| - loss=keras.losses.categorical_crossentropy, |
85 |
| - optimizer="adam", |
86 |
| - metrics=["accuracy"], |
87 |
| - ) |
88 |
| - |
89 |
| - clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) |
90 |
| - stripped_model = cluster.strip_clustering(clustered_model) |
91 |
| - weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist() |
92 |
| - unique_weights = set(weights_as_list) |
93 |
| - self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"]) |
94 |
| - |
95 |
| - @keras_parameterized.run_all_keras_modes(always_skip_v1=True) |
96 |
| - def testEndToEnd(self): |
97 |
| - |
98 |
| - """ |
99 |
| - Test End to End clustering. |
100 |
| - """ |
101 |
| - original_model = keras.Sequential( |
102 |
| - [layers.Dense(2, input_shape=(2,)), layers.Dense(2),] |
103 |
| - ) |
104 |
| - |
105 |
| - clustered_model = cluster.cluster_weights(original_model, **self.params) |
106 |
| - |
107 |
| - clustered_model.compile( |
108 |
| - loss=keras.losses.categorical_crossentropy, |
109 |
| - optimizer="adam", |
110 |
| - metrics=["accuracy"], |
111 |
| - ) |
112 |
| - |
113 |
| - clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) |
114 |
| - stripped_model = cluster.strip_clustering(clustered_model) |
115 |
| - |
116 |
| - _, tflite_file = tempfile.mkstemp(".tflite") |
117 |
| - _, keras_file = tempfile.mkstemp(".h5") |
118 |
| - |
119 |
| - if not compat.is_v1_apis(): |
120 |
| - converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model) |
121 |
| - else: |
122 |
| - tf.keras.models.save_model(stripped_model, keras_file) |
123 |
| - converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) |
124 |
| - |
125 |
| - converter.experimental_new_converter = True |
126 |
| - tflite_model = converter.convert() |
127 |
| - with open(tflite_file, "wb") as f: |
128 |
| - f.write(tflite_model) |
129 |
| - |
130 |
| - self._verify_tflite(tflite_file, self.x_train) |
131 |
| - |
132 |
| - os.remove(keras_file) |
133 |
| - os.remove(tflite_file) |
134 | 130 |
|
135 | 131 | if __name__ == "__main__":
|
136 |
| - test.main() |
| 132 | + test.main() |
0 commit comments