Skip to content

Commit 6346405

Browse files
committed
Fix minimal
1 parent d4b82b3 commit 6346405

File tree

4 files changed

+22
-6
lines changed

4 files changed

+22
-6
lines changed

segmentation_models_pytorch/base/hub_mixin.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,3 @@ def wrapper(self, *args, **kwargs):
147147
return func(self, *args, **kwargs)
148148

149149
return wrapper
150-
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading

tests/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
import os
2+
import timm
3+
from packaging.version import Version
4+
5+
6+
has_timm_test_models = Version(timm.__version__) >= Version("1.0.12")
27

38

49
def get_commit_message():

tests/encoders/test_timm_universal.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import timm
21
from tests.encoders import base
3-
from packaging.version import Version
2+
from tests.config import has_timm_test_models
43

54
# check if timm >= 1.0.12
65
timm_encoders = [
@@ -9,7 +8,7 @@
98
"tu-darknet17", # for timm universal vgg-like encoder
109
]
1110

12-
if Version(timm.__version__) >= Version("1.0.12"):
11+
if has_timm_test_models:
1312
timm_encoders.append("tu-test_resnet.r160_in1k")
1413

1514

tests/models/base.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import inspect
23
import tempfile
34
import unittest
@@ -6,9 +7,13 @@
67
import torch
78
import segmentation_models_pytorch as smp
89

10+
from tests.config import has_timm_test_models
11+
912

1013
class BaseModelTester(unittest.TestCase):
11-
test_encoder_name = "tu-test_resnet.r160_in1k"
14+
test_encoder_name = (
15+
"tu-test_resnet.r160_in1k" if has_timm_test_models else "resnet18"
16+
)
1217

1318
# should be overriden
1419
test_model_type = None
@@ -136,8 +141,12 @@ def test_save_load_with_hub_mixin(self):
136141

137142
# save model
138143
with tempfile.TemporaryDirectory() as tmpdir:
139-
model.save_pretrained(tmpdir)
144+
model.save_pretrained(
145+
tmpdir, dataset="test_dataset", metrics={"my_awesome_metric": 0.99}
146+
)
140147
restored_model = smp.from_pretrained(tmpdir)
148+
with open(os.path.join(tmpdir, "README.md"), "r") as f:
149+
readme = f.read()
141150

142151
# check inference is correct
143152
sample = self._get_sample(
@@ -153,3 +162,7 @@ def test_save_load_with_hub_mixin(self):
153162

154163
self.assertEqual(output.shape, restored_output.shape)
155164
self.assertEqual(output.shape[1], 1)
165+
166+
# check dataset and metrics are saved in readme
167+
self.assertIn("test_dataset", readme)
168+
self.assertIn("my_awesome_metric", readme)

0 commit comments

Comments
 (0)