@@ -82,15 +82,12 @@ def forward(
     ) -> torch.Tensor:
         # 1. Apply block to encoder feature
         lateral_feature = self.conv_norm_relu(lateral_feature)
-        # print("lateral_feature", lateral_feature.shape, lateral_feature.min(), lateral_feature.max(), lateral_feature.mean(), lateral_feature.std())
         # 2. Upsample encoder feature to the "state" feature resolution
         _, _, height, width = lateral_feature.shape
-        # print("current feature", state_feature.shape, state_feature.min(), state_feature.max(), state_feature.mean(), state_feature.std())
         state_feature = F.interpolate(
             state_feature, size=(height, width), mode="bilinear", align_corners=False
         )
         # 3. Sum state and encoder features
-        # print("Fusion::", state_feature.shape, state_feature.max(), lateral_feature.max())
         fused_feature = state_feature + lateral_feature
         return fused_feature
 
@@ -178,26 +175,8 @@ def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
             self.feature_norms[i](feature) for i, feature in enumerate(features)
         ]
 
-        for i, feature in enumerate(features):
-            print(
-                f"Encoder feature {i}",
-                feature.shape,
-                feature.min(),
-                feature.max(),
-                feature.mean(),
-                feature.std(),
-            )
-
         # pass lowest resolution feature to PSP module
         psp_out = self.psp(features[-1])
-        print(
-            "psp_out",
-            psp_out.shape,
-            psp_out.min(),
-            psp_out.max(),
-            psp_out.mean(),
-            psp_out.std(),
-        )
 
         # skip lowest features for FPN + reverse the order
         # [1/4, 1/8, 1/16, 1/32] -> [1/16, 1/8, 1/4]
@@ -212,30 +191,10 @@ def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
             fpn_feature = block(fpn_features[-1], fpn_encoder_feature)
             fpn_features.append(fpn_feature)
 
-        for i, fpn_feature in enumerate(fpn_features[::-1]):
-            print(
-                f"fpn_feature (before conv) {i}",
-                fpn_feature.shape,
-                fpn_feature.min(),
-                fpn_feature.max(),
-                fpn_feature.mean(),
-                fpn_feature.std(),
-            )
-
         # Apply FPN conv blocks, but skip PSP module
         for i, conv_block in enumerate(self.fpn_conv_blocks, start=1):
             fpn_features[i] = conv_block(fpn_features[i])
 
-        for i, fpn_feature in enumerate(fpn_features[::-1]):
-            print(
-                f"fpn_feature (after conv) {i}",
-                fpn_feature.shape,
-                fpn_feature.min(),
-                fpn_feature.max(),
-                fpn_feature.mean(),
-                fpn_feature.std(),
-            )
-
         # Resize all FPN features to 1/4 of the original resolution.
         resized_fpn_features = []
         target_size = fpn_features[-1].shape[2:]  # 1/4 of the original resolution
@@ -247,21 +206,5 @@ def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
 
         # reverse and concatenate
         stacked_features = torch.cat(resized_fpn_features[::-1], dim=1)
-        print(
-            "stacked_features",
-            stacked_features.shape,
-            stacked_features.min(),
-            stacked_features.max(),
-            stacked_features.mean(),
-            stacked_features.std(),
-        )
         output = self.fusion_block(stacked_features)
-        print(
-            "fusion_block",
-            output.shape,
-            output.min(),
-            output.max(),
-            output.mean(),
-            output.std(),
-        )
         return output