-from typing import Any, Dict, Union, Sequence
+from typing import Any, Dict, Union, Sequence, List

 import torch
 import torch.nn as nn
@@ -57,7 +57,9 @@ def forward(self, feature: torch.Tensor) -> torch.Tensor:
 class LayerNorm2d(nn.LayerNorm):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)  # to channels_last
-        normed_x = super().forward(x)
+        normed_x = nn.functional.layer_norm(
+            x, self.normalized_shape, self.weight, self.bias, self.eps
+        )
         normed_x = normed_x.permute(0, 3, 1, 2)  # to channels_first
         return normed_x

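For context, a minimal sketch of the patched `LayerNorm2d` outside the diff. The class name, permutes, and functional call come from the hunk above; the channel count, input shape, and the equivalence check are illustrative assumptions. Calling `nn.functional.layer_norm` directly with the module's own parameters should match the previous `super().forward(x)` path numerically:

```python
import torch
import torch.nn as nn


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of NCHW tensors (as in the hunk above)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC, so channels are the last dim
        x = nn.functional.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return x.permute(0, 3, 1, 2)  # NHWC -> NCHW


# Illustrative check: the functional call matches the inherited forward.
norm = LayerNorm2d(64)
x = torch.randn(2, 64, 32, 32)
ref = nn.LayerNorm.forward(norm, x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
torch.testing.assert_close(norm(x), ref)
```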
@@ -158,10 +160,10 @@ def __init__(
             use_norm=use_norm,
         )

-    def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
         """
         Args:
-            features (Sequence[torch.Tensor]):
+            features (List[torch.Tensor]):
                 features with: [1, 1/2, 1/4, 1/8, 1/16, ...] spatial resolutions,
                 where the first feature is the highest resolution and the number
                 of features is equal to encoder_depth + 1.
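The `Sequence` -> `List` narrowing is not cosmetic: the next hunk rewrites the normalization to assign into `features` in place, and `Sequence` does not declare `__setitem__` (a caller could legally pass a tuple). A small sketch of the distinction, with illustrative tensor shapes:

```python
from typing import List, Sequence

import torch

feats: Sequence[torch.Tensor] = (torch.randn(1, 8, 4, 4),)  # a tuple is a Sequence
# feats[0] = feats[0] * 2  # TypeError at runtime; type checkers also reject it

feats_mut: List[torch.Tensor] = [torch.randn(1, 8, 4, 4)]
feats_mut[0] = feats_mut[0] * 2  # fine: List supports item assignment
```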
@@ -171,24 +173,23 @@ def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
         features = features[2:]

         # normalize feature maps
-        features = [
-            self.feature_norms[i](feature) for i, feature in enumerate(features)
-        ]
+        for i, norm in enumerate(self.feature_norms):
+            features[i] = norm(features[i])

         # pass lowest resolution feature to PSP module
         psp_out = self.psp(features[-1])

         # skip lowest features for FPN + reverse the order
         # [1/4, 1/8, 1/16, 1/32] -> [1/16, 1/8, 1/4]
-        fpn_encoder_features = features[:-1][::-1]
+        fpn_lateral_features = features[:-1][::-1]
         fpn_features = [psp_out]
-        for fpn_encoder_feature, block in zip(
-            fpn_encoder_features, self.fpn_lateral_blocks
-        ):
+        for i, block in enumerate(self.fpn_lateral_blocks):
             # 1. for each encoder (skip) feature we apply 1x1 ConvNormRelu,
             # 2. upsample latest fpn feature to its resolution
             # 3. sum them together
-            fpn_feature = block(fpn_features[-1], fpn_encoder_feature)
+            lateral_feature = fpn_lateral_features[i]
+            state_feature = fpn_features[-1]
+            fpn_feature = block(state_feature, lateral_feature)
             fpn_features.append(fpn_feature)

         # Apply FPN conv blocks, but skip PSP module
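To make the refactored loop concrete, here is a self-contained sketch of the enumerate-based FPN lateral pass. `ToyLateralBlock` is a hypothetical stand-in for the real `fpn_lateral_blocks` (a 1x1 conv on the skip feature, upsample of the running state, then a sum, mirroring steps 1-3 in the comments above); channel counts and spatial sizes are illustrative:

```python
import torch
import torch.nn as nn


class ToyLateralBlock(nn.Module):
    """Hypothetical stand-in for an FPN lateral block."""

    def __init__(self, channels: int):
        super().__init__()
        self.lateral_conv = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, state: torch.Tensor, lateral: torch.Tensor) -> torch.Tensor:
        lateral = self.lateral_conv(lateral)  # 1x1 conv on the skip feature
        state = nn.functional.interpolate(
            state, size=lateral.shape[2:], mode="nearest"
        )  # upsample the running FPN state to the skip resolution
        return state + lateral


blocks = nn.ModuleList(ToyLateralBlock(16) for _ in range(2))
fpn_lateral_features = [torch.randn(1, 16, 8, 8), torch.randn(1, 16, 16, 16)]
fpn_features = [torch.randn(1, 16, 4, 4)]  # stand-in for psp_out
for i, block in enumerate(blocks):
    fpn_features.append(block(fpn_features[-1], fpn_lateral_features[i]))
print([f.shape[-1] for f in fpn_features])  # [4, 8, 16]
```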