Skip to content

Commit 2cb141e

Browse files
authored
Merge branch 'main' into qwen3VLMoE_lineared
2 parents 4995535 + fae9429 commit 2cb141e

File tree

10 files changed

+405
-25
lines changed

10 files changed

+405
-25
lines changed

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@ quality:
2828
ruff format --check $(CHECKDIRS);
2929

3030
# style the code according to accepted standards for the repo
31+
# Note: We run `ruff format` twice. Once to fix long lines before lint check
32+
# and again to fix any formatting issues introduced by ruff check --fix
3133
style:
3234
@echo "Running python styling";
35+
ruff format $(CHECKDIRS);
3336
ruff check --fix $(CHECKDIRS);
34-
ruff format $(CHECKDIRS);
37+
ruff format --silent $(CHECKDIRS);
3538

3639
# run tests for the repo
3740
test:
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from datasets import load_dataset
2+
from transformers import AutoModelForCausalLM, AutoTokenizer
3+
4+
from llmcompressor import oneshot
5+
from llmcompressor.modifiers.quantization import QuantizationModifier
6+
from llmcompressor.utils import dispatch_for_generation
7+
8+
# NOTE: Qwen3-Next-80B-A3B-Instruct support is not in transformers<=4.56.2
9+
# you may need to install transformers from source
10+
11+
MODEL_ID = "Qwen/Qwen3-Next-80B-A3B-Instruct"
12+
13+
# Load model.
14+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
15+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
16+
17+
18+
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
19+
DATASET_SPLIT = "train_sft"
20+
21+
# Select number of samples
22+
NUM_CALIBRATION_SAMPLES = 20
23+
MAX_SEQUENCE_LENGTH = 2048
24+
25+
# Load dataset and preprocess.
26+
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
27+
ds = ds.shuffle(seed=42)
28+
29+
30+
def preprocess(example):
31+
return {
32+
"text": tokenizer.apply_chat_template(
33+
example["messages"],
34+
tokenize=False,
35+
)
36+
}
37+
38+
39+
ds = ds.map(preprocess)
40+
41+
42+
# Tokenize inputs.
43+
def tokenize(sample):
44+
return tokenizer(
45+
sample["text"],
46+
padding=False,
47+
max_length=MAX_SEQUENCE_LENGTH,
48+
truncation=True,
49+
add_special_tokens=False,
50+
)
51+
52+
53+
ds = ds.map(tokenize, remove_columns=ds.column_names)
54+
55+
# Configure the quantization algorithm and scheme.
56+
# In this case, we:
57+
# * quantize the weights to fp4 with per group 16 via ptq
58+
# * calibrate a global_scale for activations, which will be used to
59+
# quantize activations to fp4 on the fly
60+
recipe = QuantizationModifier(
61+
targets="Linear",
62+
scheme="NVFP4",
63+
ignore=[
64+
"lm_head",
65+
"re:.*mlp.gate$",
66+
"re:.*mlp.shared_expert_gate$",
67+
"re:.*linear_attn.*",
68+
],
69+
)
70+
71+
# Apply quantization.
72+
# We set `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
73+
# during calibration.
74+
# Feel free to update the definition under
75+
# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with
76+
# this behaviour and evaluate its impact on quantization performance
77+
oneshot(
78+
model=model,
79+
dataset=ds,
80+
recipe=recipe,
81+
max_seq_length=MAX_SEQUENCE_LENGTH,
82+
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
83+
calibrate_moe_context=True,
84+
)
85+
86+
87+
print("\n\n")
88+
print("========== SAMPLE GENERATION ==============")
89+
dispatch_for_generation(model)
90+
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
91+
model.device
92+
)
93+
output = model.generate(input_ids, max_new_tokens=100)
94+
print(tokenizer.decode(output[0]))
95+
print("==========================================\n\n")
96+
97+
98+
# Save to disk in compressed-tensors format.
99+
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4"
100+
model.save_pretrained(SAVE_DIR, save_compressed=True)
101+
tokenizer.save_pretrained(SAVE_DIR)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from transformers import AutoModelForCausalLM, AutoTokenizer
2+
3+
from llmcompressor import oneshot
4+
from llmcompressor.modifiers.quantization import QuantizationModifier
5+
from llmcompressor.utils import dispatch_for_generation
6+
7+
MODEL_ID = "Qwen/Qwen3-Next-80B-A3B-Instruct"
8+
9+
# Load model.
10+
model = AutoModelForCausalLM.from_pretrained(
11+
MODEL_ID,
12+
torch_dtype="auto",
13+
low_cpu_mem_usage=True,
14+
trust_remote_code=True,
15+
)
16+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
17+
18+
recipe = QuantizationModifier(
19+
targets=["Linear"],
20+
scheme="FP8_DYNAMIC",
21+
ignore=[
22+
"lm_head",
23+
"re:.*mlp.gate$",
24+
"re:.*mlp.shared_expert_gate$",
25+
"re:.*linear_attn.*",
26+
],
27+
)
28+
29+
# Apply quantization.
30+
oneshot(model=model, recipe=recipe)
31+
32+
# Confirm generations of the quantized model look sane.
33+
print("========== SAMPLE GENERATION ==============")
34+
dispatch_for_generation(model)
35+
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
36+
model.device
37+
)
38+
output = model.generate(input_ids, max_new_tokens=20)
39+
print(tokenizer.decode(output[0]))
40+
print("==========================================")
41+
42+
# Save to disk in compressed-tensors format.
43+
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-Dynamic"
44+
model.save_pretrained(SAVE_DIR, save_compressed=True)
45+
tokenizer.save_pretrained(SAVE_DIR)

