1+ import json
12import timm
23import copy
34import warnings
45import functools
5- import torch .utils .model_zoo as model_zoo
6+ from torch .utils .model_zoo import load_url
7+ from huggingface_hub import hf_hub_download
8+ from safetensors .torch import load_file
9+
610
711from .resnet import resnet_encoders
812from .dpn import dpn_encoders
2226from .timm_universal import TimmUniversalEncoder
2327
2428from ._preprocessing import preprocess_input
29+ from ._legacy_pretrained_settings import pretrained_settings
2530
2631__all__ = [
2732 "encoders" ,
@@ -101,15 +106,43 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
101106 encoder = EncoderClass (** params )
102107
103108 if weights is not None :
104- try :
105- settings = encoders [name ]["pretrained_settings" ][weights ]
106- except KeyError :
109+ if weights not in encoders [name ]["pretrained_settings" ]:
110+ available_weights = list (encoders [name ]["pretrained_settings" ].keys ())
107111 raise KeyError (
108- "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}" .format (
109- weights , name , list (encoders [name ]["pretrained_settings" ].keys ())
110- )
112+ f"Wrong pretrained weights `{ weights } ` for encoder `{ name } `. "
113+ f"Available options are: { available_weights } "
114+ )
115+
116+ settings = encoders [name ]["pretrained_settings" ][weights ]
117+ repo_id = settings ["repo_id" ]
118+ revision = settings ["revision" ]
119+
120+ # First, try to load from HF-Hub, but as far as I know not all countries have
121+ # access to the Hub (e.g. China), so we try to load from the original url if
122+ # the first attempt fails.
123+ weights_path = None
124+ try :
125+ hf_hub_download (repo_id , filename = "config.json" , revision = revision )
126+ weights_path = hf_hub_download (
127+ repo_id , filename = "model.safetensors" , revision = revision
111128 )
112- encoder .load_state_dict (model_zoo .load_url (settings ["url" ]))
129+ except Exception as e :
130+ if name in pretrained_settings and weights in pretrained_settings [name ]:
131+ message = (
132+ f"Error loading { name } `{ weights } ` weights from Hugging Face Hub, "
133+ "trying loading from original url..."
134+ )
135+ warnings .warn (message , UserWarning )
136+ url = pretrained_settings [name ][weights ]["url" ]
137+ state_dict = load_url (url , map_location = "cpu" )
138+ else :
139+ raise e
140+
141+ if weights_path is not None :
142+ state_dict = load_file (weights_path , device = "cpu" )
143+
144+ # Load model weights
145+ encoder .load_state_dict (state_dict )
113146
114147 encoder .set_in_channels (in_channels , pretrained = weights is not None )
115148 if output_stride != 32 :
@@ -136,7 +169,25 @@ def get_preprocessing_params(encoder_name, pretrained="imagenet"):
136169 raise ValueError (
137170 "Available pretrained options {}" .format (all_settings .keys ())
138171 )
139- settings = all_settings [pretrained ]
172+
173+ repo_id = all_settings [pretrained ]["repo_id" ]
174+ revision = all_settings [pretrained ]["revision" ]
175+
176+ # Load config and model
177+ try :
178+ config_path = hf_hub_download (
179+ repo_id , filename = "config.json" , revision = revision
180+ )
181+ with open (config_path , "r" ) as f :
182+ settings = json .load (f )
183+ except Exception as e :
184+ if (
185+ encoder_name in pretrained_settings
186+ and pretrained in pretrained_settings [encoder_name ]
187+ ):
188+ settings = pretrained_settings [encoder_name ][pretrained ]
189+ else :
190+ raise e
140191
141192 formatted_settings = {}
142193 formatted_settings ["input_space" ] = settings .get ("input_space" , "RGB" )
0 commit comments