Skip to content

Commit eb9d363

Browse files
Another small fix to RetinaNet training example (keras-team#1306)
* fix to retinanet example * linting fix
1 parent 55b52e1 commit eb9d363

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

examples/training/object_detection/pascal_voc/retina_net.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def apply(inputs):
227227
bounding_boxes = keras_cv.bounding_box.convert_format(
228228
boxes, images=image, source="yxyx", target=bounding_box_format
229229
)
230-
bounding_boxes = {"boxes": boxes, "classes": classes}
230+
bounding_boxes = {"boxes": bounding_boxes, "classes": classes}
231231
return image, bounding_boxes
232232

233233
return apply
@@ -278,10 +278,13 @@ def apply(inputs):
278278
return apply
279279

280280

281-
def pad_fn(images, boxes):
282-
boxes = boxes.to_tensor(default_value=-1.0, shape=[GLOBAL_BATCH_SIZE, 32, 5])
283-
boxes = boxes[:, :, :4]
284-
classes = boxes[:, :, 4]
281+
def pad_fn(images, bounding_boxes):
282+
boxes = bounding_boxes["boxes"].to_tensor(
283+
default_value=-1.0, shape=[GLOBAL_BATCH_SIZE, 32, 4]
284+
)
285+
classes = bounding_boxes["classes"].to_tensor(
286+
default_value=-1.0, shape=[GLOBAL_BATCH_SIZE, 32]
287+
)
285288
return images, {"boxes": boxes, "classes": classes}
286289

287290

0 commit comments

Comments
 (0)