19 | 19 | class LayerNorm(nn.LayerNorm): |
20 | 20 | def forward(self, x): |
21 | 21 | if x.ndim == 4: |
22 | | - B, C, H, W = x.shape |
23 | | - x = x.view(B, C, -1).transpose(1, 2) |
24 | | - x = nn.functional.layer_norm( |
25 | | - x, self.normalized_shape, self.weight, self.bias, self.eps |
26 | | - ) |
27 | | - x = x.transpose(1, 2).view(B, -1, H, W).contiguous() |
| 22 | + x = x.permute(0, 2, 3, 1) |
| 23 | + x = super().forward(x) |
| 24 | + x = x.permute(0, 3, 1, 2) |
28 | 25 | else: |
29 | | - x = nn.functional.layer_norm( |
30 | | - x, self.normalized_shape, self.weight, self.bias, self.eps |
31 | | - ) |
| 26 | + x = super().forward(x) |
32 | 27 | return x |
33 | 28 |
34 | 29 |
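The refactor swaps the manual view/transpose reshaping for a channels-last permute followed by a call to the parent `nn.LayerNorm.forward`. Both paths normalize over the channel dimension and should produce the same values; below is a minimal standalone sketch (not part of this commit) that checks the equivalence for a 4D `(B, C, H, W)` input.

```python
import torch
import torch.nn as nn

class LayerNorm(nn.LayerNorm):
    # Copy of the refactored class above, reproduced here only for the check.
    def forward(self, x):
        if x.ndim == 4:
            x = x.permute(0, 2, 3, 1)   # (B, C, H, W) -> (B, H, W, C)
            x = super().forward(x)      # normalize over the trailing channel dim
            x = x.permute(0, 3, 1, 2)   # back to (B, C, H, W)
        else:
            x = super().forward(x)
        return x

B, C, H, W = 2, 64, 8, 8
ln = LayerNorm(C)
x = torch.randn(B, C, H, W)

# Old path from the removed lines: flatten spatial dims, normalize, reshape back.
t = nn.functional.layer_norm(
    x.view(B, C, -1).transpose(1, 2), ln.normalized_shape, ln.weight, ln.bias, ln.eps
)
y_old = t.transpose(1, 2).view(B, -1, H, W).contiguous()

assert torch.allclose(ln(x), y_old, atol=1e-6)
```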
@@ -472,25 +467,25 @@ def forward_features(self, x): |
472 | 467 | # stage 1 |
473 | 468 | x = self.patch_embed1(x) |
474 | 469 | x = self.block1(x) |
475 | | - x = self.norm1(x) |
| 470 | + x = self.norm1(x).contiguous() |
476 | 471 | outs.append(x) |
477 | 472 |
478 | 473 | # stage 2 |
479 | 474 | x = self.patch_embed2(x) |
480 | 475 | x = self.block2(x) |
481 | | - x = self.norm2(x) |
| 476 | + x = self.norm2(x).contiguous() |
482 | 477 | outs.append(x) |
483 | 478 |
484 | 479 | # stage 3 |
485 | 480 | x = self.patch_embed3(x) |
486 | 481 | x = self.block3(x) |
487 | | - x = self.norm3(x) |
| 482 | + x = self.norm3(x).contiguous() |
488 | 483 | outs.append(x) |
489 | 484 |
490 | 485 | # stage 4 |
491 | 486 | x = self.patch_embed4(x) |
492 | 487 | x = self.block4(x) |
493 | | - x = self.norm4(x) |
| 488 | + x = self.norm4(x).contiguous() |
494 | 489 | outs.append(x) |
495 | 490 |
496 | 491 | return outs |
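Because the norms now permute to channels-last internally, their 4D outputs come back with permuted strides rather than dense `(B, C, H, W)` memory, so `.contiguous()` is added after each stage norm before the feature map is passed downstream. A small standalone sketch of the effect, assuming `norm1`..`norm4` are instances of the `LayerNorm` class above:

```python
import torch

ln = torch.nn.LayerNorm(64)
x = torch.randn(2, 64, 8, 8)

# Mirrors the 4D branch of the refactored LayerNorm.
y = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(y.is_contiguous())                # False: only the strides were permuted back
# y.view(2, -1)                         # would raise a RuntimeError on this layout
print(y.contiguous().is_contiguous())   # True: the added .contiguous() makes a dense copy
```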
@@ -552,7 +547,7 @@ def forward(self, x): |
552 | 547 | if i == 1: |
553 | 548 | features.append(dummy) |
554 | 549 | else: |
555 | | - x = stages[i](x) |
| 550 | + x = stages[i](x).contiguous() |
556 | 551 | features.append(x) |
557 | 552 | return features |
558 | 553 |
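The same pattern is applied in the wrapper's stage loop. Note that `.contiguous()` returns the same tensor (no copy) when a stage output is already densely laid out, so the extra call only costs anything when a copy is actually needed; a quick standalone check:

```python
import torch

t = torch.randn(2, 8, 16, 16)
assert t.contiguous() is t        # already contiguous: same tensor, no copy

s = t.permute(0, 2, 3, 1)         # permuted strides: no longer contiguous
assert s.contiguous() is not s    # a fresh, densely laid-out copy is returned
```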