@@ -77,6 +77,7 @@ def call(self,
77
77
images : tf .Tensor ,
78
78
image_shape : Optional [tf .Tensor ] = None ,
79
79
anchor_boxes : Optional [Mapping [str , tf .Tensor ]] = None ,
80
+ output_intermediate_features : bool = False ,
80
81
training : bool = None ) -> Mapping [str , tf .Tensor ]:
81
82
"""Forward pass of the RetinaNet model.
82
83
@@ -92,6 +93,8 @@ def call(self,
92
93
- key: `str`, the level of the multilevel predictions.
93
94
- values: `Tensor`, the anchor coordinates of a particular feature
94
95
level, whose shape is [height_l, width_l, num_anchors_per_location].
96
+ output_intermediate_features: `bool` indicating whether to return the
97
+ intermediate feature maps generated by backbone and decoder.
95
98
training: `bool`, indicating whether it is in training mode.
96
99
97
100
Returns:
@@ -112,19 +115,26 @@ def call(self,
112
115
feature level, whose shape is
113
116
[batch, height_l, width_l, att_size * num_anchors_per_location].
114
117
"""
118
+ outputs = {}
115
119
# Feature extraction.
116
120
features = self .backbone (images )
121
+ if output_intermediate_features :
122
+ outputs .update (
123
+ {'backbone_{}' .format (k ): v for k , v in features .items ()})
117
124
if self .decoder :
118
125
features = self .decoder (features )
126
+ if output_intermediate_features :
127
+ outputs .update (
128
+ {'decoder_{}' .format (k ): v for k , v in features .items ()})
119
129
120
130
# Dense prediction. `raw_attributes` can be empty.
121
131
raw_scores , raw_boxes , raw_attributes = self .head (features )
122
132
123
133
if training :
124
- outputs = {
134
+ outputs . update ( {
125
135
'cls_outputs' : raw_scores ,
126
136
'box_outputs' : raw_boxes ,
127
- }
137
+ })
128
138
if raw_attributes :
129
139
outputs .update ({'attribute_outputs' : raw_attributes })
130
140
return outputs
@@ -145,12 +155,13 @@ def call(self,
145
155
[tf .shape (images )[0 ], 1 , 1 , 1 ])
146
156
147
157
# Post-processing.
148
- final_results = self .detection_generator (
149
- raw_boxes , raw_scores , anchor_boxes , image_shape , raw_attributes )
150
- outputs = {
158
+ final_results = self .detection_generator (raw_boxes , raw_scores ,
159
+ anchor_boxes , image_shape ,
160
+ raw_attributes )
161
+ outputs .update ({
151
162
'cls_outputs' : raw_scores ,
152
163
'box_outputs' : raw_boxes ,
153
- }
164
+ })
154
165
if self .detection_generator .get_config ()['apply_nms' ]:
155
166
outputs .update ({
156
167
'detection_boxes' : final_results ['detection_boxes' ],
0 commit comments