@@ -13,7 +13,6 @@
 # limitations under the License.
 
 """DETR detection task definition."""
-
 from typing import Optional
 
 from absl import logging
@@ -48,25 +47,21 @@ class DetectionTask(base_task.Task):
   def build_model(self):
     """Build DETR model."""
 
-    input_specs = tf_keras.layers.InputSpec(
-        shape=[None] + self._task_config.model.input_size
-    )
+    input_specs = tf_keras.layers.InputSpec(shape=[None] +
+                                            self._task_config.model.input_size)
 
     backbone = backbones.factory.build_backbone(
         input_specs=input_specs,
         backbone_config=self._task_config.model.backbone,
-        norm_activation_config=self._task_config.model.norm_activation,
-    )
-
-    model = detr.DETR(
-        backbone,
-        self._task_config.model.backbone_endpoint_name,
-        self._task_config.model.num_queries,
-        self._task_config.model.hidden_size,
-        self._task_config.model.num_classes,
-        self._task_config.model.num_encoder_layers,
-        self._task_config.model.num_decoder_layers,
-    )
+        norm_activation_config=self._task_config.model.norm_activation)
+
+    model = detr.DETR(backbone,
+                      self._task_config.model.backbone_endpoint_name,
+                      self._task_config.model.num_queries,
+                      self._task_config.model.hidden_size,
+                      self._task_config.model.num_classes,
+                      self._task_config.model.num_encoder_layers,
+                      self._task_config.model.num_decoder_layers)
     return model
 
   def initialize(self, model: tf_keras.Model):
@@ -89,13 +84,12 @@ def initialize(self, model: tf_keras.Model):
       status = ckpt.restore(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
 
-    logging.info(
-        'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
-    )
+    logging.info('Finished loading pretrained checkpoint from %s',
+                 ckpt_dir_or_file)
 
-  def build_inputs(
-      self, params, input_context: Optional[tf.distribute.InputContext] = None
-  ):
+  def build_inputs(self,
+                   params,
+                   input_context: Optional[tf.distribute.InputContext] = None):
     """Build input dataset."""
     if isinstance(params, coco.COCODataConfig):
       dataset = coco.COCODataLoader(params).load(input_context)
@@ -106,17 +100,14 @@ def build_inputs(
       decoder_cfg = params.decoder.get()
       if params.decoder.type == 'simple_decoder':
         decoder = tf_example_decoder.TfExampleDecoder(
-            regenerate_source_id=decoder_cfg.regenerate_source_id
-        )
+            regenerate_source_id=decoder_cfg.regenerate_source_id)
       elif params.decoder.type == 'label_map_decoder':
         decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
             label_map=decoder_cfg.label_map,
-            regenerate_source_id=decoder_cfg.regenerate_source_id,
-        )
+            regenerate_source_id=decoder_cfg.regenerate_source_id)
       else:
-        raise ValueError(
-            'Unknown decoder type: {}!'.format(params.decoder.type)
-        )
+        raise ValueError('Unknown decoder type: {}!'.format(
+            params.decoder.type))
 
       parser = detr_input.Parser(
           class_offset=self._task_config.losses.class_offset,
@@ -127,8 +118,7 @@ def build_inputs(
           params,
           dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
           decoder_fn=decoder.decode,
-          parser_fn=parser.parse_fn(params.is_training),
-      )
+          parser_fn=parser.parse_fn(params.is_training))
       dataset = reader.read(input_context=input_context)
 
     return dataset
@@ -187,8 +177,7 @@ def build_losses(self, outputs, labels, aux_losses=None):
     box_targets = labels['boxes']
 
     cost = self._compute_cost(
-        cls_outputs, box_outputs, cls_targets, box_targets
-    )
+        cls_outputs, box_outputs, cls_targets, box_targets)
 
     _, indices = matchers.hungarian_matching(cost)
     indices = tf.stop_gradient(indices)
@@ -199,53 +188,45 @@ def build_losses(self, outputs, labels, aux_losses=None):
 
     background = tf.equal(cls_targets, 0)
     num_boxes = tf.reduce_sum(
-        tf.cast(tf.logical_not(background), tf.float32), axis=-1
-    )
+        tf.cast(tf.logical_not(background), tf.float32), axis=-1)
 
     # Down-weight background to account for class imbalance.
     xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-        labels=cls_targets, logits=cls_assigned
-    )
+        labels=cls_targets, logits=cls_assigned)
     cls_loss = self._task_config.losses.lambda_cls * tf.where(
-        background,
-        self._task_config.losses.background_cls_weight * xentropy,
-        xentropy,
-    )
+        background, self._task_config.losses.background_cls_weight * xentropy,
+        xentropy)
     cls_weights = tf.where(
         background,
         self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
-        tf.ones_like(cls_loss),
-    )
+        tf.ones_like(cls_loss))
 
     # Box loss is only calculated on non-background class.
     l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
     box_loss = self._task_config.losses.lambda_box * tf.where(
-        background, tf.zeros_like(l_1), l_1
-    )
+        background, tf.zeros_like(l_1), l_1)
 
     # Giou loss is only calculated on non-background class.
