Skip to content

Commit e7bc6e0

Browse files
committed
Add tests/test_models & fix type
1 parent d8ea35f commit e7bc6e0

File tree

2 files changed

+12
-14
lines changed

segmentation_models_pytorch/encoders/timm_universal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191

9292
# Determine the model's downsampling pattern and set hierarchy flags
9393
encoder_stage = len(tmp_model.feature_info.reduction())
94-
reduction_scales = tmp_model.feature_info.reduction()
94+
reduction_scales = list(tmp_model.feature_info.reduction())
9595

9696
if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
9797
# Transformer-style downsampling: scales (4, 8, 16, 32)

tests/test_models.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@ def get_encoders():
1313
]
1414
encoders = smp.encoders.get_encoder_names()
1515
encoders = [e for e in encoders if e not in exclude_encoders]
16-
encoders.append("tu-resnet34") # for timm universal encoder
17-
return encoders
16+
encoders.append("tu-resnet34") # for timm universal traditional-like encoder
17+
encoders.append("tu-convnext_atto") # for timm universal transformer-like encoder
18+
encoders.append("tu-darknet17") # for timm universal vgg-like encoder
19+
encoders.append("mit_b0")
20+
return encoders[-3:]
1821

1922

2023
ENCODERS = get_encoders()
@@ -80,16 +83,12 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
8083
or model_class is smp.MAnet
8184
):
8285
kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
83-
if model_class in [smp.UnetPlusPlus, smp.Linknet] and encoder_name.startswith(
84-
"mit_b"
85-
):
86-
return # skip mit_b*
87-
if (
88-
model_class is smp.FPN
89-
and encoder_name.startswith("mit_b")
90-
and encoder_depth != 5
91-
):
92-
return # skip mit_b*
86+
if model_class in [smp.UnetPlusPlus, smp.Linknet]:
87+
if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
88+
return # skip transformer-like model*
89+
if model_class is smp.FPN and encoder_depth != 5:
90+
if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
91+
return # skip transformer-like model*
9392
model = model_class(
9493
encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
9594
)
@@ -180,7 +179,6 @@ def test_dilation(encoder_name):
180179
or encoder_name.startswith("vgg")
181180
or encoder_name.startswith("densenet")
182181
or encoder_name.startswith("timm-res")
183-
or encoder_name.startswith("mit_b")
184182
):
185183
return
186184

0 commit comments

Comments
 (0)