Skip to content

Commit c8bb9aa

Browse files
arashwan authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 416657609
1 parent 28213b2 commit c8bb9aa

14 files changed

+350
-23
lines changed

official/projects/volumetric_models/modeling/segmentation_model_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def test_segmentation_network_unet3d_creation(self, input_size, depth):
4747
model = segmentation_model.SegmentationModel(
4848
backbone=backbone, decoder=decoder, head=head)
4949

50-
logits = model(inputs)
50+
outputs = model(inputs)
5151
self.assertAllEqual(
5252
[2, input_size[0], input_size[0], input_size[1], num_classes],
53-
logits.numpy().shape)
53+
outputs['logits'].numpy().shape)
5454

5555
def test_serialize_deserialize(self):
5656
"""Validate the network can be serialized and deserialized."""

official/projects/volumetric_models/serving/semantic_segmentation_3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,4 @@ def serve(
5656
outputs = self.inference_step(images)
5757
output_key = 'logits' if self.params.task.model.head.output_logits else 'probs'
5858

59-
return {output_key: outputs}
59+
return {output_key: outputs['logits']}

official/projects/volumetric_models/serving/semantic_segmentation_3d_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ def test_export(self, input_type: str = 'image_tensor'):
104104
# outputs equal.
105105
expected_output = module.model(image_tensor, training=False)
106106
out = segmentation_fn(tf.constant(images))
107-
self.assertAllClose(out['logits'].numpy(), expected_output.numpy())
107+
self.assertAllClose(out['logits'].numpy(),
108+
expected_output['logits'].numpy())
108109

109110

110111
if __name__ == '__main__':

official/projects/volumetric_models/tasks/semantic_segmentation_3d.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def train_step(
198198
# Casting output layer as float32 is necessary when mixed_precision is
199199
# mixed_float16 or mixed_bfloat16 to ensure output is cast to float32.
200200
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
201+
202+
outputs = outputs['logits']
201203
if self.task_config.model.head.output_logits:
202204
outputs = tf.nn.softmax(outputs)
203205

@@ -258,6 +260,7 @@ def validation_step(
258260

259261
outputs = self.inference_step(features, model)
260262
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
263+
outputs = outputs['logits']
261264
if self.task_config.model.head.output_logits:
262265
outputs = tf.nn.softmax(outputs)
263266

@@ -268,8 +271,8 @@ def validation_step(
268271
# Compute dice score metrics on CPU.
269272
for metric in self.metrics:
270273
labels = tf.cast(labels, tf.float32)
271-
outputs = tf.cast(outputs, tf.float32)
272-
logs.update({metric.name: (labels, outputs)})
274+
logits = tf.cast(outputs, tf.float32)
275+
logs.update({metric.name: (labels, logits)})
273276

274277
return logs
275278

official/vision/beta/configs/semantic_segmentation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ class SegmentationHead(hyperparams.Config):
7575
decoder_max_level: Optional[Union[int, str]] = None
7676

7777

78+
@dataclasses.dataclass
79+
class MaskScoringHead(hyperparams.Config):
80+
"""Mask Scoring head config."""
81+
num_convs: int = 4
82+
num_filters: int = 128
83+
fc_input_size: List[int] = dataclasses.field(default_factory=list)
84+
num_fcs: int = 2
85+
fc_dims: int = 1024
86+
87+
7888
@dataclasses.dataclass
7989
class SemanticSegmentationModel(hyperparams.Config):
8090
"""Semantic segmentation model config."""
@@ -86,6 +96,7 @@ class SemanticSegmentationModel(hyperparams.Config):
8696
backbone: backbones.Backbone = backbones.Backbone(
8797
type='resnet', resnet=backbones.ResNet())
8898
decoder: decoders.Decoder = decoders.Decoder(type='identity')
99+
mask_scoring_head: Optional[MaskScoringHead] = None
89100
norm_activation: common.NormActivation = common.NormActivation()
90101

91102

official/vision/beta/losses/segmentation_losses.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
# Import libraries
1818
import tensorflow as tf
1919

20+
from official.modeling import tf_utils
21+
2022
EPSILON = 1e-5
2123

2224

@@ -87,3 +89,46 @@ def __call__(self, logits, labels):
8789
loss = tf.reduce_sum(top_k_losses) / normalizer
8890

8991
return loss
92+
93+
94+
def get_actual_mask_scores(logits, labels, ignore_label):
95+
"""Gets actual mask scores."""
96+
_, height, width, num_classes = logits.get_shape().as_list()
97+
batch_size = tf.shape(logits)[0]
98+
logits = tf.stop_gradient(logits)
99+
labels = tf.image.resize(
100+
labels, (height, width),
101+
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
102+
predicted_labels = tf.argmax(logits, -1, output_type=tf.int32)
103+
flat_predictions = tf.reshape(predicted_labels, [batch_size, -1])
104+
flat_labels = tf.cast(tf.reshape(labels, [batch_size, -1]), tf.int32)
105+
106+
one_hot_predictions = tf.one_hot(
107+
flat_predictions, num_classes, on_value=True, off_value=False)
108+
one_hot_labels = tf.one_hot(
109+
flat_labels, num_classes, on_value=True, off_value=False)
110+
keep_mask = tf.not_equal(flat_labels, ignore_label)
111+
keep_mask = tf.expand_dims(keep_mask, 2)
112+
113+
overlap = tf.logical_and(one_hot_predictions, one_hot_labels)
114+
overlap = tf.logical_and(overlap, keep_mask)
115+
overlap = tf.reduce_sum(tf.cast(overlap, tf.float32), axis=1)
116+
union = tf.logical_or(one_hot_predictions, one_hot_labels)
117+
union = tf.logical_and(union, keep_mask)
118+
union = tf.reduce_sum(tf.cast(union, tf.float32), axis=1)
119+
actual_scores = tf.divide(overlap, tf.maximum(union, EPSILON))
120+
return actual_scores
121+
122+
123+
class MaskScoringLoss:
124+
"""Mask Scoring loss."""
125+
126+
def __init__(self, ignore_label):
127+
self._ignore_label = ignore_label
128+
self._mse_loss = tf.keras.losses.MeanSquaredError(
129+
reduction=tf.keras.losses.Reduction.NONE)
130+
131+
def __call__(self, predicted_scores, logits, labels):
132+
actual_scores = get_actual_mask_scores(logits, labels, self._ignore_label)
133+
loss = tf_utils.safe_mean(self._mse_loss(actual_scores, predicted_scores))
134+
return loss

official/vision/beta/modeling/factory.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,5 +369,17 @@ def build_segmentation_model(
369369
norm_epsilon=norm_activation_config.norm_epsilon,
370370
kernel_regularizer=l2_regularizer)
371371

372-
model = segmentation_model.SegmentationModel(backbone, decoder, head)
372+
mask_scoring_head = None
373+
if model_config.mask_scoring_head:
374+
mask_scoring_head = segmentation_heads.MaskScoring(
375+
num_classes=model_config.num_classes,
376+
**model_config.mask_scoring_head.as_dict(),
377+
activation=norm_activation_config.activation,
378+
use_sync_bn=norm_activation_config.use_sync_bn,
379+
norm_momentum=norm_activation_config.norm_momentum,
380+
norm_epsilon=norm_activation_config.norm_epsilon,
381+
kernel_regularizer=l2_regularizer)
382+
383+
model = segmentation_model.SegmentationModel(
384+
backbone, decoder, head, mask_scoring_head=mask_scoring_head)
373385
return model

official/vision/beta/modeling/heads/segmentation_heads.py

Lines changed: 172 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,184 @@
1313
# limitations under the License.
1414

1515
"""Contains definitions of segmentation heads."""
16-
from typing import List, Union, Optional, Mapping, Tuple
16+
from typing import List, Union, Optional, Mapping, Tuple, Any
1717
import tensorflow as tf
1818

1919
from official.modeling import tf_utils
2020
from official.vision.beta.modeling.layers import nn_layers
2121
from official.vision.beta.ops import spatial_transform_ops
2222

2323

24+
class MaskScoring(tf.keras.Model):
25+
"""Creates a mask scoring layer.
26+
27+
This implements the mask scoring layer from the paper:
28+
29+
Zhaojin Huang, Lichao Huang, Yongchao Gong, Chang Huang, Xinggang Wang.
30+
Mask Scoring R-CNN.
31+
(https://arxiv.org/pdf/1903.00241.pdf)
32+
"""
33+
34+
def __init__(
35+
self,
36+
num_classes: int,
37+
fc_input_size: List[int],
38+
num_convs: int = 3,
39+
num_filters: int = 256,
40+
fc_dims: int = 1024,
41+
num_fcs: int = 2,
42+
activation: str = 'relu',
43+
use_sync_bn: bool = False,
44+
norm_momentum: float = 0.99,
45+
norm_epsilon: float = 0.001,
46+
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
47+
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
48+
**kwargs):
49+
50+
"""Initializes mask scoring layer.
51+
52+
Args:
53+
num_classes: An `int` for number of classes.
54+
fc_input_size: A List of `int` for the input size of the
55+
fully connected layers.
56+
num_convs: An `int` for the number of conv layers.
57+
num_filters: An `int` for the number of filters for conv layers.
58+
fc_dims: An `int` for the number of units in each fully connected layer.
59+
num_fcs: An `int` for number of fully connected layers.
60+
activation: A `str` name of the activation function.
61+
use_sync_bn: A bool, whether or not to use sync batch normalization.
62+
norm_momentum: A float for the momentum in BatchNorm. Defaults to 0.99.
63+
norm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
64+
0.001.
65+
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
66+
Conv2D. Default is None.
67+
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
68+
**kwargs: Additional keyword arguments to be passed.
69+
"""
70+
super(MaskScoring, self).__init__(**kwargs)
71+
72+
self._config_dict = {
73+
'num_classes': num_classes,
74+
'num_convs': num_convs,
75+
'num_filters': num_filters,
76+
'fc_input_size': fc_input_size,
77+
'fc_dims': fc_dims,
78+
'num_fcs': num_fcs,
79+
'use_sync_bn': use_sync_bn,
80+
'norm_momentum': norm_momentum,
81+
'norm_epsilon': norm_epsilon,
82+
'activation': activation,
83+
'kernel_regularizer': kernel_regularizer,
84+
'bias_regularizer': bias_regularizer,
85+
}
86+
87+
if tf.keras.backend.image_data_format() == 'channels_last':
88+
self._bn_axis = -1
89+
else:
90+
self._bn_axis = 1
91+
self._activation = tf_utils.get_activation(activation)
92+
93+
def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
94+
"""Creates the variables of the mask scoring head."""
95+
conv_op = tf.keras.layers.Conv2D
96+
conv_kwargs = {
97+
'filters': self._config_dict['num_filters'],
98+
'kernel_size': 3,
99+
'padding': 'same',
100+
}
101+
conv_kwargs.update({
102+
'kernel_initializer': tf.keras.initializers.VarianceScaling(
103+
scale=2, mode='fan_out', distribution='untruncated_normal'),
104+
'bias_initializer': tf.zeros_initializer(),
105+
'kernel_regularizer': self._config_dict['kernel_regularizer'],
106+
'bias_regularizer': self._config_dict['bias_regularizer'],
107+
})
108+
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
109+
if self._config_dict['use_sync_bn']
110+
else tf.keras.layers.BatchNormalization)
111+
bn_kwargs = {
112+
'axis': self._bn_axis,
113+
'momentum': self._config_dict['norm_momentum'],
114+
'epsilon': self._config_dict['norm_epsilon'],
115+
}
116+
117+
self._convs = []
118+
self._conv_norms = []
119+
for i in range(self._config_dict['num_convs']):
120+
conv_name = 'mask-scoring_{}'.format(i)
121+
self._convs.append(conv_op(name=conv_name, **conv_kwargs))
122+
bn_name = 'mask-scoring-bn_{}'.format(i)
123+
self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))
124+
125+
self._fcs = []
126+
self._fc_norms = []
127+
for i in range(self._config_dict['num_fcs']):
128+
fc_name = 'mask-scoring-fc_{}'.format(i)
129+
self._fcs.append(
130+
tf.keras.layers.Dense(
131+
units=self._config_dict['fc_dims'],
132+
kernel_initializer=tf.keras.initializers.VarianceScaling(
133+
scale=1 / 3.0, mode='fan_out', distribution='uniform'),
134+
kernel_regularizer=self._config_dict['kernel_regularizer'],
135+
bias_regularizer=self._config_dict['bias_regularizer'],
136+
name=fc_name))
137+
bn_name = 'mask-scoring-fc-bn_{}'.format(i)
138+
self._fc_norms.append(bn_op(name=bn_name, **bn_kwargs))
139+
140+
self._classifier = tf.keras.layers.Dense(
141+
units=self._config_dict['num_classes'],
142+
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
143+
bias_initializer=tf.zeros_initializer(),
144+
kernel_regularizer=self._config_dict['kernel_regularizer'],
145+
bias_regularizer=self._config_dict['bias_regularizer'],
146+
name='iou-scores')
147+
148+
super(MaskScoring, self).build(input_shape)
149+
150+
def call(self, inputs: tf.Tensor, training: bool = None):
151+
"""Forward pass mask scoring head.
152+
153+
Args:
154+
inputs: A `tf.Tensor` of the shape [batch_size, height, width, num_classes],
155+
representing the segmentation logits.
156+
training: a `bool` indicating whether it is in `training` mode.
157+
158+
Returns:
159+
mask_scores: A `tf.Tensor` of predicted mask scores
160+
[batch_size, num_classes].
161+
"""
162+
x = tf.stop_gradient(inputs)
163+
for conv, bn in zip(self._convs, self._conv_norms):
164+
x = conv(x)
165+
x = bn(x)
166+
x = self._activation(x)
167+
168+
# Casts feat to float32 so the resize op can be run on TPU.
169+
x = tf.cast(x, tf.float32)
170+
x = tf.image.resize(x, size=self._config_dict['fc_input_size'],
171+
method=tf.image.ResizeMethod.BILINEAR)
172+
# Casts it back to be compatible with the rest of the operations.
173+
x = tf.cast(x, inputs.dtype)
174+
175+
_, h, w, filters = x.get_shape().as_list()
176+
x = tf.reshape(x, [-1, h * w * filters])
177+
178+
for fc, bn in zip(self._fcs, self._fc_norms):
179+
x = fc(x)
180+
x = bn(x)
181+
x = self._activation(x)
182+
183+
ious = self._classifier(x)
184+
return ious
185+
186+
def get_config(self) -> Mapping[str, Any]:
187+
return self._config_dict
188+
189+
@classmethod
190+
def from_config(cls, config, custom_objects=None):
191+
return cls(**config)
192+
193+
24194
@tf.keras.utils.register_keras_serializable(package='Vision')
25195
class SegmentationHead(tf.keras.layers.Layer):
26196
"""Creates a segmentation head."""
@@ -225,6 +395,7 @@ def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
225395
segmentation prediction mask: A `tf.Tensor` of the segmentation mask
226396
scores predicted from input features.
227397
"""
398+
228399
backbone_output = inputs[0]
229400
decoder_output = inputs[1]
230401
if self._config_dict['feature_fusion'] == 'deeplabv3plus':

official/vision/beta/modeling/heads/segmentation_heads_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,36 @@ def test_serialize_deserialize(self):
7272
new_head = segmentation_heads.SegmentationHead.from_config(config)
7373
self.assertAllEqual(head.get_config(), new_head.get_config())
7474

75+
76+
class MaskScoringHeadTest(parameterized.TestCase, tf.test.TestCase):
77+
78+
@parameterized.parameters(
79+
(1, 1, 64, [4, 4]),
80+
(2, 1, 64, [4, 4]),
81+
(3, 1, 64, [4, 4]),
82+
(1, 2, 32, [8, 8]),
83+
(2, 2, 32, [8, 8]),
84+
(3, 2, 32, [8, 8]),)
85+
def test_forward(self, num_convs, num_fcs,
86+
num_filters, fc_input_size):
87+
features = np.random.rand(2, 64, 64, 16)
88+
89+
head = segmentation_heads.MaskScoring(
90+
num_classes=2,
91+
num_convs=num_convs,
92+
num_filters=num_filters,
93+
fc_dims=128,
94+
fc_input_size=fc_input_size)
95+
96+
scores = head(features)
97+
self.assertAllEqual(scores.numpy().shape, [2, 2])
98+
99+
def test_serialize_deserialize(self):
100+
head = segmentation_heads.MaskScoring(
101+
num_classes=2, fc_input_size=[4, 4], fc_dims=128)
102+
config = head.get_config()
103+
new_head = segmentation_heads.MaskScoring.from_config(config)
104+
self.assertAllEqual(head.get_config(), new_head.get_config())
105+
75106
if __name__ == '__main__':
76107
tf.test.main()

0 commit comments

Comments
 (0)