Skip to content

Commit 3082412

Browse files
liyunlu0618tensorflower-gardener
authored andcommitted
Support pruning nested model recursively.
PiperOrigin-RevId: 367249686
1 parent f08d37a commit 3082412

File tree

3 files changed

+41
-2
lines changed

3 files changed

+41
-2
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/prune.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,14 @@ def _prune_list(layers, **params):
163163
return wrapped_layers
164164

165165
def _add_pruning_wrapper(layer):
166+
if isinstance(layer, keras.Model):
167+
# Check whether the model is a subclass model.
168+
if (not layer._is_graph_network and
169+
not isinstance(layer, keras.models.Sequential)):
170+
raise ValueError('Subclassed models are not supported currently.')
171+
172+
return keras.models.clone_model(
173+
layer, input_tensors=None, clone_function=_add_pruning_wrapper)
166174
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
167175
return layer
168176
return pruning_wrapper.PruneLowMagnitude(layer, **params)
@@ -172,6 +180,7 @@ def _add_pruning_wrapper(layer):
172180
'block_size': block_size,
173181
'block_pooling_type': block_pooling_type
174182
}
183+
175184
is_sequential_or_functional = isinstance(
176185
to_prune, keras.Model) and (isinstance(to_prune, keras.Sequential) or
177186
to_prune._is_graph_network)
@@ -183,8 +192,7 @@ def _add_pruning_wrapper(layer):
183192
if isinstance(to_prune, list):
184193
return _prune_list(to_prune, **params)
185194
elif is_sequential_or_functional:
186-
return keras.models.clone_model(
187-
to_prune, input_tensors=None, clone_function=_add_pruning_wrapper)
195+
return _add_pruning_wrapper(to_prune)
188196
elif is_keras_layer:
189197
params.update(kwargs)
190198
return pruning_wrapper.PruneLowMagnitude(to_prune, **params)

tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,27 @@ def testPrunesEmbedding_ReachesTargetSparsity(self):
397397
input_data = np.random.randint(10, size=(32, 5))
398398
self._check_strip_pruning_matches_original(model, 0.5, input_data)
399399

400+
def testPruneRecursivelyReachesTargetSparsity(self):
401+
internal_model = keras.Sequential(
402+
[keras.layers.Dense(10, input_shape=(10,))])
403+
model = keras.Sequential([
404+
internal_model,
405+
layers.Flatten(),
406+
layers.Dense(1),
407+
])
408+
model.compile(
409+
loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
410+
test_utils.assert_model_sparsity(self, 0.0, model)
411+
model.fit(
412+
np.random.randint(10, size=(32, 10)),
413+
np.random.randint(2, size=(32, 1)),
414+
callbacks=[pruning_callbacks.UpdatePruningStep()])
415+
416+
test_utils.assert_model_sparsity(self, 0.5, model)
417+
418+
input_data = np.random.randint(10, size=(32, 10))
419+
self._check_strip_pruning_matches_original(model, 0.5, input_data)
420+
400421
@parameterized.parameters(test_utils.model_type_keys())
401422
def testPrunesMnist_ReachesTargetSparsity(self, model_type):
402423
model = test_utils.build_mnist_model(model_type, self.params)

tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ def testPruneFunctionalModelPreservesBuiltState(self):
304304
json.loads(pruned_model.to_json()))
305305
self.assertEqual(loaded_model.built, True)
306306

307+
def testPruneModelRecursively(self):
308+
internal_model = keras.Sequential(
309+
[keras.layers.Dense(10, input_shape=(10,))])
310+
original_model = keras.Sequential([
311+
internal_model,
312+
layers.Dense(10),
313+
])
314+
pruned_model = prune.prune_low_magnitude(original_model, **self.params)
315+
self.assertEqual(self._count_pruned_layers(pruned_model), 2)
316+
307317
def testPruneSubclassModel(self):
308318
model = TestSubclassedModel()
309319
with self.assertRaises(ValueError) as e:

0 commit comments

Comments
 (0)