Skip to content

Commit 87c06eb

Browse files
Ruomei and tensorflower-gardener
authored and committed
Copybara import of the project:
-- a248f89 by Ruomei Yan <[email protected]>: Enable differentiable training and update cluster indices COPYBARA_INTEGRATE_REVIEW=#519 from Ruomei:toupstream/enable_differentiable_training a248f89 PiperOrigin-RevId: 333108062
1 parent 48c08d1 commit 87c06eb

File tree

4 files changed

+199
-29
lines changed

4 files changed

+199
-29
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
k = keras.backend
2828
Layer = keras.layers.Layer
2929
Wrapper = keras.layers.Wrapper
30+
CentroidInitialization = cluster_config.CentroidInitialization
3031

3132

3233
class ClusterWeights(Wrapper):
@@ -105,7 +106,7 @@ def __init__(self,
105106
self.number_of_clusters = number_of_clusters
106107

107108
# Stores the pairs of weight names and references to their tensors
108-
self.clustered_vars = []
109+
self.ori_weights_vars_tf = {}
109110

110111
# Stores references to class instances that implement different clustering
111112
# behaviour for different shapes of objects
@@ -194,7 +195,7 @@ def build(self, input_shape):
194195
# initial value taken from a Tensor object. For each weight there is a
195196
# different set of cluster centroids
196197
self.cluster_centroids_tf[weight_name] = self.add_weight(
197-
'cluster_centroids_tf',
198+
'{}{}'.format('cluster_centroids_tf_', weight_name),
198199
shape=(self.number_of_clusters,),
199200
dtype=weight.dtype,
200201
trainable=True,
@@ -217,7 +218,7 @@ def build(self, input_shape):
217218
pulling_indices = self.clustering_impl[weight_name].\
218219
get_pulling_indices(weight)
219220
self.pulling_indices_tf[weight_name] = self.add_weight(
220-
'pulling_indices_tf',
221+
'{}{}'.format('pulling_indices_tf_', weight_name),
221222
shape=pulling_indices.shape,
222223
dtype=tf.int32,
223224
trainable=False,
@@ -227,21 +228,30 @@ def build(self, input_shape):
227228
)
228229

229230
# Keep a trainable copy of the original weight, keyed by weight name, so that these variables can be updated later on
230-
self.clustered_vars.append((weight_name, weight))
231+
self.ori_weights_vars_tf[weight_name] = self.add_weight(
232+
'{}{}'.format('ori_weights_vars_tf_', weight_name),
233+
shape=weight.shape,
234+
dtype=weight.dtype,
235+
trainable=True,
236+
initializer=initializers.Constant(
237+
value=k.batch_get_value([weight])[0]
238+
)
239+
)
231240

232241
# We use currying here to get an updater which can be triggered at any time
233242
# in future and it would return the latest version of clustered weights
234243
def get_updater(for_weight_name):
235244
def fn():
236-
return self.clustering_impl[for_weight_name].get_clustered_weight(
237-
self.pulling_indices_tf[for_weight_name]
238-
)
245+
# Get the clustered weights
246+
pulling_indices = self.pulling_indices_tf[for_weight_name]
247+
clustered_weights = self.clustering_impl[for_weight_name].\
248+
get_clustered_weight(pulling_indices)
249+
return clustered_weights
239250

240251
return fn
241252

242253
# This will allow us to restore the order of weights later
243254
# This loop stores pairs of weight names and how to restore them
244-
245255
for ct, weight in enumerate(self.layer.weights):
246256
name = self._weight_name(weight.name)
247257
full_name = self.layer.name + "/" + name
@@ -253,14 +263,26 @@ def fn():
253263
self.restore.append((name, full_name, weight))
254264

255265
def call(self, inputs):
266+
# In the forward pass, we need to update the cluster associations manually
267+
# since they are integers and not differentiable. Gradients won't flow back
268+
# through tf.argmin
256269
# Go through all tensors and replace them with their clustered copies.
257-
for weight_name, _ in self.clustered_vars:
258-
setattr(
259-
self.layer, weight_name,
260-
self.clustering_impl[weight_name].get_clustered_weight(
261-
self.pulling_indices_tf[weight_name]
262-
)
263-
)
270+
for weight_name in self.ori_weights_vars_tf:
271+
pulling_indices = self.pulling_indices_tf[weight_name]
272+
273+
# Update cluster associations
274+
pulling_indices.assign(tf.dtypes.cast(
275+
self.clustering_impl[weight_name].\
276+
get_pulling_indices(self.ori_weights_vars_tf[weight_name]),
277+
pulling_indices.dtype
278+
))
279+
280+
clustered_weights = self.clustering_impl[weight_name].\
281+
get_clustered_weight_forward(pulling_indices,\
282+
self.ori_weights_vars_tf[weight_name])
283+
284+
# Replace the weights with their clustered counterparts
285+
setattr(self.layer, weight_name, clustered_weights)
264286

265287
return self.layer.call(inputs)
266288

@@ -271,7 +293,7 @@ def get_config(self):
271293
base_config = super(ClusterWeights, self).get_config()
272294
config = {
273295
'number_of_clusters': self.number_of_clusters,
274-
'cluster_centroids_init': self.cluster_centroids_init,
296+
'cluster_centroids_init': self.cluster_centroids_init
275297
}
276298
return dict(list(base_config.items()) + list(config.items()))
277299

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Tests for keras ClusterWeights wrapper API."""
1616

1717
import itertools
18-
import numpy as np
1918
import tensorflow as tf
2019

2120
from absl.testing import parameterized
@@ -155,9 +154,9 @@ def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
155154
*itertools.product(
156155
range(2, 16, 4),
157156
(
158-
CentroidInitialization.LINEAR,
159-
CentroidInitialization.RANDOM,
160-
CentroidInitialization.DENSITY_BASED
157+
CentroidInitialization.LINEAR,
158+
CentroidInitialization.RANDOM,
159+
CentroidInitialization.DENSITY_BASED
161160
)
162161
)
163162
)
@@ -170,12 +169,8 @@ def testValuesAreClusteredAfterStripping(self,
170169
or equal to number_of_clusters.
171170
"""
172171
original_model = tf.keras.Sequential([
173-
layers.Dense(32, input_shape=(10,), name='abc'),
172+
layers.Dense(32, input_shape=(10,)),
174173
])
175-
176-
weights_name = original_model.layers[0].weights[0].name
177-
bias_name = original_model.layers[0].weights[1].name
178-
179174
clustered_model = cluster.cluster_weights(
180175
original_model,
181176
number_of_clusters=number_of_clusters,
@@ -190,9 +185,67 @@ def testValuesAreClusteredAfterStripping(self,
190185
# Make sure that the stripped layer is the Dense one
191186
self.assertIsInstance(stripped_model.layers[0], layers.Dense)
192187

193-
# Check that we keep names for weights/bias
194-
self.assertEqual(stripped_model.layers[0].weights[0].name, weights_name)
195-
self.assertEqual(stripped_model.layers[0].weights[1].name, bias_name)
188+
def testClusterReassociation(self):
  """Checks that a forward pass re-associates weights with centroids.

  The two centroids are moved twice: first so that every kernel weight
  should be closest to centroid 0, then so that every kernel weight should
  be closest to centroid 1. After each `call`, the kernel of the wrapped
  layer is expected to contain only the value of that centroid.
  """
  # A small Dense layer wrapped for clustering with two centroids.
  dummy_input_shape = (1, 2,)
  wrapped = cluster_wrapper.ClusterWeights(
      keras.layers.Dense(8, input_shape=dummy_input_shape),
      number_of_clusters=2,
      cluster_centroids_init=CentroidInitialization.LINEAR
  )
  wrapped.build(dummy_input_shape)

  # The test expects exactly one clusterable weight, named 'kernel'.
  clusterable = wrapped.layer.get_clusterable_weights()
  self.assertEqual(len(clusterable), 1)
  kernel_name = clusterable[0][0]
  self.assertEqual(kernel_name, 'kernel')
  centroids = wrapped.cluster_centroids_tf[kernel_name]

  # Kernel statistics used to place the centroids relative to the weights.
  kernel_mean = tf.reduce_mean(wrapped.layer.kernel)
  kernel_min = tf.reduce_min(wrapped.layer.kernel)
  kernel_max = tf.reduce_max(wrapped.layer.kernel)
  kernel_spread = kernel_max - kernel_min

  def assert_all_weights_associated(weights, centroid_index):
    """Asserts that every element of `weights` equals the given centroid."""
    expected = tf.constant(centroids[centroid_index], shape=weights.shape)
    matches = tf.equal(weights, expected)
    self.assertTrue(tf.reduce_all(matches))

  def move_centroids_and_check(first_value, second_value, expected_index):
    """Sets both centroids, runs a forward pass and checks the kernel."""
    centroids[0].assign(first_value)
    centroids[1].assign(second_value)
    # The forward pass is what refreshes the weight/centroid associations.
    wrapped.call(tf.ones(shape=dummy_input_shape))
    assert_all_weights_associated(
        wrapped.layer.kernel, centroid_index=expected_index)

  # Centroid 0 sits on the mean, centroid 1 far above every weight.
  move_centroids_and_check(
      kernel_mean, kernel_mean + 2.0 * kernel_spread, expected_index=0)

  # Centroid 1 sits on the mean, centroid 0 far below every weight.
  move_centroids_and_check(
      kernel_mean - 2.0 * kernel_spread, kernel_mean, expected_index=1)
248+
196249

197250
if __name__ == '__main__':
198251
test.main()

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,22 @@ def get_pulling_indices(self, weight):
6464
"""
6565
pass
6666

67+
@tf.custom_gradient
def add_custom_gradients(self, clst_weights, weights):
  """Forwards the clustered weights and overrides the backward pass.

  Forward pass: `weights` is flattened, shifted by 1e+6 and passed through
  `tf.sign`; the result multiplies `clst_weights` element-wise. For every
  weight greater than -1e+6 that factor is 1, so the output equals
  `clst_weights` while `weights` still takes part in the computation.

  Backward pass: the incoming gradient is handed back unchanged for both
  `clst_weights` and `weights`, replacing the gradients the multiplication
  and `tf.sign` would otherwise yield. This lets the original weights
  receive the gradient updates coming from the wrapped layer directly.

  :param clst_weights: clustered weights; multiplied element-wise with the
    flattened `weights`.
  :param weights: original weights; reshaped to 1-D inside this function.
  :return: the pair (forward value, gradient function) expected by
    `tf.custom_gradient`.
  """
  flat_weights = tf.reshape(weights, shape=(-1,))
  # sign(w + 1e+6) is 1 for any w above -1e+6.
  unit_factor = tf.sign(flat_weights + 1e+6)
  forward_value = clst_weights * unit_factor

  def grad(upstream):
    # Same upstream gradient for both inputs.
    return upstream, upstream

  return forward_value, grad
82+
6783
def get_clustered_weight(self, pulling_indices):
6884
"""
6985
Takes an array with integer number that represent lookup indices and forms a
@@ -75,9 +91,23 @@ def get_clustered_weight(self, pulling_indices):
7591
return tf.reshape(
7692
tf.gather(self.cluster_centroids,
7793
tf.reshape(pulling_indices, shape=(-1,))),
78-
pulling_indices.shape
94+
shape=pulling_indices.shape
7995
)
8096

97+
def get_clustered_weight_forward(self, pulling_indices, weight):
  """Builds the clustered weight used in the forward pass.

  The centroids are looked up through `get_clustered_weight`, flattened,
  and combined with the flattened original `weight` via
  `add_custom_gradients`, so that the original weights are part of the
  computation and receive gradients through the custom gradient defined
  there. The result is reshaped back to the shape of `pulling_indices`.

  :param pulling_indices: an array of indices used for lookup.
  :param weight: the original weights of the wrapped layer.
  :return: array with the same shape as `pulling_indices`.
  """
  flat_clustered = tf.reshape(
      self.get_clustered_weight(pulling_indices), shape=(-1,))
  flat_original = tf.reshape(weight, shape=(-1,))
  combined = self.add_custom_gradients(flat_clustered, flat_original)
  return tf.reshape(combined, pulling_indices.shape)
81111

82112
class ConvolutionalWeightsCA(AbstractClusteringAlgorithm):
83113
"""

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,48 @@ def _pull_values(self, ca, pulling_indices, expected_output):
4545

4646
self.assertSequenceEqual(res_np_list, expected_output)
4747

48+
def _check_gradients(self, ca, weight, pulling_indices, expected_output):
  """Compares the two gradients produced by `ca.add_custom_gradients`.

  The clustered weights are looked up from `pulling_indices`, flattened and
  fed together with `weight` into `ca.add_custom_gradients`. The gradient of
  the output with respect to the clustered weights is then compared
  element-wise with the gradient with respect to `weight`, and the result
  of that comparison is checked against `expected_output`.
  """
  indices_tensor = tf.convert_to_tensor(pulling_indices)
  weight_tensor = tf.convert_to_tensor(weight)

  # A persistent tape is needed because two gradients are requested.
  with tf.GradientTape(persistent=True) as tape:
    tape.watch(indices_tensor)
    tape.watch(weight_tensor)
    flat_clustered = tf.reshape(
        ca.get_clustered_weight(indices_tensor), shape=(-1,))
    tape.watch(flat_clustered)
    forward_out = ca.add_custom_gradients(flat_clustered, weight_tensor)

  clustered_grad = tape.gradient(forward_out, flat_clustered)
  weight_grad = tape.gradient(forward_out, weight_tensor)

  grads_match = k.batch_get_value(tf.math.equal(clustered_grad, weight_grad))
  self.assertSequenceEqual(grads_match, expected_output)
65+
66+
@parameterized.parameters(
    (
        # Cluster centroids.
        [-0.800450444, 0.864694357],
        # Original weights.
        [
            [0.220442653, 0.854694366, 0.0328432359, 0.506857157],
            [0.0527950861, -0.659555554, -0.849919915, -0.54047],
            [-0.305815876, 0.0865516588, 0.659202456, -0.355699599],
            [-0.348868281, -0.662001, 0.6171574, -0.296582848],
        ],
        # Pulling indices.
        [
            [1, 1, 1, 1],
            [1, 0, 0, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 0],
        ],
        # Expected element-wise gradient comparison (one entry per weight).
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    )
)
def testDenseWeightsCAGrad(self,
                           clustering_centroids,
                           weight,
                           pulling_indices,
                           expected_output):
  """Verifies that the gradients of DenseWeightsCA work as expected."""
  dense_ca = clustering_registry.DenseWeightsCA(clustering_centroids)
  self._check_gradients(dense_ca, weight, pulling_indices, expected_output)
89+
4890
@parameterized.parameters(
4991
([-1, 1], [[0, 0, 1], [1, 1, 1]], [[-1, -1, 1], [1, 1, 1]]),
5092
([-1, 0, 1], [[1, 1, 1], [1, 1, 1]], [[0, 0, 0], [0, 0, 0]]),
@@ -73,6 +115,29 @@ def testBiasWeightsCA(self,
73115
ca = clustering_registry.BiasWeightsCA(clustering_centroids)
74116
self._pull_values(ca, pulling_indices, expected_output)
75117

118+
@parameterized.parameters(
    (
        # Cluster centroids.
        [0.0, 3.0],
        # Original weights.
        [
            [0.1, 0.1, 0.1],
            [3.0, 3.0, 3.0],
            [0.2, 0.2, 0.2],
        ],
        # Pulling indices.
        [
            [0, 0, 0],
            [1, 1, 1],
            [0, 0, 0],
        ],
        # Expected element-wise gradient comparison (one entry per weight).
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
    )
)
def testConvolutionalWeightsCAGrad(self,
                                   clustering_centroids,
                                   weight,
                                   pulling_indices,
                                   expected_output):
  """Verifies that the gradients of ConvolutionalWeightsCA work as expected."""
  # The test previously instantiated DenseWeightsCA here, so the
  # convolutional implementation was never exercised by this test.
  ca = clustering_registry.ConvolutionalWeightsCA(clustering_centroids)
  self._check_gradients(ca, weight, pulling_indices, expected_output)
139+
140+
76141
@parameterized.parameters(
77142
([0, 3], [[[[0, 0, 0], [1, 1, 1], [0, 0, 0]]]],
78143
[[[[0, 0, 0], [3, 3, 3], [0, 0, 0]]]]),

0 commit comments

Comments
 (0)