|
| 1 | +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""SVD algorithm, where the training and inference graphs are the same.""" |
| 16 | +from typing import List |
| 17 | + |
| 18 | +import tensorflow as tf |
| 19 | + |
| 20 | +from tensorflow_model_optimization.python.core.common.keras.compression import algorithm |
| 21 | + |
| 22 | + |
| 23 | +class SVD(algorithm.WeightCompressor): |
| 24 | + """Define how to apply SVD algorithm. |
| 25 | +
|
| 26 | + This periodic update and scheduling base SVD algorithm update the original |
| 27 | + weights to make lower rank by SVD for each update_freq steps. During the |
| 28 | + warmup steps, It adjust the rank from the original to target rank gradually. |
| 29 | + """ |
| 30 | + |
| 31 | + def __init__(self, rank, update_freq=100, warmup_step=1000): |
| 32 | + self.rank = rank |
| 33 | + self.update_freq = update_freq |
| 34 | + self.warmup_step = warmup_step |
| 35 | + |
| 36 | + # TODO(tfmot): communicate that `pretrained_weight` will sometimes |
| 37 | + # be a dummy tensor and sometimes be actual pretrained values during |
| 38 | + # its actual usage. |
| 39 | + def init_training_weights( |
| 40 | + self, pretrained_weight: tf.Tensor): |
| 41 | + self.add_training_weight( |
| 42 | + name='w', |
| 43 | + shape=pretrained_weight.shape, |
| 44 | + dtype=pretrained_weight.dtype, |
| 45 | + initializer=tf.keras.initializers.Constant(pretrained_weight)) |
| 46 | + self.add_training_weight( |
| 47 | + name='step', |
| 48 | + shape=(), |
| 49 | + dtype=tf.int32, |
| 50 | + initializer=tf.keras.initializers.Constant(0)) |
| 51 | + |
| 52 | + def decompress_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor: |
| 53 | + return tf.matmul(u, sv) |
| 54 | + |
| 55 | + def project_training_weights( |
| 56 | + self, weight: tf.Tensor, step: tf.Tensor) -> tf.Tensor: |
| 57 | + weight_rank = tf.math.minimum(weight.shape[-1], weight.shape[-2]) |
| 58 | + self.update_training_weight(step, step + 1) |
| 59 | + if step % self.update_freq == 0: |
| 60 | + rank = self.rank |
| 61 | + if step < self.warmup_step: |
| 62 | + rank = tf.cast(tf.math.round( |
| 63 | + weight_rank * (self.warmup_step - step) |
| 64 | + + self.rank * step |
| 65 | + ) / self.warmup_step, tf.int32) |
| 66 | + rank = tf.math.minimum(rank, weight_rank) |
| 67 | + |
| 68 | + s, u, v = tf.linalg.svd(weight) |
| 69 | + |
| 70 | + if len(weight.shape) == 2: |
| 71 | + # FC Layer |
| 72 | + s = s[:rank] |
| 73 | + u = u[:, :rank] |
| 74 | + v = v[:, :rank] |
| 75 | + elif len(weight.shape) == 4: |
| 76 | + # Conv2D Layer |
| 77 | + s = s[:, :, :rank] |
| 78 | + u = u[:, :, :, :rank] |
| 79 | + v = v[:, :, :, :rank] |
| 80 | + else: |
| 81 | + raise NotImplementedError('Only for dimension=2 or 4 is supported.') |
| 82 | + |
| 83 | + sv = tf.matmul(tf.linalg.diag(s), v, adjoint_b=True) |
| 84 | + |
| 85 | + new_weight = tf.matmul(u, sv) |
| 86 | + self.update_training_weight(weight, new_weight) |
| 87 | + |
| 88 | + return weight |
| 89 | + |
| 90 | + def compress_training_weights(self, weight: tf.Tensor, _) -> List[tf.Tensor]: |
| 91 | + rank = self.rank |
| 92 | + s, u, v = tf.linalg.svd(weight) |
| 93 | + |
| 94 | + if len(weight.shape) == 2: |
| 95 | + # FC Layer |
| 96 | + s = s[:rank] |
| 97 | + u = u[:, :rank] |
| 98 | + v = v[:, :rank] |
| 99 | + elif len(weight.shape) == 4: |
| 100 | + # Conv2D Layer |
| 101 | + s = s[:, :, :rank] |
| 102 | + u = u[:, :, :, :rank] |
| 103 | + v = v[:, :, :, :rank] |
| 104 | + else: |
| 105 | + raise NotImplementedError('Only for dimension=2 or 4 is supported.') |
| 106 | + |
| 107 | + sv = tf.matmul(tf.linalg.diag(s), v, adjoint_b=True) |
| 108 | + |
| 109 | + return [u, sv] |
| 110 | + |
| 111 | + def get_compressible_weights( |
| 112 | + self, original_layer: tf.keras.layers.Layer) -> List[str]: |
| 113 | + if isinstance(original_layer, (tf.keras.layers.Conv2D, |
| 114 | + tf.keras.layers.Dense)): |
| 115 | + return [original_layer.kernel] |
| 116 | + return [] |
| 117 | + |
| 118 | + def optimize_model(self, to_optimize: tf.keras.Model) -> tf.keras.Model: |
| 119 | + """Model developer API for optimizing a model for training. |
| 120 | +
|
| 121 | + The returned model should be used for compression aware training. |
| 122 | + Args: |
| 123 | + to_optimize: The model to be optimize. |
| 124 | + Returns: |
| 125 | + A wrapped model that has compression optimizers. |
| 126 | + """ |
| 127 | + # pylint: disable=protected-access |
| 128 | + if not isinstance( |
| 129 | + to_optimize, tf.keras.Sequential) and not to_optimize._is_graph_network: |
| 130 | + raise ValueError( |
| 131 | + '`optimize_model` can only either be a tf.keras Sequential or ' |
| 132 | + 'Functional model.') |
| 133 | + # pylint: enable=protected-access |
| 134 | + |
| 135 | + def _optimize_layer(layer): |
| 136 | + # Require layer to be built so that the SVD-factorized weights |
| 137 | + # can be initialized from the weights. |
| 138 | + if not layer.built: |
| 139 | + raise ValueError( |
| 140 | + 'Applying SVD currently requires passing in a built model') |
| 141 | + |
| 142 | + return algorithm.create_layer_for_training(layer, algorithm=self) |
| 143 | + |
| 144 | + return tf.keras.models.clone_model( |
| 145 | + to_optimize, clone_function=_optimize_layer) |
| 146 | + |
| 147 | + def compress_model(self, to_compress: tf.keras.Model) -> tf.keras.Model: |
| 148 | + """Model developer API for optimizing a model for inference. |
| 149 | +
|
| 150 | + Args: |
| 151 | + to_compress: The model that trained for compression. This model should |
| 152 | + generated from the `optimize_model` method. |
| 153 | + Returns: |
| 154 | + A compressed model for the inference. |
| 155 | + """ |
| 156 | + def _optimize_layer(layer): |
| 157 | + # Require layer to be built so that the SVD-factorized weights |
| 158 | + # can be initialized from the weights. |
| 159 | + if not layer.built: |
| 160 | + raise ValueError( |
| 161 | + 'Applying SVD currently requires passing in a built model') |
| 162 | + |
| 163 | + return algorithm.create_layer_for_inference(layer, algorithm=self) |
| 164 | + |
| 165 | + return tf.keras.models.clone_model( |
| 166 | + to_compress, clone_function=_optimize_layer) |
0 commit comments