- from typing import Any, Dict, Union, Sequence
+ from typing import Any, Dict, Union, Sequence, List

import torch
import torch.nn as nn
@@ -57,7 +57,9 @@ def forward(self, feature: torch.Tensor) -> torch.Tensor:
class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)  # to channels_last
-         normed_x = super().forward(x)
+         normed_x = nn.functional.layer_norm(
+             x, self.normalized_shape, self.weight, self.bias, self.eps
+         )
        normed_x = normed_x.permute(0, 3, 1, 2)  # to channels_first
        return normed_x

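For reference, a minimal standalone sketch (not part of this commit) of what the rewritten forward computes: the explicit nn.functional.layer_norm call is numerically equivalent to nn.LayerNorm applied in channels_last layout.

import torch
import torch.nn as nn

norm = nn.LayerNorm(64)
x = torch.randn(2, 64, 8, 8)  # NCHW, C = 64

# channels_first -> channels_last, normalize over C, and back
reference = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

# the explicit call used in the new LayerNorm2d.forward
explicit = nn.functional.layer_norm(
    x.permute(0, 2, 3, 1), norm.normalized_shape, norm.weight, norm.bias, norm.eps
).permute(0, 3, 1, 2)

assert torch.allclose(reference, explicit)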
@@ -158,10 +160,10 @@ def __init__(
            use_norm=use_norm,
        )

-     def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
+     def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
-             features (Sequence[torch.Tensor]):
+             features (List[torch.Tensor]):
                features with: [1, 1/2, 1/4, 1/8, 1/16, ...] spatial resolutions,
                where the first feature is the highest resolution and the number
                of features is equal to encoder_depth + 1.
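For concreteness, a hypothetical example of the features list this signature expects (channel counts and input size are made up), for a 224x224 input and encoder_depth = 5:

# hypothetical encoder output: encoder_depth + 1 = 6 feature maps,
# highest resolution first, for a 224x224 input
features = [
    torch.randn(1, 3, 224, 224),   # 1   (input resolution)
    torch.randn(1, 64, 112, 112),  # 1/2
    torch.randn(1, 128, 56, 56),   # 1/4
    torch.randn(1, 256, 28, 28),   # 1/8
    torch.randn(1, 512, 14, 14),   # 1/16
    torch.randn(1, 1024, 7, 7),    # 1/32
]
# features[2:] in the body below keeps [1/4, 1/8, 1/16, 1/32],
# matching the resolutions named in the comments further down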
@@ -171,24 +173,23 @@ def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
        features = features[2:]

        # normalize feature maps
-         features = [
-             self.feature_norms[i](feature) for i, feature in enumerate(features)
-         ]
+         for i, norm in enumerate(self.feature_norms):
+             features[i] = norm(features[i])

        # pass lowest resolution feature to PSP module
        psp_out = self.psp(features[-1])

        # skip lowest features for FPN + reverse the order
        # [1/4, 1/8, 1/16, 1/32] -> [1/16, 1/8, 1/4]
-         fpn_encoder_features = features[:-1][::-1]
+         fpn_lateral_features = features[:-1][::-1]
        fpn_features = [psp_out]
-         for fpn_encoder_feature, block in zip(
-             fpn_encoder_features, self.fpn_lateral_blocks
-         ):
+         for i, block in enumerate(self.fpn_lateral_blocks):
            # 1. for each encoder (skip) feature we apply 1x1 ConvNormRelu,
            # 2. upsample the latest fpn feature to its resolution,
            # 3. sum them together
-             fpn_feature = block(fpn_features[-1], fpn_encoder_feature)
+             lateral_feature = fpn_lateral_features[i]
+             state_feature = fpn_features[-1]
+             fpn_feature = block(state_feature, lateral_feature)
            fpn_features.append(fpn_feature)

        # Apply FPN conv blocks, but skip PSP module
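The three numbered steps in the loop describe what each lateral block does with (state_feature, lateral_feature). A hedged sketch of such a block follows; the class name and internals are assumptions for illustration, not this repo's code.

import torch
import torch.nn as nn

class FPNLateralBlockSketch(nn.Module):
    """Assumed internals: 1x1 ConvNormRelu on the lateral (skip) feature,
    upsampling of the running FPN state to its size, then an elementwise sum."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv_norm_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, state_feature: torch.Tensor, lateral_feature: torch.Tensor) -> torch.Tensor:
        lateral = self.conv_norm_relu(lateral_feature)  # 1. 1x1 ConvNormRelu on the skip feature
        state = nn.functional.interpolate(              # 2. upsample state to the lateral's size
            state_feature, size=lateral.shape[2:], mode="nearest"
        )
        return state + lateral                          # 3. sum them together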