Skip to content

Commit fb85d02

Browse files
committed
Add backup plan for downloading weights
1 parent 83db397 commit fb85d02

File tree

2 files changed

+1099
-11
lines changed

2 files changed

+1099
-11
lines changed

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import copy
44
import warnings
55
import functools
6+
from torch.utils.model_zoo import load_url
67
from huggingface_hub import hf_hub_download
78
from safetensors.torch import load_file
89

10+
911
from .resnet import resnet_encoders
1012
from .dpn import dpn_encoders
1113
from .vgg import vgg_encoders
@@ -24,6 +26,7 @@
2426
from .timm_universal import TimmUniversalEncoder
2527

2628
from ._preprocessing import preprocess_input
29+
from ._legacy_pretrained_settings import pretrained_settings
2730

2831
__all__ = [
2932
"encoders",
@@ -114,14 +117,28 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
114117
repo_id = settings["repo_id"]
115118
revision = settings["revision"]
116119

117-
# Load config and model
118-
hf_hub_download(repo_id, filename="config.json", revision=revision)
119-
model_path = hf_hub_download(
120-
repo_id, filename="model.safetensors", revision=revision
121-
)
120+
# First, try to load from HF-Hub, but as far as I know not all countries have
121+
# access to the Hub (e.g. China), so we try to load from the original url if
122+
# the first attempt fails.
123+
try:
124+
hf_hub_download(repo_id, filename="config.json", revision=revision)
125+
model_path = hf_hub_download(
126+
repo_id, filename="model.safetensors", revision=revision
127+
)
128+
state_dict = load_file(model_path, device="cpu")
129+
except Exception as e:
130+
if name in pretrained_settings and weights in pretrained_settings[name]:
131+
message = (
132+
f"Error loading {name} `{weights}` weights from Hugging Face Hub, "
133+
"trying loading from original url..."
134+
)
135+
warnings.warn(message, UserWarning)
136+
url = pretrained_settings[name][weights]["url"]
137+
state_dict = load_url(url, map_location="cpu")
138+
else:
139+
raise e
122140

123141
# Load model weights
124-
state_dict = load_file(model_path, device="cpu")
125142
encoder.load_state_dict(state_dict)
126143

127144
encoder.set_in_channels(in_channels, pretrained=weights is not None)
@@ -154,11 +171,20 @@ def get_preprocessing_params(encoder_name, pretrained="imagenet"):
154171
revision = all_settings[pretrained]["revision"]
155172

156173
# Load config and model
157-
config_path = hf_hub_download(
158-
repo_id, filename="config.json", revision=revision
159-
)
160-
with open(config_path, "r") as f:
161-
settings = json.load(f)
174+
try:
175+
config_path = hf_hub_download(
176+
repo_id, filename="config.json", revision=revision
177+
)
178+
with open(config_path, "r") as f:
179+
settings = json.load(f)
180+
except Exception as e:
181+
if (
182+
encoder_name in pretrained_settings
183+
and pretrained in pretrained_settings[encoder_name]
184+
):
185+
settings = pretrained_settings[encoder_name][pretrained]
186+
else:
187+
raise e
162188

163189
formatted_settings = {}
164190
formatted_settings["input_space"] = settings.get("input_space", "RGB")

0 commit comments

Comments
 (0)