|
3 | 3 | import copy
|
4 | 4 | import warnings
|
5 | 5 | import functools
|
| 6 | +from torch.utils.model_zoo import load_url |
6 | 7 | from huggingface_hub import hf_hub_download
|
7 | 8 | from safetensors.torch import load_file
|
8 | 9 |
|
| 10 | + |
9 | 11 | from .resnet import resnet_encoders
|
10 | 12 | from .dpn import dpn_encoders
|
11 | 13 | from .vgg import vgg_encoders
|
|
24 | 26 | from .timm_universal import TimmUniversalEncoder
|
25 | 27 |
|
26 | 28 | from ._preprocessing import preprocess_input
|
| 29 | +from ._legacy_pretrained_settings import pretrained_settings |
27 | 30 |
|
28 | 31 | __all__ = [
|
29 | 32 | "encoders",
|
@@ -114,14 +117,28 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
|
114 | 117 | repo_id = settings["repo_id"]
|
115 | 118 | revision = settings["revision"]
|
116 | 119 |
|
117 |
| - # Load config and model |
118 |
| - hf_hub_download(repo_id, filename="config.json", revision=revision) |
119 |
| - model_path = hf_hub_download( |
120 |
| - repo_id, filename="model.safetensors", revision=revision |
121 |
| - ) |
| 120 | + # First, try to load from HF-Hub, but as far as I know not all countries have |
| 121 | + # access to the Hub (e.g. China), so we try to load from the original url if |
| 122 | + # the first attempt fails. |
| 123 | + try: |
| 124 | + hf_hub_download(repo_id, filename="config.json", revision=revision) |
| 125 | + model_path = hf_hub_download( |
| 126 | + repo_id, filename="model.safetensors", revision=revision |
| 127 | + ) |
| 128 | + state_dict = load_file(model_path, device="cpu") |
| 129 | + except Exception as e: |
| 130 | + if name in pretrained_settings and weights in pretrained_settings[name]: |
| 131 | + message = ( |
| 132 | + f"Error loading {name} `{weights}` weights from Hugging Face Hub, " |
| 133 | + "trying loading from original url..." |
| 134 | + ) |
| 135 | + warnings.warn(message, UserWarning) |
| 136 | + url = pretrained_settings[name][weights]["url"] |
| 137 | + state_dict = load_url(url, map_location="cpu") |
| 138 | + else: |
| 139 | + raise e |
122 | 140 |
|
123 | 141 | # Load model weights
|
124 |
| - state_dict = load_file(model_path, device="cpu") |
125 | 142 | encoder.load_state_dict(state_dict)
|
126 | 143 |
|
127 | 144 | encoder.set_in_channels(in_channels, pretrained=weights is not None)
|
@@ -154,11 +171,20 @@ def get_preprocessing_params(encoder_name, pretrained="imagenet"):
|
154 | 171 | revision = all_settings[pretrained]["revision"]
|
155 | 172 |
|
156 | 173 | # Load config and model
|
157 |
| - config_path = hf_hub_download( |
158 |
| - repo_id, filename="config.json", revision=revision |
159 |
| - ) |
160 |
| - with open(config_path, "r") as f: |
161 |
| - settings = json.load(f) |
| 174 | + try: |
| 175 | + config_path = hf_hub_download( |
| 176 | + repo_id, filename="config.json", revision=revision |
| 177 | + ) |
| 178 | + with open(config_path, "r") as f: |
| 179 | + settings = json.load(f) |
| 180 | + except Exception as e: |
| 181 | + if ( |
| 182 | + encoder_name in pretrained_settings |
| 183 | + and pretrained in pretrained_settings[encoder_name] |
| 184 | + ): |
| 185 | + settings = pretrained_settings[encoder_name][pretrained] |
| 186 | + else: |
| 187 | + raise e |
162 | 188 |
|
163 | 189 | formatted_settings = {}
|
164 | 190 | formatted_settings["input_space"] = settings.get("input_space", "RGB")
|
|
0 commit comments