@@ -234,18 +234,24 @@ def label_anchors(
234
234
box_weights = self .target_gather (weights , match_indices , mask )
235
235
ignore_mask = tf .equal (match_indicators , - 2 )
236
236
cls_weights = self .target_gather (weights , match_indices , ignore_mask )
237
- box_targets_list = box_list .BoxList (box_targets )
238
- anchor_box_list = box_list .BoxList (flattened_anchor_boxes )
239
- box_targets = self .box_coder .encode (box_targets_list , anchor_box_list )
237
+ box_targets = box_list .BoxList (box_targets )
238
+ anchor_box = box_list .BoxList (flattened_anchor_boxes )
239
+ box_targets = self .box_coder .encode (box_targets , anchor_box )
240
240
241
241
# Unpacks labels into multi-level representations.
242
- cls_targets_dict = unpack_targets (cls_targets , anchor_boxes )
243
- box_targets_dict = unpack_targets (box_targets , anchor_boxes )
244
- attribute_targets_dict = {}
245
- for k , v in att_targets .items ():
246
- attribute_targets_dict [k ] = unpack_targets (v , anchor_boxes )
247
-
248
- return cls_targets_dict , box_targets_dict , attribute_targets_dict , cls_weights , box_weights
242
+ cls_targets = unpack_targets (cls_targets , anchor_boxes )
243
+ box_targets = unpack_targets (box_targets , anchor_boxes )
244
+ attribute_targets = {
245
+ k : unpack_targets (v , anchor_boxes ) for k , v in att_targets .items ()
246
+ }
247
+
248
+ return (
249
+ cls_targets ,
250
+ box_targets ,
251
+ attribute_targets ,
252
+ cls_weights ,
253
+ box_weights ,
254
+ )
249
255
250
256
251
257
class RpnAnchorLabeler (AnchorLabeler ):
0 commit comments