|
| 1 | +# `AutoRound` Quantization |
| 2 | + |
| 3 | +`llm-compressor` supports [AutoRound](https://aclanthology.org/2024.findings-emnlp.662.pdf), an advanced quantization technique that delivers **high-accuracy**, **low-bit quantization**. The quantized results are fully compatible with `compressed-tensors` and can be served directly with vLLM. |
| 4 | + |
| 5 | +AutoRound introduces three trainable parameters (V, α, and β) to optimize rounding values and clipping ranges during quantization. The method processes each decoder layer sequentially, using block-wise output reconstruction error as the training objective to fine-tune these parameters. This approach combines the efficiency of post-training quantization with the adaptability of parameter tuning, delivering robust compression for large language models while maintaining strong performance. |
| 6 | + |
| 7 | +## Installation |
| 8 | + |
| 9 | +To get started, install: |
| 10 | + |
| 11 | +```bash |
| 12 | +git clone https://github.com/vllm-project/llm-compressor.git |
| 13 | +cd llm-compressor |
| 14 | +pip install -e . |
| 15 | +``` |
| 16 | + |
| 17 | +## Quickstart |
| 18 | + |
| 19 | +The example includes an end-to-end script for applying the AutoRound quantization algorithm. |
| 20 | + |
| 21 | +```bash |
| 22 | +python3 llama3_example.py |
| 23 | +``` |
| 24 | + |
| 25 | +The resulting model `Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound` is ready to be loaded into vLLM. |
| 26 | + |
| 27 | +## Code Walkthrough |
| 28 | + |
| 29 | +Now, we will step through the code in the example. There are four steps: |
| 30 | +1) Load model |
| 31 | +2) Prepare calibration data |
| 32 | +3) Apply quantization |
| 33 | +4) Evaluate accuracy in vLLM |
| 34 | + |
| 35 | +### 1) Load Model |
| 36 | + |
| 37 | +Load the model using `AutoModelForCausalLM` for handling quantized saving and loading. |
| 38 | + |
| 39 | +```python |
| 40 | +from transformers import AutoTokenizer, AutoModelForCausalLM |
| 41 | + |
| 42 | +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" |
| 43 | +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") |
| 44 | +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) |
| 45 | +``` |
| 46 | + |
| 47 | +### 2) Prepare Calibration Data |
| 48 | + |
| 49 | +When quantizing model weights with AutoRound, you’ll need a small set of sample data to run the algorithm. By default, we are using [NeelNanda/pile-10k](https://huggingface.co/datasets/NeelNanda/pile-10k) as our calibration dataset. |
| 50 | +Recommended starting points: |
| 51 | +- 128 samples — typically sufficient for stable calibration (increase if accuracy degrades). |
| 52 | +- 2048 sequence length — a good baseline for most LLMs. |
| 53 | +- 200 tuning steps — usually enough to converge (increase if accuracy drops). |
| 54 | + |
| 55 | +```python |
| 56 | +# Select calibration dataset. |
| 57 | +from auto_round.calib_dataset import get_dataset |
| 58 | + |
| 59 | +NUM_CALIBRATION_SAMPLES = 128 |
| 60 | +MAX_SEQUENCE_LENGTH = 2048 |
| 61 | + |
| 62 | +# Get aligned calibration dataset. |
| 63 | +ds = get_dataset( |
| 64 | + tokenizer=tokenizer, |
| 65 | + seqlen=MAX_SEQUENCE_LENGTH, |
| 66 | + nsamples=NUM_CALIBRATION_SAMPLES, |
| 67 | +) |
| 68 | +``` |
| 69 | + |
| 70 | +### 3) Apply Quantization |
| 71 | + |
| 72 | +With the dataset ready, we will now apply AutoRound quantization to the model. |
| 73 | + |
| 74 | +```python |
| 75 | +from llmcompressor import oneshot |
| 76 | +from llmcompressor.modifiers.autoround import AutoRoundModifier |
| 77 | + |
| 78 | +# Configure the quantization algorithm to run. |
| 79 | +recipe = AutoRoundModifier( |
| 80 | + targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200 |
| 81 | +) |
| 82 | + |
| 83 | +# Apply quantization. |
| 84 | +oneshot( |
| 85 | + model=model, |
| 86 | + dataset=ds, |
| 87 | + recipe=recipe, |
| 88 | + max_seq_length=MAX_SEQUENCE_LENGTH, |
| 89 | + num_calibration_samples=NUM_CALIBRATION_SAMPLES, |
| 90 | + # disable shuffling to get slightly better mmlu score |
| 91 | + shuffle_calibration_samples=False, |
| 92 | +) |
| 93 | + |
| 94 | + |
| 95 | +# Save to disk compressed. |
| 96 | +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound" |
| 97 | +model.save_pretrained(SAVE_DIR, save_compressed=True) |
| 98 | +tokenizer.save_pretrained(SAVE_DIR) |
| 99 | +``` |
| 100 | + |
| 101 | +We have successfully created an `int4` model! |
| 102 | + |
| 103 | +### 4) Evaluate Accuracy |
| 104 | + |
| 105 | +With the model created, we can now load and run in vLLM (after installing). |
| 106 | + |
| 107 | +```python |
| 108 | +from vllm import LLM |
| 109 | +model = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound") |
| 110 | +``` |
| 111 | + |
| 112 | +We can evaluate accuracy with `lm_eval` (`pip install lm-eval==0.4.9.1`): |
| 113 | +> Note: quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations. |
| 114 | +
|
| 115 | +Run the following to test accuracy on GSM-8K: |
| 116 | + |
| 117 | +```bash |
| 118 | +lm_eval --model vllm \ |
| 119 | + --model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound",add_bos_token=true \ |
| 120 | + --tasks gsm8k \ |
| 121 | + --num_fewshot 5 \ |
| 122 | + --limit 1000 \ |
| 123 | + --batch_size 'auto' |
| 124 | +``` |
| 125 | + |
| 126 | +We can see the resulting scores look good! |
| 127 | + |
| 128 | +```bash |
| 129 | +| Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr | |
| 130 | +| ----- | ------: | ---------------- | -----: | ----------- | --- | ----: | --- | -----: | |
| 131 | +| gsm8k | 3 | flexible-extract | 5 | exact_match | ↑ | 0.737 | ± | 0.0139 | |
| 132 | +| | | strict-match | 5 | exact_match | ↑ | 0.736 | ± | 0.0139 | |
| 133 | +``` |
| 134 | +> Note: quantized model accuracy may vary slightly due to nondeterminism. |
| 135 | +
|
| 136 | +### Known Issues |
| 137 | +Currently, `llm-compressor` supports applying AutoRound only on the `wNa16` quantization schemes. Support for additional schemes is planned. You can follow progress in the [RFC](https://github.com/vllm-project/llm-compressor/issues/1968). |
| 138 | + |
| 139 | +### Questions or Feature Request? |
| 140 | + |
| 141 | +Please open up an issue on [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) or [intel/auto-round](https://github.com/intel/auto-round). |
0 commit comments