Skip to content

Commit 417aa54

Browse files
committed
Moved out AbstractClusteringAlgorithm from the clustering_registry.
Change-Id: I3acf88b3b6eed2447fefb790e5e0dac8e32468f9
1 parent 8e790c6 commit 417aa54

File tree

5 files changed

+133
-88
lines changed

5 files changed

+133
-88
lines changed

tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_scope
2020
from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_weights
2121
from tensorflow_model_optimization.python.core.clustering.keras.cluster import strip_clustering
22+
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
2223

2324
from tensorflow_model_optimization.python.core.clustering.keras.cluster_config import CentroidInitialization
25+
from tensorflow_model_optimization.python.core.clustering.keras.clustering_algorithm import AbstractClusteringAlgorithm
2426
# pylint: enable=g-bad-import-order

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ py_library(
4343
visibility = ["//visibility:public"],
4444
deps = [
4545
":clusterable_layer",
46+
":clustering_algorithm"
4647
],
4748
)
4849

@@ -73,6 +74,16 @@ py_library(
7374
],
7475
)
7576

77+
py_library(
78+
name = "clustering_algorithm",
79+
srcs = ["clustering_algorithm.py"],
80+
srcs_version = "PY3",
81+
visibility = ["//visibility:public"],
82+
deps = [
83+
# tensorflow dep1,
84+
],
85+
)
86+
7687
py_library(
7788
name = "clustering_callbacks",
7889
srcs = ["clustering_callbacks.py"],
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Abstract base class for clustering algorithm."""
16+
17+
import abc
18+
import six
19+
import tensorflow as tf
20+
21+
@six.add_metaclass(abc.ABCMeta)
22+
class AbstractClusteringAlgorithm(object):
23+
24+
"""
25+
The reason to have an abstract class here is to be able to implement highly
26+
efficient vectorised look-ups.
27+
28+
We do not utilise looping for that purpose, instead we `smartly` reshape and
29+
tile arrays. The trade-off is that we are potentially using way more memory
30+
than we would have if looping is used.
31+
32+
Each class that inherits from this class is supposed to implement a particular
33+
lookup function for a certain shape.
34+
35+
For example, look-ups for 2D table will be different in the case of 3D.
36+
"""
37+
38+
def __init__(self, clusters_centroids):
39+
"""
40+
For generating clustered tensors we will need two things: cluster centroids
41+
and the final shape the tensor must have.
42+
43+
:param clusters_centroids: An array of shape (N,) that contains initial
44+
values of clusters centroids.
45+
"""
46+
self.cluster_centroids = clusters_centroids
47+
48+
@abc.abstractmethod
49+
def get_pulling_indices(self, weight):
50+
"""
51+
Takes a weight (1D, 2D or ND) and creates a tf.int32 array of the same
52+
shape holding the indices of the cluster centroids that the elements of the
53+
clustered array will be pulled from.
54+
55+
In the current setup pulling indices are meant to be created once and used
56+
everywhere
57+
58+
:param weight: ND array of weights. For each weight in this array the
59+
closest cluster centroid is found.
60+
61+
:return: ND array of the same shape as `weight` parameter of the type
62+
tf.int32. The returned array contains weight lookup indices.
63+
"""
64+
pass
65+
66+
@tf.custom_gradient
67+
def add_custom_gradients(self, clst_weights, weights):
68+
69+
"""
70+
This function overrides gradients in the backprop stage: original mul
71+
becomes add, tf.sign becomes tf.identity. It is to update the original
72+
weights with the gradient updates directly from the wrapped layer. We
73+
assume the gradient updates on individual elements inside a cluster
74+
will be different, so there is no point in mapping the gradient
75+
updates back to the original weight matrix using the LUT.
76+
"""
77+
override_weights = tf.sign(tf.reshape(weights, shape=(-1,)) + 1e+6)
78+
z = clst_weights*override_weights
79+
80+
def grad(dz):
81+
return dz, dz
82+
return z, grad
83+
84+
def get_clustered_weight(self, pulling_indices):
85+
"""
86+
Takes an array of integers that represent lookup indices and forms a
87+
new array according to the given indices.
88+
89+
:param pulling_indices: an array of indices used for lookup.
90+
:return: array with the same shape as `pulling_indices`. Each array element
91+
is a member of self.cluster_centroids
92+
"""
93+
94+
return tf.reshape(
95+
tf.gather(self.cluster_centroids,
96+
tf.reshape(pulling_indices, shape=(-1,))),
97+
shape=pulling_indices.shape
98+
)
99+
100+
def get_clustered_weight_forward(self, pulling_indices, weight):
101+
"""
102+
Takes indices (pulling_indices) and original weights (weight) as inputs
103+
and then forms a new array according to the given indices. The original
104+
weights (weight) here are added to the graph since we want the backprop
105+
to update their values via the new implementation using tf.custom_gradient
106+
107+
:param pulling_indices: an array of indices used for lookup.
108+
:param weight: the original weights of the wrapped layer.
109+
:return: array with the same shape as `pulling_indices`. Each array element
110+
is a member of self.cluster_centroids
111+
"""
112+
113+
x = tf.reshape(self.get_clustered_weight(pulling_indices), shape=(-1,))
114+
115+
return tf.reshape(self.add_custom_gradients(
116+
x, tf.reshape(weight, shape=(-1,))), pulling_indices.shape)

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 2 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -20,93 +20,9 @@
2020
from tensorflow.keras import layers
2121

2222
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
23+
from tensorflow_model_optimization.python.core.clustering.keras import clustering_algorithm
2324

24-
25-
@six.add_metaclass(abc.ABCMeta)
26-
class AbstractClusteringAlgorithm(object):
27-
"""
28-
The reason to have an abstract class here is to be able to implement highly
29-
efficient vectorised look-ups.
30-
31-
We do not utilise looping for that purpose, instead we `smartly` reshape and
32-
tile arrays. The trade-off is that we are potentially using way more memory
33-
than we would have if looping is used.
34-
35-
Each class that inherits from this class is supposed to implement a particular
36-
lookup function for a certain shape.
37-
38-
For example, look-ups for 2D table will be different in the case of 3D.
39-
"""
40-
41-
def __init__(self, clusters_centroids):
42-
"""
43-
For generating clustered tensors we will need two things: cluster centroids
44-
and the final shape tensor must have.
45-
:param clusters_centroids: An array of shape (N,) that contains initial
46-
values of clusters centroids.
47-
"""
48-
self.cluster_centroids = clusters_centroids
49-
50-
@abc.abstractmethod
51-
def get_pulling_indices(self, weight):
52-
"""
53-
Takes a weight(can be 1D, 2D or ND) and creates tf.int32 array of the same
54-
shape that will hold indices of cluster centroids clustered arrays elements
55-
will be pulled from.
56-
57-
In the current setup pulling indices are meant to be created once and used
58-
everywhere
59-
:param weight: ND array of weights. For each weight in this array the
60-
closest cluster centroids is found.
61-
:return: ND array of the same shape as `weight` parameter of the type
62-
tf.int32. The returned array contain weight lookup indices
63-
"""
64-
pass
65-
66-
@tf.custom_gradient
67-
def add_custom_gradients(self, clst_weights, weights):
68-
"""
69-
This function overrides gradients in the backprop stage: original mul
70-
becomes add, tf.sign becomes tf.identity. It is to update the original
71-
weights with the gradients updates directly from the layer wrapped. We
72-
assume the gradients updates on individual elements inside a cluster
73-
will be different so that there is no point of mapping the gradient
74-
updates back to original weight matrix using the LUT.
75-
"""
76-
override_weights = tf.sign(tf.reshape(weights, shape=(-1,)) + 1e+6)
77-
z = clst_weights*override_weights
78-
def grad(dz):
79-
return dz, dz
80-
return z, grad
81-
82-
def get_clustered_weight(self, pulling_indices):
83-
"""
84-
Takes an array with integer number that represent lookup indices and forms a
85-
new array according to the given indices.
86-
:param pulling_indices: an array of indices used for lookup.
87-
:return: array with the same shape as `pulling_indices`. Each array element
88-
is a member of self.cluster_centroids
89-
"""
90-
return tf.reshape(
91-
tf.gather(self.cluster_centroids,
92-
tf.reshape(pulling_indices, shape=(-1,))),
93-
shape=pulling_indices.shape
94-
)
95-
96-
def get_clustered_weight_forward(self, pulling_indices, weight):
97-
"""
98-
Takes indices (pulling_indices) and original weights (weight) as inputs
99-
and then forms a new array according to the given indices. The original
100-
weights (weight) here are added to the graph since we want the backprop
101-
to update their values via the new implementation using tf.custom_gradient
102-
:param pulling_indices: an array of indices used for lookup.
103-
:param weight: the original weights of the wrapped layer.
104-
:return: array with the same shape as `pulling_indices`. Each array element
105-
is a member of self.cluster_centroids
106-
"""
107-
x = tf.reshape(self.get_clustered_weight(pulling_indices), shape=(-1,))
108-
return tf.reshape(self.add_custom_gradients(
109-
x, tf.reshape(weight, shape=(-1,))), pulling_indices.shape)
25+
AbstractClusteringAlgorithm = clustering_algorithm.AbstractClusteringAlgorithm
11026

11127
class ConvolutionalWeightsCA(AbstractClusteringAlgorithm):
11228
"""

tensorflow_model_optimization/python/core/clustering/keras/mnist_clusterable_layer_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from tensorflow_model_optimization.python.core.clustering.keras import cluster
2020
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2121
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
22-
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
22+
from tensorflow_model_optimization.python.core.clustering.keras import clustering_algorithm
2323

2424
tf.random.set_seed(42)
2525

@@ -38,7 +38,7 @@ def get_clusterable_weights(self):
3838
# Cluster kernel and bias.
3939
return [('kernel', self.kernel), ('bias', self.bias)]
4040

41-
class ClusterableWeightsCA(clustering_registry.AbstractClusteringAlgorithm):
41+
class ClusterableWeightsCA(clustering_algorithm.AbstractClusteringAlgorithm):
4242
"""
4343
This class provides a special lookup function for the weights 'w'.
4444
It reshapes and tiles centroids the same way as the weights. This allows us

0 commit comments

Comments
 (0)