Skip to content

Commit f1beeb7

Browse files
Merge pull request #567 from wwwind:clustering_tidy_up_test
PiperOrigin-RevId: 337365407
2 parents 066e3f7 + 75fe2f1 commit f1beeb7

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,7 @@ def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
160160
@parameterized.parameters(
161161
*itertools.product(
162162
range(2, 16, 4),
163-
(
164-
CentroidInitialization.LINEAR,
165-
CentroidInitialization.RANDOM,
166-
CentroidInitialization.DENSITY_BASED
167-
)
163+
[type_centroid for type_centroid in CentroidInitialization]
168164
)
169165
)
170166
def testValuesAreClusteredAfterStripping(self,
@@ -178,6 +174,9 @@ def testValuesAreClusteredAfterStripping(self,
178174
original_model = tf.keras.Sequential([
179175
layers.Dense(32, input_shape=(10,)),
180176
])
177+
self.assertGreater(
178+
len(set(original_model.get_weights()[0].reshape(-1,).tolist())),
179+
number_of_clusters)
181180
clustered_model = cluster.cluster_weights(
182181
original_model,
183182
number_of_clusters=number_of_clusters,

0 commit comments

Comments
 (0)