Skip to content

Commit 99c900d

Browse files
Internal change
PiperOrigin-RevId: 531571643
1 parent 326930f commit 99c900d

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

official/vision/ops/anchor.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -234,18 +234,24 @@ def label_anchors(
234234
box_weights = self.target_gather(weights, match_indices, mask)
235235
ignore_mask = tf.equal(match_indicators, -2)
236236
cls_weights = self.target_gather(weights, match_indices, ignore_mask)
237-
box_targets_list = box_list.BoxList(box_targets)
238-
anchor_box_list = box_list.BoxList(flattened_anchor_boxes)
239-
box_targets = self.box_coder.encode(box_targets_list, anchor_box_list)
237+
box_targets = box_list.BoxList(box_targets)
238+
anchor_box = box_list.BoxList(flattened_anchor_boxes)
239+
box_targets = self.box_coder.encode(box_targets, anchor_box)
240240

241241
# Unpacks labels into multi-level representations.
242-
cls_targets_dict = unpack_targets(cls_targets, anchor_boxes)
243-
box_targets_dict = unpack_targets(box_targets, anchor_boxes)
244-
attribute_targets_dict = {}
245-
for k, v in att_targets.items():
246-
attribute_targets_dict[k] = unpack_targets(v, anchor_boxes)
247-
248-
return cls_targets_dict, box_targets_dict, attribute_targets_dict, cls_weights, box_weights
242+
cls_targets = unpack_targets(cls_targets, anchor_boxes)
243+
box_targets = unpack_targets(box_targets, anchor_boxes)
244+
attribute_targets = {
245+
k: unpack_targets(v, anchor_boxes) for k, v in att_targets.items()
246+
}
247+
248+
return (
249+
cls_targets,
250+
box_targets,
251+
attribute_targets,
252+
cls_weights,
253+
box_weights,
254+
)
249255

250256

251257
class RpnAnchorLabeler(AnchorLabeler):

0 commit comments

Comments
 (0)