examples/quantizing_moe/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Quantizing Mixtral-8x7B-Instruct-v0.1 Model with FP8
22

3-
This directory contains an example script for quantizing the `Mixtral-8x7B-Instruct-v0.1` model using the static per-tensor FP8 quantization scheme.
3+
This directory contains example scripts for quantizing LLMs using the static per-tensor FP8 quantization scheme.
44

55
## Installation
66

@@ -32,7 +32,7 @@ python mixtral_example.py
3232

3333
### Step 1: Select a Model, Dataset, and Recipe
3434

35-
In this step, you'll choose a baseline model for quantization, a dataset for calibration, and a quantization recipe.
35+
In this step, you'll choose a base model for quantization, a dataset for calibration, and a quantization recipe.
3636

3737
- **Models**: Can be referenced from a local directory or retrieved from the Hugging Face Hub.
3838
- **Datasets**: Can also be from a local directory or the Hugging Face Hub.

examples/transform/README.md

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Applying Transforms to Improve Quantization Accuracy
2+
3+
This directory contains example scripts for applying transforms to models for the purpose of improving quantization accuracy. For more information on transforms, see [QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs](https://arxiv.org/abs/2404.00456). The two transform styles currently supported are SpinQuant/QuaRot-style (`SpinQuantModifier`), and QuIP-style (`QuIPModifier`).
4+
5+
See also [[vLLM Office Hours #31] vLLM and LLM Compressor Update - August 28, 2025](https://www.youtube.com/watch?v=WVenRmF4dPY&list=PLbMP1JcGBmSHxp4-lubU5WYmJ9YgAQcf3&index=3).
6+
7+
## Installation
8+
9+
To get started, install the necessary dependencies by executing the following commands:
10+
11+
```bash
12+
git clone https://github.com/vllm-project/llm-compressor.git
13+
cd llm-compressor
14+
pip install -e .
15+
```
16+
17+
## Quickstart
18+
19+
The provided example script demonstrates the process for applying QuIP-style transforms before quantization.
20+
21+
```bash
22+
python3 quip_example.py
23+
```
24+
25+
### Step 1: Select a Model, Dataset, and Recipe
26+
27+
In this step, you'll choose a base model for quantization and a transformation + quantization recipe.
28+
29+
- **Models**: Can be referenced from a local directory or retrieved from the Hugging Face Hub.
30+
- **Recipes**: These are YAML files or Python modifier objects that describe how a model should be optimized during or after training. In this example, we use the `QuIPModifier` applied before the `QuantizationModifier` with the scheme set to `FP8`.
31+
32+
```python
33+
from llmcompressor.modifiers.transform import QuIPModifier
34+
from llmcompressor.modifiers.quantization import QuantizationModifier
35+
36+
recipe = [
37+
QuIPModifier(
38+
rotations=["v", "u"], transform_block_size=128, transform_type="hadamard"
39+
),
40+
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
41+
]
42+
```
43+
44+
Note that `QuIPModifier` can be customized. For a full list of the available arguments, see the [docstring](/src/llmcompressor/modifiers/transform/spinquant/base.py) or documentation.
45+
46+
* `rotations` determines which of the input rotations (v) or output rotations (u) should be used.
47+
* `transform_block_size` determines the size of the hadamard. Smaller hadamards require less cost at runtime.
48+
* `transform_type` determines how the transform is constructed. `hadamard` uses the Sylvester construction.
49+
50+
### Step 2: Run Quantization Using Oneshot
51+
52+
The `oneshot` method applies the selected recipe to your model and dataset without requiring any fine-tuning. The model will be quantized and saved to `Llama-3.1-8B-Instruct-quip-w4a16`. We use the "datafree" pipeline, since our recipe does not require calibration data.
53+
54+
```python
55+
from llmcompressor import oneshot
56+
57+
# Apply algorithms.
58+
oneshot(model=model, recipe=recipe, pipeline="datafree")
59+
60+
# Save to disk compressed.
61+
SAVE_DIR = MODEL_ID.split("/")[1] + "-quip-w4a16"
62+
model.save_pretrained(SAVE_DIR, save_compressed=True)
63+
tokenizer.save_pretrained(SAVE_DIR)
64+
```
65+
66+
### Step 3: Run optimized model in vLLM
67+
Models optimized with the `hadamard` transform type will be able to leverage the hadacore kernels for accelerated inference. Use the [benchmarks/latency.py](https://github.com/vllm-project/vllm/blob/main/vllm/benchmarks/latency.py) script to benchmark latency
68+
69+
```bash
70+
python3 benchmarks/benchmark_latency.py --model path/to/Llama-3.2-1B-Instruct-quip-w4a16
71+
```
72+
73+
74+
#### Dense Model Latency (sec) ####
75+
| [Base](https://huggingface.co/meta-llama/Llama-3.2-1B-instruct) | Hadacore | GEMM |
76+
| - | - | - |
77+
| 0.4710 | 0.4948 | 1.3946 |
78+
79+
#### Quantized Model Latency (sec) ####
80+
| Base W4A16 | Hadacore | GEMM |
81+
| - | - | - |
82+
| 0.4402 | 0.4489 | 1.2917 |

examples/transform/quip_example.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
# Configure the quantization algorithm to run.
2121
# * apply quip transforms to model in order to make quantization easier
2222
# * quantize the weights to 4 bit with a group size 128
23+
# * NOTE: if a model has activation shapes not divisible by 2^N, consider using
24+
# `random-hadamard` (random hadamard kernels will be added in the future)
2325
recipe = [
2426
QuIPModifier(
25-
rotations=["v", "u"], transform_block_size=128, transform_type="random-hadamard"
27+
rotations=["v", "u"], transform_block_size=128, transform_type="hadamard"
2628
),
2729
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
2830
]

examples/transform/spinquant_example.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
# * quantize the weights to 4 bit with group size 128
1919
recipe = [
2020
SpinQuantModifier(
21-
rotations=["R1", "R2", "R4"], transform_block_size=64, transform_type="hadamard"
21+
rotations=["R1", "R2", "R4"],
22+
transform_block_size=128,
23+
transform_type="hadamard",
2224
),
2325
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
2426
]

src/llmcompressor/args/dataset_arguments.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ class DatasetArguments(CustomDatasetArguments):
201201
"_prepare_4d_causal_attention_mask",
202202
"_prepare_fsmt_decoder_inputs",
203203
"_prepare_4d_causal_attention_mask_with_cache_position",
204+
"_update_linear_attn_mask",
204205
],
205206
metadata={
206207
"help": "List of functions to ignore during tracing, either "

src/llmcompressor/modeling/prepare.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99

1010
try:
1111
from llmcompressor.modeling.qwen3_vl_moe import replace as replace_Qwen3VLMoE
12+
from llmcompressor.modeling.qwen3_next_moe import replace as replace_Qwen3NextMoE
1213
except ImportError:
1314
logger.warning(
14-
"Qwen3-VL-MoE support is not available. "
15+
"Qwen3-VL-MoE and Qwen3-Next support is not available. "
1516
"Please ensure that you have the correct version of transformers installed."
1617
)
1718
replace_Qwen3VLMoE = None
18-
19+
replace_Qwen3NextMoE = None
20+
1921
from llmcompressor.utils.helpers import patch_attr
2022

2123
__all__ = ["replace_modules_for_calibration"]
@@ -50,28 +52,51 @@ def replace_modules_for_calibration(
5052
# ------------------- module replacements; during calibration --------------------
5153

5254

53-
def update_qwen3_moe(model, stack, calibrate_all_experts):
54-
for module in model.modules():
55-
cls_name = module.__class__.__name__
56-
if cls_name == "Qwen3MoeDecoderLayer":
57-
# Optionally update the model.config to pass in other arguments
58-
stack.enter_context(
59-
patch_attr(
60-
module,
61-
"mlp",
62-
replace_Qwen3MoE(
63-
config=model.config,
64-
module=module.mlp,
65-
calibrate_all_experts=calibrate_all_experts,
66-
),
67-
)
55+
def update_qwen3_moe(model, module, stack, calibrate_all_experts):
56+
cls_name = module.__class__.__name__
57+
if (
58+
cls_name == "Qwen3MoeDecoderLayer"
59+
and module.mlp.__class__.__name__ == "Qwen3MoeSparseMoeBlock"
60+
):
61+
stack.enter_context(
62+
patch_attr(
63+
module,
64+
"mlp",
65+
replace_Qwen3MoE(
66+
config=model.config,
67+
module=module.mlp,
68+
calibrate_all_experts=calibrate_all_experts,
69+
),
6870
)
71+
)
72+
73+
74+
def update_qwen3_next_moe(model, module, stack, calibrate_all_experts):
75+
cls_name = module.__class__.__name__
76+
if (
77+
cls_name == "Qwen3NextDecoderLayer"
78+
and module.mlp.__class__.__name__ == "Qwen3NextSparseMoeBlock"
79+
):
80+
stack.enter_context(
81+
patch_attr(
82+
module,
83+
"mlp",
84+
replace_Qwen3NextMoE(
85+
config=model.config,
86+
module=module.mlp,
87+
calibrate_all_experts=calibrate_all_experts,
88+
),
89+
)
90+
)
6991

7092

7193
moe_context = {
7294
"Qwen3MoeForCausalLM": update_qwen3_moe,
7395
}
7496

97+
if replace_Qwen3NextMoE is not None:
98+
moe_context["Qwen3NextForCausalLM"] = update_qwen3_next_moe
99+
75100

76101
def moe_calibration_context(
77102
model: PreTrainedModel,
@@ -80,6 +105,7 @@ def moe_calibration_context(
80105
):
81106
# Temporarily updates the MoE modules within the context
82107
# Once the context exists, parameter updates persist
83-
cls_name = model.__class__.__name__
84-
if cls_name in moe_context:
85-
moe_context.get(cls_name)(model, stack, calibrate_all_experts)
108+
model_name = model.__class__.__name__
109+
if model_name in moe_context:
110+
for module in model.modules():
111+
moe_context[model_name](model, module, stack, calibrate_all_experts)

0 commit comments

Comments
 (0)