Skip to content

Commit ea47f4e

Browse files
Apply class weights in maskrcnn loss
PiperOrigin-RevId: 571114756
1 parent aaabbd0 commit ea47f4e

File tree

4 files changed

+54
-18
lines changed

4 files changed

+54
-18
lines changed

official/vision/configs/maskrcnn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ class Losses(hyperparams.Config):
223223
frcnn_class_weight: float = 1.0
224224
frcnn_box_weight: float = 1.0
225225
mask_weight: float = 1.0
226+
class_weights: Optional[List[float]] = None
226227

227228

228229
@dataclasses.dataclass

official/vision/losses/maskrcnn_losses.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,26 +167,40 @@ def __init__(self,
167167
self._use_binary_cross_entropy = use_binary_cross_entropy
168168
self._top_k_percent = top_k_percent
169169

170-
def __call__(self, class_outputs, class_targets):
170+
def __call__(self, class_outputs, class_targets, class_weights=None):
171171
"""Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
172172
173173
This function implements the classification loss of the Fast-RCNN.
174174
175175
The classification loss is categorical (or binary) cross entropy on all
176176
RoIs.
177-
Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
177+
Reference:
178+
https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py
179+
# pylint: disable=line-too-long
178180
179181
Args:
180182
class_outputs: a float tensor representing the class prediction for each
181183
box with a shape of [batch_size, num_boxes, num_classes].
182184
class_targets: a float tensor representing the class label for each box
183185
with a shape of [batch_size, num_boxes].
186+
class_weights: A float list containing the weight of each class.
184187
185188
Returns:
186189
a scalar tensor representing total class loss.
187190
"""
188191
with tf.name_scope('fast_rcnn_loss'):
192+
output_dtype = class_outputs.dtype
189193
num_classes = class_outputs.get_shape().as_list()[-1]
194+
class_weights = (
195+
class_weights if class_weights is not None else [1.0] * num_classes
196+
)
197+
if num_classes != len(class_weights):
198+
raise ValueError(
199+
'Length of class_weights should be {}'.format(num_classes)
200+
)
201+
202+
class_weights = tf.constant(class_weights, dtype=output_dtype)
203+
190204
class_targets_one_hot = tf.one_hot(
191205
tf.cast(class_targets, dtype=tf.int32),
192206
num_classes,
@@ -195,10 +209,15 @@ def __call__(self, class_outputs, class_targets):
195209
# (batch_size, num_boxes, num_classes)
196210
cross_entropy_loss = tf.nn.sigmoid_cross_entropy_with_logits(
197211
labels=class_targets_one_hot, logits=class_outputs)
212+
cross_entropy_loss *= class_weights
198213
else:
199214
# (batch_size, num_boxes)
200215
cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(
201216
labels=class_targets_one_hot, logits=class_outputs)
217+
class_weight_mask = tf.einsum(
218+
'...y,y->...', class_targets_one_hot, class_weights
219+
)
220+
cross_entropy_loss *= class_weight_mask
202221

203222
if self._top_k_percent < 1.0:
204223
return self.aggregate_loss_top_k(cross_entropy_loss)

official/vision/losses/maskrcnn_losses_test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ def testFastrcnnClassLoss(self, use_binary_cross_entropy):
7777
maxval=num_classes + 1,
7878
dtype=tf.int32)
7979
loss_fn = maskrcnn_losses.FastrcnnClassLoss(use_binary_cross_entropy)
80-
self.assertEqual(tf.rank(loss_fn(class_outputs, class_targets)), 0)
80+
class_weights = [1.0] * num_classes
81+
self.assertEqual(
82+
tf.rank(loss_fn(class_outputs, class_targets, class_weights)), 0
83+
)
8184

8285
def testFastrcnnClassLossTopK(self):
8386
class_targets = tf.constant([[0, 0, 0, 2]])
@@ -87,16 +90,17 @@ def testFastrcnnClassLossTopK(self):
8790
[100.0, 0.0, 0.0],
8891
[0.0, 1.0, 0.0],
8992
]])
93+
class_weights = [1.0, 1.0, 1.0]
9094
self.assertAllClose(
9195
maskrcnn_losses.FastrcnnClassLoss(top_k_percent=0.5)(
92-
class_outputs, class_targets
96+
class_outputs, class_targets, class_weights
9397
),
9498
0.775718,
9599
atol=1e-4,
96100
)
97101
self.assertAllClose(
98102
maskrcnn_losses.FastrcnnClassLoss(top_k_percent=1.0)(
99-
class_outputs, class_targets
103+
class_outputs, class_targets, class_weights
100104
),
101105
0.387861,
102106
atol=1e-4,

official/vision/tasks/maskrcnn.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,10 @@ def _build_rpn_losses(
202202
return rpn_score_loss, rpn_box_loss
203203

204204
def _build_frcnn_losses(
205-
self, outputs: Mapping[str, Any],
206-
labels: Mapping[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
205+
self,
206+
outputs: Mapping[str, Any],
207+
labels: Mapping[str, Any],
208+
) -> Tuple[tf.Tensor, tf.Tensor]:
207209
"""Builds losses for Fast R-CNN."""
208210
cascade_ious = self.task_config.model.roi_sampler.cascade_iou_thresholds
209211

@@ -222,10 +224,19 @@ def _build_frcnn_losses(
222224
for cas_num in range(num_det_heads):
223225
frcnn_cls_loss_i = tf.reduce_mean(
224226
frcnn_cls_loss_fn(
225-
outputs['class_outputs_{}'
226-
.format(cas_num) if cas_num else 'class_outputs'],
227-
outputs['class_targets_{}'
228-
.format(cas_num) if cas_num else 'class_targets']))
227+
outputs[
228+
'class_outputs_{}'.format(cas_num)
229+
if cas_num
230+
else 'class_outputs'
231+
],
232+
outputs[
233+
'class_targets_{}'.format(cas_num)
234+
if cas_num
235+
else 'class_targets'
236+
],
237+
self.task_config.losses.class_weights,
238+
)
239+
)
229240
frcnn_box_loss_i = tf.reduce_mean(
230241
frcnn_box_loss_fn(
231242
outputs['box_outputs_{}'.format(cas_num
@@ -257,27 +268,28 @@ def build_losses(self,
257268
labels: Mapping[str, Any],
258269
aux_losses: Optional[Any] = None) -> Dict[str, tf.Tensor]:
259270
"""Builds Mask R-CNN losses."""
271+
loss_params = self.task_config.losses
260272
rpn_score_loss, rpn_box_loss = self._build_rpn_losses(outputs, labels)
261273
frcnn_cls_loss, frcnn_box_loss = self._build_frcnn_losses(outputs, labels)
262274
if self.task_config.model.include_mask:
263275
mask_loss = self._build_mask_loss(outputs)
264276
else:
265277
mask_loss = tf.constant(0.0, dtype=tf.float32)
266278

267-
params = self.task_config
268279
model_loss = (
269-
params.losses.rpn_score_weight * rpn_score_loss +
270-
params.losses.rpn_box_weight * rpn_box_loss +
271-
params.losses.frcnn_class_weight * frcnn_cls_loss +
272-
params.losses.frcnn_box_weight * frcnn_box_loss +
273-
params.losses.mask_weight * mask_loss)
280+
loss_params.rpn_score_weight * rpn_score_loss
281+
+ loss_params.rpn_box_weight * rpn_box_loss
282+
+ loss_params.frcnn_class_weight * frcnn_cls_loss
283+
+ loss_params.frcnn_box_weight * frcnn_box_loss
284+
+ loss_params.mask_weight * mask_loss
285+
)
274286

275287
total_loss = model_loss
276288
if aux_losses:
277289
reg_loss = tf.reduce_sum(aux_losses)
278290
total_loss = model_loss + reg_loss
279291

280-
total_loss = params.losses.loss_weight * total_loss
292+
total_loss = loss_params.loss_weight * total_loss
281293
losses = {
282294
'total_loss': total_loss,
283295
'rpn_score_loss': rpn_score_loss,

0 commit comments

Comments
 (0)