@@ -131,43 +131,45 @@ def _rpn_box_loss(self, box_outputs, box_targets, normalizer=1.0):
131
131
class FastrcnnClassLoss (object ):
132
132
"""Fast R-CNN classification loss function."""
133
133
134
- def __init__ (self ):
135
- self ._categorical_crossentropy = tf .keras .losses .CategoricalCrossentropy (
136
- reduction = tf .keras .losses .Reduction .SUM , from_logits = True )
134
+ def __init__ (self , use_binary_cross_entropy = False ):
135
+ """Initializes loss computation.
136
+
137
+ Args:
138
+ use_binary_cross_entropy: If true, uses binary cross entropy loss,
139
+ otherwise uses categorical cross entropy loss.
140
+ """
141
+ self ._use_binary_cross_entropy = use_binary_cross_entropy
137
142
138
143
def __call__ (self , class_outputs , class_targets ):
139
144
"""Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
140
145
141
146
This function implements the classification loss of the Fast-RCNN.
142
147
143
- The classification loss is softmax on all RoIs.
148
+ The classification loss is categorical (or binary) cross entropy on all
149
+ RoIs.
144
150
Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
145
151
146
152
Args:
147
- class_outputs: a float tensor representing the class prediction for each box
148
- with a shape of [batch_size, num_boxes, num_classes].
153
+ class_outputs: a float tensor representing the class prediction for each
154
+ box with a shape of [batch_size, num_boxes, num_classes].
149
155
class_targets: a float tensor representing the class label for each box
150
156
with a shape of [batch_size, num_boxes].
151
157
152
158
Returns:
153
159
a scalar tensor representing total class loss.
154
160
"""
155
161
with tf .name_scope ('fast_rcnn_loss' ):
156
- batch_size , num_boxes , num_classes = class_outputs .get_shape ().as_list ()
157
- class_targets = tf .cast (class_targets , dtype = tf .int32 )
158
- class_targets_one_hot = tf .one_hot (class_targets , num_classes )
159
- return self ._fast_rcnn_class_loss (class_outputs , class_targets_one_hot ,
160
- normalizer = batch_size * num_boxes )
161
-
162
- def _fast_rcnn_class_loss (self , class_outputs , class_targets_one_hot ,
163
- normalizer = 1.0 ):
164
- """Computes classification loss."""
165
- with tf .name_scope ('fast_rcnn_class_loss' ):
166
- class_loss = self ._categorical_crossentropy (class_targets_one_hot ,
167
- class_outputs )
168
-
169
- class_loss /= normalizer
170
- return class_loss
162
+ num_classes = class_outputs .get_shape ().as_list ()[- 1 ]
163
+ class_targets_one_hot = tf .one_hot (
164
+ tf .cast (class_targets , dtype = tf .int32 ), num_classes )
165
+ if self ._use_binary_cross_entropy :
166
+ cross_entropy_loss = tf .nn .sigmoid_cross_entropy_with_logits (
167
+ class_targets_one_hot , class_outputs )
168
+ return tf .reduce_mean (tf .reduce_sum (cross_entropy_loss , axis = - 1 ))
169
+ else :
170
+ return tf .reduce_mean (
171
+ tf .nn .softmax_cross_entropy_with_logits (class_targets_one_hot ,
172
+ class_outputs ))
171
173
172
174
173
175
class FastrcnnBoxLoss (object ):
@@ -227,22 +229,9 @@ def _assign_class_targets(self, box_outputs, class_targets):
227
229
num_classes = num_class_specific_boxes // 4
228
230
box_outputs = tf .reshape (box_outputs ,
229
231
[batch_size , num_rois , num_classes , 4 ])
230
-
231
- box_indices = tf .reshape (
232
- class_targets + tf .tile (
233
- tf .expand_dims (tf .range (batch_size ) * num_rois * num_classes , 1 ),
234
- [1 , num_rois ]) + tf .tile (
235
- tf .expand_dims (tf .range (num_rois ) * num_classes , 0 ),
236
- [batch_size , 1 ]), [- 1 ])
237
-
238
- box_outputs = tf .matmul (
239
- tf .one_hot (
240
- box_indices ,
241
- batch_size * num_rois * num_classes ,
242
- dtype = box_outputs .dtype ), tf .reshape (box_outputs , [- 1 , 4 ]))
243
- box_outputs = tf .reshape (box_outputs , [batch_size , - 1 , 4 ])
244
-
245
- return box_outputs
232
+ class_targets_ont_hot = tf .one_hot (
233
+ class_targets , num_classes , dtype = box_outputs .dtype )
234
+ return tf .einsum ('bnij,bni->bnj' , box_outputs , class_targets_ont_hot )
246
235
247
236
def _fast_rcnn_box_loss (self , box_outputs , box_targets , class_targets ,
248
237
normalizer = 1.0 ):
0 commit comments