Skip to content

Commit fd499ca

Browse files
Internal change
PiperOrigin-RevId: 397467113
1 parent a6b0c05 commit fd499ca

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

official/vision/beta/modeling/retinanet_model.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def call(self,
7777
images: tf.Tensor,
7878
image_shape: Optional[tf.Tensor] = None,
7979
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
80+
output_intermediate_features: bool = False,
8081
training: bool = None) -> Mapping[str, tf.Tensor]:
8182
"""Forward pass of the RetinaNet model.
8283
@@ -92,6 +93,8 @@ def call(self,
9293
- key: `str`, the level of the multilevel predictions.
9394
- values: `Tensor`, the anchor coordinates of a particular feature
9495
level, whose shape is [height_l, width_l, num_anchors_per_location].
96+
output_intermediate_features: `bool` indicating whether to return the
97+
intermediate feature maps generated by backbone and decoder.
9598
training: `bool`, indicating whether it is in training mode.
9699
97100
Returns:
@@ -112,19 +115,26 @@ def call(self,
112115
feature level, whose shape is
113116
[batch, height_l, width_l, att_size * num_anchors_per_location].
114117
"""
118+
outputs = {}
115119
# Feature extraction.
116120
features = self.backbone(images)
121+
if output_intermediate_features:
122+
outputs.update(
123+
{'backbone_{}'.format(k): v for k, v in features.items()})
117124
if self.decoder:
118125
features = self.decoder(features)
126+
if output_intermediate_features:
127+
outputs.update(
128+
{'decoder_{}'.format(k): v for k, v in features.items()})
119129

120130
# Dense prediction. `raw_attributes` can be empty.
121131
raw_scores, raw_boxes, raw_attributes = self.head(features)
122132

123133
if training:
124-
outputs = {
134+
outputs.update({
125135
'cls_outputs': raw_scores,
126136
'box_outputs': raw_boxes,
127-
}
137+
})
128138
if raw_attributes:
129139
outputs.update({'attribute_outputs': raw_attributes})
130140
return outputs
@@ -145,12 +155,13 @@ def call(self,
145155
[tf.shape(images)[0], 1, 1, 1])
146156

147157
# Post-processing.
148-
final_results = self.detection_generator(
149-
raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
150-
outputs = {
158+
final_results = self.detection_generator(raw_boxes, raw_scores,
159+
anchor_boxes, image_shape,
160+
raw_attributes)
161+
outputs.update({
151162
'cls_outputs': raw_scores,
152163
'box_outputs': raw_boxes,
153-
}
164+
})
154165
if self.detection_generator.get_config()['apply_nms']:
155166
outputs.update({
156167
'detection_boxes': final_results['detection_boxes'],

official/vision/beta/modeling/retinanet_model_test.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ def test_build_model(self, use_separable_conv, build_anchor_boxes,
147147
],
148148
training=[True, False],
149149
has_att_heads=[True, False],
150+
output_intermediate_features=[True, False],
150151
))
151-
def test_forward(self, strategy, image_size, training, has_att_heads):
152+
def test_forward(self, strategy, image_size, training, has_att_heads,
153+
output_intermediate_features):
152154
"""Test for creation of a R50-FPN RetinaNet."""
153155
tf.keras.backend.set_image_data_format('channels_last')
154156
num_classes = 3
@@ -202,6 +204,7 @@ def test_forward(self, strategy, image_size, training, has_att_heads):
202204
images,
203205
image_shape,
204206
anchor_boxes,
207+
output_intermediate_features=output_intermediate_features,
205208
training=training)
206209

207210
if training:
@@ -247,6 +250,19 @@ def test_forward(self, strategy, image_size, training, has_att_heads):
247250
self.assertAllEqual(
248251
[2, 10, 1],
249252
model_outputs['detection_attributes']['depth'].numpy().shape)
253+
if output_intermediate_features:
254+
for l in range(2, 6):
255+
self.assertIn('backbone_{}'.format(l), model_outputs)
256+
self.assertAllEqual([
257+
2, image_size[0] // 2**l, image_size[1] // 2**l,
258+
backbone.output_specs[str(l)].as_list()[-1]
259+
], model_outputs['backbone_{}'.format(l)].numpy().shape)
260+
for l in range(min_level, max_level + 1):
261+
self.assertIn('decoder_{}'.format(l), model_outputs)
262+
self.assertAllEqual([
263+
2, image_size[0] // 2**l, image_size[1] // 2**l,
264+
decoder.output_specs[str(l)].as_list()[-1]
265+
], model_outputs['decoder_{}'.format(l)].numpy().shape)
250266

251267
def test_serialize_deserialize(self):
252268
"""Validate the network can be serialized and deserialized."""

0 commit comments

Comments
 (0)