Skip to content

Commit 28a6e31

Browse files
kylesayrs and yiliu30
authored and committed
[Offloading] Support Disk Offloading (vllm-project#2373)
## Purpose ## * Support disk offloading for very large models ## Prerequisites ## * [[Offload] Convert accelerate for loading/saving](vllm-project/compressed-tensors#572) ## Examples ## * Add `examples/disk_offloading/qwen3_example.py` * Add `examples/disk_offloading/kimi_k2_example.py` ## Changes ## ### Required ### * Remove post-processing step where `remove_dispatch` is called * Previously, this was used to avoid conflicts between `dispatch_for_sequential` and `dispatch_for_generation`. * Now, the two functions are directly compatible: you don't need to remove the dispatch of one to use the other * Add `to_accelerate` to `save_pretrained_wrapper` * This ensures that the model is converted to `accelerate` offloading before saving * This ensures the best compatibility with `save_pretrained`, and reduces excess memory usage which would cause gpu/cpu ooms ### Hardening / Future Prep Changes ### * During oneshot preprocessing, convert `from_accelerate` if possible. This guards against users who load their model outside of the `load_offloaded_model` context * Remove `offload_device` argument from `dispatch_for_sequential` to avoid deprecation warning * `dispatch_for_sequential` now always respects the device the model was loaded on ## Testing ## * Ran `Qwen/Qwen3-0.6B` example to completion * [IN PROGRESS] Run `unsloth/Kimi-K2-Instruct-0905-BF16` example to completion --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent 144affb commit 28a6e31

File tree

7 files changed

+143
-23
lines changed

7 files changed

+143
-23
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
"""Disk-offloading example: quantize a very large model (Kimi K2) to NVFP4.

The model is loaded inside the `load_offloaded_model` context with
`device_map="auto_offload"`, which places as many weights as possible on CPU
RAM and spills the remainder to disk under `offload_folder`.
"""

from compressed_tensors.offload import get_device_map, load_offloaded_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Select model and load it in the `load_offloaded_model` context
with load_offloaded_model():
    model_id = "unsloth/Kimi-K2-Instruct-0905-BF16"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype="auto",
        device_map="auto_offload",  # fit as much as possible on cpu, rest goes on disk
        trust_remote_code=True,
        offload_folder="./offload_folder",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Confirm that model is dispatched correctly
devices = {offloaded for _onloaded, offloaded in get_device_map(model).values()}
print(f"Model was offloaded to the following devices: {devices}")

# Select calibration dataset.
DATASET_ID = "ultrachat-200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start for accuracy;
# this example uses 20 to keep the disk-offloaded run fast.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
# * quantize the weights to NVFP4
recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"])

# Apply algorithms.
oneshot(
    model=model,
    processor=tokenizer,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-NVFP4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
"""Disk-offloading example: quantize Qwen3-0.6B to NVFP4 with emulated low RAM.

Large-model disk offloading is emulated by capping the CPU memory budget
(`max_memory={"cpu": 6e8}`) below the model size so part of the weights are
offloaded to `offload_folder` on disk.
"""

from compressed_tensors.offload import (
    dispatch_model,
    get_device_map,
    load_offloaded_model,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Select model and load it in the `load_offloaded_model` context
# In this example, we emulate large model quantization with disk offloading by
# restricting the theoretical size of CPU RAM to be smaller than the size of the model
with load_offloaded_model():
    model_id = "Qwen/Qwen3-0.6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype="auto",
        device_map="auto_offload",  # fit as much as possible on cpu, rest goes on disk
        max_memory={"cpu": 6e8},  # remove this line to use as much cpu as possible
        offload_folder="./offload_folder",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

# Confirm that model is dispatched correctly
devices = {offloaded for _onloaded, offloaded in get_device_map(model).values()}
print(f"Model was offloaded to the following devices: {devices}")

# Select calibration dataset.
DATASET_ID = "ultrachat-200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start for accuracy;
# this example uses 20 to keep the demo fast.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
# * quantize the weights to NVFP4
recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"])

# Apply algorithms.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)  # onload to an execution device before generating
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-NVFP4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

src/llmcompressor/entrypoints/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import os
1111
from pathlib import PosixPath
1212

13-
from compressed_tensors.utils import remove_dispatch
13+
from compressed_tensors.offload import from_accelerate
1414
from loguru import logger
1515
from transformers import (
1616
AutoConfig,
@@ -84,6 +84,10 @@ def pre_process(
8484
if not model_args.tie_word_embeddings:
8585
untie_word_embeddings(model_args.model)
8686

87+
# if the model was loaded with accelerate offloading, convert to CT offloading
88+
if hasattr(model_args.model, "hf_device_map"):
89+
from_accelerate(model_args.model)
90+
8791
# wrap model.save_pretrained
8892
modify_save_pretrained(model_args.model)
8993

@@ -104,10 +108,6 @@ def post_process(
104108
Raises:
105109
ValueError: If saving fails due to an invalid `output_dir` or other issues.
106110
"""
107-
# remove any existing dispatches
108-
if model_args is not None and model_args.model is not None:
109-
remove_dispatch(model_args.model)
110-
111111
if model_args is not None and output_dir is not None:
112112
if recipe_args is not None and getattr(recipe_args, "stage", None) is not None:
113113
output_dir = os.path.join(output_dir, recipe_args.stage)

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def is_ancestor(module: Module) -> bool:
532532
def dispatch_for_sequential(
533533
model: PreTrainedModel,
534534
onload_device: Optional[torch.device | str] = None,
535-
offload_device: torch.device | str = torch.device("cpu"),
535+
offload_device: Optional[torch.device | str] = None,
536536
) -> PreTrainedModel:
537537
"""
538538
Dispatch a model for sequential calibration using a sequential pipeline.

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __call__(
9797
# prepare model for sequential onloading
9898
onload_device = get_main_device()
9999
offload_device = torch.device(dataset_args.sequential_offload_device)
100-
dispatch_for_sequential(model, onload_device, offload_device)
100+
dispatch_for_sequential(model, onload_device)
101101

102102
# prepare to trace subgraphs
103103
modifiers = session.lifecycle.recipe.modifiers

src/llmcompressor/transformers/compression/compressed_tensors_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
SparsityCompressionConfig,
1010
)
1111
from compressed_tensors.config import CompressionFormat
12-
from compressed_tensors.offload import is_rank0
12+
from compressed_tensors.offload import is_rank0, to_accelerate
1313
from loguru import logger
1414
from transformers import PreTrainedModel
1515

@@ -90,6 +90,9 @@ def save_pretrained_wrapper(
9090
compressor.compress_model(model)
9191

9292
if is_rank0():
93+
# convert to accelerate offloaded for optimal saving with transformers
94+
to_accelerate(model)
95+
9396
# save (compressed) model structure
9497
original_save_pretrained.__get__(model, model_class)(
9598
save_directory,

tests/llmcompressor/transformers/compression/test_quantization.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
import torch
3+
from accelerate.utils import align_module_device
34
from compressed_tensors.offload import dispatch_model
45
from compressed_tensors.quantization.utils import is_module_quantized
56
from torch.utils.data import DataLoader
@@ -36,22 +37,23 @@ def _get_quant_info(model):
3637
quant_info_weights = {}
3738
quant_info_inputs = {}
3839
for name, module in model.named_modules():
39-
if is_module_quantized(module):
40-
if module.quantization_scheme.weights is not None:
41-
quant_info_weights[name] = (
42-
module.weight_scale,
43-
module.weight_zero_point,
44-
module.weight,
45-
)
46-
47-
if module.quantization_scheme.input_activations is not None:
48-
is_dynamic = module.quantization_scheme.input_activations.dynamic
49-
if not is_dynamic:
50-
quant_info_inputs[name] = (
51-
module.input_scale,
52-
module.input_zero_point,
40+
with align_module_device(module):
41+
if is_module_quantized(module):
42+
if module.quantization_scheme.weights is not None:
43+
quant_info_weights[name] = (
44+
module.weight_scale,
45+
module.weight_zero_point,
46+
module.weight,
5347
)
5448

49+
if module.quantization_scheme.input_activations is not None:
50+
is_dynamic = module.quantization_scheme.input_activations.dynamic
51+
if not is_dynamic:
52+
quant_info_inputs[name] = (
53+
module.input_scale,
54+
module.input_zero_point,
55+
)
56+
5557
return quant_info_weights, quant_info_inputs
5658

5759

@@ -85,7 +87,7 @@ def setup_model_and_config(request, tmpdir_factory):
8587
num_calibration_samples=num_calibration_samples,
8688
recipe=config["new_recipe"],
8789
pad_to_max_length=pad_to_max_length,
88-
splits={"calibration": "train_gen[:1%]"},
90+
splits={"calibration": f"train_gen[:{num_calibration_samples}]"},
8991
save_compressed=False,
9092
)
9193

0 commit comments

Comments
 (0)