Skip to content

Commit 7c947f8

Browse files
committed
Move model archs
1 parent e97ce92 commit 7c947f8

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

segmentation_models_pytorch/__init__.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,21 @@
3030
"ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning
3131
) # for python >= 3.12
3232

33+
_MODEL_ARCHITECTURES = [
34+
Unet,
35+
UnetPlusPlus,
36+
MAnet,
37+
Linknet,
38+
FPN,
39+
PSPNet,
40+
DeepLabV3,
41+
DeepLabV3Plus,
42+
PAN,
43+
UPerNet,
44+
Segformer,
45+
]
46+
MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}
47+
3348

3449
def create_model(
3550
arch: str,
@@ -43,26 +58,12 @@ def create_model(
4358
parameters, without using its class
4459
"""
4560

46-
archs = [
47-
Unet,
48-
UnetPlusPlus,
49-
MAnet,
50-
Linknet,
51-
FPN,
52-
PSPNet,
53-
DeepLabV3,
54-
DeepLabV3Plus,
55-
PAN,
56-
UPerNet,
57-
Segformer,
58-
]
59-
archs_dict = {a.__name__.lower(): a for a in archs}
6061
try:
61-
model_class = archs_dict[arch.lower()]
62+
model_class = MODEL_ARCHITECTURES_MAPPING[arch.lower()]
6263
except KeyError:
6364
raise KeyError(
6465
"Wrong architecture type `{}`. Available options are: {}".format(
65-
arch, list(archs_dict.keys())
66+
arch, list(MODEL_ARCHITECTURES_MAPPING.keys())
6667
)
6768
)
6869
return model_class(

0 commit comments

Comments
 (0)