File tree Expand file tree Collapse file tree 2 files changed +22
-0
lines changed
tensorflow_model_optimization/python/core/clustering/keras Expand file tree Collapse file tree 2 files changed +22
-0
lines changed Original file line number Diff line number Diff line change @@ -126,7 +126,15 @@ def cluster_weights(to_cluster,
126126 format (cluster_centroids_init ))
127127
128128 def _add_clustering_wrapper (layer ):
129+
129130 if (isinstance (layer , keras .Model )):
131+ # Check whether the model is a subclass.
132+ # NB: This check is copied from keras.py file in tensorflow.
133+ # There is no available public API to do this check.
134+ if (not layer ._is_graph_network and
135+ not isinstance (layer , keras .models .Sequential )):
136+ raise ValueError ("SubClass models are not supported currently." )
137+
130138 return keras .models .clone_model (layer ,
131139 input_tensors = None ,
132140 clone_function = _add_clustering_wrapper )
Original file line number Diff line number Diff line change @@ -392,6 +392,20 @@ def testClusterSubclassModel(self):
392392 with self .assertRaises (ValueError ):
393393 _ = cluster .cluster_weights (model , ** self .params )
394394
395+ @keras_parameterized .run_all_keras_modes
396+ def testClusterSubclassModelAsSubmodel (self ):
397+ """
398+ Verifies that attempting to cluster a model with submodel
399+ that is a subclass throws an exception.
400+ """
401+ model_subclass = TestModel ()
402+ model = keras .Sequential ([
403+ layers .Dense (10 ),
404+ model_subclass
405+ ])
406+ with self .assertRaisesRegexp (ValueError , "SubClass models.*" ):
407+ _ = cluster .cluster_weights (model , ** self .params )
408+
395409 @keras_parameterized .run_all_keras_modes
396410 def testStripClusteringSequentialModel (self ):
397411 """
You can’t perform that action at this time.
0 commit comments