
Commit 17bf360

No public description
PiperOrigin-RevId: 721434742
1 parent 1748aa9 commit 17bf360

1 file changed

official/projects/detr/tasks/detection.py

Lines changed: 65 additions & 97 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 """DETR detection task definition."""
-
 from typing import Optional
 
 from absl import logging
@@ -48,25 +47,21 @@ class DetectionTask(base_task.Task):
   def build_model(self):
     """Build DETR model."""
 
-    input_specs = tf_keras.layers.InputSpec(
-        shape=[None] + self._task_config.model.input_size
-    )
+    input_specs = tf_keras.layers.InputSpec(shape=[None] +
+                                            self._task_config.model.input_size)
 
     backbone = backbones.factory.build_backbone(
         input_specs=input_specs,
         backbone_config=self._task_config.model.backbone,
-        norm_activation_config=self._task_config.model.norm_activation,
-    )
-
-    model = detr.DETR(
-        backbone,
-        self._task_config.model.backbone_endpoint_name,
-        self._task_config.model.num_queries,
-        self._task_config.model.hidden_size,
-        self._task_config.model.num_classes,
-        self._task_config.model.num_encoder_layers,
-        self._task_config.model.num_decoder_layers,
-    )
+        norm_activation_config=self._task_config.model.norm_activation)
+
+    model = detr.DETR(backbone,
+                      self._task_config.model.backbone_endpoint_name,
+                      self._task_config.model.num_queries,
+                      self._task_config.model.hidden_size,
+                      self._task_config.model.num_classes,
+                      self._task_config.model.num_encoder_layers,
+                      self._task_config.model.num_decoder_layers)
     return model
 
   def initialize(self, model: tf_keras.Model):
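
Note: a minimal sketch (not part of this commit) of driving build_model() from a
task config. The DetrTask config class name and the field values below are
assumptions for illustration.

    from official.projects.detr.configs import detr as detr_cfg
    from official.projects.detr.tasks import detection

    config = detr_cfg.DetrTask()                # assumed config class
    config.model.input_size = [640, 640, 3]     # feeds the InputSpec above
    config.model.num_queries = 100              # object queries per image

    task = detection.DetectionTask(config)
    model = task.build_model()                  # the detr.DETR keras model
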
@@ -89,13 +84,12 @@ def initialize(self, model: tf_keras.Model):
     status = ckpt.restore(ckpt_dir_or_file)
     status.expect_partial().assert_existing_objects_matched()
 
-    logging.info(
-        'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
-    )
+    logging.info('Finished loading pretrained checkpoint from %s',
+                 ckpt_dir_or_file)
 
-  def build_inputs(
-      self, params, input_context: Optional[tf.distribute.InputContext] = None
-  ):
+  def build_inputs(self,
+                   params,
+                   input_context: Optional[tf.distribute.InputContext] = None):
     """Build input dataset."""
     if isinstance(params, coco.COCODataConfig):
       dataset = coco.COCODataLoader(params).load(input_context)
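
Note: a hedged sketch of how the Optional InputContext typically reaches
build_inputs when running under tf.distribute; `task` is assumed to be a
constructed DetectionTask, and the strategy choice is illustrative.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    def dataset_fn(input_context: tf.distribute.InputContext):
      # Each replica builds its own shard; build_inputs forwards the context
      # to the underlying input readers.
      return task.build_inputs(task.task_config.train_data, input_context)

    dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
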
@@ -106,17 +100,14 @@ def build_inputs(
       decoder_cfg = params.decoder.get()
       if params.decoder.type == 'simple_decoder':
         decoder = tf_example_decoder.TfExampleDecoder(
-            regenerate_source_id=decoder_cfg.regenerate_source_id
-        )
+            regenerate_source_id=decoder_cfg.regenerate_source_id)
       elif params.decoder.type == 'label_map_decoder':
         decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
             label_map=decoder_cfg.label_map,
-            regenerate_source_id=decoder_cfg.regenerate_source_id,
-        )
+            regenerate_source_id=decoder_cfg.regenerate_source_id)
       else:
-        raise ValueError(
-            'Unknown decoder type: {}!'.format(params.decoder.type)
-        )
+        raise ValueError('Unknown decoder type: {}!'.format(
+            params.decoder.type))
 
       parser = detr_input.Parser(
           class_offset=self._task_config.losses.class_offset,
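
Note: the decoder branch is driven entirely by the data config. A hedged
example of selecting the label-map decoder; the nested field layout is assumed
from the `params.decoder.get()` pattern above, and the path is made up.

    params.decoder.type = 'label_map_decoder'
    params.decoder.label_map_decoder.label_map = '/path/to/label_map.pbtxt'
    params.decoder.label_map_decoder.regenerate_source_id = False
    # Any other value of params.decoder.type raises:
    # ValueError: Unknown decoder type: ...!
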
@@ -127,8 +118,7 @@ def build_inputs(
           params,
           dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
           decoder_fn=decoder.decode,
-          parser_fn=parser.parse_fn(params.is_training),
-      )
+          parser_fn=parser.parse_fn(params.is_training))
       dataset = reader.read(input_context=input_context)
 
     return dataset
@@ -187,8 +177,7 @@ def build_losses(self, outputs, labels, aux_losses=None):
     box_targets = labels['boxes']
 
     cost = self._compute_cost(
-        cls_outputs, box_outputs, cls_targets, box_targets
-    )
+        cls_outputs, box_outputs, cls_targets, box_targets)
 
     _, indices = matchers.hungarian_matching(cost)
     indices = tf.stop_gradient(indices)
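
Note: a hedged reference (not part of this commit) for the matching step: the
same assignment matchers.hungarian_matching computes, shown with scipy for a
single image and a made-up cost matrix.

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    cost = np.array([[0.9, 0.1, 0.5],
                     [0.4, 0.8, 0.2],
                     [0.3, 0.6, 0.7]])  # [num_queries, num_targets]
    rows, cols = linear_sum_assignment(cost)  # minimizes the summed cost
    print(list(zip(rows, cols)))  # [(0, 1), (1, 2), (2, 0)]; total cost 0.6
    # tf.stop_gradient(indices) above treats this discrete assignment as a
    # constant, so no gradient flows through the matcher itself.
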
@@ -199,53 +188,45 @@ def build_losses(self, outputs, labels, aux_losses=None):
 
     background = tf.equal(cls_targets, 0)
     num_boxes = tf.reduce_sum(
-        tf.cast(tf.logical_not(background), tf.float32), axis=-1
-    )
+        tf.cast(tf.logical_not(background), tf.float32), axis=-1)
 
     # Down-weight background to account for class imbalance.
     xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-        labels=cls_targets, logits=cls_assigned
-    )
+        labels=cls_targets, logits=cls_assigned)
     cls_loss = self._task_config.losses.lambda_cls * tf.where(
-        background,
-        self._task_config.losses.background_cls_weight * xentropy,
-        xentropy,
-    )
+        background, self._task_config.losses.background_cls_weight * xentropy,
+        xentropy)
     cls_weights = tf.where(
         background,
         self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
-        tf.ones_like(cls_loss),
-    )
+        tf.ones_like(cls_loss))
 
     # Box loss is only calculated on non-background class.
     l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
     box_loss = self._task_config.losses.lambda_box * tf.where(
-        background, tf.zeros_like(l_1), l_1
-    )
+        background, tf.zeros_like(l_1), l_1)
 
     # Giou loss is only calculated on non-background class.
-    giou = tf.linalg.diag_part(
-        1.0
-        - box_ops.bbox_generalized_overlap(
-            box_ops.cycxhw_to_yxyx(box_assigned),
-            box_ops.cycxhw_to_yxyx(box_targets),
-        )
-    )
+    giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
+        box_ops.cycxhw_to_yxyx(box_assigned),
+        box_ops.cycxhw_to_yxyx(box_targets)
+    ))
     giou_loss = self._task_config.losses.lambda_giou * tf.where(
-        background, tf.zeros_like(giou), giou
-    )
+        background, tf.zeros_like(giou), giou)
 
     # Consider doing all reduce once in train_step to speed up.
     num_boxes_per_replica = tf.reduce_sum(num_boxes)
     cls_weights_per_replica = tf.reduce_sum(cls_weights)
     replica_context = tf.distribute.get_replica_context()
     num_boxes_sum, cls_weights_sum = replica_context.all_reduce(
         tf.distribute.ReduceOp.SUM,
-        [num_boxes_per_replica, cls_weights_per_replica],
-    )
-    cls_loss = tf.math.divide_no_nan(tf.reduce_sum(cls_loss), cls_weights_sum)
-    box_loss = tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum)
-    giou_loss = tf.math.divide_no_nan(tf.reduce_sum(giou_loss), num_boxes_sum)
+        [num_boxes_per_replica, cls_weights_per_replica])
+    cls_loss = tf.math.divide_no_nan(
+        tf.reduce_sum(cls_loss), cls_weights_sum)
+    box_loss = tf.math.divide_no_nan(
+        tf.reduce_sum(box_loss), num_boxes_sum)
+    giou_loss = tf.math.divide_no_nan(
+        tf.reduce_sum(giou_loss), num_boxes_sum)
 
     aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
 
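Note: the per-replica sums are all-reduced so every replica divides by the same
global denominators, and tf.math.divide_no_nan keeps a batch with no foreground
boxes from producing NaNs. A small numeric check of that edge case:

    import tensorflow as tf

    box_loss = tf.constant([0.0, 0.0])   # no matched boxes on this replica
    num_boxes_sum = tf.constant(0.0)     # global box count is also zero
    print(tf.math.divide_no_nan(tf.reduce_sum(box_loss), num_boxes_sum))
    # tf.Tensor(0.0, shape=(), dtype=float32), not NaN
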
@@ -264,8 +245,7 @@ def build_metrics(self, training=True):
           annotation_file=self._task_config.annotation_file,
           include_mask=False,
           need_rescale_bboxes=True,
-          per_category_metrics=self._task_config.per_category_metrics,
-      )
+          per_category_metrics=self._task_config.per_category_metrics)
     return metrics
 
   def train_step(self, inputs, model, optimizer, metrics=None):
@@ -355,8 +335,7 @@ def validation_step(self, inputs, model, metrics=None):
 
     outputs = model(features, training=False)[-1]
     loss, cls_loss, box_loss, giou_loss = self.build_losses(
-        outputs=outputs, labels=labels, aux_losses=model.losses
-    )
+        outputs=outputs, labels=labels, aux_losses=model.losses)
 
     # Multiply for logging.
     # Since we expect the gradient replica sum to happen in the optimizer,
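
Note: the comment above is cut off by the hunk boundary; the idea is that the
losses were normalized by global (all-replica) denominators, so for logging
they are scaled back by the replica count to stay interpretable. A hedged
sketch of that scaling (variable names illustrative):

    import tensorflow as tf

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    logged_loss = loss * num_replicas  # `loss` from build_losses above
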
@@ -374,46 +353,35 @@ def validation_step(self, inputs, model, metrics=None):
     # This is for backward compatibility.
     if 'detection_boxes' not in outputs:
       detection_boxes = box_ops.cycxhw_to_yxyx(
-          outputs['box_outputs']
-      ) * tf.expand_dims(
-          tf.concat(
-              [
-                  labels['image_info'][:, 1:2, 0],
-                  labels['image_info'][:, 1:2, 1],
-                  labels['image_info'][:, 1:2, 0],
-                  labels['image_info'][:, 1:2, 1],
+          outputs['box_outputs']) * tf.expand_dims(
+              tf.concat([
+                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
+                                                                        1],
+                  labels['image_info'][:, 1:2, 0], labels['image_info'][:, 1:2,
+                                                                        1]
               ],
-              axis=1,
-          ),
-          axis=1,
-      )
+                        axis=1),
+              axis=1)
     else:
       detection_boxes = outputs['detection_boxes']
 
-    if 'detection_scores' not in outputs:
-      detection_scores = tf.math.reduce_max(
-          tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
-      )
-    else:
-      detection_scores = outputs['detection_scores']
+    detection_scores = tf.math.reduce_max(
+        tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1
+    ) if 'detection_scores' not in outputs else outputs['detection_scores']
 
     if 'detection_classes' not in outputs:
-      detection_classes = (
-          tf.math.argmax(outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
-      )
+      detection_classes = tf.math.argmax(
+          outputs['cls_outputs'][:, :, 1:], axis=-1) + 1
     else:
       detection_classes = outputs['detection_classes']
 
     if 'num_detections' not in outputs:
       num_detections = tf.reduce_sum(
           tf.cast(
               tf.math.greater(
-                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0
-              ),
-              tf.int32,
-          ),
-          axis=-1,
-      )
+                  tf.math.reduce_max(outputs['cls_outputs'], axis=-1), 0),
+              tf.int32),
+          axis=-1)
     else:
       num_detections = outputs['num_detections']
 
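Note: the backward-compatibility branch converts DETR's normalized
(cy, cx, h, w) boxes to absolute (ymin, xmin, ymax, xmax) for COCO evaluation,
scaling by the original image size carried in labels['image_info']. A hedged
standalone version of that conversion (for illustration only; the task uses
box_ops.cycxhw_to_yxyx):

    import tensorflow as tf

    def cycxhw_to_yxyx(boxes):
      # Center/size to corner coordinates, mirroring the library's intent.
      cy, cx, h, w = tf.unstack(boxes, axis=-1)
      return tf.stack(
          [cy - h / 2.0, cx - w / 2.0, cy + h / 2.0, cx + w / 2.0], axis=-1)

    box = tf.constant([[0.5, 0.5, 0.2, 0.4]])          # normalized cycxhw
    scale = tf.constant([480.0, 640.0, 480.0, 640.0])  # H, W, H, W
    print(cycxhw_to_yxyx(box) * scale)                 # [[192. 192. 288. 448.]]
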
@@ -423,21 +391,21 @@ def validation_step(self, inputs, model, metrics=None):
         'detection_classes': detection_classes,
         'num_detections': num_detections,
         'source_id': labels['id'],
-        'image_info': labels['image_info'],
+        'image_info': labels['image_info']
     }
 
     ground_truths = {
         'source_id': labels['id'],
         'height': labels['image_info'][:, 0:1, 0],
         'width': labels['image_info'][:, 0:1, 1],
         'num_detections': tf.reduce_sum(
-            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1
-        ),
+            tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1),
         'boxes': labels['gt_boxes'],
         'classes': labels['classes'],
-        'is_crowds': labels['is_crowd'],
+        'is_crowds': labels['is_crowd']
     }
-    logs.update({'predictions': predictions, 'ground_truths': ground_truths})
+    logs.update({'predictions': predictions,
+                 'ground_truths': ground_truths})
 
     all_losses = {
         'cls_loss': cls_loss,
@@ -457,8 +425,8 @@ def aggregate_logs(self, state=None, step_outputs=None):
       state = self.coco_metric
 
     state.update_state(
-        step_outputs['ground_truths'], step_outputs['predictions']
-    )
+        step_outputs['ground_truths'],
+        step_outputs['predictions'])
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
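
Note: aggregate_logs feeds each step's predictions and ground truths into the
COCO metric, and reduce_aggregated_logs computes the final numbers. A hedged
sketch of how an evaluation loop typically drives these hooks (loop and
variable names are illustrative, not code from this commit):

    state = None
    for batch in eval_dataset:
      logs = task.validation_step(batch, model)
      state = task.aggregate_logs(state=state, step_outputs=logs)
    metrics = task.reduce_aggregated_logs(state)  # dict of COCO AP metrics
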
