|
| 1 | +# `fp8` Weight and Activation Quantization for Granite 4 |
| 2 | + |
| 3 | +`llmcompressor` supports quantizing weights and activations to `fp8` for memory savings and inference acceleration with `vllm` |
| 4 | + |
| 5 | +For Granite 4, in addition to typical `nn.Linear` layers in `mamba` or `mlp` modules, there are three "Linear-like" layers in `GraniteMoeHybridMoe` (`moe` module) that could be quantized as well. Among the three layers, usually `router` should be kept in high precision for accuracy reason. Therefore, users could choose to quantize the other two layers, `input_linear` and `output_linear`, for better model compression. |
| 6 | + |
| 7 | +Note that input_linear and output_linear are `GraniteMoeHybridParallelExperts`, which subclasses `nn.Module` instead of `nn.Linear`, for it needs to store weights in 3D, i.e. [num_experts, out_feat, in_feat]. Because llm-compressor can only handle `nn.Linear` at the moment, our simple workaround would be: |
| 8 | +1. **Swap `GraniteMoeHybridParallelExperts` with `GraniteMoeHybridParallelExpertsLinear`** |
| 9 | + |
| 10 | + The custom class is equivalent to the original one, except it subclasses nn.Linear and stores 2D weights. Moe expert weight tensors will be converted from 3D to 2D, i.e. from [num_experts, out_feat, in_feat] to [num_experts * out_feat, in_feat]. |
| 11 | +2. **Perform dynamic fp8 quantization** |
| 12 | + |
| 13 | + The new class is compatible with typical per-channel weight quantization, llm-compressor will be able to identify those layers and process them normally. The resulting scales will have shape of [num_experts * out_feat, 1] |
| 14 | +3. **Reshape weights and scales back to 3D before saving the checkpoint** |
| 15 | + |
| 16 | +> `fp8` compuation is supported on Nvidia GPUs with compute capability > 8.9 (Ada Lovelace, Hopper). |
| 17 | +
|
| 18 | +## Installation |
| 19 | + |
| 20 | +To get started, install: |
| 21 | + |
| 22 | +```bash |
| 23 | +pip install llmcompressor |
| 24 | +``` |
| 25 | + |
| 26 | +This checkpoint format will need the latest vllm (ver >= 0.10.1.1) to run correctly. Additional dependencies and environment variables needed are: |
| 27 | +1. Dependencies: `vllm>=0.10.1.1, lm_eval>=0.4.9.1, flash-attn=2.7.3, torch>=2.7.1` |
| 28 | +2. ENV VAR: `VLLM_USE_V1=0, VLLM_WORKER_MULTIPROC_METHOD=spawn` |
| 29 | + |
| 30 | +## Quickstart |
| 31 | + |
| 32 | +`granite4_example.py` demonstrates the quantization of `mamba`, `mlp`, and those |
| 33 | +"Linear-like" input/output layers with minimal changes to `llm-compressor`. |
| 34 | + |
| 35 | + |
| 36 | +```bash |
| 37 | +python3 granite4_example.py |
| 38 | +``` |
| 39 | + |
| 40 | +The resulting model `ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter` is ready to be loaded into vLLM. |
| 41 | + |
| 42 | +## Code Walkthough |
| 43 | + |
| 44 | +Now, we will step though the code in the example. There are three steps: |
| 45 | +1) Load model |
| 46 | +2) Apply quantization |
| 47 | +3) Evaluate accuracy in vLLM |
| 48 | + |
| 49 | +### 1) Load Model |
| 50 | + |
| 51 | +Load the model using `AutoModelForCausalLM` |
| 52 | + |
| 53 | +```python |
| 54 | +from transformers import AutoTokenizer, AutoModelForCausalLM |
| 55 | + |
| 56 | +MODEL_ID = "ibm-granite/granite-4.0-tiny-preview" |
| 57 | + |
| 58 | +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") |
| 59 | +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) |
| 60 | +``` |
| 61 | + |
| 62 | +### 2) Apply Quantization |
| 63 | + |
| 64 | +We recommend targeting all `Linear` layers using the `FP8_DYNAMIC` scheme, which uses: |
| 65 | +- Static, per-channel quantization on the weights |
| 66 | +- Dynamic, per-token quantization on the activations |
| 67 | + |
| 68 | +Since simple PTQ does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow. |
| 69 | + |
| 70 | +Note that we replace the 3D moe expert layers with their 2D equivalent counterpart before quantization and convert them back to 3D before model saving. |
| 71 | + |
| 72 | +```python |
| 73 | +from compressed_tensors.utils import replace_module |
| 74 | +from llmcompressor import oneshot |
| 75 | +from llmcompressor.modifiers.quantization import QuantizationModifier |
| 76 | + |
| 77 | +skip_router_only = True # assume we want to quantize input/output moe layers |
| 78 | + |
| 79 | +ignore_lay = ["lm_head",] |
| 80 | +if skip_router_only: |
| 81 | + # swap moe linears to a custom class |
| 82 | + for n, m in model.named_modules(): |
| 83 | + if isinstance(m, GraniteMoeHybridParallelExperts): |
| 84 | + new_mod = GraniteMoeHybridParallelExpertsLinear.from_3d_expert(m) |
| 85 | + replace_module(model, n, new_mod) |
| 86 | + ignore_lay += ["re:.*block_sparse_moe.router"] |
| 87 | + SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter" |
| 88 | + |
| 89 | +# Configure the simple PTQ quantization |
| 90 | +recipe = QuantizationModifier( |
| 91 | + targets=["Linear", "GraniteMoeHybridParallelExpertsLinear"], |
| 92 | + scheme="FP8_DYNAMIC", |
| 93 | + ignore=ignore_lay, |
| 94 | +) |
| 95 | + |
| 96 | +# Apply the quantization algorithm. |
| 97 | +oneshot(model=model, recipe=recipe) |
| 98 | + |
| 99 | +# Revert weights of MoE experts to 3D format (num_experts, output_size, input_size) |
| 100 | +for n, m in model.named_modules(): |
| 101 | + if isinstance(m, GraniteMoeHybridParallelExpertsLinear): |
| 102 | + m.to_3d_expert() |
| 103 | + |
| 104 | +# Save the model. |
| 105 | +model.save_pretrained(SAVE_DIR) |
| 106 | +tokenizer.save_pretrained(SAVE_DIR) |
| 107 | +``` |
| 108 | + |
| 109 | +We have successfully created an `fp8` model! |
| 110 | + |
| 111 | +### 3) Evaluate Accuracy |
| 112 | + |
| 113 | +Install `vllm` and `lm-evaluation-harness`: |
| 114 | + |
| 115 | +```bash |
| 116 | +pip install vllm lm_eval |
| 117 | +``` |
| 118 | + |
| 119 | +Load and run the model in `vllm` and evaluate accuracy with `lm_eval` on `gsm8k`: |
| 120 | + |
| 121 | +1. **Base model** |
| 122 | +```bash |
| 123 | +export MODEL=ibm-granite/granite-4.0-tiny-preview |
| 124 | +export OPT_FLAGS=tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,enable_prefix_caching=False,max_model_len=8192 |
| 125 | +lm_eval --model vllm \ |
| 126 | + --model_args pretrained=$MODEL,$OPT_FLAGS,add_bos_token=True \ |
| 127 | + --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k |
| 128 | +``` |
| 129 | +> Note: quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations. |
| 130 | +
|
| 131 | + |
| 132 | +|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr| |
| 133 | +|-----|------:|----------------|-----:|-----------|---|----:|---|-----:| |
| 134 | +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.602|± |0.0135| |
| 135 | +| | |strict-match | 5|exact_match|↑ |0.583|± |0.0136| |
| 136 | + |
| 137 | +2. **FP8 model** |
| 138 | +```bash |
| 139 | +export MODEL=$PWD/ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter |
| 140 | +lm_eval --model vllm \ |
| 141 | + --model_args pretrained=$MODEL,$OPT_FLAGS,add_bos_token=True \ |
| 142 | + --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k |
| 143 | +``` |
| 144 | + |
| 145 | +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| |
| 146 | +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| |
| 147 | +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.6164|± |0.0134| |
| 148 | +| | |strict-match | 5|exact_match|↑ |0.5974|± |0.0135| |
| 149 | + |
| 150 | +We can see the resulting FP8 model look comparable with (and sometimes slightly better than) the baseline. |
| 151 | + |
| 152 | +> NOTE: If running with hf instead of vllm, such as the command below, there will be an error |
| 153 | +related to the `weight_scale` when the FP8 ckpt is being used. |
| 154 | +`lm_eval --model hf --model_args pretrained=$MODEL --batch_size 16 --trust_remote_code --tasks gsm8k` |
| 155 | + |
| 156 | + |
| 157 | +### Questions or Feature Request? |
| 158 | + |
| 159 | +Please open up an issue on `vllm-project/llm-compressor` |
0 commit comments