|
16 | 16 |
|
17 | 17 | import numpy as np
|
18 | 18 | import tensorflow as tf
|
19 |
| - |
| 19 | +import tempfile |
20 | 20 | from absl.testing import parameterized
|
21 | 21 | from tensorflow.python.keras import keras_parameterized
|
22 | 22 |
|
23 | 23 | from tensorflow_model_optimization.python.core.clustering.keras import cluster
|
24 | 24 | from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
|
| 25 | +from tensorflow_model_optimization.python.core.keras import compat |
| 26 | +import os |
25 | 27 |
|
26 | 28 | keras = tf.keras
|
27 | 29 | layers = keras.layers
|
|
30 | 32 | CentroidInitialization = cluster_config.CentroidInitialization
|
31 | 33 |
|
32 | 34 | class ClusterIntegrationTest(test.TestCase, parameterized.TestCase):
|
33 |
| - """Integration tests for clustering.""" |
34 | 35 |
|
35 |
| - @keras_parameterized.run_all_keras_modes |
36 |
| - def testValuesRemainClusteredAfterTraining(self): |
37 | 36 | """
|
38 |
| - Verifies that training a clustered model does not destroy the clusters. |
| 37 | + Integration tests for clustering. |
39 | 38 | """
|
40 |
| - number_of_clusters = 10 |
41 |
| - original_model = keras.Sequential([ |
42 |
| - layers.Dense(2, input_shape=(2,)), |
43 |
| - layers.Dense(2), |
44 |
| - ]) |
45 |
| - |
46 |
| - clustered_model = cluster.cluster_weights( |
47 |
| - original_model, |
48 |
| - number_of_clusters=number_of_clusters, |
49 |
| - cluster_centroids_init=CentroidInitialization.LINEAR |
50 |
| - ) |
51 |
| - |
52 |
| - clustered_model.compile( |
53 |
| - loss=keras.losses.categorical_crossentropy, |
54 |
| - optimizer='adam', |
55 |
| - metrics=['accuracy'] |
56 |
| - ) |
57 |
| - |
58 |
| - def dataset_generator(): |
59 |
| - x_train = np.array([ |
60 |
| - [0, 1], |
61 |
| - [2, 0], |
62 |
| - [0, 3], |
63 |
| - [4, 1], |
64 |
| - [5, 1], |
65 |
| - ]) |
66 |
| - y_train = np.array([ |
67 |
| - [0, 1], |
68 |
| - [1, 0], |
69 |
| - [1, 0], |
70 |
| - [0, 1], |
71 |
| - [0, 1], |
72 |
| - ]) |
73 |
| - for x, y in zip(x_train, y_train): |
74 |
| - yield np.array([x]), np.array([y]) |
75 |
| - |
76 |
| - clustered_model.fit_generator(dataset_generator(), steps_per_epoch=1) |
77 |
| - stripped_model = cluster.strip_clustering(clustered_model) |
78 |
| - weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist() |
79 |
| - unique_weights = set(weights_as_list) |
80 |
| - self.assertLessEqual(len(unique_weights), number_of_clusters) |
81 |
| - |
82 |
| - |
83 |
| -if __name__ == '__main__': |
84 |
| - test.main() |
| 39 | + def setUp(self): |
| 40 | + self.params = { |
| 41 | + "number_of_clusters": 8, |
| 42 | + "cluster_centroids_init": CentroidInitialization.LINEAR, |
| 43 | + } |
| 44 | + |
| 45 | + self.x_train = np.array( |
| 46 | + [[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0], [5.0, 1.0]], |
| 47 | + dtype="float32", |
| 48 | + ) |
| 49 | + |
| 50 | + self.y_train = np.array( |
| 51 | + [[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], |
| 52 | + dtype="float32", |
| 53 | + ) |
| 54 | + |
| 55 | + def dataset_generator(self): |
| 56 | + for x, y in zip(self.x_train, self.y_train): |
| 57 | + yield np.array([x]), np.array([y]) |
| 58 | + |
| 59 | + @staticmethod |
| 60 | + def _verify_tflite(tflite_file, x_test): |
| 61 | + interpreter = tf.lite.Interpreter(model_path=tflite_file) |
| 62 | + interpreter.allocate_tensors() |
| 63 | + input_index = interpreter.get_input_details()[0]["index"] |
| 64 | + output_index = interpreter.get_output_details()[0]["index"] |
| 65 | + x = x_test[0] |
| 66 | + x = x.reshape((1,) + x.shape) |
| 67 | + interpreter.set_tensor(input_index, x) |
| 68 | + interpreter.invoke() |
| 69 | + interpreter.get_tensor(output_index) |
| 70 | + |
| 71 | + @keras_parameterized.run_all_keras_modes |
| 72 | + def testValuesRemainClusteredAfterTraining(self): |
| 73 | + |
| 74 | + """ |
| 75 | + Verifies that training a clustered model does not destroy the clusters. |
| 76 | + """ |
| 77 | + original_model = keras.Sequential( |
| 78 | + [layers.Dense(2, input_shape=(2,)), layers.Dense(2),] |
| 79 | + ) |
| 80 | + |
| 81 | + clustered_model = cluster.cluster_weights(original_model, **self.params) |
| 82 | + |
| 83 | + clustered_model.compile( |
| 84 | + loss=keras.losses.categorical_crossentropy, |
| 85 | + optimizer="adam", |
| 86 | + metrics=["accuracy"], |
| 87 | + ) |
| 88 | + |
| 89 | + clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) |
| 90 | + stripped_model = cluster.strip_clustering(clustered_model) |
| 91 | + weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist() |
| 92 | + unique_weights = set(weights_as_list) |
| 93 | + self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"]) |
| 94 | + |
| 95 | + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) |
| 96 | + def testEndToEnd(self): |
| 97 | + |
| 98 | + """ |
| 99 | + Test End to End clustering. |
| 100 | + """ |
| 101 | + original_model = keras.Sequential( |
| 102 | + [layers.Dense(2, input_shape=(2,)), layers.Dense(2),] |
| 103 | + ) |
| 104 | + |
| 105 | + clustered_model = cluster.cluster_weights(original_model, **self.params) |
| 106 | + |
| 107 | + clustered_model.compile( |
| 108 | + loss=keras.losses.categorical_crossentropy, |
| 109 | + optimizer="adam", |
| 110 | + metrics=["accuracy"], |
| 111 | + ) |
| 112 | + |
| 113 | + clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) |
| 114 | + stripped_model = cluster.strip_clustering(clustered_model) |
| 115 | + |
| 116 | + _, tflite_file = tempfile.mkstemp(".tflite") |
| 117 | + _, keras_file = tempfile.mkstemp(".h5") |
| 118 | + |
| 119 | + if not compat.is_v1_apis(): |
| 120 | + converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model) |
| 121 | + else: |
| 122 | + tf.keras.models.save_model(stripped_model, keras_file) |
| 123 | + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) |
| 124 | + |
| 125 | + converter.experimental_new_converter = True |
| 126 | + tflite_model = converter.convert() |
| 127 | + with open(tflite_file, "wb") as f: |
| 128 | + f.write(tflite_model) |
| 129 | + |
| 130 | + self._verify_tflite(tflite_file, self.x_train) |
| 131 | + |
| 132 | + os.remove(keras_file) |
| 133 | + os.remove(tflite_file) |
| 134 | + |
| 135 | +if __name__ == "__main__": |
| 136 | + test.main() |
0 commit comments