Skip to content

Commit a039526

Browse files
committed
Update LayerNorm
1 parent 2989d41 commit a039526

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

segmentation_models_pytorch/encoders/mix_transformer.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,10 @@ def forward(self, x):
2121
if x.ndim == 4:
2222
B, C, H, W = x.shape
2323
x = x.view(B, C, -1).transpose(1, 2)
24-
x = nn.functional.layer_norm(
25-
x, self.normalized_shape, self.weight, self.bias, self.eps
26-
)
27-
x = x.transpose(1, 2).view(B, -1, H, W).contiguous()
24+
x = super().forward(x)
25+
x = x.transpose(1, 2).view(B, C, H, W)
2826
else:
29-
x = nn.functional.layer_norm(
30-
x, self.normalized_shape, self.weight, self.bias, self.eps
31-
)
27+
x = super().forward(x)
3228
return x
3329

3430

@@ -472,25 +468,25 @@ def forward_features(self, x):
472468
# stage 1
473469
x = self.patch_embed1(x)
474470
x = self.block1(x)
475-
x = self.norm1(x)
471+
x = self.norm1(x).contiguous()
476472
outs.append(x)
477473

478474
# stage 2
479475
x = self.patch_embed2(x)
480476
x = self.block2(x)
481-
x = self.norm2(x)
477+
x = self.norm2(x).contiguous()
482478
outs.append(x)
483479

484480
# stage 3
485481
x = self.patch_embed3(x)
486482
x = self.block3(x)
487-
x = self.norm3(x)
483+
x = self.norm3(x).contiguous()
488484
outs.append(x)
489485

490486
# stage 4
491487
x = self.patch_embed4(x)
492488
x = self.block4(x)
493-
x = self.norm4(x)
489+
x = self.norm4(x).contiguous()
494490
outs.append(x)
495491

496492
return outs
@@ -552,7 +548,7 @@ def forward(self, x):
552548
if i == 1:
553549
features.append(dummy)
554550
else:
555-
x = stages[i](x)
551+
x = stages[i](x).contiguous()
556552
features.append(x)
557553
return features
558554

0 commit comments

Comments
 (0)