> `2:4 sparsity + int4/int8` mixed-precision computation is supported in vLLM on NVIDIA GPUs with compute capability >= 8.0 (Ampere, Ada Lovelace, Hopper).

## NOTE:
Fine-tuning can require more steps than shown in this example.
See the [Axolotl integration blog post](https://developers.redhat.com/articles/2025/06/17/axolotl-meets-llm-compressor-fast-sparse-open) for fine-tuning best practices.

## Installation

To get started, install:
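
```bash
pip install llmcompressor
```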

## Step 1: Select a model, dataset, and recipe

The recipe `2of4_w4a16_recipe.yaml` contains instructions to prune the model to 2:4 sparsity, run one epoch of recovery fine-tuning, and quantize the weights to 4 bits in one shot using GPTQ.

```python
from pathlib import Path

import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot, train

# load the model in bfloat16 to save memory and compute
model_stub = "neuralmagic/Llama-2-7b-ultrachat200k"
model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_stub)

# uses LLM Compressor's built-in preprocessing for ultrachat
dataset = "ultrachat-200k"

# select the recipe for 2:4 sparsity and 4-bit weight quantization (W4A16)
recipe = "2of4_w4a16_recipe.yaml"

# save location of the compressed model
output_dir = "output_llama7b_2of4_w4a16_channel"
output_path = Path(output_dir)

# dataset config parameters
splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
max_seq_length = 512
num_calibration_samples = 512

# training parameters for fine-tuning
# increase num_train_epochs for longer training
num_train_epochs = 0.01
logging_steps = 500
save_steps = 5000
gradient_checkpointing = True  # saves memory during training
learning_rate = 0.0001
bf16 = False  # use full precision for training
lr_scheduler_type = "cosine"
warmup_ratio = 0.1
preprocessing_num_workers = 8
```

## Step 2: Run sparsification, fine-tuning, and quantization

The compression process runs in three stages: sparsification, fine-tuning, and quantization.
Each stage saves its intermediate model output to the `output_llama7b_2of4_w4a16_channel` directory.

```python
from pathlib import Path

from llmcompressor import oneshot, train

output_dir = "output_llama7b_2of4_w4a16_channel"
output_path = Path(output_dir)

# 1. Oneshot sparsification: prune the model to 2:4 sparsity
oneshot(
    model=model,
    dataset=dataset,
    recipe=recipe,
    splits=splits,
    num_calibration_samples=num_calibration_samples,
    preprocessing_num_workers=preprocessing_num_workers,
    output_dir=output_dir,
    stage="sparsity_stage",
)

# 2. Sparse fine-tuning: recover the accuracy of the pruned model
train(
    model=output_path / "sparsity_stage",
    dataset=dataset,
    recipe=recipe,
    splits=splits,
    num_calibration_samples=num_calibration_samples,
    preprocessing_num_workers=preprocessing_num_workers,
    bf16=bf16,
    max_seq_length=max_seq_length,
    num_train_epochs=num_train_epochs,
    logging_steps=logging_steps,
    save_steps=save_steps,
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    output_dir=output_dir,
    stage="finetuning_stage",
)

# 3. Oneshot quantization: quantize the weights to 4 bits with GPTQ
quantized_model = oneshot(
    model=output_path / "finetuning_stage",
    dataset=dataset,
    recipe=recipe,
    splits=splits,
    num_calibration_samples=num_calibration_samples,
    preprocessing_num_workers=preprocessing_num_workers,
    output_dir=output_dir,
    stage="quantization_stage",
)

# skip_sparsity_compression_stats is set to False so that the model's 2:4 sparsity
# is accounted for when compressing and saving the weights
quantized_model.save_pretrained(
    f"{output_dir}/quantization_stage", skip_sparsity_compression_stats=False
)
tokenizer.save_pretrained(f"{output_dir}/quantization_stage")
```
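
Per the note at the top, the saved `quantization_stage` checkpoint can then be served with vLLM on a supported GPU. The snippet below is a minimal sketch, assuming vLLM is installed and the model fits on the available hardware:

```python
from vllm import LLM, SamplingParams

# load the compressed checkpoint produced by the quantization stage
llm = LLM(model="output_llama7b_2of4_w4a16_channel/quantization_stage")

# run a short generation to sanity-check the compressed model
outputs = llm.generate(["Tell me about 2:4 sparsity."], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```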

### Custom Quantization

The current repo supports multiple quantization techniques configured using a recipe. Supported strategies are `tensor`, `group`, and `channel`.

The recipe `2of4_w4a16_recipe.yaml` uses channel-wise quantization (`strategy: "channel"`).
To change the quantization strategy, edit the recipe file accordingly, as in the sketch after this list:

- Use `tensor` for per-tensor quantization.
- Use `group` for group-wise quantization and specify the `group_size` parameter (e.g., 128).
- See `2of4_w4a16_group-128_recipe.yaml` for a group-size example.
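
For illustration, a group-wise weight config might look roughly like the fragment below; the exact keys and surrounding stage structure are an assumption here, so treat `2of4_w4a16_group-128_recipe.yaml` as the authoritative version:

```yaml
# hypothetical recipe fragment; see 2of4_w4a16_group-128_recipe.yaml for the real one
config_groups:
  group_0:
    weights:
      num_bits: 4
      type: "int"
      symmetric: true
      strategy: "group"  # changed from "channel"
      group_size: 128    # number of weights sharing one quantization scale
    targets: ["Linear"]
```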