@@ -275,22 +275,27 @@ def _parse_train_data(self, data):
275
275
if self ._use_bfloat16 :
276
276
image = tf .cast (image , dtype = tf .bfloat16 )
277
277
278
+ inputs = {
279
+ 'image' : image ,
280
+ 'image_info' : image_info ,
281
+ }
278
282
# Packs labels for model_fn outputs.
279
283
labels = {
280
284
'anchor_boxes' : input_anchor .multilevel_boxes ,
281
285
'image_info' : image_info ,
282
286
'rpn_score_targets' : rpn_score_targets ,
283
287
'rpn_box_targets' : rpn_box_targets ,
284
288
}
285
- labels ['gt_boxes' ] = input_utils .pad_to_fixed_size (
286
- boxes , self ._max_num_instances , - 1 )
287
- labels ['gt_classes' ] = input_utils .pad_to_fixed_size (
289
+ inputs ['gt_boxes' ] = input_utils .pad_to_fixed_size (boxes ,
290
+ self ._max_num_instances ,
291
+ - 1 )
292
+ inputs ['gt_classes' ] = input_utils .pad_to_fixed_size (
288
293
classes , self ._max_num_instances , - 1 )
289
294
if self ._include_mask :
290
- labels ['gt_masks' ] = input_utils .pad_to_fixed_size (
295
+ inputs ['gt_masks' ] = input_utils .pad_to_fixed_size (
291
296
masks , self ._max_num_instances , - 1 )
292
297
293
- return image , labels
298
+ return inputs , labels
294
299
295
300
def _parse_eval_data (self , data ):
296
301
"""Parses data for evaluation."""
@@ -348,11 +353,7 @@ def _parse_predict_data(self, data):
348
353
self ._anchor_size ,
349
354
(image_height , image_width ))
350
355
351
- labels = {
352
- 'source_id' : dataloader_utils .process_source_id (data ['source_id' ]),
353
- 'anchor_boxes' : input_anchor .multilevel_boxes ,
354
- 'image_info' : image_info ,
355
- }
356
+ labels = {}
356
357
357
358
if self ._mode == ModeKeys .PREDICT_WITH_GT :
358
359
# Converts boxes from normalized coordinates to pixel coordinates.
@@ -372,6 +373,11 @@ def _parse_predict_data(self, data):
372
373
groundtruths ['source_id' ])
373
374
groundtruths = dataloader_utils .pad_groundtruths_to_fixed_size (
374
375
groundtruths , self ._max_num_instances )
376
+ # TODO(yeqing): Remove the `groundtrtuh` layer key (no longer needed).
375
377
labels ['groundtruths' ] = groundtruths
378
+ inputs = {
379
+ 'image' : image ,
380
+ 'image_info' : image_info ,
381
+ }
376
382
377
- return image , labels
383
+ return inputs , labels
0 commit comments