Commit f4e484d
[Performance] Sequential onloading (#1263)
# Sequential Onloading #
<p align="center"><img width="403" alt="Screenshot 2025-06-05 at 22 53 01" src="https://github.com/user-attachments/assets/ffd610ac-c511-4dc1-b858-b0ed2bf95193" /></p>
```
(25/33): Calibrating: 0%| | 0/512 [00:00<?, ?it/s]
<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>.weight -> cuda
<class 'torch.nn.modules.linear.Linear'>.weight -> cuda
<class 'torch.nn.modules.linear.Linear'>.weight_scale -> cuda
<class 'torch.nn.modules.linear.Linear'>.weight_zero_point -> cuda
...
(25/33): Calibrating: 100%|█████| 512/512 [00:23<00:00, 21.91it/s]
2025-06-03T17:29:15.536963-0400 | compress_modules | INFO - Quantizing model.layers.24.self_attn.q_proj using 512 samples
2025-06-03T17:29:17.328720-0400 | compress | METRIC - time 1.79s
2025-06-03T17:29:17.329265-0400 | compress | METRIC - error 8948.54
2025-06-03T17:29:17.329781-0400 | compress | METRIC - GPU 0 | usage: 5.41% | total memory: 85 GB
2025-06-03T17:29:17.330248-0400 | compress | METRIC - Compressed module size: 33.947648 MB
...
(25/33): Propagating: 100%|█████| 512/512 [00:03<00:00, 131.16it/s]
<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>.weight -> meta
<class 'torch.nn.modules.linear.Linear'>.weight -> meta
<class 'torch.nn.modules.linear.Linear'>.weight_scale -> meta
<class 'torch.nn.modules.linear.Linear'>.weight_zero_point -> meta
...
```
## Purpose ##
* Reduce hardware requirements for calibrating large models
* Reduce runtime caused by excess device movement when calibrating
offloaded models
## Prerequisites ##
* vllm-project/compressed-tensors#354
* vllm-project/compressed-tensors#355
* vllm-project/compressed-tensors#356
* vllm-project/compressed-tensors#357
## Related Issues ##
* Resolves #1383
* Resolves #1228
* Resolves #1122
* Resolves #1078
* Resolves #1216
* Resolves #1483
## Changes ##
* Keep layer parameters onloaded during the entire sequential
calibration + compression + propagation step
* This is achieved through the `keep_onload_context`, which disables
offloading until the context is exited
* Dispatch model within each calibration pipeline
* Sequential pipeline offloads the model to CPU, and executes on the
first cuda device
* Deprecate passing `sequential_targets` via modifiers; instead, prefer
passing via the oneshot argument
* Use sequential pipeline as the default pipeline (the basic pipeline is
never used)
* Dispatch model before sample generation
* The model is dispatched exactly as it would have been if loaded with
`device_map="auto"`
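To illustrate the onloading pattern described above (this is a minimal pure-Python sketch, not the actual llmcompressor/accelerate implementation; `OffloadedModule` is a toy stand-in for a hook-offloaded module):

```python
from contextlib import contextmanager


class OffloadedModule:
    """Toy stand-in for an offloaded module: parameters live on CPU and are
    onloaded to GPU only for the duration of each forward call."""

    def __init__(self):
        self.device = "cpu"
        self.offload_enabled = True

    def forward(self):
        self.device = "cuda"          # onload for the computation
        if self.offload_enabled:
            self.device = "cpu"       # offload again after each call
        return "ok"


@contextmanager
def keep_onload_context(module):
    """Disable offloading until the context exits, then offload once."""
    module.offload_enabled = False
    try:
        yield module
    finally:
        module.offload_enabled = True
        module.device = "cpu"         # single offload at context exit


module = OffloadedModule()
module.forward()
assert module.device == "cpu"         # offloaded after every call

with keep_onload_context(module):
    module.forward()
    module.forward()
    assert module.device == "cuda"    # stays onloaded across calls
assert module.device == "cpu"         # offloaded once on exit
```

The point of the pattern is that calibration, compression, and propagation for a layer each touch the same parameters, so keeping them onloaded for the whole step replaces many CPU↔GPU round trips with one.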
### Examples ###
* Models are loaded onto CPU before oneshot (rather than being
dispatched across GPUs)
* The model is reloaded from disk in order to redispatch onto the "auto"
device map
* In my opinion, this is a better flow anyway, since models can raise
errors or take a very long time during generation, which can cause the
entire compression job to go to waste
* The alternative is to either call `accelerate.remove_hooks(model)` and
`accelerate.dispatch_model(model)` before generating, or drop
sample generation entirely. One of these may be required if
`compressed_linear` isn't reliable enough to add to our examples
<details><summary>New example script</summary>
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.utils.dev import dispatch_for_generation

# Load model (on cpu)
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")  # model is loaded on cpu
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Define recipe
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# Apply oneshot (model execution device is set to cuda, model stays on cpu)
oneshot(
    model=model,
    dataset="ultrachat_200k",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)

# Perform sample generation
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
</details>
## Testing ##
* Calibrated and GPTQ-compressed one layer of Deepseek-V3 with a single
H100 in 50 seconds
* 4.5× improvement over the original 236 seconds
* Peak memory of ~40 GB, which can be further reduced by increasing the
granularity of sequential targets
* Not offloading activations did not result in a performance improvement
* TODO: Test that all example models can be reloaded and run
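Per the deprecation above, finer-grained sequential targets would be passed via the oneshot argument. A hypothetical fragment (the target class names depend on the model architecture; `LlamaAttention`/`LlamaMLP` here are illustrative only, and `model`/`recipe` are as in the example script):

```python
# Sketch only: smaller sequential targets mean smaller onloaded units,
# trading lower peak memory for more frequent device movement.
oneshot(
    model=model,
    dataset="ultrachat_200k",
    recipe=recipe,
    sequential_targets=["LlamaAttention", "LlamaMLP"],  # finer than whole decoder layers
    num_calibration_samples=512,
)
```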
---------
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Co-authored-by: Brian Dellabetta <bdellabe@redhat.com>
File tree (72 files changed: +346, −890 lines)
- examples
  - awq
  - big_models_with_accelerate
  - big_models_with_sequential_onloading
    - assets
  - multimodal_audio
  - multimodal_vision
  - quantization_2of4_sparse_w4a16
  - quantization_kv_cache
  - quantization_w4a16_fp4
  - quantization_w4a16
  - quantization_w4a4_fp4
  - quantization_w8a8_fp8
  - quantization_w8a8_int8
  - quantizing_moe
  - sparse_2of4_quantization_fp8
  - trl_mixin
- src/llmcompressor
  - args
  - entrypoints
  - modifiers/awq
  - pipelines
    - basic
    - layer_sequential
    - sequential
  - pytorch/model_load
  - transformers/compression
  - utils
- tests
  - e2e
    - vLLM
  - examples
  - llmcompressor/transformers/tracing
  - lmeval