|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | """Contains definitions of segmentation heads."""
|
16 |
| -from typing import List, Union, Optional, Mapping, Tuple |
| 16 | +from typing import List, Union, Optional, Mapping, Tuple, Any |
17 | 17 | import tensorflow as tf
|
18 | 18 |
|
19 | 19 | from official.modeling import tf_utils
|
20 | 20 | from official.vision.beta.modeling.layers import nn_layers
|
21 | 21 | from official.vision.beta.ops import spatial_transform_ops
|
22 | 22 |
|
23 | 23 |
|
| 24 | +class MaskScoring(tf.keras.Model): |
| 25 | + """Creates a mask scoring layer. |
| 26 | +
|
| 27 | + This implements mask scoring layer from the paper: |
| 28 | +
|
| 29 | + Zhaojin Huang, Lichao Huang, Yongchao Gong, Chang Huang, Xinggang Wang. |
| 30 | + Mask Scoring R-CNN. |
| 31 | + (https://arxiv.org/pdf/1903.00241.pdf) |
| 32 | + """ |
| 33 | + |
| 34 | + def __init__( |
| 35 | + self, |
| 36 | + num_classes: int, |
| 37 | + fc_input_size: List[int], |
| 38 | + num_convs: int = 3, |
| 39 | + num_filters: int = 256, |
| 40 | + fc_dims: int = 1024, |
| 41 | + num_fcs: int = 2, |
| 42 | + activation: str = 'relu', |
| 43 | + use_sync_bn: bool = False, |
| 44 | + norm_momentum: float = 0.99, |
| 45 | + norm_epsilon: float = 0.001, |
| 46 | + kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, |
| 47 | + bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, |
| 48 | + **kwargs): |
| 49 | + |
| 50 | + """Initializes mask scoring layer. |
| 51 | +
|
| 52 | + Args: |
| 53 | + num_classes: An `int` for number of classes. |
| 54 | + fc_input_size: A List of `int` for the input size of the |
| 55 | + fully connected layers. |
| 56 | + num_convs: An`int` for number of conv layers. |
| 57 | + num_filters: An `int` for the number of filters for conv layers. |
| 58 | + fc_dims: An `int` number of filters for each fully connected layers. |
| 59 | + num_fcs: An `int` for number of fully connected layers. |
| 60 | + activation: A `str` name of the activation function. |
| 61 | + use_sync_bn: A bool, whether or not to use sync batch normalization. |
| 62 | + norm_momentum: A float for the momentum in BatchNorm. Defaults to 0.99. |
| 63 | + norm_epsilon: A float for the epsilon value in BatchNorm. Defaults to |
| 64 | + 0.001. |
| 65 | + kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for |
| 66 | + Conv2D. Default is None. |
| 67 | + bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D. |
| 68 | + **kwargs: Additional keyword arguments to be passed. |
| 69 | + """ |
| 70 | + super(MaskScoring, self).__init__(**kwargs) |
| 71 | + |
| 72 | + self._config_dict = { |
| 73 | + 'num_classes': num_classes, |
| 74 | + 'num_convs': num_convs, |
| 75 | + 'num_filters': num_filters, |
| 76 | + 'fc_input_size': fc_input_size, |
| 77 | + 'fc_dims': fc_dims, |
| 78 | + 'num_fcs': num_fcs, |
| 79 | + 'use_sync_bn': use_sync_bn, |
| 80 | + 'norm_momentum': norm_momentum, |
| 81 | + 'norm_epsilon': norm_epsilon, |
| 82 | + 'activation': activation, |
| 83 | + 'kernel_regularizer': kernel_regularizer, |
| 84 | + 'bias_regularizer': bias_regularizer, |
| 85 | + } |
| 86 | + |
| 87 | + if tf.keras.backend.image_data_format() == 'channels_last': |
| 88 | + self._bn_axis = -1 |
| 89 | + else: |
| 90 | + self._bn_axis = 1 |
| 91 | + self._activation = tf_utils.get_activation(activation) |
| 92 | + |
| 93 | + def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]): |
| 94 | + """Creates the variables of the mask scoring head.""" |
| 95 | + conv_op = tf.keras.layers.Conv2D |
| 96 | + conv_kwargs = { |
| 97 | + 'filters': self._config_dict['num_filters'], |
| 98 | + 'kernel_size': 3, |
| 99 | + 'padding': 'same', |
| 100 | + } |
| 101 | + conv_kwargs.update({ |
| 102 | + 'kernel_initializer': tf.keras.initializers.VarianceScaling( |
| 103 | + scale=2, mode='fan_out', distribution='untruncated_normal'), |
| 104 | + 'bias_initializer': tf.zeros_initializer(), |
| 105 | + 'kernel_regularizer': self._config_dict['kernel_regularizer'], |
| 106 | + 'bias_regularizer': self._config_dict['bias_regularizer'], |
| 107 | + }) |
| 108 | + bn_op = (tf.keras.layers.experimental.SyncBatchNormalization |
| 109 | + if self._config_dict['use_sync_bn'] |
| 110 | + else tf.keras.layers.BatchNormalization) |
| 111 | + bn_kwargs = { |
| 112 | + 'axis': self._bn_axis, |
| 113 | + 'momentum': self._config_dict['norm_momentum'], |
| 114 | + 'epsilon': self._config_dict['norm_epsilon'], |
| 115 | + } |
| 116 | + |
| 117 | + self._convs = [] |
| 118 | + self._conv_norms = [] |
| 119 | + for i in range(self._config_dict['num_convs']): |
| 120 | + conv_name = 'mask-scoring_{}'.format(i) |
| 121 | + self._convs.append(conv_op(name=conv_name, **conv_kwargs)) |
| 122 | + bn_name = 'mask-scoring-bn_{}'.format(i) |
| 123 | + self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs)) |
| 124 | + |
| 125 | + self._fcs = [] |
| 126 | + self._fc_norms = [] |
| 127 | + for i in range(self._config_dict['num_fcs']): |
| 128 | + fc_name = 'mask-scoring-fc_{}'.format(i) |
| 129 | + self._fcs.append( |
| 130 | + tf.keras.layers.Dense( |
| 131 | + units=self._config_dict['fc_dims'], |
| 132 | + kernel_initializer=tf.keras.initializers.VarianceScaling( |
| 133 | + scale=1 / 3.0, mode='fan_out', distribution='uniform'), |
| 134 | + kernel_regularizer=self._config_dict['kernel_regularizer'], |
| 135 | + bias_regularizer=self._config_dict['bias_regularizer'], |
| 136 | + name=fc_name)) |
| 137 | + bn_name = 'mask-scoring-fc-bn_{}'.format(i) |
| 138 | + self._fc_norms.append(bn_op(name=bn_name, **bn_kwargs)) |
| 139 | + |
| 140 | + self._classifier = tf.keras.layers.Dense( |
| 141 | + units=self._config_dict['num_classes'], |
| 142 | + kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), |
| 143 | + bias_initializer=tf.zeros_initializer(), |
| 144 | + kernel_regularizer=self._config_dict['kernel_regularizer'], |
| 145 | + bias_regularizer=self._config_dict['bias_regularizer'], |
| 146 | + name='iou-scores') |
| 147 | + |
| 148 | + super(MaskScoring, self).build(input_shape) |
| 149 | + |
| 150 | + def call(self, inputs: tf.Tensor, training: bool = None): |
| 151 | + """Forward pass mask scoring head. |
| 152 | +
|
| 153 | + Args: |
| 154 | + inputs: A `tf.Tensor` of the shape [batch_size, width, size, num_classes], |
| 155 | + representing the segmentation logits. |
| 156 | + training: a `bool` indicating whether it is in `training` mode. |
| 157 | +
|
| 158 | + Returns: |
| 159 | + mask_scores: A `tf.Tensor` of predicted mask scores |
| 160 | + [batch_size, num_classes]. |
| 161 | + """ |
| 162 | + x = tf.stop_gradient(inputs) |
| 163 | + for conv, bn in zip(self._convs, self._conv_norms): |
| 164 | + x = conv(x) |
| 165 | + x = bn(x) |
| 166 | + x = self._activation(x) |
| 167 | + |
| 168 | + # Casts feat to float32 so the resize op can be run on TPU. |
| 169 | + x = tf.cast(x, tf.float32) |
| 170 | + x = tf.image.resize(x, size=self._config_dict['fc_input_size'], |
| 171 | + method=tf.image.ResizeMethod.BILINEAR) |
| 172 | + # Casts it back to be compatible with the rest opetations. |
| 173 | + x = tf.cast(x, inputs.dtype) |
| 174 | + |
| 175 | + _, h, w, filters = x.get_shape().as_list() |
| 176 | + x = tf.reshape(x, [-1, h * w * filters]) |
| 177 | + |
| 178 | + for fc, bn in zip(self._fcs, self._fc_norms): |
| 179 | + x = fc(x) |
| 180 | + x = bn(x) |
| 181 | + x = self._activation(x) |
| 182 | + |
| 183 | + ious = self._classifier(x) |
| 184 | + return ious |
| 185 | + |
| 186 | + def get_config(self) -> Mapping[str, Any]: |
| 187 | + return self._config_dict |
| 188 | + |
| 189 | + @classmethod |
| 190 | + def from_config(cls, config, custom_objects=None): |
| 191 | + return cls(**config) |
| 192 | + |
| 193 | + |
24 | 194 | @tf.keras.utils.register_keras_serializable(package='Vision')
|
25 | 195 | class SegmentationHead(tf.keras.layers.Layer):
|
26 | 196 | """Creates a segmentation head."""
|
@@ -225,6 +395,7 @@ def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
|
225 | 395 | segmentation prediction mask: A `tf.Tensor` of the segmentation mask
|
226 | 396 | scores predicted from input features.
|
227 | 397 | """
|
| 398 | + |
228 | 399 | backbone_output = inputs[0]
|
229 | 400 | decoder_output = inputs[1]
|
230 | 401 | if self._config_dict['feature_fusion'] == 'deeplabv3plus':
|
|
0 commit comments