Skip to content

Commit 357794b

Browse files
No public description
PiperOrigin-RevId: 569683088
1 parent 3339258 commit 357794b

File tree

3 files changed

+186
-8
lines changed

3 files changed

+186
-8
lines changed

official/projects/centernet/modeling/layers/detection_generator.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class CenterNetDetectionGenerator(tf.keras.layers.Layer):
3434
"""CenterNet Detection Generator."""
3535

3636
def __init__(self,
37-
input_image_dims: int = 512,
37+
input_image_dims: tuple[int, int] | int = 512,
3838
net_down_scale: int = 4,
3939
max_detections: int = 100,
4040
peak_error: float = 1e-6,
@@ -47,7 +47,10 @@ def __init__(self,
4747
"""Initialize CenterNet Detection Generator.
4848
4949
Args:
50-
input_image_dims: An `int` that specifies the input image size.
50+
input_image_dims: The input image size. If it is a tuple of two `int`s, it
51+
is the size (height, width) of the input images. If it is an `int`, the
52+
input images are assumed to be square images whose height and width
53+
are equal.
5154
net_down_scale: An `int` that specifies the stride of the output.
5255
max_detections: An `int` specifying the maximum number of bounding
5356
boxes generated. This is an upper bound, so the number of generated
@@ -67,6 +70,9 @@ def __init__(self,
6770
"""
6871
super(CenterNetDetectionGenerator, self).__init__(**kwargs)
6972

73+
if isinstance(input_image_dims, int):
74+
input_image_dims = (input_image_dims, input_image_dims)
75+
7076
# Object center selection parameters
7177
self._max_detections = max_detections
7278
self._peak_error = peak_error
@@ -246,10 +252,28 @@ def get_boxes(self,
246252
return boxes, detection_classes
247253

248254
def convert_strided_predictions_to_normalized_boxes(self, boxes: tf.Tensor):
255+
"""Converts strided predictions to normalized boxes.
256+
257+
Args:
258+
boxes: A tf.Tensor of shape [batch_size, num_predictions, 4], representing
259+
the strided predictions of the detected objects.
260+
261+
Returns:
262+
A tf.Tensor of shape [batch_size, num_predictions, 4], representing
263+
the normalized boxes of the detected objects.
264+
"""
249265
boxes = boxes * tf.cast(self._net_down_scale, boxes.dtype)
250-
boxes = boxes / tf.cast(self._input_image_dims, boxes.dtype)
251-
boxes = tf.clip_by_value(boxes, 0.0, 1.0)
252-
return boxes
266+
267+
height = tf.cast(self._input_image_dims[0], boxes.dtype)
268+
width = tf.cast(self._input_image_dims[1], boxes.dtype)
269+
ymin = boxes[..., 0:1] / height
270+
xmin = boxes[..., 1:2] / width
271+
ymax = boxes[..., 2:3] / height
272+
xmax = boxes[..., 3:4] / width
273+
274+
normalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1)
275+
normalized_boxes = tf.clip_by_value(normalized_boxes, 0.0, 1.0)
276+
return normalized_boxes
253277

254278
def __call__(self, inputs):
255279
# Get heatmaps from decoded outputs via final hourglass stack output
@@ -308,8 +332,7 @@ def __call__(self, inputs):
308332
nms_thresh=0.4)
309333

310334
num_det = tf.reduce_sum(tf.cast(scores > 0, dtype=tf.int32), axis=1)
311-
boxes = box_ops.denormalize_boxes(
312-
boxes, [self._input_image_dims, self._input_image_dims])
335+
boxes = box_ops.denormalize_boxes(boxes, self._input_image_dims)
313336

314337
return {
315338
'boxes': boxes,
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for CenterNet detection_generator."""
16+
17+
from collections.abc import Mapping, Sequence
18+
19+
from absl.testing import parameterized
20+
import tensorflow as tf
21+
22+
from official.projects.centernet.modeling.layers import detection_generator
23+
24+
25+
def _build_input_example(
26+
batch_size: int, height: int, width: int, num_classes: int, num_outputs: int
27+
) -> Mapping[str, Sequence[tf.Tensor]]:
28+
"""Builds a random input example for CenterNetDetectionGenerator.
29+
30+
Args:
31+
batch_size: The batch size.
32+
height: The height of the feature_map.
33+
width: The width of the feature_map.
34+
num_classes: The number of classes to detect.
35+
num_outputs: The number of output heatmaps, which corresponds to the length
36+
of CenterNetHead's input_levels.
37+
38+
Returns:
39+
A dictionary, mapping from feature names to sequences of tensors.
40+
"""
41+
return {
42+
'ct_heatmaps': [
43+
tf.random.normal([batch_size, height, width, num_classes])
44+
for _ in range(num_outputs)
45+
],
46+
'ct_size': [
47+
tf.random.normal([batch_size, height, width, 2])
48+
for _ in range(num_outputs)
49+
],
50+
'ct_offset': [
51+
tf.random.normal([batch_size, height, width, 2])
52+
for _ in range(num_outputs)
53+
],
54+
}
55+
56+
57+
class CenterNetDetectionGeneratorTest(parameterized.TestCase, tf.test.TestCase):
58+
59+
@parameterized.parameters(
60+
(1, 256),
61+
(1, 512),
62+
(2, 256),
63+
(2, 512),
64+
)
65+
def test_squered_image_forward(self, batch_size, input_image_dims):
66+
max_detections = 128
67+
num_classes = 80
68+
generator = detection_generator.CenterNetDetectionGenerator(
69+
input_image_dims=input_image_dims, max_detections=max_detections
70+
)
71+
test_input = _build_input_example(
72+
batch_size=batch_size,
73+
height=input_image_dims,
74+
width=input_image_dims,
75+
num_classes=num_classes,
76+
num_outputs=2,
77+
)
78+
79+
output = generator(test_input)
80+
81+
self.assert_detection_generator_output_shapes(
82+
output, batch_size, max_detections
83+
)
84+
85+
@parameterized.parameters(
86+
(1, (256, 512)),
87+
(1, (512, 256)),
88+
(2, (256, 512)),
89+
(2, (512, 256)),
90+
)
91+
def test_rectangular_image_forward(self, batch_size, input_image_dims):
92+
max_detections = 128
93+
num_classes = 80
94+
generator = detection_generator.CenterNetDetectionGenerator(
95+
input_image_dims=input_image_dims, max_detections=max_detections
96+
)
97+
test_input = _build_input_example(
98+
batch_size=batch_size,
99+
height=input_image_dims[0],
100+
width=input_image_dims[1],
101+
num_classes=num_classes,
102+
num_outputs=2,
103+
)
104+
105+
output = generator(test_input)
106+
107+
self.assert_detection_generator_output_shapes(
108+
output, batch_size, max_detections
109+
)
110+
111+
def assert_detection_generator_output_shapes(
112+
self,
113+
output: Mapping[str, tf.Tensor],
114+
batch_size: int,
115+
max_detections: int,
116+
):
117+
self.assertAllEqual(output['boxes'].shape, (batch_size, max_detections, 4))
118+
self.assertAllEqual(output['classes'].shape, (batch_size, max_detections))
119+
self.assertAllEqual(
120+
output['confidence'].shape, (batch_size, max_detections)
121+
)
122+
self.assertAllEqual(output['num_detections'].shape, (batch_size,))
123+
124+
@parameterized.parameters(
125+
(256,),
126+
(512,),
127+
((256, 512),),
128+
((512, 256),),
129+
)
130+
def test_serialize_deserialize(self, input_image_dims):
131+
kwargs = {
132+
'input_image_dims': input_image_dims,
133+
'net_down_scale': 4,
134+
'max_detections': 128,
135+
'peak_error': 1e-6,
136+
'peak_extract_kernel_size': 3,
137+
'class_offset': 1,
138+
'use_nms': False,
139+
'nms_pre_thresh': 0.1,
140+
'nms_thresh': 0.5,
141+
}
142+
143+
generator = detection_generator.CenterNetDetectionGenerator(**kwargs)
144+
new_generator = detection_generator.CenterNetDetectionGenerator.from_config(
145+
generator.get_config()
146+
)
147+
148+
self.assertAllEqual(generator.get_config(), new_generator.get_config())
149+
150+
151+
if __name__ == '__main__':
152+
tf.test.main()

official/projects/centernet/tasks/centernet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,10 @@ def build_model(self):
130130
peak_extract_kernel_size=dg_config.peak_extract_kernel_size,
131131
class_offset=dg_config.class_offset,
132132
net_down_scale=self._net_down_scale,
133-
input_image_dims=model_config.input_size[0],
133+
input_image_dims=(
134+
model_config.input_size[0],
135+
model_config.input_size[1],
136+
),
134137
use_nms=dg_config.use_nms,
135138
nms_pre_thresh=dg_config.nms_pre_thresh,
136139
nms_thresh=dg_config.nms_thresh)

0 commit comments

Comments
 (0)