Skip to content

Commit 4ccce0d

Browse files
Vighnesh BirodkarTF Object Detection Team
authored andcommitted
Add IOU heatmap loss for CenterNet.
PiperOrigin-RevId: 421140005
1 parent 671615c commit 4ccce0d

File tree

4 files changed

+193
-11
lines changed

4 files changed

+193
-11
lines changed

research/object_detection/core/target_assigner.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,9 @@ def __init__(self,
925925
compute_heatmap_sparse=False,
926926
keypoint_class_id=None,
927927
keypoint_indices=None,
928-
keypoint_weights_for_center=None):
928+
keypoint_weights_for_center=None,
929+
box_heatmap_type='adaptive_gaussian',
930+
heatmap_exponent=1.0):
929931
"""Initializes the target assigner.
930932
931933
Args:
@@ -947,6 +949,17 @@ def __init__(self,
947949
the number of keypoints. The object center is calculated by the weighted
948950
mean of the keypoint locations. If not provided, the object center is
949951
determined by the center of the bounding box (default behavior).
952+
box_heatmap_type: str, the algorithm used to compute the box heatmap,
953+
used when calling the assign_center_targets_from_boxes method.
954+
Options are:
955+
'adaptaive_gaussian': A box-size adaptive Gaussian from the original
956+
paper[1].
957+
'iou': IOU based heatmap target where each point is assigned an IOU
958+
based on its location, assuming that it produced a box centered at
959+
that point with the correct size.
960+
heatmap_exponent: float, The generated heatmap is exponentiated with
961+
this number. A number > 1 will result in the heatmap being more peaky
962+
and a number < 1 will cause the heatmap to be more spreadout.
950963
"""
951964

