     set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
-    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
-    get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
-    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
+    fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
+    filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
+    get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
+    np_cache_weights_iterator, pt_weights_iterator,
     runai_safetensors_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -275,7 +276,8 @@ def _prepare_weights(
         # Some quantized models use .pt files for storing the weights.
         if load_format == LoadFormat.AUTO:
             allow_patterns = ["*.safetensors", "*.bin"]
-        elif load_format == LoadFormat.SAFETENSORS:
+        elif (load_format == LoadFormat.SAFETENSORS
+              or load_format == LoadFormat.FASTSAFETENSORS):
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
         elif load_format == LoadFormat.MISTRAL:
@@ -357,10 +359,16 @@ def _get_weights_iterator(
                 self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(
-                hf_weights_files,
-                self.load_config.use_tqdm_on_load,
-            )
+            if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
+                weights_iterator = fastsafetensors_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                )
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                )
         else:
             weights_iterator = pt_weights_iterator(
                 hf_weights_files,
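The new code path is selected through the load format rather than a separate loader class. A usage sketch, assuming the string value behind LoadFormat.FASTSAFETENSORS is "fastsafetensors" and that the fastsafetensors package is installed (neither detail is visible in these hunks):

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",  # hypothetical; any safetensors checkpoint
    load_format="fastsafetensors",  # assumed value of LoadFormat.FASTSAFETENSORS
)

With any other safetensors-based load format, the else branch above falls back to safetensors_weights_iterator, so existing behaviour is unchanged.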