Skip to content

Commit 6c9941e

Browse files
committed
Added check before clone_model when we copy layers: if layer is SubClass model, we throw an exception.
This PR addresses reviewer's comment. Change-Id: I0bd72324fe60da7eda3d3c440c68d1797beecd6c
1 parent e29f872 commit 6c9941e

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,15 @@ def cluster_weights(to_cluster,
126126
format(cluster_centroids_init))
127127

128128
def _add_clustering_wrapper(layer):
129+
129130
if (isinstance(layer, keras.Model)):
131+
# Check whether the model is a subclass.
132+
# NB: This check is copied from keras.py file in tensorflow.
133+
# There is no available public API to do this check.
134+
if (not layer._is_graph_network and
135+
not isinstance(layer, keras.models.Sequential)):
136+
raise ValueError("SubClass models are not supported currently.")
137+
130138
return keras.models.clone_model(layer,
131139
input_tensors=None,
132140
clone_function=_add_clustering_wrapper)

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,20 @@ def testClusterSubclassModel(self):
392392
with self.assertRaises(ValueError):
393393
_ = cluster.cluster_weights(model, **self.params)
394394

395+
@keras_parameterized.run_all_keras_modes
396+
def testClusterSubclassModelAsSubmodel(self):
397+
"""
398+
Verifies that attempting to cluster a model with submodel
399+
that is a subclass throws an exception.
400+
"""
401+
model_subclass = TestModel()
402+
model = keras.Sequential([
403+
layers.Dense(10),
404+
model_subclass
405+
])
406+
with self.assertRaisesRegexp(ValueError, "SubClass models.*"):
407+
_ = cluster.cluster_weights(model, **self.params)
408+
395409
@keras_parameterized.run_all_keras_modes
396410
def testStripClusteringSequentialModel(self):
397411
"""

0 commit comments

Comments
 (0)