-    giou = tf.linalg.diag_part(
-        1.0
-        - box_ops.bbox_generalized_overlap(
-            box_ops.cycxhw_to_yxyx(box_assigned),
-            box_ops.cycxhw_to_yxyx(box_targets),
-        )
-    )
+    giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
+        box_ops.cycxhw_to_yxyx(box_assigned),
+        box_ops.cycxhw_to_yxyx(box_targets)
+    ))
     giou_loss = self._task_config.losses.lambda_giou * tf.where(
-        background, tf.zeros_like(giou), giou
-    )
+        background, tf.zeros_like(giou), giou)
 
     # Consider doing all reduce once in train_step to speed up.
     num_boxes_per_replica = tf.reduce_sum(num_boxes)
     cls_weights_per_replica = tf.reduce_sum(cls_weights)
     replica_context = tf.distribute.get_replica_context()
     num_boxes_sum, cls_weights_sum = replica_context.all_reduce(
         tf.distribute.ReduceOp.SUM,
-        [num_boxes_per_replica, cls_weights_per_replica],
-    )
-    cls_loss = tf.math.divide_no_nan(tf.reduce_sum(cls_loss), cls_weights_sum)
-    box_loss = tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum)
-    giou_loss = tf.math.divide_no_nan(tf.reduce_sum(giou_loss), num_boxes_sum)
+        [num_boxes_per_replica, cls_weights_per_replica])
+    cls_loss = tf.math.divide_no_nan(
+        tf.reduce_sum(cls_loss), cls_weights_sum)
+    box_loss = tf.math.divide_no_nan(
+        tf.reduce_sum(box_loss), num_boxes_sum)
+    giou_loss = tf.math.divide_no_nan(
+        tf.reduce_sum(giou_loss), num_boxes_sum)
 
     aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
 
@@ -264,8 +245,7 @@ def build_metrics(self, training=True):
           annotation_file=self._task_config.annotation_file,
           include_mask=False,
           need_rescale_bboxes=True,
-          per_category_metrics=self._task_config.per_category_metrics,
-      )
+          per_category_metrics=self._task_config.per_category_metrics)
     return metrics
 
   def train_step(self, inputs, model, optimizer, metrics=None):
@@ -355,8 +335,7 @@ def validation_step(self, inputs, model, metrics=None):
 
     outputs = model(features, training=False)[-1]
     loss, cls_loss, box_loss, giou_loss = self.build_losses(
-        outputs=outputs, labels=labels, aux_losses=model.losses
-    )
+        outputs=outputs, labels=labels, aux_losses=model.losses)
 
     # Multiply for logging.
     # Since we expect the gradient replica sum to happen in the optimizer,
@@ -374,46 +353,35 @@ def validation_step(self, inputs, model, metrics=None):
     # This is for backward compatibility.
     if 'detection_boxes' not in outputs:
       detection_boxes = box_ops.cycxhw_to_yxyx(
-          outputs['box_outputs']
-      ) * tf.expand_dims(
-          tf.concat(
-              [
-                  labels['image_info'][:, 1:2, 0],
-                  labels['image_info'][:, 1:2, 1],
-                  labels['image_info'][:, 1:2, 0],
-                  labels['image_info'][:, 1:2, 1],
+          outputs['box_outputs']) * tf.expand_dims(
+              tf.concat([
+                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
+                                                                        1],
+                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
+                                                                        1]
               ],
-              axis=1,
-          ),
-          axis=1,
-      )
+                        axis=1),
+              axis=1)
     else:
       detection_boxes = outputs['detection_boxes']
 
-    if 'detection_scores' not in outputs:
-      detection_scores = tf.math.reduce_max(
-          tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
-      )
-    else:
-      detection_scores = outputs['detection_scores']
+    detection_scores = tf.math.reduce_max(
+        tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
+    ) if 'detection_scores' not in outputs else outputs['detection_scores']
 
     if 'detection_classes' not in outputs:
-      detection_classes = (
-          tf.math.argmax(outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
-      )
+      detection_classes = tf.math.argmax(
+          outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
     else:
       detection_classes = outputs['detection_classes']
 
     if 'num_detections' not in outputs:
       num_detections = tf.reduce_sum(
           tf.cast(
               tf.math.greater(
-                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0
-              ),
-              tf.int32,
-          ),
-          axis=-1,
-      )
+                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0),
+              tf.int32),
+          axis=-1)
     else:
       num_detections = outputs['num_detections']
 
@@ -423,21 +391,21 @@ def validation_step(self, inputs, model, metrics=None):
         'detection_classes': detection_classes,
         'num_detections': num_detections,
         'source_id': labels['id'],
-        'image_info': labels['image_info'],
+        'image_info': labels['image_info']
     }
 
     ground_truths = {
         'source_id': labels['id'],
         'height': labels['image_info'][:, 0:1, 0],
         'width': labels['image_info'][:, 0:1, 1],
         'num_detections': tf.reduce_sum(
-            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1
-        ),
+            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1),
         'boxes': labels['gt_boxes'],
         'classes': labels['classes'],
-        'is_crowds': labels['is_crowd'],
+        'is_crowds': labels['is_crowd']
     }
-    logs.update({'predictions': predictions, 'ground_truths': ground_truths})
+    logs.update({'predictions': predictions,
+                 'ground_truths': ground_truths})
 
     all_losses = {
         'cls_loss': cls_loss,
@@ -457,8 +425,8 @@ def aggregate_logs(self, state=None, step_outputs=None):
       state = self.coco_metric
 
     state.update_state(
-        step_outputs['ground_truths'], step_outputs['predictions']
-    )
+        step_outputs['ground_truths'],
+        step_outputs['predictions'])
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):