
Commit f70d861

Fix scripting for encoders
1 parent 556b3aa commit f70d861

16 files changed: +156 −110 lines

segmentation_models_pytorch/encoders/_base.py

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,10 @@ class EncoderMixin:
     - patching first convolution for arbitrary input channels
     """
 
+    _is_torch_scriptable = True
+    _is_torch_exportable = True
+    _is_torch_compilable = True
+
     def __init__(self):
         self._depth = 5
         self._in_channels = 3
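These mixin-level flags advertise which torch compilation paths an encoder supports; subclasses override them to opt out (as DPN and EfficientNet do below). A minimal usage sketch, assuming the flags are read by caller-side guard code; the guard itself is illustrative, not part of this commit:

import torch
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name="resnet34", encoder_weights=None)
encoder = model.encoder

# Default from EncoderMixin is True; encoders that cannot be
# scripted set _is_torch_scriptable = False on the class.
if getattr(encoder, "_is_torch_scriptable", True):
    scripted = torch.jit.script(encoder)
    features = scripted(torch.rand(1, 3, 224, 224))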

segmentation_models_pytorch/encoders/densenet.py

Lines changed: 15 additions & 17 deletions
@@ -24,8 +24,6 @@
 """
 
 import re
-import torch
-import torch.nn as nn
 
 from torchvision.models.densenet import DenseNet
 
@@ -47,15 +45,6 @@ def make_dilated(self, *args, **kwargs):
             "due to pooling operation for downsampling!"
         )
 
-    def apply_transition(
-        self, transition: torch.nn.Sequential, x: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        for module in transition:
-            x = module(x)
-            if isinstance(module, nn.ReLU):
-                intermediate = x
-        return x, intermediate
-
     def forward(self, x):
         features = [x]
 
@@ -68,20 +57,29 @@ def forward(self, x):
         if self._depth >= 2:
             x = self.features.pool0(x)
             x = self.features.denseblock1(x)
-            x, intermediate = self.apply_transition(self.features.transition1, x)
-            features.append(intermediate)
+            x = self.features.transition1.norm(x)
+            x = self.features.transition1.relu(x)
+            features.append(x)
 
         if self._depth >= 3:
+            x = self.features.transition1.conv(x)
+            x = self.features.transition1.pool(x)
             x = self.features.denseblock2(x)
-            x, intermediate = self.apply_transition(self.features.transition2, x)
-            features.append(intermediate)
+            x = self.features.transition2.norm(x)
+            x = self.features.transition2.relu(x)
+            features.append(x)
 
         if self._depth >= 4:
+            x = self.features.transition2.conv(x)
+            x = self.features.transition2.pool(x)
             x = self.features.denseblock3(x)
-            x, intermediate = self.apply_transition(self.features.transition3, x)
-            features.append(intermediate)
+            x = self.features.transition3.norm(x)
+            x = self.features.transition3.relu(x)
+            features.append(x)
 
         if self._depth >= 5:
+            x = self.features.transition3.conv(x)
+            x = self.features.transition3.pool(x)
             x = self.features.denseblock4(x)
             x = self.features.norm5(x)
             features.append(x)
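The removed apply_transition helper took the transition nn.Sequential as a method argument and conditionally assigned intermediate inside the loop, both of which are problematic for torch.jit.script; the rewrite calls the transition's named children directly instead. A standalone sketch of the pattern, using the same child names torchvision's _Transition defines (norm, relu, conv, pool):

import torch
import torch.nn as nn
from collections import OrderedDict

transition = nn.Sequential(
    OrderedDict(
        [
            ("norm", nn.BatchNorm2d(8)),
            ("relu", nn.ReLU(inplace=True)),
            ("conv", nn.Conv2d(8, 4, kernel_size=1, stride=1, bias=False)),
            ("pool", nn.AvgPool2d(kernel_size=2, stride=2)),
        ]
    )
)

x = torch.rand(1, 8, 32, 32)
y = transition.relu(transition.norm(x))  # feature tapped before downsampling
z = transition.pool(transition.conv(y))  # downsampled input for the next block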

segmentation_models_pytorch/encoders/dpn.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@
 
 
 class DPNEncoder(DPN, EncoderMixin):
+    _is_torch_scriptable = False
+
     def __init__(
         self,
         stage_idxs: List[int],

segmentation_models_pytorch/encoders/efficientnet.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@
 
 
 class EfficientNetEncoder(EfficientNet, EncoderMixin):
+    _is_torch_scriptable = False
+
     def __init__(
        self,
        stage_idxs: List[int],
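DPNEncoder and EfficientNetEncoder opt out of scripting entirely by overriding the mixin default, while export and compile stay enabled. A hedged test-style sketch of how the flag could be checked (illustrative; not the repository's actual test suite):

from segmentation_models_pytorch.encoders import get_encoder

for name in ["dpn68", "efficientnet-b0"]:
    encoder = get_encoder(name, weights=None)
    assert encoder._is_torch_scriptable is False  # set by this commit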

segmentation_models_pytorch/encoders/inceptionresnetv2.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@ def __init__(
             if isinstance(m, nn.MaxPool2d):
                 m.padding = (1, 1)
 
+        # for torchscript: block8 does not have relu defined
+        self.block8.relu = nn.Identity()
+
         # remove linear layers
         del self.avgpool_1a
         del self.last_linear
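torch.jit.script requires every attribute referenced in forward to exist on the module, even when the access is guarded by a runtime flag; pretrainedmodels builds the final Block8 without a relu, so scripting previously failed on the missing attribute. A minimal sketch of the failure mode (illustrative class, not the library's):

import torch.nn as nn

class Block(nn.Module):
    def __init__(self, no_relu: bool = False):
        super().__init__()
        self.no_relu = no_relu
        if not no_relu:
            self.relu = nn.ReLU()  # never defined when no_relu=True

    def forward(self, x):
        if not self.no_relu:
            x = self.relu(x)  # scripting fails: attribute may not exist
        return x

block = Block(no_relu=True)
block.relu = nn.Identity()  # the commit's workaround: give the attribute a concrete module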

segmentation_models_pytorch/encoders/inceptionv4.py

Lines changed: 14 additions & 20 deletions
@@ -35,19 +35,19 @@
 class InceptionV4Encoder(InceptionV4, EncoderMixin):
     def __init__(
         self,
-        stage_idxs: List[int],
+        out_indexes: List[int],
         out_channels: List[int],
         depth: int = 5,
         output_stride: int = 32,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
-        self._stage_idxs = stage_idxs
         self._depth = depth
         self._in_channels = 3
         self._out_channels = out_channels
         self._output_stride = output_stride
+        self._out_indexes = out_indexes
 
         # correct paddings
         for m in self.modules():
@@ -67,28 +67,22 @@ def make_dilated(self, *args, **kwargs):
         )
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        depth = 0
         features = [x]
 
-        if self._depth >= 1:
-            x = self.features[: self._stage_idxs[0]](x)
-            features.append(x)
+        for i, module in enumerate(self.features):
+            x = module(x)
 
-        if self._depth >= 2:
-            x = self.features[self._stage_idxs[0] : self._stage_idxs[1]](x)
-            features.append(x)
+            if i in self._out_indexes:
+                features.append(x)
+                depth += 1
 
-        if self._depth >= 3:
-            x = self.features[self._stage_idxs[1] : self._stage_idxs[2]](x)
-            features.append(x)
-
-        if self._depth >= 4:
-            x = self.features[self._stage_idxs[2] : self._stage_idxs[3]](x)
-            features.append(x)
-
-        if self._depth >= 5:
-            x = self.features[self._stage_idxs[3] :](x)
-            features.append(x)
+            # torchscript does not support break in a loop, so we just
+            # go over all modules and then slice the collected features
+            if not torch.jit.is_scripting() and depth > self._depth:
+                break
 
+        features = features[: self._depth + 1]
         return features
 
     def load_state_dict(self, state_dict, **kwargs):
@@ -121,7 +115,7 @@ def load_state_dict(self, state_dict, **kwargs):
             },
         },
         "params": {
-            "stage_idxs": [3, 5, 9, 15],
+            "out_indexes": [2, 4, 8, 14],
             "out_channels": [3, 64, 192, 384, 1024, 1536],
             "num_classes": 1001,
         },
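The staged slicing over stage_idxs is replaced by a single pass that taps features at out_indexes: in eager mode the loop exits early once enough features are collected, while the scripted path (per the commit's comment) runs all modules and relies on the final slice. The same pattern in isolation, as a self-contained sketch:

import torch
import torch.nn as nn
from typing import List

class FeatureCollector(nn.Module):
    def __init__(self, blocks: nn.ModuleList, out_indexes: List[int], depth: int):
        super().__init__()
        self.blocks = blocks
        self._out_indexes = out_indexes
        self._depth = depth

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        depth = 0
        features = [x]
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i in self._out_indexes:
                features.append(x)
                depth += 1
            # eager-only early exit; the scripted graph runs every block
            if not torch.jit.is_scripting() and depth > self._depth:
                break
        return features[: self._depth + 1]

blocks = nn.ModuleList(nn.Conv2d(3, 3, 3, padding=1) for _ in range(4))
collector = torch.jit.script(FeatureCollector(blocks, out_indexes=[0, 2], depth=2))
feats = collector(torch.rand(1, 3, 16, 16))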

segmentation_models_pytorch/encoders/mix_transformer.py

Lines changed: 35 additions & 30 deletions
@@ -11,20 +11,22 @@
 import math
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from functools import partial
+from typing import Dict, Sequence, List
 
 from timm.layers import DropPath, to_2tuple, trunc_normal_
 
 
 class LayerNorm(nn.LayerNorm):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.ndim == 4:
-            B, C, H, W = x.shape
-            x = x.view(B, C, -1).transpose(1, 2)
-            x = super().forward(x)
-            x = x.transpose(1, 2).view(B, C, H, W)
+            batch_size, channels, height, width = x.shape
+            x = x.view(batch_size, channels, -1).transpose(1, 2)
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            x = x.transpose(1, 2).view(batch_size, channels, height, width)
         else:
-            x = super().forward(x)
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         return x
 
 
@@ -60,9 +62,9 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x, H, W):
+    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
         x = self.fc1(x)
-        x = self.dwconv(x, H, W)
+        x = self.dwconv(x, height, width)
         x = self.act(x)
         x = self.drop(x)
         x = self.fc2(x)
@@ -101,6 +103,10 @@ def __init__(
         if sr_ratio > 1:
             self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
             self.norm = LayerNorm(dim)
+        else:
+            # for torchscript compatibility
+            self.sr = nn.Identity()
+            self.norm = nn.Identity()
 
         self.apply(self._init_weights)
 
@@ -119,27 +125,27 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x, H, W):
-        B, N, C = x.shape
+    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        batch_size, N, C = x.shape
         q = (
             self.q(x)
-            .reshape(B, N, self.num_heads, C // self.num_heads)
+            .reshape(batch_size, N, self.num_heads, C // self.num_heads)
             .permute(0, 2, 1, 3)
         )
 
         if self.sr_ratio > 1:
-            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
-            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
+            x_ = x.permute(0, 2, 1).reshape(batch_size, C, height, width)
+            x_ = self.sr(x_).reshape(batch_size, C, -1).permute(0, 2, 1)
             x_ = self.norm(x_)
             kv = (
                 self.kv(x_)
-                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
+                .reshape(batch_size, -1, 2, self.num_heads, C // self.num_heads)
                 .permute(2, 0, 3, 1, 4)
             )
         else:
             kv = (
                 self.kv(x)
-                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
+                .reshape(batch_size, -1, 2, self.num_heads, C // self.num_heads)
                 .permute(2, 0, 3, 1, 4)
             )
         k, v = kv[0], kv[1]
@@ -148,7 +154,7 @@ def forward(self, x, H, W):
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(batch_size, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
 
@@ -209,12 +215,12 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x):
-        B, _, H, W = x.shape
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, _, height, width = x.shape
         x = x.flatten(2).transpose(1, 2)
-        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
-        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
-        x = x.transpose(1, 2).view(B, -1, H, W)
+        x = x + self.drop_path(self.attn(self.norm1(x), height, width))
+        x = x + self.drop_path(self.mlp(self.norm2(x), height, width))
+        x = x.transpose(1, 2).view(batch_size, -1, height, width)
         return x
 
 
@@ -256,7 +262,7 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         x = self.norm(x)
         return x
@@ -462,7 +468,7 @@ def reset_classifier(self, num_classes, global_pool=""):
             nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
         )
 
-    def forward_features(self, x):
+    def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]:
         outs = []
 
         # stage 1
@@ -491,21 +497,21 @@ def forward_features(self, x):
 
         return outs
 
-    def forward(self, x):
-        x = self.forward_features(x)
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        features = self.forward_features(x)
         # x = self.head(x)
 
-        return x
+        return features
 
 
 class DWConv(nn.Module):
     def __init__(self, dim=768):
         super(DWConv, self).__init__()
         self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
 
-    def forward(self, x, H, W):
-        B, _, C = x.shape
-        x = x.transpose(1, 2).view(B, C, H, W)
+    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        batch_size, _, channels = x.shape
+        x = x.transpose(1, 2).view(batch_size, channels, height, width)
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
 
@@ -516,7 +522,6 @@ def forward(self, x, H, W):
 # End of NVIDIA code
 # ---------------------------------------------------------------
 
-from typing import Dict, Sequence, List # noqa E402
 from ._base import EncoderMixin # noqa E402
 
 
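The custom LayerNorm now calls F.layer_norm directly instead of super().forward(x), since super() calls are a known pain point for torch.jit.script; the functional call is exactly what nn.LayerNorm.forward does, so behavior is unchanged. A quick equivalence check:

import torch
import torch.nn.functional as F

ln = torch.nn.LayerNorm(16)
x = torch.rand(2, 8, 16)
expected = ln(x)
actual = F.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
assert torch.allclose(expected, actual)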

segmentation_models_pytorch/encoders/mobilenet.py

Lines changed: 12 additions & 16 deletions
@@ -40,6 +40,7 @@ def __init__(
         self._in_channels = 3
         self._out_channels = out_channels
         self._output_stride = output_stride
+        self._out_indexes = [2, 4, 7, 14]
 
         del self.classifier
 
@@ -52,25 +53,20 @@ def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         features = [x]
 
-        if self._depth >= 1:
-            x = self.features[:2](x)
-            features.append(x)
+        depth = 0
+        for i, module in enumerate(self.features):
+            x = module(x)
 
-        if self._depth >= 2:
-            x = self.features[2:4](x)
-            features.append(x)
+            if i in self._out_indexes:
+                features.append(x)
+                depth += 1
 
-        if self._depth >= 3:
-            x = self.features[4:7](x)
-            features.append(x)
+            # torchscript does not support break in a loop, so we just
+            # go over all modules and then slice the collected features
+            if not torch.jit.is_scripting() and depth > self._depth:
+                break
 
-        if self._depth >= 4:
-            x = self.features[7:14](x)
-            features.append(x)
-
-        if self._depth >= 5:
-            x = self.features[14:](x)
-            features.append(x)
+        features = features[: self._depth + 1]
 
         return features
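With the unrolled loop, the MobileNetV2 encoder should now produce the same features eagerly and under torch.jit.script. A hedged round-trip sketch (encoder name as registered in the library):

import torch
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("mobilenet_v2", weights=None).eval()
scripted = torch.jit.script(encoder)

x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    eager_features = encoder(x)
    scripted_features = scripted(x)

for a, b in zip(eager_features, scripted_features):
    assert torch.allclose(a, b)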