@@ -13,8 +13,11 @@ def get_encoders():
1313 ]
1414 encoders = smp .encoders .get_encoder_names ()
1515 encoders = [e for e in encoders if e not in exclude_encoders ]
16- encoders .append ("tu-resnet34" ) # for timm universal encoder
17- return encoders
16+ encoders .append ("tu-resnet34" ) # for timm universal traditional-like encoder
17+ encoders .append ("tu-convnext_atto" ) # for timm universal transformer-like encoder
18+ encoders .append ("tu-darknet17" ) # for timm universal vgg-like encoder
19+ encoders .append ("mit_b0" )
20+ return encoders [- 3 :]
1821
1922
2023ENCODERS = get_encoders ()
@@ -80,16 +83,12 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
8083 or model_class is smp .MAnet
8184 ):
8285 kwargs ["decoder_channels" ] = (16 , 16 , 16 , 16 , 16 )[- encoder_depth :]
83- if model_class in [smp .UnetPlusPlus , smp .Linknet ] and encoder_name .startswith (
84- "mit_b"
85- ):
86- return # skip mit_b*
87- if (
88- model_class is smp .FPN
89- and encoder_name .startswith ("mit_b" )
90- and encoder_depth != 5
91- ):
92- return # skip mit_b*
86+ if model_class in [smp .UnetPlusPlus , smp .Linknet ]:
87+ if encoder_name .startswith ("mit_b" ) or encoder_name .startswith ("tu-convnext" ):
88+ return # skip transformer-like model*
89+ if model_class is smp .FPN and encoder_depth != 5 :
90+ if encoder_name .startswith ("mit_b" ) or encoder_name .startswith ("tu-convnext" ):
91+ return # skip transformer-like model*
9392 model = model_class (
9493 encoder_name , encoder_depth = encoder_depth , encoder_weights = None , ** kwargs
9594 )
@@ -180,7 +179,6 @@ def test_dilation(encoder_name):
180179 or encoder_name .startswith ("vgg" )
181180 or encoder_name .startswith ("densenet" )
182181 or encoder_name .startswith ("timm-res" )
183- or encoder_name .startswith ("mit_b" )
184182 ):
185183 return
186184
0 commit comments