Skip to content

Commit 95f1070

Browse files
tensorflower-gardenerfyangf
authored andcommitted
Internal change
PiperOrigin-RevId: 490350449
1 parent 93cea12 commit 95f1070

File tree

8 files changed

+152
-43
lines changed

8 files changed

+152
-43
lines changed

official/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
120120
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
121121
nms_iou_threshold=generator_config.nms_iou_threshold,
122122
max_num_detections=generator_config.max_num_detections,
123-
nms_version=generator_config.nms_version)
123+
nms_version=generator_config.nms_version,
124+
use_sigmoid_probability=generator_config.use_sigmoid_probability)
124125

125126
if model_config.include_mask:
126127
mask_head = deep_instance_heads.DeepMaskHead(

official/vision/configs/maskrcnn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class DetectionGenerator(hyperparams.Config):
135135
nms_version: str = 'v2' # `v2`, `v1`, `batched`
136136
use_cpu_nms: bool = False
137137
soft_nms_sigma: Optional[float] = None # Only works when nms_version='v1'.
138+
use_sigmoid_probability: bool = False
138139

139140

140141
@dataclasses.dataclass
@@ -189,6 +190,7 @@ class Losses(hyperparams.Config):
189190
loss_weight: float = 1.0
190191
rpn_huber_loss_delta: float = 1. / 9.
191192
frcnn_huber_loss_delta: float = 1.
193+
frcnn_class_use_binary_cross_entropy: bool = False
192194
l2_weight_decay: float = 0.0
193195
rpn_score_weight: float = 1.0
194196
rpn_box_weight: float = 1.0

official/vision/losses/maskrcnn_losses.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -131,43 +131,45 @@ def _rpn_box_loss(self, box_outputs, box_targets, normalizer=1.0):
131131
class FastrcnnClassLoss(object):
132132
"""Fast R-CNN classification loss function."""
133133

134-
def __init__(self):
135-
self._categorical_crossentropy = tf.keras.losses.CategoricalCrossentropy(
136-
reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
134+
def __init__(self, use_binary_cross_entropy=False):
135+
"""Initializes loss computation.
136+
137+
Args:
138+
use_binary_cross_entropy: If true, uses binary cross entropy loss,
139+
otherwise uses categorical cross entropy loss.
140+
"""
141+
self._use_binary_cross_entropy = use_binary_cross_entropy
137142

138143
def __call__(self, class_outputs, class_targets):
139144
"""Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
140145
141146
This function implements the classification loss of the Fast-RCNN.
142147
143-
The classification loss is softmax on all RoIs.
148+
The classification loss is categorical (or binary) cross entropy on all
149+
RoIs.
144150
Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
145151
146152
Args:
147-
class_outputs: a float tensor representing the class prediction for each box
148-
with a shape of [batch_size, num_boxes, num_classes].
153+
class_outputs: a float tensor representing the class prediction for each
154+
box with a shape of [batch_size, num_boxes, num_classes].
149155
class_targets: a float tensor representing the class label for each box
150156
with a shape of [batch_size, num_boxes].
151157
152158
Returns:
153159
a scalar tensor representing total class loss.
154160
"""
155161
with tf.name_scope('fast_rcnn_loss'):
156-
batch_size, num_boxes, num_classes = class_outputs.get_shape().as_list()
157-
class_targets = tf.cast(class_targets, dtype=tf.int32)
158-
class_targets_one_hot = tf.one_hot(class_targets, num_classes)
159-
return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot,
160-
normalizer=batch_size * num_boxes)
161-
162-
def _fast_rcnn_class_loss(self, class_outputs, class_targets_one_hot,
163-
normalizer=1.0):
164-
"""Computes classification loss."""
165-
with tf.name_scope('fast_rcnn_class_loss'):
166-
class_loss = self._categorical_crossentropy(class_targets_one_hot,
167-
class_outputs)
168-
169-
class_loss /= normalizer
170-
return class_loss
162+
num_classes = class_outputs.get_shape().as_list()[-1]
163+
class_targets_one_hot = tf.one_hot(
164+
tf.cast(class_targets, dtype=tf.int32), num_classes)
165+
if self._use_binary_cross_entropy:
166+
cross_entropy_loss = tf.nn.sigmoid_cross_entropy_with_logits(
167+
class_targets_one_hot, class_outputs)
168+
return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, axis=-1))
169+
else:
170+
return tf.reduce_mean(
171+
tf.nn.softmax_cross_entropy_with_logits(class_targets_one_hot,
172+
class_outputs))
171173

172174

