Commit f4e484d

[Performance] Sequential onloading (#1263)
# Sequential Onloading #

<p align="center"><img width="403" alt="Screenshot 2025-06-05 at 22 53 01" src="https://github.com/user-attachments/assets/ffd610ac-c511-4dc1-b858-b0ed2bf95193" /></p>

```
(25/33): Calibrating:   0%|          | 0/512 [00:00<?, ?it/s]
<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>.weight -> cuda
<class 'torch.nn.modules.linear.Linear'>.weight -> cuda
<class 'torch.nn.modules.linear.Linear'>.weight_scale -> cuda
<class 'torch.nn.modules.linear.Linear'>.weight_zero_point -> cuda
...
(25/33): Calibrating: 100%|█████| 512/512 [00:23<00:00, 21.91it/s]
2025-06-03T17:29:15.536963-0400 | compress_modules | INFO - Quantizing model.layers.24.self_attn.q_proj using 512 samples
2025-06-03T17:29:17.328720-0400 | compress | METRIC - time 1.79s
2025-06-03T17:29:17.329265-0400 | compress | METRIC - error 8948.54
2025-06-03T17:29:17.329781-0400 | compress | METRIC - GPU 0 | usage: 5.41% | total memory: 85 GB
2025-06-03T17:29:17.330248-0400 | compress | METRIC - Compressed module size: 33.947648 MB
...
(25/33): Propagating: 100%|█████| 512/512 [00:03<00:00, 131.16it/s]
<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>.weight -> meta
<class 'torch.nn.modules.linear.Linear'>.weight -> meta
<class 'torch.nn.modules.linear.Linear'>.weight_scale -> meta
<class 'torch.nn.modules.linear.Linear'>.weight_zero_point -> meta
...
```

## Purpose ##

* Reduce hardware requirements for calibrating large models
* Reduce runtime caused by excess device movement when calibrating offloaded models

## Prerequisites ##

* neuralmagic/compressed-tensors#354
* neuralmagic/compressed-tensors#355
* neuralmagic/compressed-tensors#356
* neuralmagic/compressed-tensors#357

## Related Issues ##

* Resolves #1383
* Resolves #1228
* Resolves #1122
* Resolves #1078
* Resolves #1216
* Resolves #1483

## Changes ##

* Keep layer parameters onloaded during the entire sequential calibration + compression + propagation step
  * This is achieved through the `keep_onload_context`, which disables offloading until the context is exited
* Dispatch the model within each calibration pipeline
  * The sequential pipeline offloads the model to CPU and executes on the first cuda device
* Deprecate passing `sequential_targets` via modifiers; instead, prefer passing it via the `oneshot` argument (see the sketch below)
* Use the sequential pipeline as the default pipeline (the basic pipeline is never used)
* Dispatch the model before sample generation
  * The model is dispatched exactly as it would be if it was loaded with `device_map="auto"`
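As a rough illustration of that deprecation (not taken from the diff itself), a minimal sketch of the preferred call; the `sequential_targets` keyword name and the `"LlamaDecoderLayer"` target are assumptions that mirror the deprecated modifier field, and the model and dataset are placeholders borrowed from the examples:

```python
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Model is loaded on CPU; the sequential pipeline handles device placement
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype="auto"
)

# Previously: GPTQModifier(..., sequential_targets=[...]) -- now deprecated
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model=model,
    dataset="ultrachat_200k",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    # preferred location for sequential targets (kwarg name assumed here)
    sequential_targets=["LlamaDecoderLayer"],
)
```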
### Examples ###

* Models are loaded onto CPU before oneshot (rather than being dispatched across GPUs)
* The model is reloaded from disk in order to re-dispatch onto the "auto" device map
  * In my opinion, this is a better flow anyway, since models can raise errors or take a very long time during generation, which can cause the entire compression job to go to waste
  * The alternative is to either call `accelerate.remove_hooks(model)` and `accelerate.dispatch_model(model)` before generating, or get rid of sample generation entirely. One of these may be required if `compressed_linear` isn't reliable enough to add to our examples

<details><summary>New example script</summary>

```python3
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.utils.dev import dispatch_for_generation

# Load model (on cpu)
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")  # model is loaded on cpu
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Define recipe
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# Apply oneshot (model execution device is set to cuda, model stays on cpu)
oneshot(
    model=model,
    dataset="ultrachat_200k",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)

# Perform sample generation
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk in compressed format
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

</details>

## Testing ##

* Calibrated and GPTQ-compressed one layer of Deepseek-V3 with a single H100 in 50 seconds
  * A 4.5x improvement over the original 236 seconds
  * Peak memory of ~40 GB, which can be further reduced by increasing the granularity of sequential targets
* Not offloading activations did not result in a performance improvement
* TODO: Test that all example models can be reloaded and run

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
1 parent ca00edd commit f4e484d

72 files changed: +346 −890 lines

examples/awq/README.md

Lines changed: 1 addition & 5 deletions

````diff
@@ -18,11 +18,7 @@ recipe = [
 To use your own model, start with an existing example change the `model_id` to match your own model stub.
 ```python
 model_id = "path/to/your/model"
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype="auto",
-)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
 ```

 ## Adding Mappings ##
````

examples/awq/llama_example.py

Lines changed: 1 addition & 3 deletions

````diff
@@ -7,9 +7,7 @@
 # Select model and load it.
 MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID, device_map="auto", torch_dtype="auto"
-)
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

 # Select calibration dataset.
````

examples/awq/qwen3_moe_example.py

Lines changed: 3 additions & 3 deletions

````diff
@@ -3,13 +3,12 @@

 from llmcompressor import oneshot
 from llmcompressor.modifiers.awq import AWQModifier
+from llmcompressor.utils import dispatch_for_generation

 # Select model and load it.
 MODEL_ID = "Qwen/Qwen3-30B-A3B"

-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID, device_map="auto", torch_dtype="auto"
-)
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

 # Select calibration dataset.
@@ -71,6 +70,7 @@ def tokenize(sample):
 # Confirm generations of the quantized model look sane.
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
 input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
 output = model.generate(input_ids, max_new_tokens=100)
 print(tokenizer.decode(output[0]))
````
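Distilled from the hunks above, a minimal sketch of the new dispatch-before-generation pattern; the `oneshot` calibration step is elided, and the surrounding setup mirrors this example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.utils import dispatch_for_generation

MODEL_ID = "Qwen/Qwen3-30B-A3B"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")  # stays on CPU
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# ... oneshot(...) calibration would run here, with the model kept on CPU ...

# Dispatch across available GPUs (as if loaded with device_map="auto") before sampling
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
```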

examples/big_models_with_accelerate/README.md

Lines changed: 0 additions & 95 deletions
This file was deleted.

examples/big_models_with_accelerate/cpu_offloading_fp8.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

examples/big_models_with_accelerate/mult_gpus_int8_device_map.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

examples/big_models_with_accelerate/multi_gpu_int8.py

Lines changed: 0 additions & 78 deletions
This file was deleted.
Lines changed: 12 additions & 0 deletions (new file)

````diff
@@ -0,0 +1,12 @@
+## Big Modeling with Sequential Onloading ##
+### What is Sequential Onloading? ###
+Sequential onloading is a memory-efficient approach for compressing large language models (LLMs) using only a single GPU. Instead of loading the entire model into memory—which can easily require hundreds of gigabytes—this method loads and compresses one layer at a time. The outputs are offloaded before the next layer is processed, dramatically reducing peak memory usage while maintaining high compression fidelity.
+
+<p align="center">
+    <img src="assets/sequential_onloading.png"/>
+</p>
+
+For more information, see the [RedHat AI blog post](https://developers.redhat.com/articles/2025/05/09/llm-compressor-optimize-llms-low-latency-deployments#generalizing_to_multimodal_and_moe_architectures) or the [LLM Compressor Office Hours Recording](https://www.youtube.com/watch?v=GrhuqQDmBk8).
+
+### Using Sequential Onloading ###
+Sequential onloading is enabled by default within LLM Compressor. To disable sequential onloading, add the `pipeline="basic"` argument to the LLM Compressor `oneshot` function call.
````
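As a rough illustration of that last line, a minimal sketch of opting out of sequential onloading via `pipeline="basic"`; the model, recipe, and dataset are placeholders borrowed from the other examples:

```python
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype="auto"
)
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# Sequential onloading is the default; pass pipeline="basic" to opt out
oneshot(
    model=model,
    dataset="ultrachat_200k",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    pipeline="basic",
)
```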
assets/sequential_onloading.png — 69.5 KB (new image file)

examples/multimodal_audio/README.md

Lines changed: 1 addition & 5 deletions

````diff
@@ -21,11 +21,7 @@ This directory contains example scripts for quantizing a variety of audio langua
 To use your own multimodal modal, start with an existing example change the `model_id` to match your own model stub.
 ```python3
 model_id = "path/to/your/model"
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype="auto",
-)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
 ```

 ## Customizing GPTQModifier Parameters ##
````
