4 changes: 3 additions & 1 deletion README.md
@@ -41,11 +41,12 @@ Some of the exciting new features include:
* **Distributed GPTQ Support**: GPTQ now supports Distributed Data Parallel (DDP) functionality to significantly improve calibration runtime. An example using DDP with GPTQ can be found [here](examples/quantization_w4a16/llama3_ddp_example.py).
* **Updated FP4 Microscale Support**: GPTQ now supports FP4 quantization schemes, including both [MXFP4](examples/quantization_w4a16_fp4/mxfp4/llama3_example.py) and [NVFP4](examples/quantization_w4a4_fp4/llama3_gptq_example.py) (a minimal sketch of the NVFP4 flow follows this list). MXFP4 support has also been improved with updated weight scale generation. Models with weight-only quantization in the MXFP4 format can now run in vLLM as of vLLM v0.14.0; MXFP4 models with activation quantization are not yet supported in vLLM for compressed-tensors models.
* **New Model-Free PTQ Pathway**: A new model-free PTQ pathway has been added to LLM Compressor, called [`model_free_ptq`](src/llmcompressor/entrypoints/model_free/__init__.py#L36). This pathway lets you quantize a model without requiring a Hugging Face model definition and is especially useful in cases where `oneshot` may fail. It currently supports data-free schemes only (i.e., FP8 quantization) and was leveraged to quantize the [Mistral Large 3 model](https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512). Additional [examples](examples/model_free_ptq) illustrate how LLM Compressor can be used for Kimi K2.
* **MXFP8 Microscale Support (Experimental)**: LLM Compressor now supports MXFP8 quantization for both weights and activations via PTQ. An example script can be found in the [experimental folder](experimental/mxfp8/qwen3_example_w8a8_mxfp8.py).
* **Extended KV Cache and Attention Quantization Support**: LLM Compressor now supports attention quantization. KV cache quantization, which previously supported only per-tensor scales, has been extended to support any quantization scheme, including a new `per-head` scheme. Support for these checkpoints is ongoing in vLLM, and scripts to get started have been added to the [experimental folder](experimental/attention).
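
A minimal sketch of the GPTQ + NVFP4 flow referenced above — the scheme name, model, and calibration settings here are assumptions; see the linked example for the maintained version:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Hypothetical recipe: GPTQ with the NVFP4 scheme, skipping the output head.
recipe = GPTQModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"])

# Unlike data-free PTQ, GPTQ needs calibration data; "open_platypus" is a
# dataset registered with llmcompressor, and the sample counts are placeholders.
oneshot(
    model=model,
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
```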


### Supported Formats
* Activation Quantization: W8A8 (int8 and fp8)
* Activation Quantization: W8A8 (int8 and fp8), MXFP8 (experimental)
* Mixed Precision: W4A16, W8A16, NVFP4 (W4A4 and W4A16 support)
* 2:4 Semi-structured and Unstructured Sparsity

@@ -75,6 +76,7 @@ pip install llmcompressor
Applying quantization with `llmcompressor`:
* [Activation quantization to `int8`](examples/quantization_w8a8_int8/README.md)
* [Activation quantization to `fp8`](examples/quantization_w8a8_fp8/README.md) (a minimal sketch follows this list)
* [Activation quantization to MXFP8 (experimental)](experimental/mxfp8/qwen3_example_w8a8_mxfp8.py)
* [Activation quantization to `fp4`](examples/quantization_w4a4_fp4/llama3_example.py)
* [Activation quantization to `fp4` using AutoRound](examples/autoround/quantization_w4a4_fp4/README.md)
* [Activation quantization to `fp8` and weight quantization to `int4`](examples/quantization_w4a8_fp8/)
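
A minimal sketch of the `fp8` flow, assuming the `FP8_DYNAMIC` scheme name from compressed-tensors (dynamic per-token activation scales, so no calibration data is needed):

```python
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", dtype="auto"  # illustrative model
)

# FP8 weights + dynamic FP8 activations; no dataset argument is required.
recipe = QuantizationModifier(
    targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
)
oneshot(model=model, recipe=recipe)
```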
37 changes: 37 additions & 0 deletions experimental/mxfp8/qwen3_example_w8a8_mxfp8.py
@@ -0,0 +1,37 @@
from compressed_tensors.offload import dispatch_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-8B"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights and activations to MXFP8 via PTQ
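# * MXFP8 stores FP8 values (typically E4M3) with a shared power-of-two
#   (E8M0) scale per 32-element block, per the OCP Microscaling (MX) spec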
recipe = QuantizationModifier(
    targets="Linear", scheme="MXFP8", ignore=["lm_head"]
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
print("==========================================")

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-MXFP8"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
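
# A hypothetical follow-up, not part of the original example: once MXFP8
# compressed-tensors checkpoints are supported in a vLLM release (support is
# experimental per the README above), serving the saved directory would follow
# the standard compressed-tensors flow:
# from vllm import LLM
# llm = LLM(model=SAVE_DIR)
# print(llm.generate("Hello my name is"))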
