Skip to content

Commit 83db397

Browse files
committed
Move everything to HF hub
1 parent 456871a commit 83db397

File tree

15 files changed

+583
-622
lines changed

15 files changed

+583
-622
lines changed

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import json
12
import timm
23
import copy
34
import warnings
45
import functools
5-
import torch.utils.model_zoo as model_zoo
6+
from huggingface_hub import hf_hub_download
7+
from safetensors.torch import load_file
68

79
from .resnet import resnet_encoders
810
from .dpn import dpn_encoders
@@ -101,15 +103,26 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
101103
encoder = EncoderClass(**params)
102104

103105
if weights is not None:
104-
try:
105-
settings = encoders[name]["pretrained_settings"][weights]
106-
except KeyError:
106+
if weights not in encoders[name]["pretrained_settings"]:
107+
available_weights = list(encoders[name]["pretrained_settings"].keys())
107108
raise KeyError(
108-
"Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
109-
weights, name, list(encoders[name]["pretrained_settings"].keys())
110-
)
109+
f"Wrong pretrained weights `{weights}` for encoder `{name}`. "
110+
f"Available options are: {available_weights}"
111111
)
112-
encoder.load_state_dict(model_zoo.load_url(settings["url"]))
112+
113+
settings = encoders[name]["pretrained_settings"][weights]
114+
repo_id = settings["repo_id"]
115+
revision = settings["revision"]
116+
117+
# Load config and model
118+
hf_hub_download(repo_id, filename="config.json", revision=revision)
119+
model_path = hf_hub_download(
120+
repo_id, filename="model.safetensors", revision=revision
121+
)
122+
123+
# Load model weights
124+
state_dict = load_file(model_path, device="cpu")
125+
encoder.load_state_dict(state_dict)
113126

114127
encoder.set_in_channels(in_channels, pretrained=weights is not None)
115128
if output_stride != 32:
@@ -136,7 +149,16 @@ def get_preprocessing_params(encoder_name, pretrained="imagenet"):
136149
raise ValueError(
137150
"Available pretrained options {}".format(all_settings.keys())
138151
)
139-
settings = all_settings[pretrained]
152+
153+
repo_id = all_settings[pretrained]["repo_id"]
154+
revision = all_settings[pretrained]["revision"]
155+
156+
# Load config and model
157+
config_path = hf_hub_download(
158+
repo_id, filename="config.json", revision=revision
159+
)
160+
with open(config_path, "r") as f:
161+
settings = json.load(f)
140162

141163
formatted_settings = {}
142164
formatted_settings["input_space"] = settings.get("input_space", "RGB")

segmentation_models_pytorch/encoders/densenet.py

Lines changed: 24 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -110,92 +110,65 @@ def load_state_dict(self, state_dict):
110110
super().load_state_dict(state_dict)
111111

112112

