|
| 1 | +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Abstract base class for clustering algorithm.""" |
| 16 | + |
| 17 | +import abc |
| 18 | +import six |
| 19 | +import tensorflow as tf |
| 20 | + |
@six.add_metaclass(abc.ABCMeta)
class AbstractClusteringAlgorithm(object):
  """Abstract base for clustering algorithms built on centroid look-ups.

  The reason to have an abstract class here is to be able to implement highly
  efficient vectorised look-ups. Looping is not used for that purpose; arrays
  are reshaped and tiled instead. The trade-off is that this potentially uses
  far more memory than a looping implementation would.

  Each class that inherits from this class is supposed to implement a
  particular lookup function for a certain shape. For example, look-ups for a
  2D table will be different from look-ups for a 3D one.
  """

  def __init__(self, clusters_centroids):
    """Stores the cluster centroids that clustered tensors are built from.

    :param clusters_centroids: An array of shape (N,) that contains initial
      values of clusters centroids.
    """
    self.cluster_centroids = clusters_centroids

  @abc.abstractmethod
  def get_pulling_indices(self, weight):
    """Computes, for every weight element, the index of its centroid.

    Takes a weight (can be 1D, 2D or ND) and creates a tf.int32 array of the
    same shape that holds the indices of the cluster centroids the elements of
    the clustered array will be pulled from.

    In the current setup pulling indices are meant to be created once and used
    everywhere.

    :param weight: ND array of weights. For each weight in this array the
      closest cluster centroid is found.

    :return: ND array of the same shape as the `weight` parameter, of type
      tf.int32. The returned array contains weight lookup indices.
    """

  @tf.custom_gradient
  def add_custom_gradients(self, clst_weights, weights):
    """Forwards the clustered weights while overriding the backprop gradients.

    This function overrides gradients in the backprop stage: original mul
    becomes add, tf.sign becomes tf.identity. It is to update the original
    weights with the gradient updates coming directly from the wrapped layer.
    The gradient updates on individual elements inside a cluster are assumed
    to differ, so there is no point in mapping the gradient updates back to
    the original weight matrix using the LUT.

    :param clst_weights: clustered weights; multiplied element-wise with the
      flattened `weights`-derived multiplier.
    :param weights: original weights; flattened to 1D before use.
    :return: a pair made of the forward value and the `grad` function, as
      `tf.custom_gradient` expects from the function it decorates.
    """
    flat_weights = tf.reshape(weights, shape=(-1,))
    # tf.sign(w + 1e+6) evaluates to 1 for every w greater than -1e+6, so the
    # product below leaves `clst_weights` numerically unchanged while making
    # the forward value depend on `weights`.
    # NOTE(review): a weight at or below -1e+6 would not yield a multiplier
    # of 1 -- confirm such magnitudes cannot occur.
    unit_multiplier = tf.sign(flat_weights + 1e+6)
    forwarded = clst_weights * unit_multiplier

    def grad(upstream):
      # The same upstream gradient is handed to both inputs.
      return upstream, upstream

    return forwarded, grad

  def get_clustered_weight(self, pulling_indices):
    """Builds an array of centroid values selected by `pulling_indices`.

    Takes an array of integers that represent lookup indices and forms a new
    array according to the given indices.

    :param pulling_indices: an array of indices used for lookup.
    :return: array with the same shape as `pulling_indices`. Each array
      element is a member of self.cluster_centroids.
    """
    # Flatten the indices, look the centroids up, then restore the shape.
    flat_indices = tf.reshape(pulling_indices, shape=(-1,))
    flat_centroids = tf.gather(self.cluster_centroids, flat_indices)
    return tf.reshape(flat_centroids, shape=pulling_indices.shape)

  def get_clustered_weight_forward(self, pulling_indices, weight):
    """Builds the clustered array and routes gradients to the original weights.

    Takes indices (pulling_indices) and original weights (weight) as inputs
    and then forms a new array according to the given indices. The original
    weights (weight) are added to the graph here since the backprop should
    update their values via the implementation based on tf.custom_gradient.

    :param pulling_indices: an array of indices used for lookup.
    :param weight: the original weights of the wrapped layer.
    :return: array with the same shape as `pulling_indices`. Each array
      element is a member of self.cluster_centroids.
    """
    # Both operands are flattened to 1D before they are combined.
    flat_clustered = tf.reshape(
        self.get_clustered_weight(pulling_indices), shape=(-1,))
    flat_original = tf.reshape(weight, shape=(-1,))
    combined = self.add_custom_gradients(flat_clustered, flat_original)
    return tf.reshape(combined, pulling_indices.shape)
0 commit comments