Skip to content

Commit f70502e

Browse files
authored
Add preprocessing_params (#45)
1 parent 8b365d0 commit f70502e

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,20 @@ def get_encoder_names():
3535
return list(encoders.keys())
3636

3737

38-
def get_preprocessing_fn(encoder_name, pretrained='imagenet'):
38+
def get_preprocessing_params(encoder_name, pretrained='imagenet'):
3939
settings = encoders[encoder_name]['pretrained_settings']
4040

4141
if pretrained not in settings.keys():
4242
raise ValueError('Avaliable pretrained options {}'.format(settings.keys()))
43-
44-
input_space = settings[pretrained].get('input_space')
45-
input_range = settings[pretrained].get('input_range')
46-
mean = settings[pretrained].get('mean')
47-
std = settings[pretrained].get('std')
4843

49-
return functools.partial(preprocess_input, mean=mean, std=std, input_space=input_space, input_range=input_range)
44+
formatted_settings = {}
45+
formatted_settings['input_space'] = settings[pretrained].get('input_space')
46+
formatted_settings['input_range'] = settings[pretrained].get('input_range')
47+
formatted_settings['mean'] = settings[pretrained].get('mean')
48+
formatted_settings['std'] = settings[pretrained].get('std')
49+
return formatted_settings
50+
51+
52+
def get_preprocessing_fn(encoder_name, pretrained='imagenet'):
53+
params = get_preprocessing_params(encoder_name, pretrained=pretrained)
54+
return functools.partial(preprocess_input, **params)

0 commit comments

Comments
 (0)