@@ -21,14 +21,10 @@ def forward(self, x):
         if x.ndim == 4:
             B, C, H, W = x.shape
             x = x.view(B, C, -1).transpose(1, 2)
-            x = nn.functional.layer_norm(
-                x, self.normalized_shape, self.weight, self.bias, self.eps
-            )
-            x = x.transpose(1, 2).view(B, -1, H, W).contiguous()
+            x = super().forward(x)
+            x = x.transpose(1, 2).view(B, C, H, W)
         else:
-            x = nn.functional.layer_norm(
-                x, self.normalized_shape, self.weight, self.bias, self.eps
-            )
+            x = super().forward(x)
         return x
 
 
@@ -472,25 +468,25 @@ def forward_features(self, x):
         # stage 1
         x = self.patch_embed1(x)
         x = self.block1(x)
-        x = self.norm1(x)
+        x = self.norm1(x).contiguous()
         outs.append(x)
 
         # stage 2
         x = self.patch_embed2(x)
         x = self.block2(x)
-        x = self.norm2(x)
+        x = self.norm2(x).contiguous()
         outs.append(x)
 
         # stage 3
         x = self.patch_embed3(x)
         x = self.block3(x)
-        x = self.norm3(x)
+        x = self.norm3(x).contiguous()
         outs.append(x)
 
         # stage 4
         x = self.patch_embed4(x)
         x = self.block4(x)
-        x = self.norm4(x)
+        x = self.norm4(x).contiguous()
         outs.append(x)
 
         return outs
@@ -552,7 +548,7 @@ def forward(self, x):
             if i == 1:
                 features.append(dummy)
             else:
-                x = stages[i](x)
+                x = stages[i](x).contiguous()
                 features.append(x)
         return features
 
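For context, below is a minimal runnable sketch of the kind of `nn.LayerNorm` subclass the first hunk patches, reconstructed from the post-change lines. The class name `LayerNorm2d`, the constructor, and the usage at the bottom are assumptions; only the `forward` body mirrors the diff.

```python
import torch
import torch.nn as nn


class LayerNorm2d(nn.LayerNorm):
    """Hypothetical wrapper (name assumed): LayerNorm over the channel dim of NCHW tensors."""

    def forward(self, x):
        if x.ndim == 4:
            B, C, H, W = x.shape
            # (B, C, H, W) -> (B, H*W, C) so the last dim matches normalized_shape
            x = x.view(B, C, -1).transpose(1, 2)
            x = super().forward(x)
            # Back to (B, C, H, W); this is a view of a channels-last buffer,
            # so the result is not contiguous in NCHW order.
            x = x.transpose(1, 2).view(B, C, H, W)
        else:
            x = super().forward(x)
        return x


# Usage sketch: normalize a 4D feature map over its 64 channels.
norm = LayerNorm2d(64)
y = norm(torch.randn(2, 64, 14, 14))
print(y.shape, y.is_contiguous())  # torch.Size([2, 64, 14, 14]) False
```

Because the reshaped output is no longer NCHW-contiguous once the old `.contiguous()` call is dropped, the later hunks add `.contiguous()` to the stage outputs that consume it.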