1313# limitations under the License.
1414
1515"""DETR detection task definition."""
16+
1617from typing import Optional
1718
1819from absl import logging
@@ -47,21 +48,25 @@ class DetectionTask(base_task.Task):
4748 def build_model (self ):
4849 """Build DETR model."""
4950
50- input_specs = tf_keras .layers .InputSpec (shape = [None ] +
51- self ._task_config .model .input_size )
51+ input_specs = tf_keras .layers .InputSpec (
52+ shape = [None ] + self ._task_config .model .input_size
53+ )
5254
5355 backbone = backbones .factory .build_backbone (
5456 input_specs = input_specs ,
5557 backbone_config = self ._task_config .model .backbone ,
56- norm_activation_config = self ._task_config .model .norm_activation )
57-
58- model = detr .DETR (backbone ,
59- self ._task_config .model .backbone_endpoint_name ,
60- self ._task_config .model .num_queries ,
61- self ._task_config .model .hidden_size ,
62- self ._task_config .model .num_classes ,
63- self ._task_config .model .num_encoder_layers ,
64- self ._task_config .model .num_decoder_layers )
58+ norm_activation_config = self ._task_config .model .norm_activation ,
59+ )
60+
61+ model = detr .DETR (
62+ backbone ,
63+ self ._task_config .model .backbone_endpoint_name ,
64+ self ._task_config .model .num_queries ,
65+ self ._task_config .model .hidden_size ,
66+ self ._task_config .model .num_classes ,
67+ self ._task_config .model .num_encoder_layers ,
68+ self ._task_config .model .num_decoder_layers ,
69+ )
6570 return model
6671
6772 def initialize (self , model : tf_keras .Model ):
@@ -84,12 +89,13 @@ def initialize(self, model: tf_keras.Model):
8489 status = ckpt .restore (ckpt_dir_or_file )
8590 status .expect_partial ().assert_existing_objects_matched ()
8691
87- logging .info ('Finished loading pretrained checkpoint from %s' ,
88- ckpt_dir_or_file )
92+ logging .info (
93+ 'Finished loading pretrained checkpoint from %s' , ckpt_dir_or_file
94+ )
8995
90- def build_inputs (self ,
91- params ,
92- input_context : Optional [ tf . distribute . InputContext ] = None ):
96+ def build_inputs (
97+ self , params , input_context : Optional [ tf . distribute . InputContext ] = None
98+ ):
9399 """Build input dataset."""
94100 if isinstance (params , coco .COCODataConfig ):
95101 dataset = coco .COCODataLoader (params ).load (input_context )
@@ -100,14 +106,17 @@ def build_inputs(self,
100106 decoder_cfg = params .decoder .get ()
101107 if params .decoder .type == 'simple_decoder' :
102108 decoder = tf_example_decoder .TfExampleDecoder (
103- regenerate_source_id = decoder_cfg .regenerate_source_id )
109+ regenerate_source_id = decoder_cfg .regenerate_source_id
110+ )
104111 elif params .decoder .type == 'label_map_decoder' :
105112 decoder = tf_example_label_map_decoder .TfExampleDecoderLabelMap (
106113 label_map = decoder_cfg .label_map ,
107- regenerate_source_id = decoder_cfg .regenerate_source_id )
114+ regenerate_source_id = decoder_cfg .regenerate_source_id ,
115+ )
108116 else :
109- raise ValueError ('Unknown decoder type: {}!' .format (
110- params .decoder .type ))
117+ raise ValueError (
118+ 'Unknown decoder type: {}!' .format (params .decoder .type )
119+ )
111120
112121 parser = detr_input .Parser (
113122 class_offset = self ._task_config .losses .class_offset ,
@@ -118,7 +127,8 @@ def build_inputs(self,
118127 params ,
119128 dataset_fn = dataset_fn .pick_dataset_fn (params .file_type ),
120129 decoder_fn = decoder .decode ,
121- parser_fn = parser .parse_fn (params .is_training ))
130+ parser_fn = parser .parse_fn (params .is_training ),
131+ )
122132 dataset = reader .read (input_context = input_context )
123133
124134 return dataset
@@ -128,35 +138,44 @@ def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
128138 # The 1 is a constant that doesn't change the matching, it can be omitted.
129139 # background: 0
130140 cls_cost = self ._task_config .losses .lambda_cls * tf .gather (
131- - tf .nn .softmax (cls_outputs ), cls_targets , batch_dims = 1 , axis = - 1 )
141+ - tf .nn .softmax (cls_outputs ), cls_targets , batch_dims = 1 , axis = - 1
142+ )
132143
133144 # Compute the L1 cost between boxes,
134145 paired_differences = self ._task_config .losses .lambda_box * tf .abs (
135- tf .expand_dims (box_outputs , 2 ) - tf .expand_dims (box_targets , 1 ))
146+ tf .expand_dims (box_outputs , 2 ) - tf .expand_dims (box_targets , 1 )
147+ )
136148 box_cost = tf .reduce_sum (paired_differences , axis = - 1 )
137149
138150 # Compute the giou cost between boxes
139- giou_cost = self ._task_config .losses .lambda_giou * - box_ops .bbox_generalized_overlap (
140- box_ops .cycxhw_to_yxyx (box_outputs ),
141- box_ops .cycxhw_to_yxyx (box_targets ))
151+ giou_cost = (
152+ self ._task_config .losses .lambda_giou
153+ * - box_ops .bbox_generalized_overlap (
154+ box_ops .cycxhw_to_yxyx (box_outputs ),
155+ box_ops .cycxhw_to_yxyx (box_targets ),
156+ )
157+ )
142158
143159 total_cost = cls_cost + box_cost + giou_cost
144160
145161 max_cost = (
146- self ._task_config .losses .lambda_cls * 0.0 +
147- self ._task_config .losses .lambda_box * 4. +
148- self ._task_config .losses .lambda_giou * 0.0 )
162+ self ._task_config .losses .lambda_cls * 0.0
163+ + self ._task_config .losses .lambda_box * 4.0
164+ + self ._task_config .losses .lambda_giou * 0.0
165+ )
149166
150167 # Set pads to large constant
151168 valid = tf .expand_dims (
152- tf .cast (tf .not_equal (cls_targets , 0 ), dtype = total_cost .dtype ), axis = 1 )
169+ tf .cast (tf .not_equal (cls_targets , 0 ), dtype = total_cost .dtype ), axis = 1
170+ )
153171 total_cost = (1 - valid ) * max_cost + valid * total_cost
154172
155173 # Set inf or nan to large constant
156174 total_cost = tf .where (
157175 tf .logical_or (tf .math .is_nan (total_cost ), tf .math .is_inf (total_cost )),
158176 max_cost * tf .ones_like (total_cost , dtype = total_cost .dtype ),
159- total_cost )
177+ total_cost ,
178+ )
160179
161180 return total_cost
162181
@@ -168,7 +187,8 @@ def build_losses(self, outputs, labels, aux_losses=None):
168187 box_targets = labels ['boxes' ]
169188
170189 cost = self ._compute_cost (
171- cls_outputs , box_outputs , cls_targets , box_targets )
190+ cls_outputs , box_outputs , cls_targets , box_targets
191+ )
172192
173193 _ , indices = matchers .hungarian_matching (cost )
174194 indices = tf .stop_gradient (indices )
@@ -179,45 +199,53 @@ def build_losses(self, outputs, labels, aux_losses=None):
179199
180200 background = tf .equal (cls_targets , 0 )
181201 num_boxes = tf .reduce_sum (
182- tf .cast (tf .logical_not (background ), tf .float32 ), axis = - 1 )
202+ tf .cast (tf .logical_not (background ), tf .float32 ), axis = - 1
203+ )
183204
184205 # Down-weight background to account for class imbalance.
185206 xentropy = tf .nn .sparse_softmax_cross_entropy_with_logits (
186- labels = cls_targets , logits = cls_assigned )
207+ labels = cls_targets , logits = cls_assigned
208+ )
187209 cls_loss = self ._task_config .losses .lambda_cls * tf .where (
188- background , self ._task_config .losses .background_cls_weight * xentropy ,
189- xentropy )
210+ background ,
211+ self ._task_config .losses .background_cls_weight * xentropy ,
212+ xentropy ,
213+ )
190214 cls_weights = tf .where (
191215 background ,
192216 self ._task_config .losses .background_cls_weight * tf .ones_like (cls_loss ),
193- tf .ones_like (cls_loss ))
217+ tf .ones_like (cls_loss ),
218+ )
194219
195220 # Box loss is only calculated on non-background class.
196221 l_1 = tf .reduce_sum (tf .abs (box_assigned - box_targets ), axis = - 1 )
197222 box_loss = self ._task_config .losses .lambda_box * tf .where (
198- background , tf .zeros_like (l_1 ), l_1 )
223+ background , tf .zeros_like (l_1 ), l_1
224+ )
199225
200226 # Giou loss is only calculated on non-background class.
201- giou = tf .linalg .diag_part (1.0 - box_ops .bbox_generalized_overlap (
202- box_ops .cycxhw_to_yxyx (box_assigned ),
203- box_ops .cycxhw_to_yxyx (box_targets )
204- ))
227+ giou = tf .linalg .diag_part (
228+ 1.0
229+ - box_ops .bbox_generalized_overlap (
230+ box_ops .cycxhw_to_yxyx (box_assigned ),
231+ box_ops .cycxhw_to_yxyx (box_targets ),
232+ )
233+ )
205234 giou_loss = self ._task_config .losses .lambda_giou * tf .where (
206- background , tf .zeros_like (giou ), giou )
235+ background , tf .zeros_like (giou ), giou
236+ )
207237
208238 # Consider doing all reduce once in train_step to speed up.
209239 num_boxes_per_replica = tf .reduce_sum (num_boxes )
210240 cls_weights_per_replica = tf .reduce_sum (cls_weights )
211241 replica_context = tf .distribute .get_replica_context ()
212242 num_boxes_sum , cls_weights_sum = replica_context .all_reduce (
213243 tf .distribute .ReduceOp .SUM ,
214- [num_boxes_per_replica , cls_weights_per_replica ])
215- cls_loss = tf .math .divide_no_nan (
216- tf .reduce_sum (cls_loss ), cls_weights_sum )
217- box_loss = tf .math .divide_no_nan (
218- tf .reduce_sum (box_loss ), num_boxes_sum )
219- giou_loss = tf .math .divide_no_nan (
220- tf .reduce_sum (giou_loss ), num_boxes_sum )
244+ [num_boxes_per_replica , cls_weights_per_replica ],
245+ )
246+ cls_loss = tf .math .divide_no_nan (tf .reduce_sum (cls_loss ), cls_weights_sum )
247+ box_loss = tf .math .divide_no_nan (tf .reduce_sum (box_loss ), num_boxes_sum )
248+ giou_loss = tf .math .divide_no_nan (tf .reduce_sum (giou_loss ), num_boxes_sum )
221249
222250 aux_losses = tf .add_n (aux_losses ) if aux_losses else 0.0
223251
@@ -236,7 +264,8 @@ def build_metrics(self, training=True):
236264 annotation_file = self ._task_config .annotation_file ,
237265 include_mask = False ,
238266 need_rescale_bboxes = True ,
239- per_category_metrics = self ._task_config .per_category_metrics )
267+ per_category_metrics = self ._task_config .per_category_metrics ,
268+ )
240269 return metrics
241270
242271 def train_step (self , inputs , model , optimizer , metrics = None ):
@@ -262,8 +291,11 @@ def train_step(self, inputs, model, optimizer, metrics=None):
262291
263292 for output in outputs :
264293 # Computes per-replica loss.
265- layer_loss , layer_cls_loss , layer_box_loss , layer_giou_loss = self .build_losses (
266- outputs = output , labels = labels , aux_losses = model .losses )
294+ layer_loss , layer_cls_loss , layer_box_loss , layer_giou_loss = (
295+ self .build_losses (
296+ outputs = output , labels = labels , aux_losses = model .losses
297+ )
298+ )
267299 loss += layer_loss
268300 cls_loss += layer_cls_loss
269301 box_loss += layer_box_loss
@@ -323,7 +355,8 @@ def validation_step(self, inputs, model, metrics=None):
323355
324356 outputs = model (features , training = False )[- 1 ]
325357 loss , cls_loss , box_loss , giou_loss = self .build_losses (
326- outputs = outputs , labels = labels , aux_losses = model .losses )
358+ outputs = outputs , labels = labels , aux_losses = model .losses
359+ )
327360
328361 # Multiply for logging.
329362 # Since we expect the gradient replica sum to happen in the optimizer,
@@ -341,35 +374,46 @@ def validation_step(self, inputs, model, metrics=None):
341374 # This is for backward compatibility.
342375 if 'detection_boxes' not in outputs :
343376 detection_boxes = box_ops .cycxhw_to_yxyx (
344- outputs ['box_outputs' ]) * tf .expand_dims (
345- tf .concat ([
346- labels ['image_info' ][:, 1 :2 , 0 ], labels ['image_info' ][:, 1 :2 ,
347- 1 ],
348- labels ['image_info' ][:, 1 :2 , 0 ], labels ['image_info' ][:, 1 :2 ,
349- 1 ]
377+ outputs ['box_outputs' ]
378+ ) * tf .expand_dims (
379+ tf .concat (
380+ [
381+ labels ['image_info' ][:, 1 :2 , 0 ],
382+ labels ['image_info' ][:, 1 :2 , 1 ],
383+ labels ['image_info' ][:, 1 :2 , 0 ],
384+ labels ['image_info' ][:, 1 :2 , 1 ],
350385 ],
351- axis = 1 ),
352- axis = 1 )
386+ axis = 1 ,
387+ ),
388+ axis = 1 ,
389+ )
353390 else :
354391 detection_boxes = outputs ['detection_boxes' ]
355392
356- detection_scores = tf .math .reduce_max (
357- tf .nn .softmax (outputs ['cls_outputs' ])[:, :, 1 :], axis = - 1
358- ) if 'detection_scores' not in outputs else outputs ['detection_scores' ]
393+ if 'detection_scores' not in outputs :
394+ detection_scores = tf .math .reduce_max (
395+ tf .nn .softmax (outputs ['cls_outputs' ])[:, :, 1 :], axis = - 1
396+ )
397+ else :
398+ detection_scores = outputs ['detection_scores' ]
359399
360400 if 'detection_classes' not in outputs :
361- detection_classes = tf .math .argmax (
362- outputs ['cls_outputs' ][:, :, 1 :], axis = - 1 ) + 1
401+ detection_classes = (
402+ tf .math .argmax (outputs ['cls_outputs' ][:, :, 1 :], axis = - 1 ) + 1
403+ )
363404 else :
364405 detection_classes = outputs ['detection_classes' ]
365406
366407 if 'num_detections' not in outputs :
367408 num_detections = tf .reduce_sum (
368409 tf .cast (
369410 tf .math .greater (
370- tf .math .reduce_max (outputs ['cls_outputs' ], axis = - 1 ), 0 ),
371- tf .int32 ),
372- axis = - 1 )
411+ tf .math .reduce_max (outputs ['cls_outputs' ], axis = - 1 ), 0
412+ ),
413+ tf .int32 ,
414+ ),
415+ axis = - 1 ,
416+ )
373417 else :
374418 num_detections = outputs ['num_detections' ]
375419
@@ -379,21 +423,21 @@ def validation_step(self, inputs, model, metrics=None):
379423 'detection_classes' : detection_classes ,
380424 'num_detections' : num_detections ,
381425 'source_id' : labels ['id' ],
382- 'image_info' : labels ['image_info' ]
426+ 'image_info' : labels ['image_info' ],
383427 }
384428
385429 ground_truths = {
386430 'source_id' : labels ['id' ],
387431 'height' : labels ['image_info' ][:, 0 :1 , 0 ],
388432 'width' : labels ['image_info' ][:, 0 :1 , 1 ],
389433 'num_detections' : tf .reduce_sum (
390- tf .cast (tf .math .greater (labels ['classes' ], 0 ), tf .int32 ), axis = - 1 ),
434+ tf .cast (tf .math .greater (labels ['classes' ], 0 ), tf .int32 ), axis = - 1
435+ ),
391436 'boxes' : labels ['gt_boxes' ],
392437 'classes' : labels ['classes' ],
393- 'is_crowds' : labels ['is_crowd' ]
438+ 'is_crowds' : labels ['is_crowd' ],
394439 }
395- logs .update ({'predictions' : predictions ,
396- 'ground_truths' : ground_truths })
440+ logs .update ({'predictions' : predictions , 'ground_truths' : ground_truths })
397441
398442 all_losses = {
399443 'cls_loss' : cls_loss ,
@@ -413,8 +457,8 @@ def aggregate_logs(self, state=None, step_outputs=None):
413457 state = self .coco_metric
414458
415459 state .update_state (
416- step_outputs ['ground_truths' ],
417- step_outputs [ 'predictions' ] )
460+ step_outputs ['ground_truths' ], step_outputs [ 'predictions' ]
461+ )
418462 return state
419463
420464 def reduce_aggregated_logs (self , aggregated_logs , global_step = None ):
0 commit comments