
Commit f5cf4b8

khornlund authored and qubvel committed
Add inceptionv4 backbone (#88)
1 parent b83e000 commit f5cf4b8

File tree

4 files changed (+64, -10 lines)


README.md

Lines changed: 6 additions & 6 deletions

@@ -7,13 +7,13 @@ The main features of this library are:
 
 - High level API (just two lines to create neural network)
 - 4 models architectures for binary and multi class segmentation (including legendary Unet)
-- 30 available encoders for each architecture
+- 31 available encoders for each architecture
 - All encoders have pre-trained weights for faster and better convergence
 
 ### Table of content
 1. [Quick start](#start)
 2. [Examples](#examples)
-3. [Models](#models)
+3. [Models](#models)
 1. [Architectures](#architectires)
 2. [Encoders](#encoders)
 3. [Pretrained weights](#weights)
@@ -57,7 +57,7 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
 - [Linknet](https://arxiv.org/abs/1707.03718)
 - [FPN](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)
 - [PSPNet](https://arxiv.org/abs/1612.01105)
-
+
 #### Encoders <a name="encoders"></a>
 
 | Type | Encoder names |
@@ -82,10 +82,10 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
 | [instagram](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) | resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
 
 ### Models API <a name="api"></a>
-- `model.encoder` - pretrained backbone to extract features of different spatial resolution
-- `model.decoder` - segmentation head, depends on models architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
+- `model.encoder` - pretrained backbone to extract features of different spatial resolution
+- `model.decoder` - segmentation head, depends on models architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
 - `model.activation` - output activation function, one of `sigmoid`, `softmax`
-- `model.forward(x)` - sequentially pass `x` through model\`s encoder and decoder (return logits!)
+- `model.forward(x)` - sequentially pass `x` through model\`s encoder and decoder (return logits!)
 - `model.predict(x)` - inference method, switch model to `.eval()` mode, call `.forward(x)` and apply activation function with `torch.no_grad()`
 
 ### Installation <a name="installation"></a>
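The `model.predict(x)` bullet in the README excerpt above fully specifies its inference contract. A minimal sketch of that behavior, assuming only what the README states (the function name `predict_sketch` is illustrative, not part of the library):

import torch

def predict_sketch(model, x):
    # per the README: switch to eval mode, run forward (which returns
    # logits), and apply the model's activation function under no_grad
    model.eval()
    with torch.no_grad():
        logits = model.forward(x)
        return model.activation(logits)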

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 5 additions & 3 deletions

@@ -6,7 +6,8 @@
 from .vgg import vgg_encoders
 from .senet import senet_encoders
 from .densenet import densenet_encoders
-from .inceptionresnetv2 import inception_encoders
+from .inceptionresnetv2 import inceptionresnetv2_encoders
+from .inceptionv4 import inceptionv4_encoders
 from .efficientnet import efficient_net_encoders
 
 
@@ -18,7 +19,8 @@
 encoders.update(vgg_encoders)
 encoders.update(senet_encoders)
 encoders.update(densenet_encoders)
-encoders.update(inception_encoders)
+encoders.update(inceptionresnetv2_encoders)
+encoders.update(inceptionv4_encoders)
 encoders.update(efficient_net_encoders)
 
 
@@ -43,7 +45,7 @@ def get_preprocessing_params(encoder_name, pretrained='imagenet'):
 
     if pretrained not in settings.keys():
         raise ValueError('Avaliable pretrained options {}'.format(settings.keys()))
-
+
     formatted_settings = {}
     formatted_settings['input_space'] = settings[pretrained].get('input_space')
     formatted_settings['input_range'] = settings[pretrained].get('input_range')
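With both encoder families registered under distinct keys, the new backbone is reachable by name anywhere an encoder name is accepted. A hedged usage sketch, assuming the high-level API the README documents (`get_preprocessing_fn` appears in the hunk header above; the `Unet` call mirrors the README's quick start):

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# the registry update makes 'inceptionv4' a valid encoder name
model = smp.Unet('inceptionv4', encoder_weights='imagenet')
preprocess_input = get_preprocessing_fn('inceptionv4', pretrained='imagenet')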

segmentation_models_pytorch/encoders/inceptionresnetv2.py

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ def load_state_dict(self, state_dict, **kwargs):
         super().load_state_dict(state_dict, **kwargs)
 
 
-inception_encoders = {
+inceptionresnetv2_encoders = {
     'inceptionresnetv2': {
         'encoder': InceptionResNetV2Encoder,
         'pretrained_settings': pretrained_settings['inceptionresnetv2'],
segmentation_models_pytorch/encoders/inceptionv4.py

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+import torch.nn as nn
+from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d
+from pretrainedmodels.models.inceptionv4 import pretrained_settings
+
+
+class InceptionV4Encoder(InceptionV4):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.in_channels = 3
+        self.features[0] = BasicConv2d(self.in_channels, 32, kernel_size=3, stride=2, padding=1)
+        self.features[1] = BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
+
+        self.chunks = [3, 5, 9, 15]
+
+        # correct paddings
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.kernel_size == (3, 3):
+                    m.padding = (1, 1)
+            if isinstance(m, nn.MaxPool2d):
+                m.padding = (1, 1)
+
+        # remove linear layers
+        del self.last_linear
+
+    def forward(self, x):
+        x0 = self.features[:self.chunks[0]](x)
+        x1 = self.features[self.chunks[0]:self.chunks[1]](x0)
+        x2 = self.features[self.chunks[1]:self.chunks[2]](x1)
+        x3 = self.features[self.chunks[2]:self.chunks[3]](x2)
+        x4 = self.features[self.chunks[3]:](x3)
+
+        features = [x4, x3, x2, x1, x0]
+        return features
+
+    def load_state_dict(self, state_dict, **kwargs):
+        state_dict.pop('last_linear.bias')
+        state_dict.pop('last_linear.weight')
+        super().load_state_dict(state_dict, **kwargs)
+
+
+inceptionv4_encoders = {
+    'inceptionv4': {
+        'encoder': InceptionV4Encoder,
+        'pretrained_settings': pretrained_settings['inceptionv4'],
+        'out_shapes': (1536, 1024, 384, 192, 64),
+        'params': {
+            'num_classes': 1001,
+        }
+    }
+}
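The `chunks` boundaries slice the `features` Sequential into five stages, and `out_shapes` records the channel count of each returned map, deepest first. A quick smoke-test sketch of that correspondence (the 256x256 input size is an arbitrary choice, and this assumes a randomly initialized model rather than pretrained weights):

import torch
from segmentation_models_pytorch.encoders.inceptionv4 import (
    InceptionV4Encoder,
    inceptionv4_encoders,
)

# num_classes only sizes the classifier head, which the encoder deletes
params = inceptionv4_encoders['inceptionv4']['params']
encoder = InceptionV4Encoder(**params)
encoder.eval()

with torch.no_grad():
    features = encoder(torch.randn(1, 3, 256, 256))

# features come deepest first, matching out_shapes = (1536, 1024, 384, 192, 64)
for fmap, channels in zip(features, inceptionv4_encoders['inceptionv4']['out_shapes']):
    assert fmap.shape[1] == channels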
