
Commit fdd376c

Merge branch 'main' into add_qwen35_docs
2 parents 6a02f09 + 026c917 commit fdd376c

File tree

3 files changed: 79 additions (+), 2 deletions (-)

README.md

Lines changed: 5 additions & 2 deletions
@@ -41,12 +41,13 @@ Some of the exciting new features include:
 * **Distributed GPTQ Support**: GPTQ now supports Distributed Data Parallel (DDP) functionality to significantly improve calibration runtime. An example using DDP with GPTQ can be found [here](examples/quantization_w4a16/llama3_ddp_example.py).
 * **Updated FP4 Microscale Support**: GPTQ now supports FP4 quantization schemes, including both [MXFP4](examples/quantization_w4a16_fp4/mxfp4/llama3_example.py) and [NVFP4](examples/quantization_w4a4_fp4/llama3_gptq_example.py). MXFP4 support has also been improved with updated weight scale generation. Models with weight-only quantization in the MXFP4 format can now run in vLLM as of vLLM v0.14.0. MXFP4 models with activation quantization are not yet supported in vLLM for compressed-tensors models.
 * **New Model-Free PTQ Pathway**: A new model-free PTQ pathway has been added to LLM Compressor, called [`model_free_ptq`](src/llmcompressor/entrypoints/model_free/__init__.py#L36). This pathway allows you to quantize your model without requiring a Hugging Face model definition and is especially useful in cases where `oneshot` may fail. It is currently supported for data-free pathways only (i.e. FP8 quantization) and was leveraged to quantize the [Mistral Large 3 model](https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512). Additional [examples](examples/model_free_ptq) have been added illustrating how LLM Compressor can be used for Kimi K2.
+* **MXFP8 Microscale Support (Experimental)**: LLM Compressor now supports MXFP8 quantization via PTQ. Both W8A8 ([MXFP8](experimental/mxfp8/qwen3_example_w8a8_mxfp8.py)) and W8A16 weight-only ([MXFP8A16](experimental/mxfp8/qwen3_example_w8a16_mxfp8.py)) modes are available.
 * **Extended KV Cache and Attention Quantization Support**: LLM Compressor now supports attention quantization. KV cache quantization, which previously only supported per-tensor scales, has been extended to support any quantization scheme, including a new `per-head` scheme. Support for these checkpoints is ongoing in vLLM, and scripts to get started have been added to the [experimental folder](experimental/attention).
 
 
 ### Supported Formats
-* Activation Quantization: W8A8 (int8 and fp8)
-* Mixed Precision: W4A16, W8A16, NVFP4 (W4A4 and W4A16 support)
+* Activation Quantization: W8A8 (int8 and fp8), MXFP8 (experimental)
+* Mixed Precision: W4A16, W8A16, MXFP8A16 (experimental), NVFP4 (W4A4 and W4A16 support)
 * 2:4 Semi-structured and Unstructured Sparsity
 
 ### Supported Algorithms
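The MXFP8 microscale format added above pairs FP8 (E4M3) elements with one shared power-of-two scale per small block of values. As a rough, dependency-free sketch of that idea (the block values, `elem_max`, and the scale rule here are illustrative simplifications, not LLM Compressor's implementation):

```python
import math


def quantize_mx_block(block, elem_max=448.0):
    """Toy microscale quantization: one shared power-of-two scale per
    block, elements clamped to the FP8 E4M3 dynamic range (~448).
    Real MXFP8 also rounds each element onto the E4M3 grid; omitted here."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0, [0.0] * len(block)
    # Smallest power-of-two scale keeping amax inside the element range.
    scale = 2.0 ** math.ceil(math.log2(amax / elem_max))
    scaled = [max(-elem_max, min(elem_max, v / scale)) for v in block]
    return scale, scaled


def dequantize_mx_block(scale, scaled):
    return [scale * v for v in scaled]


scale, q = quantize_mx_block([0.5, -1.0, 0.25, 0.0])
print(scale)                           # shared scale for the whole block
print(dequantize_mx_block(scale, q))   # values recovered from scale * element
```

Because the shared scale is a power of two, values that are themselves powers of two round-trip exactly in this toy; real data additionally incurs E4M3 element rounding error.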
@@ -75,6 +76,8 @@ pip install llmcompressor
 Applying quantization with `llmcompressor`:
 * [Activation quantization to `int8`](examples/quantization_w8a8_int8/README.md)
 * [Activation quantization to `fp8`](examples/quantization_w8a8_fp8/README.md)
+* [Activation quantization to MXFP8 (experimental)](experimental/mxfp8/qwen3_example_w8a8_mxfp8.py)
+* [Weight-only quantization to MXFP8A16 (experimental)](experimental/mxfp8/qwen3_example_w8a16_mxfp8.py)
 * [Activation quantization to `fp4`](examples/quantization_w4a4_fp4/llama3_example.py)
 * [Activation quantization to `fp4` using AutoRound](examples/autoround/quantization_w4a4_fp4/README.md)
 * [Activation quantization to `fp8` and weight quantization to `int4`](examples/quantization_w4a8_fp8/)
experimental/mxfp8/qwen3_example_w8a16_mxfp8.py

Lines changed: 37 additions & 0 deletions
from compressed_tensors.offload import dispatch_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-8B"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to mxfp8 via ptq
recipe = QuantizationModifier(
    targets="Linear", scheme="MXFP8A16", ignore=["lm_head"]
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
print("==========================================")

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-MXFP8A16"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
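The save directory in the script above is derived from the model ID: strip any trailing `/`, keep the repo name after the last `/`, and append the scheme suffix. As a standalone illustration (the helper name is hypothetical):

```python
def save_dir_for(model_id: str, suffix: str) -> str:
    # Mirror SAVE_DIR in the example: trailing "/" stripped, last path
    # segment kept, scheme suffix appended.
    return model_id.rstrip("/").split("/")[-1] + "-" + suffix


print(save_dir_for("Qwen/Qwen3-8B", "MXFP8A16"))  # -> Qwen3-8B-MXFP8A16
```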
experimental/mxfp8/qwen3_example_w8a8_mxfp8.py

Lines changed: 37 additions & 0 deletions
from compressed_tensors.offload import dispatch_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-8B"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights and activations to mxfp8 via ptq
recipe = QuantizationModifier(
    targets="Linear", scheme="MXFP8", ignore=["lm_head"]
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
print("==========================================")

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-MXFP8"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
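The only functional difference between the two example scripts is the scheme string: `MXFP8A16` quantizes weights only, while `MXFP8` quantizes activations as well. A toy sketch of where that difference lands at matmul time (a uniform grid stands in for FP8 here; the step and range are illustrative, not the real E4M3 grid or per-block scales):

```python
def fake_quant(xs, step=1.0, max_val=127.0):
    # Toy stand-in for FP8: clamp to a symmetric range, snap to a grid.
    return [max(-max_val, min(max_val, round(x / step) * step)) for x in xs]


def dot_w8a16(w_row, act):
    # Weight-only (MXFP8A16-style): activations stay high precision.
    return sum(w * a for w, a in zip(fake_quant(w_row), act))


def dot_w8a8(w_row, act):
    # Weights and activations quantized (MXFP8-style).
    return sum(w * a for w, a in zip(fake_quant(w_row), fake_quant(act)))


w, a = [1.2, 2.0], [0.4, 1.0]
print(dot_w8a16(w, a))  # only the weights lose precision
print(dot_w8a8(w, a))   # the activations are rounded too
```

W8A8 trades additional activation rounding error for FP8 compute on both operands; W8A16 keeps activations exact and only compresses storage and weight loads.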
