Skip to content

Commit f29cb95

Browse files
authored
Add preprocessing for timm (#533)
* Add preprocessing for timm
1 parent bc597e9 commit f29cb95

File tree

3 files changed

+34
-13
lines changed

3 files changed

+34
-13
lines changed

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import timm
12
import functools
23
import torch.utils.model_zoo as model_zoo
34

@@ -91,16 +92,24 @@ def get_encoder_names():
9192

9293

9394
def get_preprocessing_params(encoder_name, pretrained="imagenet"):
94-
settings = encoders[encoder_name]["pretrained_settings"]
9595

96-
if pretrained not in settings.keys():
97-
raise ValueError("Available pretrained options {}".format(settings.keys()))
96+
if encoder_name.startswith("tu-"):
97+
encoder_name = encoder_name[3:]
98+
if encoder_name not in timm.models.registry._model_has_pretrained:
99+
raise ValueError(f"{encoder_name} does not have pretrained weights and preprocessing parameters")
100+
settings = timm.models.registry._model_default_cfgs[encoder_name]
101+
else:
102+
all_settings = encoders[encoder_name]["pretrained_settings"]
103+
if pretrained not in all_settings.keys():
104+
raise ValueError("Available pretrained options {}".format(all_settings.keys()))
105+
settings = all_settings[pretrained]
98106

99107
formatted_settings = {}
100-
formatted_settings["input_space"] = settings[pretrained].get("input_space")
101-
formatted_settings["input_range"] = settings[pretrained].get("input_range")
102-
formatted_settings["mean"] = settings[pretrained].get("mean")
103-
formatted_settings["std"] = settings[pretrained].get("std")
108+
formatted_settings["input_space"] = settings.get("input_space", "RGB")
109+
formatted_settings["input_range"] = list(settings.get("input_range", [0, 1]))
110+
formatted_settings["mean"] = list(settings.get("mean"))
111+
formatted_settings["std"] = list(settings.get("std"))
112+
104113
return formatted_settings
105114

106115

tests/test_models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import os
21
import sys
32
import mock
43
import pytest
54
import torch
65

76
# mock detection module
87
sys.modules["torchvision._C"] = mock.Mock()
9-
import segmentation_models_pytorch as smp
8+
import segmentation_models_pytorch as smp # noqa
109

1110

1211
def get_encoders():

tests/test_preprocessing.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
import os
21
import sys
32
import mock
4-
import pytest
53
import numpy as np
64

75
# mock detection module
86
sys.modules["torchvision._C"] = mock.Mock()
9-
10-
import segmentation_models_pytorch as smp
7+
import segmentation_models_pytorch as smp # noqa
118

129

1310
def _test_preprocessing(inp, out, **params):
@@ -41,3 +38,19 @@ def test_input_space():
4138
inp = np.stack([np.ones((32, 32)), np.zeros((32, 32))], axis=-1)
4239
out = np.stack([np.zeros((32, 32)), np.ones((32, 32))], axis=-1)
4340
_test_preprocessing(inp, out, input_space="BGR")
41+
42+
43+
def test_preprocessing_params():
44+
# check default encoder params
45+
params = smp.encoders.get_preprocessing_params("resnet18")
46+
assert params["mean"] == [0.485, 0.456, 0.406]
47+
assert params["std"] == [0.229, 0.224, 0.225]
48+
assert params["input_range"] == [0, 1]
49+
assert params["input_space"] == "RGB"
50+
51+
# check timm params
52+
params = smp.encoders.get_preprocessing_params("tu-resnet18")
53+
assert params["mean"] == [0.485, 0.456, 0.406]
54+
assert params["std"] == [0.229, 0.224, 0.225]
55+
assert params["input_range"] == [0, 1]
56+
assert params["input_space"] == "RGB"

0 commit comments

Comments
 (0)