
Commit 84575f2

chichun-charlie-liu, andrea-fasoli, and dsikka authored
Add FP8 quantization example for Granite4 (#1814)
SUMMARY: Create an example for Granite4 FP8 quantization, mainly to handle the two "Linear-like" layers in the MoE block, which llm-compressor had problems identifying and quantizing.

TEST PLAN:
1. This example was tested with no errors.
2. The resulting checkpoint was tested with vllm as well. See details in `granite4_example.py`'s docstring.
3. New code was formatted as suggested by `make quality`.
4. Only 3 new files were added; no other files were changed or impacted.

Signed-off-by: cliu-us <[email protected]>
Signed-off-by: Andrea Fasoli <[email protected]>
Co-authored-by: Andrea Fasoli <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent b78b052 commit 84575f2

File tree

3 files changed: +314 −0 lines changed

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
# `fp8` Weight and Activation Quantization for Granite 4

`llmcompressor` supports quantizing weights and activations to `fp8` for memory savings and inference acceleration with `vllm`.

For Granite 4, in addition to the typical `nn.Linear` layers in the `mamba` or `mlp` modules, there are three "Linear-like" layers in `GraniteMoeHybridMoe` (the `moe` module) that can be quantized as well. Of the three, `router` should usually be kept in high precision for accuracy reasons, so users can choose to quantize the other two layers, `input_linear` and `output_linear`, for better model compression.

Note that `input_linear` and `output_linear` are `GraniteMoeHybridParallelExperts` modules, which subclass `nn.Module` instead of `nn.Linear` because they need to store their weights in 3D, i.e. `[num_experts, out_feat, in_feat]`. Because llm-compressor can only handle `nn.Linear` at the moment, our simple workaround is:

1. **Swap `GraniteMoeHybridParallelExperts` with `GraniteMoeHybridParallelExpertsLinear`**

   The custom class is equivalent to the original one, except that it subclasses `nn.Linear` and stores 2D weights. MoE expert weight tensors are converted from 3D to 2D, i.e. from `[num_experts, out_feat, in_feat]` to `[num_experts * out_feat, in_feat]` (see the sketch after this list).

2. **Perform dynamic fp8 quantization**

   The new class is compatible with typical per-channel weight quantization, so llm-compressor can identify those layers and process them normally. The resulting scales have shape `[num_experts * out_feat, 1]`.

3. **Reshape the weights and scales back to 3D before saving the checkpoint**
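
At its core, this workaround is just a pair of `view` calls on the expert weight tensor (and, after quantization, on the per-channel scale). Below is a minimal sketch with dummy tensors and made-up sizes; the scale computation is only a stand-in for whatever `FP8_DYNAMIC` actually produces:

```python
import torch

num_experts, out_feat, in_feat = 4, 8, 16

# 3D expert weights as stored by GraniteMoeHybridParallelExperts
w3d = torch.randn(num_experts, out_feat, in_feat)

# Step 1: flatten the experts into a single 2D "Linear" weight for llm-compressor
w2d = w3d.view(num_experts * out_feat, in_feat)

# Step 2: per-channel quantization yields one scale per output row
scale_2d = w2d.abs().amax(dim=1, keepdim=True)  # shape [num_experts * out_feat, 1]

# Step 3: reshape weights and scales back to 3D before saving the checkpoint
w3d_back = w2d.view(num_experts, out_feat, in_feat)
scale_3d = scale_2d.view(num_experts, out_feat, 1)

assert torch.equal(w3d, w3d_back)
```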

> `fp8` computation is supported on Nvidia GPUs with compute capability >= 8.9 (Ada Lovelace, Hopper).

## Installation

To get started, install:

```bash
pip install llmcompressor
```

This checkpoint format requires a recent vllm (>= 0.10.1.1) to run correctly. Additional dependencies and environment variables are:
1. Dependencies: `vllm>=0.10.1.1, lm_eval>=0.4.9.1, flash-attn=2.7.3, torch>=2.7.1`
2. Environment variables: `VLLM_USE_V1=0, VLLM_WORKER_MULTIPROC_METHOD=spawn`

## Quickstart

`granite4_example.py` demonstrates the quantization of `mamba`, `mlp`, and those "Linear-like" input/output layers with minimal changes to `llm-compressor`.

```bash
python3 granite4_example.py
```

The resulting model `ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter` is ready to be loaded into vLLM.
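
For a quick smoke test before the full evaluation below, the checkpoint can also be loaded with vLLM's offline `LLM` API. This is only a sketch; it assumes the checkpoint directory produced above and sets the environment variables from the Installation section programmatically:

```python
import os

# Must be set before vllm is imported (see the Installation section).
os.environ["VLLM_USE_V1"] = "0"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM, SamplingParams

llm = LLM(model="ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter")
params = SamplingParams(temperature=0.0, max_tokens=32)

outputs = llm.generate(["What is your favorite TV show?"], params)
print(outputs[0].outputs[0].text)
```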

## Code Walkthrough

Now, we will step through the code in the example. There are three steps:
1) Load model
2) Apply quantization
3) Evaluate accuracy in vLLM

### 1) Load Model

Load the model using `AutoModelForCausalLM`:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "ibm-granite/granite-4.0-tiny-preview"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```

### 2) Apply Quantization

We recommend targeting all `Linear` layers using the `FP8_DYNAMIC` scheme, which uses:
- Static, per-channel quantization on the weights
- Dynamic, per-token quantization on the activations

Since simple PTQ does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow. A conceptual sketch of what this scheme computes is shown below.
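
The sketch below illustrates the idea behind `FP8_DYNAMIC`: a static per-output-channel scale for the weights and a per-token scale computed at runtime for the activations. It is a conceptual illustration only, not llm-compressor's actual implementation:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quantize_weight_per_channel(w: torch.Tensor):
    """Static, per-channel weight quantization: one scale per output row."""
    scale = w.abs().amax(dim=1, keepdim=True) / FP8_E4M3_MAX
    scale = scale.clamp(min=1e-12)  # avoid division by zero for all-zero rows
    return (w / scale).to(torch.float8_e4m3fn), scale

def quantize_activation_per_token(x: torch.Tensor):
    """Dynamic, per-token activation quantization: one scale per token, at runtime."""
    scale = x.abs().amax(dim=-1, keepdim=True) / FP8_E4M3_MAX
    scale = scale.clamp(min=1e-12)
    return (x / scale).to(torch.float8_e4m3fn), scale
```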

Note that we replace the 3D MoE expert layers with their 2D equivalents before quantization and convert them back to 3D before saving the model.

```python
from compressed_tensors.utils import replace_module
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
    GraniteMoeHybridParallelExperts,
)

from llmcompressor import oneshot
from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear
from llmcompressor.modifiers.quantization import QuantizationModifier

skip_router_only = True  # assume we want to quantize the input/output moe layers

ignore_lay = ["lm_head"]
if skip_router_only:
    # swap moe linears to a custom class
    for n, m in model.named_modules():
        if isinstance(m, GraniteMoeHybridParallelExperts):
            new_mod = GraniteMoeHybridParallelExpertsLinear.from_3d_expert(m)
            replace_module(model, n, new_mod)
    ignore_lay += ["re:.*block_sparse_moe.router"]
    SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter"

# Configure the simple PTQ quantization
recipe = QuantizationModifier(
    targets=["Linear", "GraniteMoeHybridParallelExpertsLinear"],
    scheme="FP8_DYNAMIC",
    ignore=ignore_lay,
)

# Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)

# Revert weights of MoE experts to 3D format (num_experts, output_size, input_size)
for n, m in model.named_modules():
    if isinstance(m, GraniteMoeHybridParallelExpertsLinear):
        m.to_3d_expert()

# Save the model.
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
```

We have successfully created an `fp8` model!
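
As an optional check (not part of the original example), you can peek at the saved `config.json` and confirm that a `quantization_config` entry was written by `compressed-tensors`; the exact fields vary by version, so treat the keys below as illustrative:

```python
import json
from pathlib import Path

SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter"

cfg = json.loads((Path(SAVE_DIR) / "config.json").read_text())
qcfg = cfg.get("quantization_config", {})
print(qcfg.get("quant_method"), qcfg.get("format"))
```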

### 3) Evaluate Accuracy

Install `vllm` and `lm-evaluation-harness`:

```bash
pip install vllm lm_eval
```

Load and run the model in `vllm` and evaluate accuracy with `lm_eval` on `gsm8k`:

1. **Base model**
```bash
export MODEL=ibm-granite/granite-4.0-tiny-preview
export OPT_FLAGS=tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,enable_prefix_caching=False,max_model_len=8192
lm_eval --model vllm \
  --model_args pretrained=$MODEL,$OPT_FLAGS,add_bos_token=True \
  --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k
```
> Note: quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations.

|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match||0.602|± |0.0135|
| | |strict-match | 5|exact_match||0.583|± |0.0136|

2. **FP8 model**
```bash
export MODEL=$PWD/ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter
lm_eval --model vllm \
  --model_args pretrained=$MODEL,$OPT_FLAGS,add_bos_token=True \
  --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k
```

|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match||0.6164|± |0.0134|
| | |strict-match | 5|exact_match||0.5974|± |0.0135|

We can see that the resulting FP8 model is comparable to (and sometimes slightly better than) the baseline.

> NOTE: If running with `hf` instead of `vllm`, such as the command below, there will be an error
> related to the `weight_scale` when the FP8 checkpoint is used:
> `lm_eval --model hf --model_args pretrained=$MODEL --batch_size 16 --trust_remote_code --tasks gsm8k`

### Questions or Feature Requests?

Please open an issue on `vllm-project/llm-compressor`.
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
1+
from compressed_tensors.utils import replace_module
2+
from transformers import AutoModelForCausalLM, AutoTokenizer
3+
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
4+
GraniteMoeHybridParallelExperts,
5+
)
6+
7+
from llmcompressor import oneshot
8+
from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear
9+
from llmcompressor.modifiers.quantization import QuantizationModifier
10+
from llmcompressor.utils import dispatch_for_generation
11+
12+
"""Please see details in `README_granite4.md`."""
13+
14+
MODEL_ID = "ibm-granite/granite-4.0-tiny-preview"
15+
16+
# Load model.
17+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
18+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
19+
20+
skip_router_only = True # assume we want to quantize input/output moe layers
21+
ignore_lay = [
22+
"lm_head",
23+
]
24+
if skip_router_only:
25+
# swap moe linears to a custom class
26+
for n, m in model.named_modules():
27+
if isinstance(m, GraniteMoeHybridParallelExperts):
28+
new_mod = GraniteMoeHybridParallelExpertsLinear.from_3d_expert(m)
29+
replace_module(model, n, new_mod)
30+
ignore_lay += ["re:.*block_sparse_moe.router"]
31+
SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter"
32+
else:
33+
# Skip all .input_linear, .output-linear, and router layers.
34+
ignore_lay += ["re:.*block_sparse_moe"]
35+
SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoe"
36+
37+
recipe = QuantizationModifier(
38+
targets=["Linear", "GraniteMoeHybridParallelExpertsLinear"],
39+
scheme="FP8_DYNAMIC",
40+
ignore=ignore_lay,
41+
)
42+
43+
# Apply quantization.
44+
oneshot(model=model, recipe=recipe)
45+
46+
# Confirm generations of the quantized model look sane.
47+
print("========== SAMPLE GENERATION ==============")
48+
dispatch_for_generation(model)
49+
input_ids = tokenizer(
50+
"What is your favorite TV show?", return_tensors="pt"
51+
).input_ids.to("cuda")
52+
output = model.generate(input_ids, max_new_tokens=20)
53+
print(tokenizer.decode(output[0]))
54+
print("==========================================")
55+
56+
# Revert weights of MoE experts to 3D format (num_experts, output_size, input_size)
57+
for n, m in model.named_modules():
58+
if isinstance(m, GraniteMoeHybridParallelExpertsLinear):
59+
# NOTE: can assert type != "meta" instead, which is sign of offloading
60+
assert m.weight.device.type == "cuda", (
61+
"Found some offloaded weights. This is not compatible with reshaping "
62+
"experts to 3D prior model save. Ensure the model is fully on cuda."
63+
)
64+
m.to_3d_expert()
65+
66+
model.save_pretrained(SAVE_DIR)
67+
tokenizer.save_pretrained(SAVE_DIR)
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
1+
import torch
2+
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
3+
GraniteMoeHybridParallelExperts,
4+
)
5+
6+
7+
class GraniteMoeHybridParallelExpertsLinear(torch.nn.Linear):
8+
def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
9+
"""Use a real Linear so that llmcompressor and vllm can handle it easier.
10+
1. Change .weight from 3D [num_experts, output_size, input_size] to 2D
11+
[num_experts * output_size, input_size] before calling llm-compressor
12+
2. Change it back to 3D before saving ckpt
13+
"""
14+
super().__init__(
15+
input_size, output_size * num_experts, bias=False, device="meta"
16+
)
17+
self.num_experts = num_experts
18+
self.input_size = input_size
19+
self.output_size = output_size
20+
self.is_2d: bool = True
21+
22+
@classmethod
23+
def from_3d_expert(cls, original: GraniteMoeHybridParallelExperts):
24+
"""Reshape weights of GraniteMoeHybridParallelExperts module into 2D and store
25+
them as weights of this "Linear" module.
26+
"""
27+
newMoeLin = cls(original.num_experts, original.input_size, original.output_size)
28+
newMoeLin.weight = torch.nn.Parameter(
29+
original.weight.view(-1, original.input_size).clone(),
30+
requires_grad=False,
31+
)
32+
original.to("cpu")
33+
newMoeLin.is_2d = True
34+
return newMoeLin
35+
36+
def to_3d_expert(self) -> None:
37+
"""Convert weights and quantization parameters from 2D to 3D shape."""
38+
dim0_mul = self.num_experts * self.output_size
39+
assert (
40+
self.weight.shape == torch.Size((dim0_mul, self.input_size))
41+
and hasattr(self, "weight_scale")
42+
and self.weight_scale.shape == torch.Size((dim0_mul, 1))
43+
), "Shape mismatch, please check."
44+
45+
self.weight = torch.nn.Parameter(
46+
self.weight.view(
47+
self.num_experts, self.output_size, self.input_size
48+
).clone(),
49+
requires_grad=False,
50+
)
51+
self.weight_scale = torch.nn.Parameter(
52+
self.weight_scale.view(self.num_experts, self.output_size, 1).clone(),
53+
requires_grad=False,
54+
)
55+
if hasattr(self, "weight_zero_point"):
56+
assert self.weight_zero_point.shape == torch.Size((dim0_mul, 1))
57+
self.weight_zero_point = torch.nn.Parameter(
58+
self.weight_zero_point.view(
59+
self.num_experts, self.output_size, 1
60+
).clone(),
61+
requires_grad=False,
62+
)
63+
self.is_2d = False
64+
65+
def forward(self, inputs, expert_size):
66+
"""Modified from original forward()"""
67+
68+
input_list = inputs.split(expert_size, dim=0)
69+
70+
weight_3d = self.weight.view(
71+
self.num_experts, self.output_size, self.input_size
72+
)
73+
output_list = []
74+
for i in range(self.num_experts):
75+
output_list.append(torch.nn.functional.linear(input_list[i], weight_3d[i]))
76+
77+
results = torch.cat(output_list, dim=0)
78+
return results
79+
80+
def __repr__(self):
81+
if self.is_2d:
82+
sizes_str = f"(out={self.weight.shape[0]},in={self.weight.shape[1]})"
83+
else:
84+
sizes_str = (
85+
f"(exp={self.weight.shape[0]},out={self.weight.shape[1]},"
86+
f"in={self.weight.shape[2]})"
87+
)
88+
return f"{self.__class__.__name__}{sizes_str}"
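
As a quick standalone sanity check of the 2D-to-3D round trip (not part of this commit), the class can be exercised with made-up sizes and a dummy `weight_scale` standing in for what llm-compressor attaches during quantization:

```python
import torch
from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear

num_experts, out_feat, in_feat = 4, 8, 16  # illustrative sizes only

# Build the 2D "Linear" wrapper directly and give it dummy weights,
# standing in for what from_3d_expert() copies out of the real MoE module.
lin = GraniteMoeHybridParallelExpertsLinear(num_experts, in_feat, out_feat)
lin.weight = torch.nn.Parameter(
    torch.randn(num_experts * out_feat, in_feat), requires_grad=False
)
lin.weight_scale = torch.nn.Parameter(
    torch.ones(num_experts * out_feat, 1), requires_grad=False
)

print(lin)  # 2D view: (out=32,in=16)
lin.to_3d_expert()
print(lin)  # 3D view: (exp=4,out=8,in=16)
print(lin.weight.shape, lin.weight_scale.shape)
```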
