@@ -125,15 +125,29 @@ def get_quant_config(
125
125
def prepare_hf_model_weights (
126
126
model_name_or_path : str ,
127
127
cache_dir : Optional [str ] = None ,
128
- use_safetensors : bool = False ,
128
+ load_format : str = "auto" ,
129
129
fall_back_to_pt : bool = True ,
130
130
revision : Optional [str ] = None ,
131
131
) -> Tuple [str , List [str ], bool ]:
132
132
# Download model weights from huggingface.
133
133
is_local = os .path .isdir (model_name_or_path )
134
+ use_safetensors = False
134
135
# Some quantized models use .pt files for storing the weights.
135
- allow_patterns = ["*.safetensors"
136
- ] if use_safetensors else ["*.bin" , "*.pt" ]
136
+ if load_format == "auto" :
137
+ allow_patterns = ["*.safetensors" , "*.bin" ]
138
+ elif load_format == "safetensors" :
139
+ use_safetensors = True
140
+ allow_patterns = ["*.safetensors" ]
141
+ elif load_format == "pt" :
142
+ allow_patterns = ["*.pt" ]
143
+ elif load_format == "npcache" :
144
+ allow_patterns = ["*.bin" ]
145
+ else :
146
+ raise ValueError (f"Unknown load_format: { load_format } " )
147
+
148
+ if fall_back_to_pt :
149
+        allow_patterns += ["*.pt"]
150
+
137
151
if not is_local :
138
152
# Use file lock to prevent multiple processes from
139
153
# downloading the same model weights at the same time.
@@ -148,6 +162,10 @@ def prepare_hf_model_weights(
148
162
hf_weights_files : List [str ] = []
149
163
for pattern in allow_patterns :
150
164
hf_weights_files += glob .glob (os .path .join (hf_folder , pattern ))
165
+ if len (hf_weights_files ) > 0 :
166
+ if pattern == "*.safetensors" :
167
+ use_safetensors = True
168
+ break
151
169
if not use_safetensors :
152
170
# Exclude files that are not needed for inference.
153
171
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
@@ -163,13 +181,6 @@ def prepare_hf_model_weights(
163
181
if not any (f .endswith (x ) for x in blacklist )
164
182
]
165
183
166
- if len (hf_weights_files ) == 0 and use_safetensors and fall_back_to_pt :
167
- return prepare_hf_model_weights (model_name_or_path ,
168
- cache_dir = cache_dir ,
169
- use_safetensors = False ,
170
- fall_back_to_pt = False ,
171
- revision = revision )
172
-
173
184
if len (hf_weights_files ) == 0 :
174
185
raise RuntimeError (
175
186
f"Cannot find any model weights with `{ model_name_or_path } `" )
@@ -182,30 +193,16 @@ def hf_model_weights_iterator(
182
193
cache_dir : Optional [str ] = None ,
183
194
load_format : str = "auto" ,
184
195
revision : Optional [str ] = None ,
196
+ fall_back_to_pt : Optional [bool ] = True ,
185
197
) -> Iterator [Tuple [str , torch .Tensor ]]:
186
- use_safetensors = False
187
- use_np_cache = False
188
- fall_back_to_pt = False
189
- if load_format == "auto" :
190
- use_safetensors = True
191
- fall_back_to_pt = True
192
- elif load_format == "safetensors" :
193
- use_safetensors = True
194
- elif load_format == "pt" :
195
- pass
196
- elif load_format == "npcache" :
197
- use_np_cache = True
198
- else :
199
- raise ValueError (f"Unknown load_format: { load_format } " )
200
-
201
198
hf_folder , hf_weights_files , use_safetensors = prepare_hf_model_weights (
202
199
model_name_or_path ,
203
200
cache_dir = cache_dir ,
204
- use_safetensors = use_safetensors ,
201
+ load_format = load_format ,
205
202
fall_back_to_pt = fall_back_to_pt ,
206
203
revision = revision )
207
204
208
- if use_np_cache :
205
+ if load_format == "npcache" :
209
206
# Currently np_cache only support *.bin checkpoints
210
207
assert use_safetensors is False
211
208
0 commit comments