Skip to content

Commit 1dc525a

Browse files
committed
Internal change
PiperOrigin-RevId: 515247220
1 parent fc0a507 commit 1dc525a

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

official/vision/modeling/layers/detection_generator.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -661,9 +661,11 @@ def __call__(self,
661661
decoded_boxes = box_ops.decode_boxes(
662662
raw_boxes, anchor_boxes, weights=regression_weights)
663663

664-
# Box clipping
665-
decoded_boxes = box_ops.clip_boxes(
666-
decoded_boxes, tf.expand_dims(image_shape, axis=1))
664+
# Box clipping.
665+
if image_shape is not None:
666+
decoded_boxes = box_ops.clip_boxes(
667+
decoded_boxes, tf.expand_dims(image_shape, axis=1)
668+
)
667669

668670
if bbox_per_class:
669671
decoded_boxes = tf.reshape(
@@ -835,8 +837,10 @@ def _decode_multilevel_outputs(
835837
boxes_i = box_ops.decode_boxes(raw_boxes_i, anchor_boxes_i)
836838

837839
# Box clipping.
838-
boxes_i = box_ops.clip_boxes(
839-
boxes_i, tf.expand_dims(image_shape, axis=1))
840+
if image_shape is not None:
841+
boxes_i = box_ops.clip_boxes(
842+
boxes_i, tf.expand_dims(image_shape, axis=1)
843+
)
840844

841845
boxes.append(boxes_i)
842846
scores.append(scores_i)

0 commit comments

Comments
 (0)