173175
class FastrcnnBoxLoss(object):
@@ -227,22 +229,9 @@ def _assign_class_targets(self, box_outputs, class_targets):
227229
num_classes = num_class_specific_boxes // 4
228230
box_outputs = tf.reshape(box_outputs,
229231
[batch_size, num_rois, num_classes, 4])
230-
231-
box_indices = tf.reshape(
232-
class_targets + tf.tile(
233-
tf.expand_dims(tf.range(batch_size) * num_rois * num_classes, 1),
234-
[1, num_rois]) + tf.tile(
235-
tf.expand_dims(tf.range(num_rois) * num_classes, 0),
236-
[batch_size, 1]), [-1])
237-
238-
box_outputs = tf.matmul(
239-
tf.one_hot(
240-
box_indices,
241-
batch_size * num_rois * num_classes,
242-
dtype=box_outputs.dtype), tf.reshape(box_outputs, [-1, 4]))
243-
box_outputs = tf.reshape(box_outputs, [batch_size, -1, 4])
244-
245-
return box_outputs
232+
class_targets_ont_hot = tf.one_hot(
233+
class_targets, num_classes, dtype=box_outputs.dtype)
234+
return tf.einsum('bnij,bni->bnj', box_outputs, class_targets_ont_hot)
246235

247236
def _fast_rcnn_box_loss(self, box_outputs, box_targets, class_targets,
248237
normalizer=1.0):
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for maskrcnn_losses."""
16+
17+
from absl.testing import parameterized
18+
import tensorflow as tf
19+
20+
from official.vision.losses import maskrcnn_losses
21+
22+
23+
class MaskrcnnLossesTest(parameterized.TestCase, tf.test.TestCase):
24+
25+
def testRpnScoreLoss(self):
26+
batch_size = 2
27+
height = 32
28+
width = 32
29+
num_anchors = 10
30+
score_outputs = {
31+
'1': tf.random.uniform([batch_size, height, width, num_anchors])
32+
}
33+
score_targets = {
34+
'1':
35+
tf.random.uniform([batch_size, height, width, num_anchors],
36+
minval=-1,
37+
maxval=2,
38+
dtype=tf.int32)
39+
}
40+
loss_fn = maskrcnn_losses.RpnScoreLoss(rpn_batch_size_per_im=8)
41+
self.assertEqual(tf.rank(loss_fn(score_outputs, score_targets)), 0)
42+
43+
def testRpnBoxLoss(self):
44+
batch_size = 2
45+
height = 32
46+
width = 32
47+
num_anchors = 10
48+
box_outputs = {
49+
'1': tf.random.uniform([batch_size, height, width, num_anchors * 4])
50+
}
51+
box_targets = {
52+
'1': tf.random.uniform([batch_size, height, width, num_anchors * 4])
53+
}
54+
loss_fn = maskrcnn_losses.RpnBoxLoss(huber_loss_delta=1. / 9.)
55+
self.assertEqual(tf.rank(loss_fn(box_outputs, box_targets)), 0)
56+
57+
@parameterized.parameters((True, False))
58+
def testFastrcnnClassLoss(self, use_binary_cross_entropy):
59+
batch_size = 2
60+
num_boxes = 10
61+
num_classes = 5
62+
class_outputs = tf.random.uniform([batch_size, num_boxes, num_classes])
63+
class_targets = tf.random.uniform([batch_size, num_boxes],
64+
minval=0,
65+
maxval=num_classes + 1,
66+
dtype=tf.int32)
67+
loss_fn = maskrcnn_losses.FastrcnnClassLoss(use_binary_cross_entropy)
68+
self.assertEqual(tf.rank(loss_fn(class_outputs, class_targets)), 0)
69+
70+
def testFastrcnnBoxLoss(self):
71+
batch_size = 2
72+
num_boxes = 10
73+
num_classes = 5
74+
box_outputs = tf.random.uniform([batch_size, num_boxes, num_classes * 4])
75+
box_targets = tf.random.uniform([batch_size, num_boxes, 4])
76+
class_targets = tf.random.uniform([batch_size, num_boxes],
77+
minval=0,
78+
maxval=num_classes + 1,
79+
dtype=tf.int32)
80+
loss_fn = maskrcnn_losses.FastrcnnBoxLoss(huber_loss_delta=1.)
81+
self.assertEqual(
82+
tf.rank(loss_fn(box_outputs, box_targets, class_targets)), 0)
83+
84+
def testMaskrcnnLoss(self):
85+
batch_size = 2
86+
num_masks = 10
87+
mask_height = 16
88+
mask_width = 16
89+
num_classes = 5
90+
mask_outputs = tf.random.uniform(
91+
[batch_size, num_masks, mask_height, mask_width])
92+
mask_targets = tf.cast(
93+
tf.random.uniform([batch_size, num_masks, mask_height, mask_width],
94+
minval=0,
95+
maxval=2,
96+
dtype=tf.int32), tf.float32)
97+
select_class_targets = tf.random.uniform([batch_size, num_masks],
98+
minval=0,
99+
maxval=num_classes + 1,
100+
dtype=tf.int32)
101+
loss_fn = maskrcnn_losses.MaskrcnnLoss()
102+
self.assertEqual(
103+
tf.rank(loss_fn(mask_outputs, mask_targets, select_class_targets)), 0)

official/vision/modeling/factory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
204204
max_num_detections=generator_config.max_num_detections,
205205
nms_version=generator_config.nms_version,
206206
use_cpu_nms=generator_config.use_cpu_nms,
207-
soft_nms_sigma=generator_config.soft_nms_sigma)
207+
soft_nms_sigma=generator_config.soft_nms_sigma,
208+
use_sigmoid_probability=generator_config.use_sigmoid_probability)
208209

209210
if model_config.include_mask:
210211
mask_head = instance_heads.MaskHead(

official/vision/modeling/layers/detection_generator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ def __init__(self,
572572
nms_version: str = 'v2',
573573
use_cpu_nms: bool = False,
574574
soft_nms_sigma: Optional[float] = None,
575+
use_sigmoid_probability: bool = False,
575576
**kwargs):
576577
"""Initializes a detection generator.
577578
@@ -590,6 +591,8 @@ def __init__(self,
590591
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
591592
soft_nms_sigma: A `float` representing the sigma parameter for Soft NMS.
592593
When soft_nms_sigma=0.0, we fall back to standard NMS.
594+
use_sigmoid_probability: A `bool`, if true, use sigmoid to get
595+
probability, otherwise use softmax.
593596
**kwargs: Additional keyword arguments passed to Layer.
594597
"""
595598
self._config_dict = {
@@ -601,6 +604,7 @@ def __init__(self,
601604
'nms_version': nms_version,
602605
'use_cpu_nms': use_cpu_nms,
603606
'soft_nms_sigma': soft_nms_sigma,
607+
'use_sigmoid_probability': use_sigmoid_probability,
604608
}
605609
super(DetectionGenerator, self).__init__(**kwargs)
606610

@@ -644,7 +648,10 @@ def __call__(self,
644648
`decoded_box_scores`: A `float` tf.Tensor of shape
645649
[batch, num_raw_boxes] representing socres of all the decoded boxes.
646650
"""
647-
box_scores = tf.nn.softmax(raw_scores, axis=-1)
651+
if self._config_dict['use_sigmoid_probability']:
652+
box_scores = tf.math.sigmoid(raw_scores)
653+
else:
654+
box_scores = tf.nn.softmax(raw_scores, axis=-1)
648655

649656
# Removes the background class.
650657
box_scores_shape = tf.shape(box_scores)

official/vision/modeling/layers/detection_generator_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ class DetectionGeneratorTest(
4646
@parameterized.product(
4747
nms_version=['batched', 'v1', 'v2'],
4848
use_cpu_nms=[True, False],
49-
soft_nms_sigma=[None, 0.1])
50-
def testDetectionsOutputShape(self, nms_version, use_cpu_nms, soft_nms_sigma):
49+
soft_nms_sigma=[None, 0.1],
50+
use_sigmoid_probability=[True, False])
51+
def testDetectionsOutputShape(self, nms_version, use_cpu_nms, soft_nms_sigma,
52+
use_sigmoid_probability):
5153
max_num_detections = 10
5254
num_classes = 4
5355
pre_nms_top_k = 5000
@@ -62,6 +64,7 @@ def testDetectionsOutputShape(self, nms_version, use_cpu_nms, soft_nms_sigma):
6264
'nms_version': nms_version,
6365
'use_cpu_nms': use_cpu_nms,
6466
'soft_nms_sigma': soft_nms_sigma,
67+
'use_sigmoid_probability': use_sigmoid_probability,
6568
}
6669
generator = detection_generator.DetectionGenerator(**kwargs)
6770

@@ -103,6 +106,7 @@ def test_serialize_deserialize(self):
103106
'nms_version': 'v2',
104107
'use_cpu_nms': False,
105108
'soft_nms_sigma': None,
109+
'use_sigmoid_probability': False,
106110
}
107111
generator = detection_generator.DetectionGenerator(**kwargs)
108112

official/vision/tasks/maskrcnn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def _build_frcnn_losses(
193193
"""Build losses for Fast R-CNN."""
194194
cascade_ious = self.task_config.model.roi_sampler.cascade_iou_thresholds
195195

196-
frcnn_cls_loss_fn = maskrcnn_losses.FastrcnnClassLoss()
196+
frcnn_cls_loss_fn = maskrcnn_losses.FastrcnnClassLoss(
197+
use_binary_cross_entropy=self.task_config.losses
198+
.frcnn_class_use_binary_cross_entropy)
197199
frcnn_box_loss_fn = maskrcnn_losses.FastrcnnBoxLoss(
198200
self.task_config.losses.frcnn_huber_loss_delta,
199201
self.task_config.model.detection_head.class_agnostic_bbox_pred)

0 commit comments

Comments
 (0)