File tree Expand file tree Collapse file tree 1 file changed +17
-16
lines changed
segmentation_models_pytorch Expand file tree Collapse file tree 1 file changed +17
-16
lines changed Original file line number Diff line number Diff line change 3030 "ignore" , message = r'"is" with \'str\' literal.*' , category = SyntaxWarning
3131) # for python >= 3.12
3232
33+ _MODEL_ARCHITECTURES = [
34+ Unet ,
35+ UnetPlusPlus ,
36+ MAnet ,
37+ Linknet ,
38+ FPN ,
39+ PSPNet ,
40+ DeepLabV3 ,
41+ DeepLabV3Plus ,
42+ PAN ,
43+ UPerNet ,
44+ Segformer ,
45+ ]
46+ MODEL_ARCHITECTURES_MAPPING = {a .__name__ .lower (): a for a in _MODEL_ARCHITECTURES }
47+
3348
3449def create_model (
3550 arch : str ,
@@ -43,26 +58,12 @@ def create_model(
4358 parameters, without using its class
4459 """
4560
46- archs = [
47- Unet ,
48- UnetPlusPlus ,
49- MAnet ,
50- Linknet ,
51- FPN ,
52- PSPNet ,
53- DeepLabV3 ,
54- DeepLabV3Plus ,
55- PAN ,
56- UPerNet ,
57- Segformer ,
58- ]
59- archs_dict = {a .__name__ .lower (): a for a in archs }
6061 try :
61- model_class = archs_dict [arch .lower ()]
62+ model_class = MODEL_ARCHITECTURES_MAPPING [arch .lower ()]
6263 except KeyError :
6364 raise KeyError (
6465 "Wrong architecture type `{}`. Available options are: {}" .format (
65- arch , list (archs_dict .keys ())
66+ arch , list (MODEL_ARCHITECTURES_MAPPING .keys ())
6667 )
6768 )
6869 return model_class (
You can’t perform that action at this time.
0 commit comments