Commit 4392454

Update mix_transformer.py
1 parent b796a19 commit 4392454

File tree

1 file changed (+69 −45 lines changed)

segmentation_models_pytorch/encoders/mix_transformer.py

Lines changed: 69 additions & 45 deletions
@@ -1,7 +1,12 @@
 # ---------------------------------------------------------------
 # Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
 #
-# This work is licensed under the NVIDIA Source Code License
+# Licensed under the NVIDIA Source Code License. For full license
+# terms, please refer to the LICENSE file provided with this code
+# or visit NVIDIA's official repository at
+# https://github.com/NVlabs/SegFormer/tree/master.
+#
+# This code has been modified.
 # ---------------------------------------------------------------
 import math
 import torch
@@ -11,6 +16,21 @@
 from timm.layers import DropPath, to_2tuple, trunc_normal_


+class LayerNorm(nn.LayerNorm):
+    def forward(self, x):
+        if x.ndim == 4:
+            B, C, H, W = x.shape
+            x = x.view(B, C, -1).transpose(1, 2)
+            x = nn.functional.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+            x = x.transpose(1, 2).view(B, -1, H, W).contiguous()
+        else:
+            x = nn.functional.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+        return x
+
 class Mlp(nn.Module):
     def __init__(
         self,
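
Note (not part of the commit): the LayerNorm subclass added above lets nn.LayerNorm consume channels-first NCHW tensors directly, which is what allows norms to sit inside nn.Sequential stages later in this diff. A minimal equivalence sketch, assuming an illustrative embedding width of 64:

    import torch
    import torch.nn as nn

    # The 4D branch flattens the spatial dims, normalizes over the channel
    # axis, and restores NCHW; it matches a permute-to-channels-last round trip.
    ln = nn.LayerNorm(64)
    x = torch.randn(2, 64, 8, 8)

    B, C, H, W = x.shape
    y = x.view(B, C, -1).transpose(1, 2)                  # (B, H*W, C)
    y = nn.functional.layer_norm(y, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
    y = y.transpose(1, 2).view(B, -1, H, W)               # back to (B, C, H, W)

    ref = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)   # channels-last path
    assert torch.allclose(y, ref, atol=1e-6)
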
@@ -36,9 +56,6 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
             fan_out //= m.groups
@@ -86,7 +103,7 @@ def __init__(
         self.sr_ratio = sr_ratio
         if sr_ratio > 1:
             self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
-            self.norm = nn.LayerNorm(dim)
+            self.norm = LayerNorm(dim)

         self.apply(self._init_weights)

@@ -95,7 +112,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -153,7 +170,7 @@ def __init__(
         attn_drop=0.0,
         drop_path=0.0,
         act_layer=nn.GELU,
-        norm_layer=nn.LayerNorm,
+        norm_layer=LayerNorm,
         sr_ratio=1,
     ):
         super().__init__()
@@ -185,7 +202,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -195,10 +212,12 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()

-    def forward(self, x, H, W):
+    def forward(self, x: torch.Tensor):
+        B, _, H, W = x.shape
+        x = x.flatten(2).transpose(1, 2)
         x = x + self.drop_path(self.attn(self.norm1(x), H, W))
         x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
-
+        x = x.transpose(1, 2).view(B, -1, H, W)
         return x

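With this change each Block is a self-contained NCHW-to-NCHW module. A shape-contract sketch (the constructor arguments below are an illustrative configuration of the Block defined in this file):

    import torch

    block = Block(dim=64, num_heads=1, mlp_ratio=4, sr_ratio=8)
    x = torch.randn(2, 64, 32, 32)
    out = block(x)   # internally: flatten to tokens -> attention/MLP -> reshape back
    assert out.shape == (2, 64, 32, 32)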

@@ -221,7 +240,7 @@ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=7
             stride=stride,
             padding=(patch_size[0] // 2, patch_size[1] // 2),
         )
-        self.norm = nn.LayerNorm(embed_dim)
+        self.norm = LayerNorm(embed_dim)

         self.apply(self._init_weights)

@@ -230,7 +249,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -242,11 +261,8 @@ def _init_weights(self, m):

     def forward(self, x):
         x = self.proj(x)
-        _, _, H, W = x.shape
-        x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)
-
-        return x, H, W
+        return x


 class MixVisionTransformer(nn.Module):
@@ -307,8 +323,8 @@ def __init__(
             x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
         ]  # stochastic depth decay rule
         cur = 0
-        self.block1 = nn.ModuleList(
-            [
+        self.block1 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[0],
                     num_heads=num_heads[0],
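
Note on the ModuleList-to-Sequential switch: nn.Sequential feeds each module the previous module's single output, which the new Block.forward(x) signature satisfies; the old forward(x, H, W) signature did not. A stand-in sketch (nn.Identity in place of Block):

    import torch
    import torch.nn as nn

    # Stand-ins for Block: any NCHW-to-NCHW modules chain without a manual loop.
    stage = nn.Sequential(*[nn.Identity() for _ in range(2)])
    x = torch.randn(2, 64, 32, 32)
    assert stage(x).shape == (2, 64, 32, 32)

The same rewrite is applied to block2, block3, and block4 below.
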
@@ -327,8 +343,8 @@ def __init__(
         self.norm1 = norm_layer(embed_dims[0])

         cur += depths[0]
-        self.block2 = nn.ModuleList(
-            [
+        self.block2 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[1],
                     num_heads=num_heads[1],
@@ -347,8 +363,8 @@ def __init__(
         self.norm2 = norm_layer(embed_dims[1])

         cur += depths[1]
-        self.block3 = nn.ModuleList(
-            [
+        self.block3 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[2],
                     num_heads=num_heads[2],
@@ -367,8 +383,8 @@ def __init__(
         self.norm3 = norm_layer(embed_dims[2])

         cur += depths[2]
-        self.block4 = nn.ModuleList(
-            [
+        self.block4 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[3],
                     num_heads=num_heads[3],
@@ -396,7 +412,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -454,35 +470,27 @@ def forward_features(self, x):
         outs = []

         # stage 1
-        x, H, W = self.patch_embed1(x)
-        for i, blk in enumerate(self.block1):
-            x = blk(x, H, W)
+        x = self.patch_embed1(x)
+        x = self.block1(x)
         x = self.norm1(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
         outs.append(x)

         # stage 2
-        x, H, W = self.patch_embed2(x)
-        for i, blk in enumerate(self.block2):
-            x = blk(x, H, W)
+        x = self.patch_embed2(x)
+        x = self.block2(x)
         x = self.norm2(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
         outs.append(x)

         # stage 3
-        x, H, W = self.patch_embed3(x)
-        for i, blk in enumerate(self.block3):
-            x = blk(x, H, W)
+        x = self.patch_embed3(x)
+        x = self.block3(x)
         x = self.norm3(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
         outs.append(x)

         # stage 4
-        x, H, W = self.patch_embed4(x)
-        for i, blk in enumerate(self.block4):
-            x = blk(x, H, W)
+        x = self.patch_embed4(x)
+        x = self.block4(x)
         x = self.norm4(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
         outs.append(x)

         return outs
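
Each stage now collapses to patch_embed -> blocks -> norm with no per-stage reshape, since every piece operates on NCHW tensors. Expected shapes for a hypothetical mit_b0-style configuration (embed_dims = [32, 64, 160, 256], patch-embed strides 4/2/2/2) on a 1x3x512x512 input:

    # outs[0]: (1,  32, 128, 128)   # H/4,  W/4
    # outs[1]: (1,  64,  64,  64)   # H/8,  W/8
    # outs[2]: (1, 160,  32,  32)   # H/16, W/16
    # outs[3]: (1, 256,  16,  16)   # H/32, W/32
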
@@ -500,7 +508,7 @@ def __init__(self, dim=768):
         self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

     def forward(self, x, H, W):
-        B, N, C = x.shape
+        B, _, C = x.shape
         x = x.transpose(1, 2).view(B, C, H, W)
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
@@ -522,8 +530,15 @@ def __init__(self, out_channels, depth=5, **kwargs):
         self._depth = depth
         self._in_channels = 3

-    def make_dilated(self, *args, **kwargs):
-        raise ValueError("MixVisionTransformer encoder does not support dilated mode")
+    def get_stages(self):
+        return [
+            nn.Identity(),
+            nn.Identity(),
+            nn.Sequential(self.patch_embed1, self.block1, self.norm1),
+            nn.Sequential(self.patch_embed2, self.block2, self.norm2),
+            nn.Sequential(self.patch_embed3, self.block3, self.norm3),
+            nn.Sequential(self.patch_embed4, self.block4, self.norm4),
+        ]

     def set_in_channels(self, in_channels, *args, **kwargs):
         if in_channels != 3:
@@ -532,11 +547,20 @@ def set_in_channels(self, in_channels, *args, **kwargs):
             )

     def forward(self, x):
+        stages = self.get_stages()
+
         # create dummy output for the first block
         B, C, H, W = x.shape
         dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)

-        return [x, dummy] + self.forward_features(x)[: self._depth - 1]
+        features = []
+        for i in range(self._depth + 1):
+            if i == 1:
+                features.append(dummy)
+            else:
+                x = stages[i](x)
+                features.append(x)
+        return features

     def load_state_dict(self, state_dict):
         state_dict.pop("head.weight", None)
