Skip to content

Commit 226f21e

Browse files
Merge pull request #951 from wwwind:clustering_per_channel_bugfix
PiperOrigin-RevId: 448168313
2 parents 057e721 + e490e7e commit 226f21e

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def __init__(self,
154154
if hasattr(layer, 'data_format'):
155155
self.data_format = self.layer.data_format
156156
else:
157-
self.data_format = None
157+
self.data_format = 'channels_last'
158158

159159
# Save the input shape specified in the build
160160
self.build_input_shape = None
@@ -215,7 +215,7 @@ def build(self, input_shape):
215215
centroid_init = centroid_init_factory.get_centroid_initializer(
216216
self.cluster_centroids_init)(weight, self.number_of_clusters,
217217
self.cluster_per_channel,
218-
self.num_channels,
218+
self.data_format,
219219
self.preserve_sparsity)
220220

221221
# Init the cluster centroids

tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,15 @@ class AbstractCentroidsInitialisation:
4444
"""
4545

4646
def __init__(self, weights, number_of_clusters,
47-
cluster_per_channel=False, data_format=None,
47+
cluster_per_channel=False, data_format='channels_last',
4848
preserve_sparsity=False):
49+
50+
# Input checks
51+
if (data_format != 'channels_first' and data_format != 'channels_last'):
52+
raise ValueError(
53+
'The given parameter data_format is not correct: {input}'.format(
54+
input = data_format))
55+
4956
self.weights = weights
5057
self.number_of_clusters = number_of_clusters
5158
self.cluster_per_channel = cluster_per_channel

tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def testThrowsValueErrorForNonExistingInit(self):
6666
with self.assertRaises(ValueError):
6767
self.factory.get_centroid_initializer("DEADBEEF")
6868

69+
def testThrowsValueErrorForWrongDataFormat(self):
70+
"""Verifies that the centroid initializer factory method raises an exception
71+
when invoked with an wrong type of data_format."""
72+
f = self.factory.get_centroid_initializer("CentroidInitialization.KMEANS_PLUS_PLUS")
73+
with self.assertRaises(ValueError):
74+
f([1, 2], 2, True, "NCHW")
75+
6976
@parameterized.parameters(
7077
(0, 0, 1, 1, 1, 0),
7178
(0, 0, 5, 5, 1, 0),

0 commit comments

Comments
 (0)