@@ -13,8 +13,11 @@ def get_encoders():
13
13
]
14
14
encoders = smp .encoders .get_encoder_names ()
15
15
encoders = [e for e in encoders if e not in exclude_encoders ]
16
- encoders .append ("tu-resnet34" ) # for timm universal encoder
17
- return encoders
16
+ encoders .append ("tu-resnet34" ) # for timm universal traditional-like encoder
17
+ encoders .append ("tu-convnext_atto" ) # for timm universal transformer-like encoder
18
+ encoders .append ("tu-darknet17" ) # for timm universal vgg-like encoder
19
+ encoders .append ("mit_b0" )
20
+ return encoders [- 3 :]
18
21
19
22
20
23
ENCODERS = get_encoders ()
@@ -80,16 +83,12 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
80
83
or model_class is smp .MAnet
81
84
):
82
85
kwargs ["decoder_channels" ] = (16 , 16 , 16 , 16 , 16 )[- encoder_depth :]
83
- if model_class in [smp .UnetPlusPlus , smp .Linknet ] and encoder_name .startswith (
84
- "mit_b"
85
- ):
86
- return # skip mit_b*
87
- if (
88
- model_class is smp .FPN
89
- and encoder_name .startswith ("mit_b" )
90
- and encoder_depth != 5
91
- ):
92
- return # skip mit_b*
86
+ if model_class in [smp .UnetPlusPlus , smp .Linknet ]:
87
+ if encoder_name .startswith ("mit_b" ) or encoder_name .startswith ("tu-convnext" ):
88
+ return # skip transformer-like model*
89
+ if model_class is smp .FPN and encoder_depth != 5 :
90
+ if encoder_name .startswith ("mit_b" ) or encoder_name .startswith ("tu-convnext" ):
91
+ return # skip transformer-like model*
93
92
model = model_class (
94
93
encoder_name , encoder_depth = encoder_depth , encoder_weights = None , ** kwargs
95
94
)
@@ -180,7 +179,6 @@ def test_dilation(encoder_name):
180
179
or encoder_name .startswith ("vgg" )
181
180
or encoder_name .startswith ("densenet" )
182
181
or encoder_name .startswith ("timm-res" )
183
- or encoder_name .startswith ("mit_b" )
184
182
):
185
183
return
186
184
0 commit comments