Skip to content

Commit 8d9a16c

Browse files
yeqinglitensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 285844156
1 parent 913640d commit 8d9a16c

File tree

6 files changed

+26
-19
lines changed

6 files changed

+26
-19
lines changed

official/vision/detection/configs/base_config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@
3030
BASE_CFG = {
3131
'model_dir': '',
3232
'use_tpu': True,
33+
'strategy_type': 'tpu',
3334
'isolate_session_state': False,
3435
'train': {
3536
'iterations_per_loop': 100,
36-
'train_batch_size': 64,
37+
'batch_size': 64,
3738
'total_steps': 22500,
3839
'num_cores_per_replica': None,
3940
'input_partition_dims': None,
@@ -57,13 +58,13 @@
5758
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
5859
'train_file_pattern': '',
5960
'train_dataset_type': 'tfrecord',
60-
'transpose_input': True,
61+
'transpose_input': False,
6162
'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
6263
'l2_weight_decay': 0.0001,
6364
'gradient_clip_norm': 0.0,
6465
},
6566
'eval': {
66-
'eval_batch_size': 8,
67+
'batch_size': 8,
6768
'eval_samples': 5000,
6869
'min_eval_interval': 180,
6970
'eval_timeout': None,

official/vision/detection/configs/maskrcnn_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
'maskrcnn_parser': {
3535
'use_bfloat16': True,
3636
'output_size': [1024, 1024],
37+
'num_channels': 3,
3738
'rpn_match_threshold': 0.7,
3839
'rpn_unmatched_threshold': 0.3,
3940
'rpn_batch_size_per_im': 256,

official/vision/detection/dataloader/maskrcnn_parser.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -275,22 +275,27 @@ def _parse_train_data(self, data):
275275
if self._use_bfloat16:
276276
image = tf.cast(image, dtype=tf.bfloat16)
277277

278+
inputs = {
279+
'image': image,
280+
'image_info': image_info,
281+
}
278282
# Packs labels for model_fn outputs.
279283
labels = {
280284
'anchor_boxes': input_anchor.multilevel_boxes,
281285
'image_info': image_info,
282286
'rpn_score_targets': rpn_score_targets,
283287
'rpn_box_targets': rpn_box_targets,
284288
}
285-
labels['gt_boxes'] = input_utils.pad_to_fixed_size(
286-
boxes, self._max_num_instances, -1)
287-
labels['gt_classes'] = input_utils.pad_to_fixed_size(
289+
inputs['gt_boxes'] = input_utils.pad_to_fixed_size(boxes,
290+
self._max_num_instances,
291+
-1)
292+
inputs['gt_classes'] = input_utils.pad_to_fixed_size(
288293
classes, self._max_num_instances, -1)
289294
if self._include_mask:
290-
labels['gt_masks'] = input_utils.pad_to_fixed_size(
295+
inputs['gt_masks'] = input_utils.pad_to_fixed_size(
291296
masks, self._max_num_instances, -1)
292297

293-
return image, labels
298+
return inputs, labels
294299

295300
def _parse_eval_data(self, data):
296301
"""Parses data for evaluation."""
@@ -348,11 +353,7 @@ def _parse_predict_data(self, data):
348353
self._anchor_size,
349354
(image_height, image_width))
350355

351-
labels = {
352-
'source_id': dataloader_utils.process_source_id(data['source_id']),
353-
'anchor_boxes': input_anchor.multilevel_boxes,
354-
'image_info': image_info,
355-
}
356+
labels = {}
356357

357358
if self._mode == ModeKeys.PREDICT_WITH_GT:
358359
# Converts boxes from normalized coordinates to pixel coordinates.
@@ -372,6 +373,11 @@ def _parse_predict_data(self, data):
372373
groundtruths['source_id'])
373374
groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size(
374375
groundtruths, self._max_num_instances)
376+
# TODO(yeqing): Remove the `groundtrtuh` layer key (no longer needed).
375377
labels['groundtruths'] = groundtruths
378+
inputs = {
379+
'image': image,
380+
'image_info': image_info,
381+
}
376382

377-
return image, labels
383+
return inputs, labels

official/vision/detection/modeling/base_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(self, params):
9999
params.train.learning_rate)
100100

101101
self._frozen_variable_prefix = params.train.frozen_variable_prefix
102+
self._l2_weight_decay = params.train.l2_weight_decay
102103

103104
# Checkpoint restoration.
104105
self._checkpoint = params.train.checkpoint.as_dict()

official/vision/detection/modeling/losses.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class RpnBoxLoss(object):
147147
"""Region Proposal Network box regression loss function."""
148148

149149
def __init__(self, params):
150+
self._delta = params.huber_loss_delta
150151
self._huber_loss = tf.keras.losses.Huber(
151152
delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
152153

@@ -212,7 +213,7 @@ def __call__(self, class_outputs, class_targets):
212213
a scalar tensor representing total class loss.
213214
"""
214215
with tf.name_scope('fast_rcnn_loss'):
215-
_, _, _, num_classes = class_outputs.get_shape().as_list()
216+
_, _, num_classes = class_outputs.get_shape().as_list()
216217
class_targets = tf.cast(class_targets, dtype=tf.int32)
217218
class_targets_one_hot = tf.one_hot(class_targets, num_classes)
218219
return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot)
@@ -320,9 +321,6 @@ def _fast_rcnn_box_loss(self, box_outputs, box_targets, class_targets,
320321
class MaskrcnnLoss(object):
321322
"""Mask R-CNN instance segmentation mask loss function."""
322323

323-
def __init__(self):
324-
raise ValueError('Not TF 2.0 ready.')
325-
326324
def __call__(self, mask_outputs, mask_targets, select_class_targets):
327325
"""Computes the mask loss of Mask-RCNN.
328326

official/vision/detection/modeling/retinanet_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def __init__(self, params):
5656
self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
5757
params.postprocess)
5858

59-
self._l2_weight_decay = params.train.l2_weight_decay
6059
self._transpose_input = params.train.transpose_input
6160
assert not self._transpose_input, 'Transpose input is not supportted.'
6261
# Input layer.
@@ -134,6 +133,7 @@ def build_model(self, params, mode=None):
134133
return self._keras_model
135134

136135
def post_processing(self, labels, outputs):
136+
# TODO(yeqing): Moves the output related part into build_outputs.
137137
required_output_fields = ['cls_outputs', 'box_outputs']
138138
for field in required_output_fields:
139139
if field not in outputs:

0 commit comments

Comments
 (0)