@@ -13,7 +13,6 @@
 # limitations under the License.
 
 """DETR detection task definition."""
-
 from typing import Optional
 
 from absl import logging
@@ -48,25 +47,21 @@ class DetectionTask(base_task.Task):
   def build_model(self):
     """Build DETR model."""
 
-    input_specs = tf_keras.layers.InputSpec(
-        shape=[None] + self._task_config.model.input_size
-    )
+    input_specs = tf_keras.layers.InputSpec(shape=[None] +
+                                            self._task_config.model.input_size)
 
     backbone = backbones.factory.build_backbone(
         input_specs=input_specs,
         backbone_config=self._task_config.model.backbone,
-        norm_activation_config=self._task_config.model.norm_activation,
-    )
-
-    model = detr.DETR(
-        backbone,
-        self._task_config.model.backbone_endpoint_name,
-        self._task_config.model.num_queries,
-        self._task_config.model.hidden_size,
-        self._task_config.model.num_classes,
-        self._task_config.model.num_encoder_layers,
-        self._task_config.model.num_decoder_layers,
-    )
+        norm_activation_config=self._task_config.model.norm_activation)
+
+    model = detr.DETR(backbone,
+                      self._task_config.model.backbone_endpoint_name,
+                      self._task_config.model.num_queries,
+                      self._task_config.model.hidden_size,
+                      self._task_config.model.num_classes,
+                      self._task_config.model.num_encoder_layers,
+                      self._task_config.model.num_decoder_layers)
     return model
 
   def initialize(self, model: tf_keras.Model):
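As context for the positional call above, here is a minimal, self-contained sketch of the same construction with the standard DETR hyperparameters from the paper (100 queries, hidden size 256, 6+6 transformer layers, 91 COCO classes). The ResNet-50 backbone, the 1333x1333 input size, and the endpoint name `'5'` are illustrative assumptions, not values read from this task config.

```python
# Sketch only: mirrors the positional detr.DETR(...) call in build_model.
# Hyperparameters are the standard DETR defaults, assumed for illustration.
import tf_keras

from official.projects.detr.modeling import detr
from official.vision.modeling.backbones import resnet

input_specs = tf_keras.layers.InputSpec(shape=[None, 1333, 1333, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)

model = detr.DETR(
    backbone,  # feature extractor
    '5',       # backbone_endpoint_name: feature level fed to the encoder
    100,       # num_queries: upper bound on detections per image
    256,       # hidden_size of the transformer
    91,        # num_classes, including background at index 0
    6,         # num_encoder_layers
    6)         # num_decoder_layers
```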
@@ -89,13 +84,12 @@ def initialize(self, model: tf_keras.Model):
     status = ckpt.restore(ckpt_dir_or_file)
     status.expect_partial().assert_existing_objects_matched()
 
-    logging.info(
-        'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
-    )
+    logging.info('Finished loading pretrained checkpoint from %s',
+                 ckpt_dir_or_file)
 
-  def build_inputs(
-      self, params, input_context: Optional[tf.distribute.InputContext] = None
-  ):
+  def build_inputs(self,
+                   params,
+                   input_context: Optional[tf.distribute.InputContext] = None):
     """Build input dataset."""
     if isinstance(params, coco.COCODataConfig):
       dataset = coco.COCODataLoader(params).load(input_context)
@@ -106,17 +100,14 @@ def build_inputs(
         decoder_cfg = params.decoder.get()
         if params.decoder.type == 'simple_decoder':
           decoder = tf_example_decoder.TfExampleDecoder(
-              regenerate_source_id=decoder_cfg.regenerate_source_id
-          )
+              regenerate_source_id=decoder_cfg.regenerate_source_id)
         elif params.decoder.type == 'label_map_decoder':
           decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
               label_map=decoder_cfg.label_map,
-              regenerate_source_id=decoder_cfg.regenerate_source_id,
-          )
+              regenerate_source_id=decoder_cfg.regenerate_source_id)
         else:
-          raise ValueError(
-              'Unknown decoder type: {}!'.format(params.decoder.type)
-          )
+          raise ValueError('Unknown decoder type: {}!'.format(
+              params.decoder.type))
 
       parser = detr_input.Parser(
           class_offset=self._task_config.losses.class_offset,
@@ -127,8 +118,7 @@ def build_inputs(
           params,
           dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
           decoder_fn=decoder.decode,
-          parser_fn=parser.parse_fn(params.is_training),
-      )
+          parser_fn=parser.parse_fn(params.is_training))
       dataset = reader.read(input_context=input_context)
 
     return dataset
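A usage note: `build_inputs` is normally not called directly. The trainer hands it to the distribution strategy, which supplies the optional `tf.distribute.InputContext` so each replica reads its own shard. A hedged sketch, assuming `task` is a built `DetectionTask` and `params` is its data config:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# `task` (a DetectionTask) and `params` (its data config) are assumed to
# exist. The strategy injects a per-replica InputContext for sharding.
dataset = strategy.distribute_datasets_from_function(
    lambda input_context: task.build_inputs(params, input_context))
```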
@@ -187,8 +177,7 @@ def build_losses(self, outputs, labels, aux_losses=None):
     box_targets = labels['boxes']
 
     cost = self._compute_cost(
-        cls_outputs, box_outputs, cls_targets, box_targets
-    )
+        cls_outputs, box_outputs, cls_targets, box_targets)
 
     _, indices = matchers.hungarian_matching(cost)
     indices = tf.stop_gradient(indices)
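`matchers.hungarian_matching` solves a linear sum assignment between the predictions and the padded targets of each image, and `tf.stop_gradient` keeps the discrete matching out of backprop. As an illustration only (the project ships its own TF matcher; this sketch substitutes `scipy.optimize.linear_sum_assignment`), the per-image problem reduces to:

```python
# Illustration of the assignment hungarian_matching computes for one image.
import numpy as np
from scipy.optimize import linear_sum_assignment

# cost[i, j]: cost of assigning prediction i to target j, e.g. a weighted
# sum of classification, L1-box and GIoU terms as in _compute_cost.
cost = np.array([
    [0.9, 0.2, 0.7],
    [0.1, 0.8, 0.6],
    [0.5, 0.4, 0.3],
])

pred_idx, tgt_idx = linear_sum_assignment(cost)
# tgt_idx == [1, 0, 2]: prediction 0 -> target 1, 1 -> 0, 2 -> 2.
print(cost[pred_idx, tgt_idx].sum())  # 0.6, the minimal total cost
```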
@@ -199,53 +188,45 @@ def build_losses(self, outputs, labels, aux_losses=None):
 
     background = tf.equal(cls_targets, 0)
     num_boxes = tf.reduce_sum(
-        tf.cast(tf.logical_not(background), tf.float32), axis=-1
-    )
+        tf.cast(tf.logical_not(background), tf.float32), axis=-1)
 
     # Down-weight background to account for class imbalance.
     xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-        labels=cls_targets, logits=cls_assigned
-    )
+        labels=cls_targets, logits=cls_assigned)
     cls_loss = self._task_config.losses.lambda_cls * tf.where(
-        background,
-        self._task_config.losses.background_cls_weight * xentropy,
-        xentropy,
-    )
+        background, self._task_config.losses.background_cls_weight * xentropy,
+        xentropy)
     cls_weights = tf.where(
         background,
         self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
-        tf.ones_like(cls_loss),
-    )
+        tf.ones_like(cls_loss))
 
     # Box loss is only calculated on non-background class.
     l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
     box_loss = self._task_config.losses.lambda_box * tf.where(
-        background, tf.zeros_like(l_1), l_1
-    )
+        background, tf.zeros_like(l_1), l_1)
 
     # Giou loss is only calculated on non-background class.
-    giou = tf.linalg.diag_part(
-        1.0
-        - box_ops.bbox_generalized_overlap(
-            box_ops.cycxhw_to_yxyx(box_assigned),
-            box_ops.cycxhw_to_yxyx(box_targets),
-        )
-    )
+    giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
+        box_ops.cycxhw_to_yxyx(box_assigned),
+        box_ops.cycxhw_to_yxyx(box_targets)
+    ))
     giou_loss = self._task_config.losses.lambda_giou * tf.where(
-        background, tf.zeros_like(giou), giou
-    )
+        background, tf.zeros_like(giou), giou)
 
     # Consider doing all reduce once in train_step to speed up.
     num_boxes_per_replica = tf.reduce_sum(num_boxes)
     cls_weights_per_replica = tf.reduce_sum(cls_weights)
     replica_context = tf.distribute.get_replica_context()
     num_boxes_sum, cls_weights_sum = replica_context.all_reduce(
         tf.distribute.ReduceOp.SUM,
-        [num_boxes_per_replica, cls_weights_per_replica],
-    )
-    cls_loss = tf.math.divide_no_nan(tf.reduce_sum(cls_loss), cls_weights_sum)
-    box_loss = tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum)
-    giou_loss = tf.math.divide_no_nan(tf.reduce_sum(giou_loss), num_boxes_sum)
+        [num_boxes_per_replica, cls_weights_per_replica])
+    cls_loss = tf.math.divide_no_nan(
+        tf.reduce_sum(cls_loss), cls_weights_sum)
+    box_loss = tf.math.divide_no_nan(
+        tf.reduce_sum(box_loss), num_boxes_sum)
+    giou_loss = tf.math.divide_no_nan(
+        tf.reduce_sum(giou_loss), num_boxes_sum)
 
     aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
 
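For reference, `box_ops.bbox_generalized_overlap` computes generalized IoU (Rezatofighi et al., 2019): GIoU = IoU - |C \ (A U B)| / |C|, with C the smallest box enclosing both, so 1 - GIoU stays in [0, 2] and remains informative even for disjoint boxes. The code above takes `tf.linalg.diag_part` of the pairwise matrix to keep only matched pairs; a minimal sketch of that aligned case, as an illustration rather than the library implementation:

```python
import tensorflow as tf

def giou_aligned(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
  """GIoU for aligned [..., 4] boxes in (ymin, xmin, ymax, xmax) format."""
  # Intersection box (empty intersections clamp to zero area).
  inter_min = tf.maximum(a[..., :2], b[..., :2])
  inter_max = tf.minimum(a[..., 2:], b[..., 2:])
  inter_hw = tf.maximum(inter_max - inter_min, 0.0)
  inter = inter_hw[..., 0] * inter_hw[..., 1]

  area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
  area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
  union = area_a + area_b - inter
  iou = tf.math.divide_no_nan(inter, union)

  # Smallest box C enclosing both a and b.
  hull_min = tf.minimum(a[..., :2], b[..., :2])
  hull_max = tf.maximum(a[..., 2:], b[..., 2:])
  hull = (hull_max - hull_min)[..., 0] * (hull_max - hull_min)[..., 1]

  return iou - tf.math.divide_no_nan(hull - union, hull)
```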
@@ -264,8 +245,7 @@ def build_metrics(self, training=True):
           annotation_file=self._task_config.annotation_file,
           include_mask=False,
           need_rescale_bboxes=True,
-          per_category_metrics=self._task_config.per_category_metrics,
-      )
+          per_category_metrics=self._task_config.per_category_metrics)
     return metrics
 
   def train_step(self, inputs, model, optimizer, metrics=None):
@@ -355,8 +335,7 @@ def validation_step(self, inputs, model, metrics=None):
 
     outputs = model(features, training=False)[-1]
     loss, cls_loss, box_loss, giou_loss = self.build_losses(
-        outputs=outputs, labels=labels, aux_losses=model.losses
-    )
+        outputs=outputs, labels=labels, aux_losses=model.losses)
 
     # Multiply for logging.
     # Since we expect the gradient replica sum to happen in the optimizer,
@@ -374,46 +353,35 @@ def validation_step(self, inputs, model, metrics=None):
     # This is for backward compatibility.
     if 'detection_boxes' not in outputs:
       detection_boxes = box_ops.cycxhw_to_yxyx(
-          outputs['box_outputs']
-      ) * tf.expand_dims(
-          tf.concat(
-              [
-                  labels['image_info'][:, 1:2, 0],
-                  labels['image_info'][:, 1:2, 1],
-                  labels['image_info'][:, 1:2, 0],
-                  labels['image_info'][:, 1:2, 1],
+          outputs['box_outputs']) * tf.expand_dims(
+              tf.concat([
+                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
+                                                                        1],
+                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
+                                                                        1]
               ],
-              axis=1,
-          ),
-          axis=1,
-      )
+                        axis=1),
+              axis=1)
     else:
       detection_boxes = outputs['detection_boxes']
 
-    if 'detection_scores' not in outputs:
-      detection_scores = tf.math.reduce_max(
-          tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
-      )
-    else:
-      detection_scores = outputs['detection_scores']
+    detection_scores = tf.math.reduce_max(
+        tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
+    ) if 'detection_scores' not in outputs else outputs['detection_scores']
 
     if 'detection_classes' not in outputs:
-      detection_classes = (
-          tf.math.argmax(outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
-      )
+      detection_classes = tf.math.argmax(
+          outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
     else:
       detection_classes = outputs['detection_classes']
 
     if 'num_detections' not in outputs:
       num_detections = tf.reduce_sum(
           tf.cast(
               tf.math.greater(
-                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0
-              ),
-              tf.int32,
-          ),
-          axis=-1,
-      )
+                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0),
+              tf.int32),
+          axis=-1)
     else:
       num_detections = outputs['num_detections']
 
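The decode path above converts normalized (center-y, center-x, height, width) boxes to corner form, then rescales them by the original image size carried in `labels['image_info']`. A short sketch of the conversion, assuming the standard definition of `box_ops.cycxhw_to_yxyx`:

```python
import tensorflow as tf

def cycxhw_to_yxyx(boxes: tf.Tensor) -> tf.Tensor:
  """[..., 4] boxes: (cy, cx, h, w) -> (ymin, xmin, ymax, xmax)."""
  cy, cx, h, w = tf.unstack(boxes, axis=-1)
  return tf.stack(
      [cy - h / 2.0, cx - w / 2.0, cy + h / 2.0, cx + w / 2.0], axis=-1)
```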
@@ -423,21 +391,21 @@ def validation_step(self, inputs, model, metrics=None):
         'detection_classes': detection_classes,
         'num_detections': num_detections,
         'source_id': labels['id'],
-        'image_info': labels['image_info'],
+        'image_info': labels['image_info']
     }
 
     ground_truths = {
         'source_id': labels['id'],
         'height': labels['image_info'][:, 0:1, 0],
         'width': labels['image_info'][:, 0:1, 1],
         'num_detections': tf.reduce_sum(
-            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1
-        ),
+            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1),
         'boxes': labels['gt_boxes'],
         'classes': labels['classes'],
-        'is_crowds': labels['is_crowd'],
+        'is_crowds': labels['is_crowd']
     }
-    logs.update({'predictions': predictions, 'ground_truths': ground_truths})
+    logs.update({'predictions': predictions,
+                 'ground_truths': ground_truths})
 
     all_losses = {
         'cls_loss': cls_loss,
@@ -457,8 +425,8 @@ def aggregate_logs(self, state=None, step_outputs=None):
       state = self.coco_metric
 
     state.update_state(
-        step_outputs['ground_truths'], step_outputs['predictions']
-    )
+        step_outputs['ground_truths'],
+        step_outputs['predictions'])
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):