113-
pretrained_settings = {
114-
"densenet121": {
115-
"imagenet": {
116-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet121-fbdb23505.pth",
117-
"input_space": "RGB",
118-
"input_size": [3, 224, 224],
119-
"input_range": [0, 1],
120-
"mean": [0.485, 0.456, 0.406],
121-
"std": [0.229, 0.224, 0.225],
122-
"num_classes": 1000,
123-
}
124-
},
125-
"densenet169": {
126-
"imagenet": {
127-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet169-f470b90a4.pth",
128-
"input_space": "RGB",
129-
"input_size": [3, 224, 224],
130-
"input_range": [0, 1],
131-
"mean": [0.485, 0.456, 0.406],
132-
"std": [0.229, 0.224, 0.225],
133-
"num_classes": 1000,
134-
}
135-
},
136-
"densenet201": {
137-
"imagenet": {
138-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet201-5750cbb1e.pth",
139-
"input_space": "RGB",
140-
"input_size": [3, 224, 224],
141-
"input_range": [0, 1],
142-
"mean": [0.485, 0.456, 0.406],
143-
"std": [0.229, 0.224, 0.225],
144-
"num_classes": 1000,
145-
}
146-
},
147-
"densenet161": {
148-
"imagenet": {
149-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet161-347e6b360.pth",
150-
"input_space": "RGB",
151-
"input_size": [3, 224, 224],
152-
"input_range": [0, 1],
153-
"mean": [0.485, 0.456, 0.406],
154-
"std": [0.229, 0.224, 0.225],
155-
"num_classes": 1000,
156-
}
157-
},
158-
}
159-
160113
densenet_encoders = {
161114
"densenet121": {
162115
"encoder": DenseNetEncoder,
163-
"pretrained_settings": pretrained_settings["densenet121"],
164116
"params": {
165117
"out_channels": [3, 64, 256, 512, 1024, 1024],
166118
"num_init_features": 64,
167119
"growth_rate": 32,
168120
"block_config": (6, 12, 24, 16),
169121
},
122+
"pretrained_settings": {
123+
"imagenet": {
124+
"repo_id": "smp-hub/densenet121-imagenet",
125+
"revision": "main",
126+
}
127+
},
170128
},
171129
"densenet169": {
172130
"encoder": DenseNetEncoder,
173-
"pretrained_settings": pretrained_settings["densenet169"],
174131
"params": {
175132
"out_channels": [3, 64, 256, 512, 1280, 1664],
176133
"num_init_features": 64,
177134
"growth_rate": 32,
178135
"block_config": (6, 12, 32, 32),
179136
},
137+
"pretrained_settings": {
138+
"imagenet": {
139+
"repo_id": "smp-hub/densenet169-imagenet",
140+
"revision": "main",
141+
}
142+
},
180143
},
181144
"densenet201": {
182145
"encoder": DenseNetEncoder,
183-
"pretrained_settings": pretrained_settings["densenet201"],
184146
"params": {
185147
"out_channels": [3, 64, 256, 512, 1792, 1920],
186148
"num_init_features": 64,
187149
"growth_rate": 32,
188150
"block_config": (6, 12, 48, 32),
189151
},
152+
"pretrained_settings": {
153+
"imagenet": {
154+
"repo_id": "smp-hub/densenet201-imagenet",
155+
"revision": "main",
156+
}
157+
},
190158
},
191159
"densenet161": {
192160
"encoder": DenseNetEncoder,
193-
"pretrained_settings": pretrained_settings["densenet161"],
194161
"params": {
195162
"out_channels": [3, 96, 384, 768, 2112, 2208],
196163
"num_init_features": 96,
197164
"growth_rate": 48,
198165
"block_config": (6, 12, 36, 24),
199166
},
167+
"pretrained_settings": {
168+
"imagenet": {
169+
"repo_id": "smp-hub/densenet161-imagenet",
170+
"revision": "main",
171+
}
172+
},
200173
},
201174
}

segmentation_models_pytorch/encoders/dpn.py

Lines changed: 36 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -101,79 +101,15 @@ def load_state_dict(self, state_dict, **kwargs):
101101
super().load_state_dict(state_dict, **kwargs)
102102

103103

