Skip to content

Commit c0399c5

Browse files
committed
Update LayerNorm
1 parent 3b51edd commit c0399c5

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

segmentation_models_pytorch/encoders/mix_transformer.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,11 @@
1919
class LayerNorm(nn.LayerNorm):
2020
def forward(self, x):
2121
if x.ndim == 4:
22-
B, C, H, W = x.shape
23-
x = x.view(B, C, -1).transpose(1, 2)
24-
x = nn.functional.layer_norm(
25-
x, self.normalized_shape, self.weight, self.bias, self.eps
26-
)
27-
x = x.transpose(1, 2).view(B, -1, H, W).contiguous()
22+
x = x.permute(0, 2, 3, 1)
23+
x = super().forward(x)
24+
x = x.permute(0, 3, 1, 2)
2825
else:
29-
x = nn.functional.layer_norm(
30-
x, self.normalized_shape, self.weight, self.bias, self.eps
31-
)
26+
x = super().forward(x)
3227
return x
3328

3429

@@ -472,25 +467,25 @@ def forward_features(self, x):
472467
# stage 1
473468
x = self.patch_embed1(x)
474469
x = self.block1(x)
475-
x = self.norm1(x)
470+
x = self.norm1(x).contiguous()
476471
outs.append(x)
477472

478473
# stage 2
479474
x = self.patch_embed2(x)
480475
x = self.block2(x)
481-
x = self.norm2(x)
476+
x = self.norm2(x).contiguous()
482477
outs.append(x)
483478

484479
# stage 3
485480
x = self.patch_embed3(x)
486481
x = self.block3(x)
487-
x = self.norm3(x)
482+
x = self.norm3(x).contiguous()
488483
outs.append(x)
489484

490485
# stage 4
491486
x = self.patch_embed4(x)
492487
x = self.block4(x)
493-
x = self.norm4(x)
488+
x = self.norm4(x).contiguous()
494489
outs.append(x)
495490

496491
return outs
@@ -552,7 +547,7 @@ def forward(self, x):
552547
if i == 1:
553548
features.append(dummy)
554549
else:
555-
x = stages[i](x)
550+
x = stages[i](x).contiguous()
556551
features.append(x)
557552
return features
558553

0 commit comments

Comments
 (0)