File tree Expand file tree Collapse file tree 2 files changed +22
-0
lines changed
tensorflow_model_optimization/python/core/clustering/keras Expand file tree Collapse file tree 2 files changed +22
-0
lines changed Original file line number Diff line number Diff line change @@ -126,7 +126,15 @@ def cluster_weights(to_cluster,
126
126
format (cluster_centroids_init ))
127
127
128
128
def _add_clustering_wrapper (layer ):
129
+
129
130
if (isinstance (layer , keras .Model )):
131
+ # Check whether the model is a subclass.
132
+ # NB: This check is copied from keras.py file in tensorflow.
133
+ # There is no available public API to do this check.
134
+ if (not layer ._is_graph_network and
135
+ not isinstance (layer , keras .models .Sequential )):
136
+ raise ValueError ("SubClass models are not supported currently." )
137
+
130
138
return keras .models .clone_model (layer ,
131
139
input_tensors = None ,
132
140
clone_function = _add_clustering_wrapper )
Original file line number Diff line number Diff line change @@ -392,6 +392,20 @@ def testClusterSubclassModel(self):
392
392
with self .assertRaises (ValueError ):
393
393
_ = cluster .cluster_weights (model , ** self .params )
394
394
395
+ @keras_parameterized .run_all_keras_modes
396
+ def testClusterSubclassModelAsSubmodel (self ):
397
+ """
398
+ Verifies that attempting to cluster a model with submodel
399
+ that is a subclass throws an exception.
400
+ """
401
+ model_subclass = TestModel ()
402
+ model = keras .Sequential ([
403
+ layers .Dense (10 ),
404
+ model_subclass
405
+ ])
406
+ with self .assertRaisesRegexp (ValueError , "SubClass models.*" ):
407
+ _ = cluster .cluster_weights (model , ** self .params )
408
+
395
409
@keras_parameterized .run_all_keras_modes
396
410
def testStripClusteringSequentialModel (self ):
397
411
"""
You can’t perform that action at this time.
0 commit comments