49 changes: 43 additions & 6 deletions examples/autoround/quantization_w4a4_fp4/README.md
100644 → 100755
@@ -16,15 +16,17 @@ pip install -e .

## Quickstart

The example includes end-to-end scripts for applying the AutoRound quantization algorithm.

### Llama 3.1 Example

```bash
python3 llama3.1_example.py
```

The resulting model `Meta-Llama-3.1-8B-Instruct-NVFP4-AutoRound` is ready to be loaded into vLLM.

#### Evaluate Accuracy

With the model created, we can now load and run it in vLLM (after installing vLLM).

@@ -46,33 +48,68 @@ lm_eval --model vllm \
  --batch_size 'auto'
```

##### meta-llama/Meta-Llama-3.1-8B-Instruct
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7710|± |0.0116|
| | |strict-match | 5|exact_match|↑ |0.7043|± |0.0126|

##### Meta-Llama-3.1-8B-Instruct-NVFP4 (QuantizationModifier)
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7248|± |0.0123|
| | |strict-match | 5|exact_match|↑ |0.6611|± |0.0130|


##### Meta-Llama-3.1-8B-Instruct-NVFP4-AutoRound (AutoRoundModifier, iters=0)
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7362|± |0.0121|
| | |strict-match | 5|exact_match|↑ |0.6702|± |0.0129|

##### Meta-Llama-3.1-8B-Instruct-NVFP4-AutoRound (AutoRoundModifier, iters=200)
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7210|± |0.0124|
| | |strict-match | 5|exact_match|↑ |0.6945|± |0.0127|

> Note: quantized model accuracy may vary slightly due to nondeterminism.
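As a quick sanity check, the quantized scores can be compared against the baseline directly. The snippet below is illustrative arithmetic using the flexible-extract `exact_match` values copied by hand from the tables above:

```python
# Flexible-extract exact_match scores from the tables above.
baseline = 0.7710  # meta-llama/Meta-Llama-3.1-8B-Instruct
quantized = {
    "QuantizationModifier": 0.7248,
    "AutoRoundModifier, iters=0": 0.7362,
    "AutoRoundModifier, iters=200": 0.7210,
}

for name, score in quantized.items():
    recovery = score / baseline  # fraction of baseline accuracy retained
    print(f"{name}: {score:.4f} ({recovery:.1%} of baseline)")
```

Note that `iters=200` scores lower than `iters=0` on flexible-extract but higher on strict-match (0.6945 vs 0.6702), so neither configuration dominates on this task alone.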

### Qwen3-VL Example

```bash
python3 qwen3_vl_example.py
```

The resulting model `Qwen3-VL-8B-Instruct-NVFP4-AutoRound` is ready to be loaded into vLLM.

#### Evaluate Accuracy

Run the following to test accuracy on GSM-8K:

```bash
lm_eval --model vllm-vlm \
  --model_args pretrained="./Qwen3-VL-8B-Instruct-NVFP4-AutoRound",add_bos_token=true \
  --tasks gsm8k \
  --num_fewshot 5 \
  --batch_size 'auto'
```

##### Qwen3-VL-8B-Instruct (Baseline)
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.8628|± |0.0095|
| | |strict-match | 5|exact_match|↑ |0.8453|± |0.0100|


##### Qwen3-VL-8B-Instruct-NVFP4-AutoRound (AutoRoundModifier, iters=200)
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.8415|± |0.0101|
| | |strict-match | 5|exact_match|↑ |0.8408|± |0.0101|

> Note: quantized model accuracy may vary slightly due to nondeterminism.
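The reported stderr columns give a rough sense of whether the quantization gap is within noise. Below is a back-of-the-envelope comparison (treating the two runs as independent, which lm_eval does not guarantee) using the strict-match numbers from the tables above:

```python
import math

# Strict-match exact_match and stderr from the tables above.
baseline, baseline_se = 0.8453, 0.0100    # Qwen3-VL-8B-Instruct
quantized, quantized_se = 0.8408, 0.0101  # NVFP4-AutoRound

gap = baseline - quantized
combined_se = math.sqrt(baseline_se**2 + quantized_se**2)
print(f"gap = {gap:.4f}, combined stderr = {combined_se:.4f}")
# A gap well under ~2x the combined stderr is indistinguishable from noise.
```

Here the strict-match gap (about 0.005) is far smaller than the combined stderr, consistent with the note above about run-to-run nondeterminism.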

### Questions or Feature Request?

Please open up an issue on [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) or [intel/auto-round](https://github.com/intel/auto-round).
60 changes: 60 additions & 0 deletions examples/autoround/quantization_w4a4_fp4/qwen3_vl_example.py
@@ -0,0 +1,60 @@
from auto_round.calib_dataset import get_dataset
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Load model.
MODEL_ID = "Qwen/Qwen3-VL-8B-Instruct"
model = Qwen3VLForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer

# Select calibration dataset.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048
# Get aligned calibration dataset.

ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)


# Configure the quantization algorithm to run.
# * quantize the weights to 4 bit (NVFP4 scheme) with AutoRound
recipe = AutoRoundModifier(
    targets="Linear",
    scheme="NVFP4",
    ignore=["re:.*lm_head", "re:.*visual.*"],
    iters=200,
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # disable shuffling to get slightly better mmlu score
    shuffle_calibration_samples=False,
)

print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")


# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4-AutoRound"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)