|
| 1 | +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +# pylint: disable=protected-access |
| 16 | +"""Pruning Policy classes to control application of pruning wrapper.""" |
| 17 | + |
| 18 | +import abc |
| 19 | +import tensorflow as tf |
| 20 | + |
| 21 | +layers = tf.keras.layers |
| 22 | +activations = tf.keras.activations |
| 23 | + |
| 24 | + |
| 25 | +class PruningPolicy(abc.ABC): |
| 26 | + """Specifies what layers to prune in the model. |
| 27 | +
|
| 28 | + PruningPolicy controls application of `PruneLowMagnitude` wrapper on per-layer |
| 29 | + basis and checks that the model contains only supported layers. |
| 30 | + PruningPolicy works together with `prune_low_magnitude` through which it |
| 31 | + provides fine-grained control over pruning in the model. |
| 32 | +
|
| 33 | + ```python |
| 34 | + pruning_params = { |
| 35 | + 'pruning_schedule': ConstantSparsity(0.5, 0), |
| 36 | + 'block_size': (1, 1), |
| 37 | + 'block_pooling_type': 'AVG' |
| 38 | + } |
| 39 | +
|
| 40 | + model = prune_low_magnitude( |
| 41 | + keras.Sequential([ |
| 42 | + layers.Dense(10, activation='relu', input_shape=(100,)), |
| 43 | + layers.Dense(2, activation='sigmoid') |
| 44 | + ]), |
| 45 | + pruning_policy=PruneForLatencyOnXNNPack(), |
| 46 | + **pruning_params) |
| 47 | + ``` |
| 48 | +
|
| 49 | + You can inherit this class to write your own custom pruning policy. |
| 50 | + """ |
| 51 | + |
| 52 | + @abc.abstractmethod |
| 53 | + def allow_pruning(self, layer): |
| 54 | + """Checks if pruning wrapper should be applied for the current layer. |
| 55 | +
|
| 56 | + Args: |
| 57 | + layer: Current layer in the model. |
| 58 | +
|
| 59 | + Returns: |
| 60 | + True/False, whether the pruning wrapper should be applied for the layer. |
| 61 | + """ |
| 62 | + raise NotImplementedError |
| 63 | + |
| 64 | + @abc.abstractmethod |
| 65 | + def ensure_model_supports_pruning(self, model): |
| 66 | + """Checks that the model contains only supported layers. |
| 67 | +
|
| 68 | + Args: |
| 69 | + model: A `tf.keras.Model` instance which is going to be pruned. |
| 70 | +
|
| 71 | + Raises: |
| 72 | + ValueError: if the keras model doesn't support pruning policy, i.e. keras |
| 73 | + model contains an unsupported layer. |
| 74 | + """ |
| 75 | + raise NotImplementedError |
| 76 | + |
| 77 | + |
| 78 | +class PruneForLatencyOnXNNPack(PruningPolicy): |
| 79 | + """Specifies to prune only 1x1 Conv2D layers in the model. |
| 80 | +
|
| 81 | + PruneForLatencyOnXNNPack checks that the model contains a subgraph that can |
| 82 | + leverage XNNPACK's sparse inference and applies pruning wrapper only to |
| 83 | + Conv2D with `kernel_size = (1, 1)`. |
| 84 | +
|
| 85 | + Reference: |
| 86 | + - [Fast Sparse ConvNets](https://arxiv.org/abs/1911.09723) |
| 87 | + - [XNNPACK Sparse Inference](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md#sparse-inference) # pylint: disable=line-too-long |
| 88 | + """ |
| 89 | + |
| 90 | + def allow_pruning(self, layer): |
| 91 | + """Allows to prune only 1x1 Conv2D layers.""" |
| 92 | + return isinstance(layer, layers.Conv2D) and layer.kernel_size == (1, 1) |
| 93 | + |
| 94 | + def _get_producers(self, layer): |
| 95 | + producers = [] |
| 96 | + for node in layer._inbound_nodes: |
| 97 | + if isinstance(node.inbound_layers, list): |
| 98 | + producers.extend(node.inbound_layers) |
| 99 | + else: |
| 100 | + producers.append(node.inbound_layers) |
| 101 | + return producers |
| 102 | + |
| 103 | + def _get_consumers(self, layer): |
| 104 | + return [node.outbound_layer for node in layer._outbound_nodes] |
| 105 | + |
| 106 | + def _lookup_layers(self, source_layers, stop_fn, next_fn): |
| 107 | + """Traverses the model and returns layers satisfying `stop_fn` criteria.""" |
| 108 | + to_visit = set(source_layers) |
| 109 | + used_layers = set(source_layers) |
| 110 | + found_layers = set() |
| 111 | + while to_visit: |
| 112 | + layer = to_visit.pop() |
| 113 | + if stop_fn(layer): |
| 114 | + found_layers.add(layer) |
| 115 | + else: |
| 116 | + next_layers = next_fn(layer) |
| 117 | + if not next_layers: |
| 118 | + return set() |
| 119 | + for next_layer in next_layers: |
| 120 | + if next_layer not in used_layers: |
| 121 | + used_layers.add(next_layer) |
| 122 | + to_visit.add(next_layer) |
| 123 | + |
| 124 | + return found_layers |
| 125 | + |
| 126 | + def _start_layer_stop_fn(self, layer): |
| 127 | + """Determines whether the layer starts a subgraph of sparse inference.""" |
| 128 | + return (isinstance(layer, layers.Conv2D) and hasattr(layer, 'kernel') and |
| 129 | + layer.kernel.shape[:3] == (3, 3, 3) and layer.strides == (2, 2) and |
| 130 | + layer.padding.lower() == 'valid') |
| 131 | + |
| 132 | + def _end_layer_stop_fn(self, layer): |
| 133 | + """Determines whether the layer ends a subgraph of sparse inference.""" |
| 134 | + return isinstance(layer, layers.GlobalAveragePooling2D) and layer.keepdims |
| 135 | + |
| 136 | + def _check_layer_support(self, layer): |
| 137 | + """Returns whether the layer is supported or not. |
| 138 | +
|
| 139 | + Mimics XNNPACK's behaviour of compatibility function. |
| 140 | +
|
| 141 | + Args: |
| 142 | + layer: Current layer in the model. |
| 143 | +
|
| 144 | + Returns: |
| 145 | + True if the layer is supported, False otherwise. |
| 146 | +
|
| 147 | + References: |
| 148 | + - https://github.com/google/XNNPACK/blob/master/src/subgraph.c#L130 |
| 149 | + """ |
| 150 | + if isinstance(layer, (layers.Add, layers.Multiply, layers.ZeroPadding2D, |
| 151 | + layers.ReLU, layers.LeakyReLU, layers.ELU)): |
| 152 | + return True |
| 153 | + elif isinstance(layer, layers.DepthwiseConv2D): |
| 154 | + # 3x3 stride-1 convolution (no dilation, padding 1 on each side). |
| 155 | + # 3x3 stride-2 convolution (no dilation, padding 1 on each side). |
| 156 | + # 5x5 stride-1 convolution (no dilation, padding 2 on each side). |
| 157 | + # 5x5 stride-2 convolution (no dilation, padding 2 on each side). |
| 158 | + return (layer.depth_multiplier == 1 and layer.dilation_rate == (1, 1) and |
| 159 | + (layer.kernel_size == (3, 3) or layer.kernel_size == (5, 5)) and |
| 160 | + ((layer.padding.lower() == 'same' and layer.strides == (1, 1)) or |
| 161 | + (layer.padding.lower() == 'valid' and layer.strides == (2, 2)))) |
| 162 | + elif isinstance(layer, layers.Conv2D): |
| 163 | + # 1x1 convolution (no stride, no dilation, no padding, no groups). |
| 164 | + return (layer.groups == 1 and layer.dilation_rate == (1, 1) and |
| 165 | + layer.kernel_size == (1, 1) and layer.strides == (1, 1)) |
| 166 | + elif isinstance(layer, layers.GlobalAveragePooling2D): |
| 167 | + return layer.keepdims |
| 168 | + elif isinstance(layer, layers.BatchNormalization): |
| 169 | + return list(layer.axis) == [3] |
| 170 | + elif isinstance(layer, layers.UpSampling2D): |
| 171 | + return layer.interpolation == 'bilinear' |
| 172 | + elif isinstance(layer, layers.Activation): |
| 173 | + return activations.serialize(layer.activation) in ('relu', 'relu6', |
| 174 | + 'leaky_relu', 'elu', |
| 175 | + 'sigmoid') |
| 176 | + return False |
| 177 | + |
| 178 | + def ensure_model_supports_pruning(self, model): |
| 179 | + """Ensures that the model contains only supported layers.""" |
| 180 | + |
| 181 | + # Check whether the model is a subclass model. |
| 182 | + if (not model._is_graph_network and |
| 183 | + not isinstance(model, tf.keras.models.Sequential)): |
| 184 | + raise ValueError('Subclassed models are not supported currently.') |
| 185 | + |
| 186 | + if not model.built: |
| 187 | + raise ValueError('Unbuilt models are not supported currently.') |
| 188 | + |
| 189 | + # Gather the layers that consume model's input tensors. |
| 190 | + input_layers = set(inp._keras_history.layer for inp in model.inputs) |
| 191 | + |
| 192 | + # Search for the start layer (Conv2D 3x3, `stride = (2, 2)`, |
| 193 | + # `filters = 3`, `padding = `VALID``) in every input branch (forward). |
| 194 | + start_layers = self._lookup_layers( |
| 195 | + input_layers, |
| 196 | + self._start_layer_stop_fn, |
| 197 | + self._get_consumers, |
| 198 | + ) |
| 199 | + if not start_layers: |
| 200 | + raise ValueError(('Could not find `Conv2D 3x3` layer with stride 2x2, ' |
| 201 | + '`input filters == 3` and `VALID` padding in all input ' |
| 202 | + 'branches of the model')) |
| 203 | + |
| 204 | + # Search for the end layer (GlobalAveragePooling with `keepdims = True`) |
| 205 | + # for every output branch (backward). |
| 206 | + output_layers = set(inp._keras_history.layer for inp in model.outputs) |
| 207 | + end_layers = self._lookup_layers( |
| 208 | + output_layers, |
| 209 | + self._end_layer_stop_fn, |
| 210 | + self._get_producers, |
| 211 | + ) |
| 212 | + if not end_layers: |
| 213 | + raise ValueError(('Could not find a `GlobalAveragePooling2D` layer with ' |
| 214 | + '`keepdims = True` in all output branches')) |
| 215 | + |
| 216 | + # Ensure that all layers between the start and the end layers are supported |
| 217 | + # for pruning. |
| 218 | + def visit_fn(layer): |
| 219 | + if layer not in end_layers and not self._check_layer_support(layer): |
| 220 | + raise ValueError(('Layer {layer} is not supported for the {policy} ' |
| 221 | + 'pruning policy'.format( |
| 222 | + layer=layer.__class__.__name__, |
| 223 | + policy=self.__class__.__name__))) |
| 224 | + return layer in end_layers |
| 225 | + |
| 226 | + _ = self._lookup_layers( |
| 227 | + sum([self._get_consumers(layer) for layer in start_layers], []), |
| 228 | + visit_fn, |
| 229 | + self._get_consumers, |
| 230 | + ) |
0 commit comments