
Commit caee359

evcu authored and tensorflower-gardener committed
Ceiling the remaining weights to ensure there is at least 1 connection.

Fixed #215
PiperOrigin-RevId: 338257075
1 parent 8c875fc commit caee359

File tree

4 files changed: +49 -14 lines changed

RELEASE.md
Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ Tested against TensorFlow nightly, and Python 3.
 Keras pruning API:
 
 Tested against TensorFlow 1.14.0, 2.0.0, and nightly, and Python 3.
+Pruning no longer removes the last remaining connection, so even extreme sparsities like 0.999... leave one connection intact.
 
 
 # TensorFlow Model Optimization 0.4.0
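
To illustrate the behavior this release note describes, here is a minimal sketch at the Keras API level. It assumes TensorFlow 2.x with tensorflow_model_optimization installed; the toy model, random data, and optimizer are illustrative choices, not part of the commit. Even at an extreme target sparsity, each pruned kernel now retains at least one nonzero weight after training.

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Hypothetical toy model; any Keras model with prunable layers works.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, input_shape=(10,))])

# Extreme sparsity: round(100 * (1 - 0.999)) == 0 before this fix.
pruned = tfmot.sparsity.keras.prune_low_magnitude(
    model,
    pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=0.999, begin_step=0, frequency=1))
pruned.compile(optimizer='sgd', loss='mse')
pruned.fit(
    np.random.rand(32, 10), np.random.rand(32, 10),
    epochs=1,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# The kernel keeps at least one nonzero entry (biases are not pruned).
for weight in pruned.trainable_weights:
  print(weight.name, np.count_nonzero(weight.numpy()))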

tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
Lines changed: 9 additions & 7 deletions

@@ -94,12 +94,12 @@ def _get_params_for_layer(layer_type):
       # TODO(tf-mot): fix for Conv2DTranspose on some form of eager,
       # with or without functions. The weights become nan (though the
       # mask seems fine still).
-      #layers.Conv2DTranspose: ([2, (3, 3)], (7, 6, 3)),
+      # layers.Conv2DTranspose: ([2, (3, 3)], (7, 6, 3)),
       layers.Conv3D: ([2, (3, 3, 3)], (5, 7, 6, 3)),
       # TODO(tf-mot): fix for Conv3DTranspose on some form of eager,
       # with or without functions. The weights become nan (though the
       # mask seems fine still).
-      #layers.Conv3DTranspose: ([2, (3, 3, 3)], (5, 7, 6, 3)),
+      # layers.Conv3DTranspose: ([2, (3, 3, 3)], (5, 7, 6, 3)),
       layers.SeparableConv1D: ([4, 3], (3, 6)),
       layers.SeparableConv2D: ([4, (2, 2)], (4, 6, 1)),
       layers.Dense: ([4], (6,)),

@@ -193,17 +193,19 @@ def testPrunesZeroSparsity_IsNoOp(self):
     self._assert_weights_different_objects(model, pruned_model)
     self._assert_weights_same_values(model, pruned_model)
 
-  # TODO(tfmot): https://github.com/tensorflow/model-optimization/issues/215
-  def testPruneWithHighSparsity_Fails(self):
+  def testPruneWithHighSparsity(self):
     params = self.params
     params['pruning_schedule'] = pruning_schedule.ConstantSparsity(
         target_sparsity=0.99, begin_step=0, frequency=1)
 
     model = prune.prune_low_magnitude(
         keras_test_utils.build_simple_dense_model(), **params)
-
-    with self.assertRaises(tf.errors.InvalidArgumentError):
-      self._train_model(model, epochs=1)
+    self._train_model(model, epochs=1)
+    for layer in model.layers:
+      if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
+        for weight in layer.layer.get_prunable_weights():
+          self.assertEqual(
+              1, np.count_nonzero(tf.keras.backend.get_value(weight)))
 
   ###################################################################
   # Tests for training with pruning with pretrained models or weights.

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py
Lines changed: 6 additions & 3 deletions

@@ -85,9 +85,12 @@ def _update_mask(self, weights):
     with tf.name_scope('pruning_ops'):
       abs_weights = tf.math.abs(weights)
       k = tf.dtypes.cast(
-          tf.math.round(
-              tf.dtypes.cast(tf.size(abs_weights), tf.float32) *
-              (1 - sparsity)), tf.int32)
+          tf.math.maximum(
+              tf.math.round(
+                  tf.dtypes.cast(tf.size(abs_weights), tf.float32) *
+                  (1 - sparsity)),
+              1),
+          tf.int32)
       # Sort the entire array
       values, _ = tf.math.top_k(
           tf.reshape(abs_weights, [-1]), k=tf.size(abs_weights))
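
For context, a standalone sketch of the clamped computation above (assumes TensorFlow 2.x eager mode; the tensor size and sparsity value are illustrative). Without the tf.math.maximum, round(size * (1 - sparsity)) evaluates to 0 at extreme sparsity and no connection survives, the failure mode the old testPruneWithHighSparsity_Fails expected to raise an InvalidArgumentError; clamping pins k at 1 so the single largest-magnitude weight is always kept.

import tensorflow as tf

weights = tf.random.normal([100])
sparsity = 0.9999  # extreme: round(100 * (1 - 0.9999)) == 0 without the clamp

abs_weights = tf.math.abs(weights)
# The fixed computation: clamp the number of surviving weights to >= 1.
k = tf.dtypes.cast(
    tf.math.maximum(
        tf.math.round(
            tf.dtypes.cast(tf.size(abs_weights), tf.float32) *
            (1 - sparsity)),
        1),
    tf.int32)

# Mirror of the masking step: keep the k largest-magnitude weights.
values, _ = tf.math.top_k(
    tf.reshape(abs_weights, [-1]), k=tf.size(abs_weights))
current_threshold = tf.gather(values, k - 1)
mask = tf.cast(abs_weights >= current_threshold, tf.float32)

print(int(k), int(tf.reduce_sum(mask)))  # -> 1 1: one connection survives

Only the clamp on k is new; the full top_k sort and threshold selection are unchanged from the existing _update_mask logic.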

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py
Lines changed: 33 additions & 4 deletions

@@ -71,6 +71,37 @@ def training_step_fn():
 
     compat.initialize_variables(self)
 
+  def testExtremelySparseMask(self):
+    weight = tf.Variable(np.linspace(1.0, 100.0, 100), name="weights")
+    weight_dtype = weight.dtype.base_dtype
+    mask = tf.Variable(
+        tf.ones(weight.get_shape(), dtype=weight_dtype),
+        name="mask",
+        dtype=weight_dtype)
+    threshold = tf.Variable(
+        tf.zeros([], dtype=weight_dtype), name="threshold", dtype=weight_dtype)
+    self.initialize()
+
+    extreme_sparsity = pruning_schedule.ConstantSparsity(0.9999, 0, 100, 1)
+    p = pruning_impl.Pruning(
+        pruning_vars=[(weight, mask, threshold)],
+        training_step_fn=self.training_step_fn,
+        pruning_schedule=extreme_sparsity,
+        block_size=self.block_size,
+        block_pooling_type=self.block_pooling_type)
+
+    mask_before_pruning = K.get_value(mask)
+    self.assertAllEqual(np.count_nonzero(mask_before_pruning), 100)
+
+    if tf.executing_eagerly():
+      p.conditional_mask_update()
+    else:
+      K.get_session().run(p.conditional_mask_update())
+
+    # We should always have a single connection remaining.
+    mask_after_pruning = K.get_value(mask)
+    self.assertAllEqual(np.count_nonzero(mask_after_pruning), 1)
+
   def testUpdateSingleMask(self):
     weight = tf.Variable(np.linspace(1.0, 100.0, 100), name="weights")
     weight_dtype = weight.dtype.base_dtype

@@ -154,8 +185,7 @@ def testBlockMaskingMax(self):
     block_size = (2, 2)
     block_pooling_type = "MAX"
     weight = tf.constant([[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2],
-                          [0.3, 0.0, 0.4, 0.0], [0.0, -0.3, 0.0,
-                                                 -0.4]])
+                          [0.3, 0.0, 0.4, 0.0], [0.0, -0.3, 0.0, -0.4]])
     expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
                      [1., 1., 1., 1.], [1., 1., 1., 1.]]

@@ -167,8 +197,7 @@ def testBlockMaskingWithHigherDimensionsRaisesError(self):
     block_pooling_type = "AVG"
     # Weights as in testBlockMasking, but with one extra dimension.
     weight = tf.constant([[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
-                           [0.3, 0.3, 0.4, 0.4], [0.3, 0.3, 0.4,
-                                                  0.4]]])
+                           [0.3, 0.3, 0.4, 0.4], [0.3, 0.3, 0.4, 0.4]]])
     expected_mask = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
                       [1., 1., 1., 1.], [1., 1., 1., 1.]]]