104-
pretrained_settings = {
105-
"dpn68": {
106-
"imagenet": {
107-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn68-4af7d88d2.pth",
108-
"input_space": "RGB",
109-
"input_size": [3, 224, 224],
110-
"input_range": [0, 1],
111-
"mean": [124 / 255, 117 / 255, 104 / 255],
112-
"std": [1 / (0.0167 * 255)] * 3,
113-
"num_classes": 1000,
114-
}
115-
},
116-
"dpn68b": {
117-
"imagenet+5k": {
118-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-363ab9c19.pth",
119-
"input_space": "RGB",
120-
"input_size": [3, 224, 224],
121-
"input_range": [0, 1],
122-
"mean": [124 / 255, 117 / 255, 104 / 255],
123-
"std": [1 / (0.0167 * 255)] * 3,
124-
"num_classes": 1000,
125-
}
126-
},
127-
"dpn92": {
128-
"imagenet+5k": {
129-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-fda993c95.pth",
130-
"input_space": "RGB",
131-
"input_size": [3, 224, 224],
132-
"input_range": [0, 1],
133-
"mean": [124 / 255, 117 / 255, 104 / 255],
134-
"std": [1 / (0.0167 * 255)] * 3,
135-
"num_classes": 1000,
136-
}
137-
},
138-
"dpn98": {
139-
"imagenet": {
140-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn98-722954780.pth",
141-
"input_space": "RGB",
142-
"input_size": [3, 224, 224],
143-
"input_range": [0, 1],
144-
"mean": [124 / 255, 117 / 255, 104 / 255],
145-
"std": [1 / (0.0167 * 255)] * 3,
146-
"num_classes": 1000,
147-
}
148-
},
149-
"dpn131": {
150-
"imagenet": {
151-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn131-7af84be88.pth",
152-
"input_space": "RGB",
153-
"input_size": [3, 224, 224],
154-
"input_range": [0, 1],
155-
"mean": [124 / 255, 117 / 255, 104 / 255],
156-
"std": [1 / (0.0167 * 255)] * 3,
157-
"num_classes": 1000,
158-
}
159-
},
160-
"dpn107": {
161-
"imagenet+5k": {
162-
"url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-b7f9f4cc9.pth",
163-
"input_space": "RGB",
164-
"input_size": [3, 224, 224],
165-
"input_range": [0, 1],
166-
"mean": [124 / 255, 117 / 255, 104 / 255],
167-
"std": [1 / (0.0167 * 255)] * 3,
168-
"num_classes": 1000,
169-
}
170-
},
171-
}
172-
173104
dpn_encoders = {
174105
"dpn68": {
175106
"encoder": DPNEncoder,
176-
"pretrained_settings": pretrained_settings["dpn68"],
107+
"pretrained_settings": {
108+
"imagenet": {
109+
"repo_id": "smp-hub/dpn68-imagenet",
110+
"revision": "main",
111+
}
112+
},
177113
"params": {
178114
"stage_idxs": [4, 8, 20, 24],
179115
"out_channels": [3, 10, 144, 320, 704, 832],
@@ -189,7 +125,12 @@ def load_state_dict(self, state_dict, **kwargs):
189125
},
190126
"dpn68b": {
191127
"encoder": DPNEncoder,
192-
"pretrained_settings": pretrained_settings["dpn68b"],
128+
"pretrained_settings": {
129+
"imagenet+5k": {
130+
"repo_id": "smp-hub/dpn68b-imagenet-5k",
131+
"revision": "main",
132+
}
133+
},
193134
"params": {
194135
"stage_idxs": [4, 8, 20, 24],
195136
"out_channels": [3, 10, 144, 320, 704, 832],
@@ -206,7 +147,12 @@ def load_state_dict(self, state_dict, **kwargs):
206147
},
207148
"dpn92": {
208149
"encoder": DPNEncoder,
209-
"pretrained_settings": pretrained_settings["dpn92"],
150+
"pretrained_settings": {
151+
"imagenet+5k": {
152+
"repo_id": "smp-hub/dpn92-imagenet-5k",
153+
"revision": "main",
154+
}
155+
},
210156
"params": {
211157
"stage_idxs": [4, 8, 28, 32],
212158
"out_channels": [3, 64, 336, 704, 1552, 2688],
@@ -221,7 +167,12 @@ def load_state_dict(self, state_dict, **kwargs):
221167
},
222168
"dpn98": {
223169
"encoder": DPNEncoder,
224-
"pretrained_settings": pretrained_settings["dpn98"],
170+
"pretrained_settings": {
171+
"imagenet": {
172+
"repo_id": "smp-hub/dpn98-imagenet",
173+
"revision": "main",
174+
}
175+
},
225176
"params": {
226177
"stage_idxs": [4, 10, 30, 34],
227178
"out_channels": [3, 96, 336, 768, 1728, 2688],
@@ -236,7 +187,12 @@ def load_state_dict(self, state_dict, **kwargs):
236187
},
237188
"dpn107": {
238189
"encoder": DPNEncoder,
239-
"pretrained_settings": pretrained_settings["dpn107"],
190+
"pretrained_settings": {
191+
"imagenet+5k": {
192+
"repo_id": "smp-hub/dpn107-imagenet-5k",
193+
"revision": "main",
194+
}
195+
},
240196
"params": {
241197
"stage_idxs": [5, 13, 33, 37],
242198
"out_channels": [3, 128, 376, 1152, 2432, 2688],
@@ -251,7 +207,12 @@ def load_state_dict(self, state_dict, **kwargs):
251207
},
252208
"dpn131": {
253209
"encoder": DPNEncoder,
254-
"pretrained_settings": pretrained_settings["dpn131"],
210+
"pretrained_settings": {
211+
"imagenet": {
212+
"repo_id": "smp-hub/dpn131-imagenet",
213+
"revision": "main",
214+
}
215+
},
255216
"params": {
256217
"stage_idxs": [5, 13, 41, 45],
257218
"out_channels": [3, 128, 352, 832, 1984, 2688],

0 commit comments

Comments
 (0)