Skip to content

Commit a05751a

Browse files
committed
Move loading file outside of try/except
1 parent 5b34296 commit a05751a

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,12 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
120120
# First, try to load from HF-Hub, but as far as I know not all countries have
121121
# access to the Hub (e.g. China), so we try to load from the original url if
122122
# the first attempt fails.
123+
weights_path = None
123124
try:
124125
hf_hub_download(repo_id, filename="config.json", revision=revision)
125-
model_path = hf_hub_download(
126+
weights_path = hf_hub_download(
126127
repo_id, filename="model.safetensors", revision=revision
127128
)
128-
state_dict = load_file(model_path, device="cpu")
129129
except Exception as e:
130130
if name in pretrained_settings and weights in pretrained_settings[name]:
131131
message = (
@@ -138,6 +138,9 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
138138
else:
139139
raise e
140140

141+
if weights_path is not None:
142+
state_dict = load_file(weights_path, device="cpu")
143+
141144
# Load model weights
142145
encoder.load_state_dict(state_dict)
143146

0 commit comments

Comments
 (0)