Skip to content

Commit 318eaa4

Browse files
committed
Fix torch_scriptable
1 parent 4f6999a commit 318eaa4

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

segmentation_models_pytorch/decoders/upernet/decoder.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Union, Sequence
1+
from typing import Any, Dict, Union, Sequence, List
22

33
import torch
44
import torch.nn as nn
@@ -57,7 +57,9 @@ def forward(self, feature: torch.Tensor) -> torch.Tensor:
5757
class LayerNorm2d(nn.LayerNorm):
5858
def forward(self, x: torch.Tensor) -> torch.Tensor:
5959
x = x.permute(0, 2, 3, 1) # to channels_last
60-
normed_x = super().forward(x)
60+
normed_x = nn.functional.layer_norm(
61+
x, self.normalized_shape, self.weight, self.bias, self.eps
62+
)
6163
normed_x = normed_x.permute(0, 3, 1, 2) # to channels_first
6264
return normed_x
6365

@@ -158,10 +160,10 @@ def __init__(
158160
use_norm=use_norm,
159161
)
160162

161-
def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
163+
def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
162164
"""
163165
Args:
164-
features (Sequence[torch.Tensor]):
166+
features (List[torch.Tensor]):
165167
features with: [1, 1/2, 1/4, 1/8, 1/16, ...] spatial resolutions,
166168
where the first feature is the highest resolution and the number
167169
of features is equal to encoder_depth + 1.
@@ -171,24 +173,23 @@ def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
171173
features = features[2:]
172174

173175
# normalize feature maps
174-
features = [
175-
self.feature_norms[i](feature) for i, feature in enumerate(features)
176-
]
176+
for i, norm in enumerate(self.feature_norms):
177+
features[i] = norm(features[i])
177178

178179
# pass lowest resolution feature to PSP module
179180
psp_out = self.psp(features[-1])
180181

181182
# skip lowest features for FPN + reverse the order
182183
# [1/4, 1/8, 1/16, 1/32] -> [1/16, 1/8, 1/4]
183-
fpn_encoder_features = features[:-1][::-1]
184+
fpn_lateral_features = features[:-1][::-1]
184185
fpn_features = [psp_out]
185-
for fpn_encoder_feature, block in zip(
186-
fpn_encoder_features, self.fpn_lateral_blocks
187-
):
186+
for i, block in enumerate(self.fpn_lateral_blocks):
188187
# 1. for each encoder (skip) feature we apply 1x1 ConvNormRelu,
189188
# 2. upsample latest fpn feature to it's resolution
190189
# 3. sum them together
191-
fpn_feature = block(fpn_features[-1], fpn_encoder_feature)
190+
lateral_feature = fpn_lateral_features[i]
191+
state_feature = fpn_features[-1]
192+
fpn_feature = block(state_feature, lateral_feature)
192193
fpn_features.append(fpn_feature)
193194

194195
# Apply FPN conv blocks, but skip PSP module

segmentation_models_pytorch/decoders/upernet/model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ class UPerNet(SegmentationModel):
6363
6464
"""
6565

66-
_is_torch_scriptable = False
67-
6866
@supports_config_loading
6967
def __init__(
7068
self,

0 commit comments

Comments
 (0)