@@ -35,15 +35,20 @@ def get_encoder_names():
35
35
return list (encoders .keys ())
36
36
37
37
38
- def get_preprocessing_fn (encoder_name , pretrained = 'imagenet' ):
38
+ def get_preprocessing_params (encoder_name , pretrained = 'imagenet' ):
39
39
settings = encoders [encoder_name ]['pretrained_settings' ]
40
40
41
41
if pretrained not in settings .keys ():
42
42
raise ValueError ('Avaliable pretrained options {}' .format (settings .keys ()))
43
-
44
- input_space = settings [pretrained ].get ('input_space' )
45
- input_range = settings [pretrained ].get ('input_range' )
46
- mean = settings [pretrained ].get ('mean' )
47
- std = settings [pretrained ].get ('std' )
48
43
49
- return functools .partial (preprocess_input , mean = mean , std = std , input_space = input_space , input_range = input_range )
44
+ formatted_settings = {}
45
+ formatted_settings ['input_space' ] = settings [pretrained ].get ('input_space' )
46
+ formatted_settings ['input_range' ] = settings [pretrained ].get ('input_range' )
47
+ formatted_settings ['mean' ] = settings [pretrained ].get ('mean' )
48
+ formatted_settings ['std' ] = settings [pretrained ].get ('std' )
49
+ return formatted_settings
50
+
51
+
52
+ def get_preprocessing_fn (encoder_name , pretrained = 'imagenet' ):
53
+ params = get_preprocessing_params (encoder_name , pretrained = pretrained )
54
+ return functools .partial (preprocess_input , ** params )
0 commit comments