# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
-# This work is licensed under the NVIDIA Source Code License
+# Licensed under the NVIDIA Source Code License. For full license
+# terms, please refer to the LICENSE file provided with this code
+# or visit NVIDIA's official repository at
+# https://github.com/NVlabs/SegFormer/tree/master.
+#
+# This code has been modified.
# ---------------------------------------------------------------
import math
import torch
from timm.layers import DropPath, to_2tuple, trunc_normal_


+class LayerNorm(nn.LayerNorm):
+    def forward(self, x):
+        if x.ndim == 4:
+            B, C, H, W = x.shape
+            x = x.view(B, C, -1).transpose(1, 2)
+            x = nn.functional.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+            x = x.transpose(1, 2).view(B, -1, H, W).contiguous()
+        else:
+            x = nn.functional.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+        return x
+
class Mlp(nn.Module):
    def __init__(
        self,
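
The LayerNorm subclass added above accepts both (B, N, C) token sequences and (B, C, H, W) feature maps, which is what lets the later hunks drop the explicit flatten/reshape around normalization. A minimal sketch of the 4-D path, not part of the diff; the shapes are illustrative and only standard PyTorch is assumed:

# Illustrative check of the channels-first branch (not part of the diff).
x = torch.randn(2, 64, 32, 32)                # (B, C, H, W), arbitrary sizes
ln = LayerNorm(64)                            # normalizes over the 64 channels
ref = nn.functional.layer_norm(
    x.flatten(2).transpose(1, 2), (64,), ln.weight, ln.bias, ln.eps
).transpose(1, 2).view(2, 64, 32, 32)
assert torch.allclose(ln(x), ref)             # same result, no manual reshaping
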
@@ -36,9 +56,6 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
@@ -86,7 +103,7 @@ def __init__(
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
-            self.norm = nn.LayerNorm(dim)
+            self.norm = LayerNorm(dim)

        self.apply(self._init_weights)

@@ -95,7 +112,7 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
@@ -153,7 +170,7 @@ def __init__(
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
-        norm_layer=nn.LayerNorm,
+        norm_layer=LayerNorm,
        sr_ratio=1,
    ):
        super().__init__()
@@ -185,7 +202,7 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
@@ -195,10 +212,12 @@ def _init_weights(self, m):
            if m.bias is not None:
                m.bias.data.zero_()

-    def forward(self, x, H, W):
+    def forward(self, x: torch.Tensor):
+        B, _, H, W = x.shape
+        x = x.flatten(2).transpose(1, 2)
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
-
+        x = x.transpose(1, 2).view(B, -1, H, W)
        return x


@@ -221,7 +240,7 @@ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=7
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2),
        )
-        self.norm = nn.LayerNorm(embed_dim)
+        self.norm = LayerNorm(embed_dim)

        self.apply(self._init_weights)

@@ -230,7 +249,7 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
@@ -242,11 +261,8 @@ def _init_weights(self, m):

    def forward(self, x):
        x = self.proj(x)
-        _, _, H, W = x.shape
-        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
-
-        return x, H, W
+        return x


class MixVisionTransformer(nn.Module):
@@ -307,8 +323,8 @@ def __init__(
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0
-        self.block1 = nn.ModuleList(
-            [
+        self.block1 = nn.Sequential(
+            *[
                Block(
                    dim=embed_dims[0],
                    num_heads=num_heads[0],
@@ -327,8 +343,8 @@ def __init__(
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
-        self.block2 = nn.ModuleList(
-            [
+        self.block2 = nn.Sequential(
+            *[
                Block(
                    dim=embed_dims[1],
                    num_heads=num_heads[1],
@@ -347,8 +363,8 @@ def __init__(
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
-        self.block3 = nn.ModuleList(
-            [
+        self.block3 = nn.Sequential(
+            *[
                Block(
                    dim=embed_dims[2],
                    num_heads=num_heads[2],
@@ -367,8 +383,8 @@ def __init__(
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
-        self.block4 = nn.ModuleList(
-            [
+        self.block4 = nn.Sequential(
+            *[
                Block(
                    dim=embed_dims[3],
                    num_heads=num_heads[3],
@@ -396,7 +412,7 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
@@ -454,35 +470,27 @@ def forward_features(self, x):
        outs = []

        # stage 1
-        x, H, W = self.patch_embed1(x)
-        for i, blk in enumerate(self.block1):
-            x = blk(x, H, W)
+        x = self.patch_embed1(x)
+        x = self.block1(x)
        x = self.norm1(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
-        x, H, W = self.patch_embed2(x)
-        for i, blk in enumerate(self.block2):
-            x = blk(x, H, W)
+        x = self.patch_embed2(x)
+        x = self.block2(x)
        x = self.norm2(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
-        x, H, W = self.patch_embed3(x)
-        for i, blk in enumerate(self.block3):
-            x = blk(x, H, W)
+        x = self.patch_embed3(x)
+        x = self.block3(x)
        x = self.norm3(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
-        x, H, W = self.patch_embed4(x)
-        for i, blk in enumerate(self.block4):
-            x = blk(x, H, W)
+        x = self.patch_embed4(x)
+        x = self.block4(x)
        x = self.norm4(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs
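
With every stage now operating in channel-first layout end to end, forward_features returns the four feature maps directly, at strides 4, 8, 16 and 32. A hedged usage sketch; the constructor arguments below follow a mit_b0-style SegFormer configuration and are assumptions, not something shown in this diff:

# Sketch; embed_dims/num_heads/depths/sr_ratios mirror a mit_b0-style config.
model = MixVisionTransformer(
    embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8],
    depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], norm_layer=LayerNorm,
)
feats = model.forward_features(torch.randn(1, 3, 224, 224))
for f, ch, stride in zip(feats, (32, 64, 160, 256), (4, 8, 16, 32)):
    assert f.shape[1:] == (ch, 224 // stride, 224 // stride)
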
@@ -500,7 +508,7 @@ def __init__(self, dim=768):
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
-        B, N, C = x.shape
+        B, _, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
@@ -522,8 +530,15 @@ def __init__(self, out_channels, depth=5, **kwargs):
        self._depth = depth
        self._in_channels = 3

-    def make_dilated(self, *args, **kwargs):
-        raise ValueError("MixVisionTransformer encoder does not support dilated mode")
+    def get_stages(self):
+        return [
+            nn.Identity(),
+            nn.Identity(),
+            nn.Sequential(self.patch_embed1, self.block1, self.norm1),
+            nn.Sequential(self.patch_embed2, self.block2, self.norm2),
+            nn.Sequential(self.patch_embed3, self.block3, self.norm3),
+            nn.Sequential(self.patch_embed4, self.block4, self.norm4),
+        ]

    def set_in_channels(self, in_channels, *args, **kwargs):
        if in_channels != 3:
@@ -532,11 +547,20 @@ def set_in_channels(self, in_channels, *args, **kwargs):
            )

    def forward(self, x):
+        stages = self.get_stages()
+
        # create dummy output for the first block
        B, C, H, W = x.shape
        dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)

-        return [x, dummy] + self.forward_features(x)[: self._depth - 1]
+        features = []
+        for i in range(self._depth + 1):
+            if i == 1:
+                features.append(dummy)
+            else:
+                x = stages[i](x)
+                features.append(x)
+        return features

    def load_state_dict(self, state_dict):
        state_dict.pop("head.weight", None)
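
Taken together, the encoder's forward now walks get_stages(), keeping the raw input at index 0 and the zero-channel dummy at index 1 so the returned list still lines up with a depth-5 feature pyramid. A usage sketch; the constructor arguments are assumptions based on how segmentation_models_pytorch typically instantiates its encoders and are not shown in this diff:

# Sketch; out_channels and the backbone kwargs below are illustrative guesses.
encoder = MixVisionTransformerEncoder(
    out_channels=(3, 0, 32, 64, 160, 256), depth=5,
    embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8],
    depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], norm_layer=LayerNorm,
)
feats = encoder(torch.randn(1, 3, 224, 224))
assert len(feats) == 6                        # input, dummy, then 4 stage outputs
assert feats[0].shape == (1, 3, 224, 224)     # stages[0] is nn.Identity()
assert feats[1].shape[1] == 0                 # the zero-channel dummy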