Skip to content

Commit 761702f

Browse files
authored
[Core] Integrate fastsafetensors loader for loading model weights (#10647)
Signed-off-by: Manish Sethi <[email protected]>
1 parent 9606d57 commit 761702f

File tree

11 files changed

+152
-9
lines changed

11 files changed

+152
-9
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Loading model weights with fastsafetensors
2+
===================================================================
3+
4+
Using the fastsafetensors library enables loading model weights to GPU memory by leveraging GPU Direct Storage. See https://github.com/foundation-model-stack/fastsafetensors for more details.
5+
To enable this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``.

docs/source/models/extensions/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55

66
runai_model_streamer
77
tensorizer
8+
fastsafetensor
89
:::

requirements/test.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,4 @@ tritonclient==2.51.0
4141
numpy < 2.0.0
4242
runai-model-streamer==0.11.0
4343
runai-model-streamer-s3==0.11.0
44+
fastsafetensors>=0.1.10

requirements/test.txt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ click==8.1.7
6767
# jiwer
6868
# nltk
6969
# ray
70+
# typer
7071
colorama==0.4.6
7172
# via
7273
# awscli
@@ -122,6 +123,8 @@ fastparquet==2024.11.0
122123
# via genai-perf
123124
fastrlock==0.8.2
124125
# via cupy-cuda12x
126+
fastsafetensors==0.1.10
127+
# via -r requirements/test.in
125128
filelock==3.16.1
126129
# via
127130
# datasets
@@ -505,7 +508,9 @@ requests==2.32.3
505508
responses==0.25.3
506509
# via genai-perf
507510
rich==13.9.4
508-
# via genai-perf
511+
# via
512+
# genai-perf
513+
# typer
509514
rouge-score==0.1.2
510515
# via lm-eval
511516
rpds-py==0.20.1
@@ -550,6 +555,8 @@ setuptools==75.8.0
550555
# via
551556
# pytablewriter
552557
# torch
558+
shellingham==1.5.4
559+
# via typer
553560
six==1.16.0
554561
# via
555562
# python-dateutil
@@ -600,6 +607,7 @@ torch==2.6.0
600607
# accelerate
601608
# bitsandbytes
602609
# encodec
610+
# fastsafetensors
603611
# lm-eval
604612
# peft
605613
# runai-model-streamer
@@ -654,6 +662,8 @@ typepy==1.3.2
654662
# dataproperty
655663
# pytablewriter
656664
# tabledata
665+
typer==0.15.2
666+
# via fastsafetensors
657667
typing-extensions==4.12.2
658668
# via
659669
# huggingface-hub
@@ -663,6 +673,7 @@ typing-extensions==4.12.2
663673
# pydantic
664674
# pydantic-core
665675
# torch
676+
# typer
666677
tzdata==2024.2
667678
# via pandas
668679
urllib3==2.2.3

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def _read_requirements(filename: str) -> list[str]:
680680
install_requires=get_requirements(),
681681
extras_require={
682682
"tensorizer": ["tensorizer>=2.9.0"],
683+
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
683684
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
684685
"audio": ["librosa", "soundfile"], # Required for audio processing
685686
"video": ["decord"] # Required for video processing

tests/fastsafetensors_loader/__init__.py

Whitespace-only changes.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# SPDX-License-Identifier: Apache-2.0
"""Smoke test for the fastsafetensors weight-loading path.

Loads a small model with ``LoadFormat.FASTSAFETENSORS`` and checks that
generation produces output.
"""

from vllm import SamplingParams
from vllm.config import LoadFormat

test_model = "openai-community/gpt2"

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Deterministic sampling configuration shared by the test below.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)


def test_model_loader_download_files(vllm_runner):
    """Generate with the fastsafetensors load format; expect non-empty output."""
    with vllm_runner(
            test_model,
            load_format=LoadFormat.FASTSAFETENSORS,
    ) as llm:
        deserialized_outputs = llm.generate(prompts, sampling_params)
        assert deserialized_outputs
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# SPDX-License-Identifier: Apache-2.0
"""Verify the fastsafetensors weight iterator against the stock one.

Downloads a small safetensors checkpoint and checks that
``fastsafetensors_weights_iterator`` yields exactly the same tensors
(names, dtypes, shapes, values) as ``safetensors_weights_iterator``.
"""

import glob
import tempfile

import huggingface_hub.constants
import torch

from vllm.model_executor.model_loader.weight_utils import (
    download_weights_from_hf, fastsafetensors_weights_iterator,
    safetensors_weights_iterator)


def test_fastsafetensors_model_loader():
    """Both iterators must produce identical tensors for the same files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Force online mode for the download, then restore the previous
        # value so this test does not leak global state into other tests.
        saved_offline = huggingface_hub.constants.HF_HUB_OFFLINE
        huggingface_hub.constants.HF_HUB_OFFLINE = False
        try:
            download_weights_from_hf("openai-community/gpt2",
                                     allow_patterns=["*.safetensors"],
                                     cache_dir=tmpdir)
        finally:
            huggingface_hub.constants.HF_HUB_OFFLINE = saved_offline
        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
        assert safetensors, "no .safetensors files were downloaded"

        fastsafetensors_tensors = {}
        hf_safetensors_tensors = {}

        # Second argument enables the tqdm progress bar in both iterators.
        for name, tensor in fastsafetensors_weights_iterator(
                safetensors, True):
            fastsafetensors_tensors[name] = tensor

        for name, tensor in safetensors_weights_iterator(safetensors, True):
            hf_safetensors_tensors[name] = tensor

        assert len(fastsafetensors_tensors) == len(hf_safetensors_tensors)

        for name, fast_tensor in fastsafetensors_tensors.items():
            # fastsafetensors may load to GPU; move to CPU for comparison.
            fast_tensor = fast_tensor.to('cpu')
            ref_tensor = hf_safetensors_tensors[name]
            assert fast_tensor.dtype == ref_tensor.dtype
            assert fast_tensor.shape == ref_tensor.shape
            # Shapes/dtypes already checked, so torch.equal is the direct
            # equivalent of torch.all(a.eq(b)).
            assert torch.equal(fast_tensor, ref_tensor)


if __name__ == "__main__":
    test_fastsafetensors_model_loader()

vllm/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,7 @@ class LoadFormat(str, enum.Enum):
12771277
BITSANDBYTES = "bitsandbytes"
12781278
MISTRAL = "mistral"
12791279
RUNAI_STREAMER = "runai_streamer"
1280+
FASTSAFETENSORS = "fastsafetensors"
12801281

12811282

12821283
@dataclass

vllm/model_executor/model_loader/loader.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@
4949
set_default_torch_dtype)
5050
from vllm.model_executor.model_loader.weight_utils import (
5151
download_safetensors_index_file_from_hf, download_weights_from_hf,
52-
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
53-
get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
54-
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
52+
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
53+
filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
54+
get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
55+
np_cache_weights_iterator, pt_weights_iterator,
5556
runai_safetensors_weights_iterator, safetensors_weights_iterator)
5657
from vllm.model_executor.utils import set_weight_attrs
5758
from vllm.platforms import current_platform
@@ -275,7 +276,8 @@ def _prepare_weights(
275276
# Some quantized models use .pt files for storing the weights.
276277
if load_format == LoadFormat.AUTO:
277278
allow_patterns = ["*.safetensors", "*.bin"]
278-
elif load_format == LoadFormat.SAFETENSORS:
279+
elif (load_format == LoadFormat.SAFETENSORS
280+
or load_format == LoadFormat.FASTSAFETENSORS):
279281
use_safetensors = True
280282
allow_patterns = ["*.safetensors"]
281283
elif load_format == LoadFormat.MISTRAL:
@@ -357,10 +359,16 @@ def _get_weights_iterator(
357359
self.load_config.use_tqdm_on_load,
358360
)
359361
elif use_safetensors:
360-
weights_iterator = safetensors_weights_iterator(
361-
hf_weights_files,
362-
self.load_config.use_tqdm_on_load,
363-
)
362+
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
363+
weights_iterator = fastsafetensors_weights_iterator(
364+
hf_weights_files,
365+
self.load_config.use_tqdm_on_load,
366+
)
367+
else:
368+
weights_iterator = safetensors_weights_iterator(
369+
hf_weights_files,
370+
self.load_config.use_tqdm_on_load,
371+
)
364372
else:
365373
weights_iterator = pt_weights_iterator(
366374
hf_weights_files,

0 commit comments

Comments
 (0)