952965
self._stride = stride
@@ -955,6 +968,8 @@ def __init__(self,
955968
self._keypoint_class_id = keypoint_class_id
956969
self._keypoint_indices = keypoint_indices
957970
self._keypoint_weights_for_center = keypoint_weights_for_center
971+
self._box_heatmap_type = box_heatmap_type
972+
self._heatmap_exponent = heatmap_exponent
958973

959974
def assign_center_targets_from_boxes(self,
960975
height,
@@ -1018,19 +1033,29 @@ def assign_center_targets_from_boxes(self,
10181033
self._min_overlap)
10191034
# Apply the Gaussian kernel to the center coordinates. Returned heatmap
10201035
# has shape of [out_height, out_width, num_classes]
1021-
heatmap = ta_utils.coordinates_to_heatmap(
1022-
y_grid=y_grid,
1023-
x_grid=x_grid,
1024-
y_coordinates=y_center,
1025-
x_coordinates=x_center,
1026-
sigma=sigma,
1027-
channel_onehot=class_targets,
1028-
channel_weights=weights,
1029-
sparse=self._compute_heatmap_sparse)
1036+
1037+
if self._box_heatmap_type == 'adaptive_gaussian':
1038+
heatmap = ta_utils.coordinates_to_heatmap(
1039+
y_grid=y_grid,
1040+
x_grid=x_grid,
1041+
y_coordinates=y_center,
1042+
x_coordinates=x_center,
1043+
sigma=sigma,
1044+
channel_onehot=class_targets,
1045+
channel_weights=weights,
1046+
sparse=self._compute_heatmap_sparse)
1047+
elif self._box_heatmap_type == 'iou':
1048+
heatmap = ta_utils.coordinates_to_iou(y_grid, x_grid, boxes,
1049+
class_targets, weights)
1050+
else:
1051+
raise ValueError(f'Unknown heatmap type - {self._box_heatmap_type}')
1052+
10301053
heatmaps.append(heatmap)
10311054

10321055
# Return the stacked heatmaps over the batch.
1033-
return tf.stack(heatmaps, axis=0)
1056+
stacked_heatmaps = tf.stack(heatmaps, axis=0)
1057+
return (tf.pow(stacked_heatmaps, self._heatmap_exponent) if
1058+
self._heatmap_exponent != 1.0 else stacked_heatmaps)
10341059

10351060
def assign_center_targets_from_keypoints(self,
10361061
height,

research/object_detection/core/target_assigner_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,6 +1678,66 @@ def graph_fn():
16781678
np.testing.assert_array_equal(preds, [[1, 2], [3, 4], [5, 6], [7, 8]])
16791679

16801680

1681+
class CenterNetIOUTargetAssignerTest(test_case.TestCase):
1682+
1683+
def setUp(self):
1684+
super(CenterNetIOUTargetAssignerTest, self).setUp()
1685+
1686+
self._box_center = [0.0, 0.0, 1.0, 1.0]
1687+
self._box_center_small = [0.25, 0.25, 0.75, 0.75]
1688+
self._box_lower_left = [0.5, 0.0, 1.0, 0.5]
1689+
self._box_center_offset = [0.1, 0.05, 1.0, 1.0]
1690+
self._box_odd_coordinates = [0.1625, 0.2125, 0.5625, 0.9625]
1691+
1692+
def test_center_location(self):
1693+
"""Test that the centers are at the correct location."""
1694+
def graph_fn():
1695+
box_batch = [tf.constant([self._box_center, self._box_lower_left]),
1696+
tf.constant([self._box_lower_left, self._box_center])]
1697+
classes = [
1698+
tf.one_hot([0, 1], depth=4),
1699+
tf.one_hot([2, 2], depth=4)
1700+
]
1701+
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(
1702+
4, box_heatmap_type='iou')
1703+
targets = assigner.assign_center_targets_from_boxes(
1704+
80, 80, box_batch, classes)
1705+
return targets
1706+
targets = self.execute(graph_fn, [])
1707+
self.assertEqual((10, 10), _array_argmax(targets[0, :, :, 0]))
1708+
self.assertAlmostEqual(1.0, targets[0, 10, 10, 0])
1709+
self.assertEqual((15, 5), _array_argmax(targets[0, :, :, 1]))
1710+
self.assertAlmostEqual(1.0, targets[0, 15, 5, 1])
1711+
1712+
self.assertAlmostEqual(1.0, targets[1, 15, 5, 2])
1713+
self.assertAlmostEqual(1.0, targets[1, 10, 10, 2])
1714+
self.assertAlmostEqual(0.0, targets[1, 0, 19, 1])
1715+
1716+
def test_exponent(self):
1717+
"""Test that the centers are at the correct location."""
1718+
def graph_fn():
1719+
box_batch = [tf.constant([self._box_center, self._box_lower_left]),
1720+
tf.constant([self._box_lower_left, self._box_center])]
1721+
classes = [
1722+
tf.one_hot([0], depth=2),
1723+
]
1724+
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(
1725+
1, box_heatmap_type='iou')
1726+
targets = assigner.assign_center_targets_from_boxes(
1727+
4, 4, box_batch, classes)
1728+
1729+
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(
1730+
1, box_heatmap_type='iou', heatmap_exponent=0.5)
1731+
targets_pow = assigner.assign_center_targets_from_boxes(
1732+
4, 4, box_batch, classes)
1733+
return targets, targets_pow
1734+
1735+
targets, targets_pow = self.execute(graph_fn, [])
1736+
self.assertLess(targets[0, 2, 3, 0], 1.0)
1737+
self.assertLess(targets_pow[0, 2, 3, 0], 1.0)
1738+
self.assertAlmostEqual(targets[0, 2, 3, 0], targets_pow[0, 2, 3, 0] ** 2)
1739+
1740+
16811741
class CenterNetKeypointTargetAssignerTest(test_case.TestCase):
16821742

16831743
def test_keypoint_heatmap_targets(self):

research/object_detection/utils/target_assigner_utils.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,77 @@ def compute_floor_offsets_with_indices(y_source,
236236
return offsets, indices
237237

238238

239+
def coordinates_to_iou(y_grid, x_grid, blist,
240+
channels_onehot, weights=None):
241+
"""Computes a per-pixel IoU with groundtruth boxes.
242+
243+
At each pixel, we return the IoU assuming that we predicted the
244+
ideal height and width for the box at that location.
245+
246+
Args:
247+
y_grid: A 2D tensor with shape [height, width] which contains the grid
248+
y-coordinates given in the (output) image dimensions.
249+
x_grid: A 2D tensor with shape [height, width] which contains the grid
250+
x-coordinates given in the (output) image dimensions.
251+
blist: A BoxList object with `num_instances` number of boxes.
252+
channels_onehot: A 2D tensor with shape [num_instances, num_channels]
253+
representing the one-hot encoded channel labels for each point.
254+
weights: A 1D tensor with shape [num_instances] corresponding to the
255+
weight of each instance.
256+
257+
Returns:
258+
iou_heatmap: A [height, width, num_channels] shapes float tensor denoting
259+
the IoU based heatmap.
260+
"""
261+
262+
image_height, image_width = tf.shape(y_grid)[0], tf.shape(y_grid)[1]
263+
num_pixels = image_height * image_width
264+
_, _, height, width = blist.get_center_coordinates_and_sizes()
265+
num_boxes = tf.shape(height)[0]
266+
267+
per_pixel_ymin = (y_grid[tf.newaxis, :, :] -
268+
(height[:, tf.newaxis, tf.newaxis] / 2.0))
269+
per_pixel_xmin = (x_grid[tf.newaxis, :, :] -
270+
(width[:, tf.newaxis, tf.newaxis] / 2.0))
271+
per_pixel_ymax = (y_grid[tf.newaxis, :, :] +
272+
(height[:, tf.newaxis, tf.newaxis] / 2.0))
273+
per_pixel_xmax = (x_grid[tf.newaxis, :, :] +
274+
(width[:, tf.newaxis, tf.newaxis] / 2.0))
275+
276+
# [num_boxes, height, width] -> [num_boxes * height * width]
277+
per_pixel_ymin = tf.reshape(
278+
per_pixel_ymin, [num_pixels * num_boxes])
279+
per_pixel_xmin = tf.reshape(
280+
per_pixel_xmin, [num_pixels * num_boxes])
281+
per_pixel_ymax = tf.reshape(
282+
per_pixel_ymax, [num_pixels * num_boxes])
283+
per_pixel_xmax = tf.reshape(
284+
per_pixel_xmax, [num_pixels * num_boxes])
285+
per_pixel_blist = box_list.BoxList(
286+
tf.stack([per_pixel_ymin, per_pixel_xmin,
287+
per_pixel_ymax, per_pixel_xmax], axis=1))
288+
289+
target_boxes = tf.tile(
290+
blist.get()[:, tf.newaxis, :], [1, num_pixels, 1])
291+
# [num_boxes, height * width, 4] -> [num_boxes * height * wdith, 4]
292+
target_boxes = tf.reshape(target_boxes,
293+
[num_pixels * num_boxes, 4])
294+
target_blist = box_list.BoxList(target_boxes)
295+
296+
ious = box_list_ops.matched_iou(target_blist, per_pixel_blist)
297+
ious = tf.reshape(ious, [num_boxes, image_height, image_width])
298+
per_class_iou = (
299+
ious[:, :, :, tf.newaxis] *
300+
channels_onehot[:, tf.newaxis, tf.newaxis, :])
301+
302+
if weights is not None:
303+
per_class_iou = (
304+
per_class_iou * weights[:, tf.newaxis, tf.newaxis, tf.newaxis])
305+
306+
per_class_iou = tf.maximum(per_class_iou, 0.0)
307+
return tf.reduce_max(per_class_iou, axis=0)
308+
309+
239310
def get_valid_keypoint_mask_for_class(keypoint_coordinates,
240311
class_id,
241312
class_onehot,

research/object_detection/utils/target_assigner_utils_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import tensorflow.compat.v1 as tf
2020

21+
from object_detection.core import box_list
2122
from object_detection.utils import target_assigner_utils as ta_utils
2223
from object_detection.utils import test_case
2324

@@ -265,6 +266,31 @@ def graph_fn():
265266
np.array([[0.0, 3.0, 4.0, 0.0, 4.0]]))
266267
self.assertAllEqual(valid, [[False, True, True, False, True]])
267268

269+
def test_coordinates_to_iou(self):
270+
271+
def graph_fn():
272+
y, x = tf.meshgrid(tf.range(32, dtype=tf.float32),
273+
tf.range(32, dtype=tf.float32))
274+
blist = box_list.BoxList(
275+
tf.constant([[0., 0., 32., 32.],
276+
[0., 0., 16., 16.],
277+
[0.0, 0.0, 4.0, 4.0]]))
278+
classes = tf.constant([[0., 1., 0.],
279+
[1., 0., 0.],
280+
[0., 0., 1.]])
281+
282+
result = ta_utils.coordinates_to_iou(
283+
y, x, blist, classes)
284+
return result
285+
286+
result = self.execute(graph_fn, [])
287+
self.assertEqual(result.shape, (32, 32, 3))
288+
self.assertAlmostEqual(result[0, 0, 0], 1.0 / 7.0)
289+
self.assertAlmostEqual(result[0, 0, 1], 1.0 / 7.0)
290+
self.assertAlmostEqual(result[0, 16, 0], 1.0 / 7.0)
291+
self.assertAlmostEqual(result[2, 2, 2], 1.0)
292+
self.assertAlmostEqual(result[8, 8, 2], 0.0)
293+
268294

269295
if __name__ == '__main__':
270296
tf.test.main()

0 commit comments

Comments
 (0)