Commit 4c95fd2

[Transforms] Update examples for R4 and transform_block_size option (#1870)
SUMMARY:

Prerequisites:
- [x] neuralmagic/compressed-tensors#472

This PR updates the SpinQuant and QuIP examples to include `transform_block_size` and the latest R4 feature in SpinQuant. It also reverts the `TransformScheme.block_size` changes previously introduced into compressed-tensors, as updated in the PR linked above. While `block_size` is the more appropriate name, `head_dim` has already landed in vLLM, and changing it there would be too disruptive. Users will rarely create their own `TransformScheme` anyway.

TEST PLAN:
- [x] Both examples run, and the saved models can be run in vLLM with meaningful output.
- [x] With prints, confirmed hadacore is used for `QuIPModifier(rotations=["v", "u"], transform_block_size=64, transform_type="hadamard")`
- [x] and dense GEMM is used for `QuIPModifier(rotations=["v", "u"], transform_block_size=64, transform_type="random-hadamard")` (both configurations are sketched below)

---------

Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 4c8c0a7 commit 4c95fd2
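For reference, a minimal sketch of the two `QuIPModifier` configurations exercised in the test plan above. The modifier arguments come from this commit; the import path is assumed to match the examples, and the kernel comments reflect the behavior reported in the test plan rather than anything added here.

```python
from llmcompressor.modifiers.transform import QuIPModifier

# Fused Hadamard transform with a 64-wide block: per the test plan, this
# configuration was observed to hit the hadacore kernel when served in vLLM.
hadacore_config = QuIPModifier(
    rotations=["v", "u"], transform_block_size=64, transform_type="hadamard"
)

# Random-Hadamard transform with the same block size: per the test plan,
# the online rotation falls back to a dense GEMM in vLLM.
dense_gemm_config = QuIPModifier(
    rotations=["v", "u"], transform_block_size=64, transform_type="random-hadamard"
)
```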

File tree: 4 files changed, +14 −11 lines


examples/transform/quip_example.py

Lines changed: 3 additions & 1 deletion
@@ -21,7 +21,9 @@
 # * apply quip transforms to model in order to make quantization easier
 # * quantize the weights to 4 bit with a group size 128
 recipe = [
-    QuIPModifier(rotations=["v", "u"], transform_type="random-hadamard"),
+    QuIPModifier(
+        rotations=["v", "u"], transform_block_size=128, transform_type="random-hadamard"
+    ),
     QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]
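A minimal sketch of how this updated recipe would be applied end to end, assuming the rest of quip_example.py follows the usual llmcompressor `oneshot` flow; the model ID and save directory here are illustrative, not part of this commit.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import QuIPModifier

MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

recipe = [
    QuIPModifier(
        rotations=["v", "u"], transform_block_size=128, transform_type="random-hadamard"
    ),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply the QuIP transforms and weight-only W4A16 quantization; this sketch
# assumes no calibration dataset is needed for this recipe.
oneshot(model=model, recipe=recipe)

SAVE_DIR = MODEL_ID.split("/")[1] + "-quip-w4a16"  # illustrative name
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```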

examples/transform/spinquant_example.py

Lines changed: 6 additions & 5 deletions
@@ -11,14 +11,15 @@
 model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
-# NOTE: currently only fused rotations (R1 & R2) are available
-# Learned rotations and online rotations (R3 & R4) will be added
-# in a future release.
+# NOTE: currently only rotations R1, R2, and R4 are available
+# R3 and learned R1/R2 rotations will be added in a future release.
 # Configure the quantization algorithm to run.
 # * apply spinquant transforms to model to reduce quantization loss
 # * quantize the weights to 4 bit with group size 128
 recipe = [
-    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
+    SpinQuantModifier(
+        rotations=["R1", "R2", "R4"], transform_block_size=64, transform_type="hadamard"
+    ),
     QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]

@@ -37,6 +38,6 @@
 print("==========================================\n\n")
 
 # Save to disk compressed.
-SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR1R2-w4a16"
+SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR1R2R4-w4a16"
 model.save_pretrained(SAVE_DIR, save_compressed=True)
 tokenizer.save_pretrained(SAVE_DIR)
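Since the test plan checks that the saved checkpoint runs in vLLM with meaningful output, a quick sanity-check sketch follows; the model directory name is illustrative and depends on the `MODEL_ID` used in the example.

```python
from vllm import LLM, SamplingParams

# Load the compressed checkpoint saved by spinquant_example.py and confirm
# generations are still coherent after R1/R2/R4 transforms + W4A16.
llm = LLM(model="Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16")  # path is illustrative
outputs = llm.generate(
    ["The capital of France is"], SamplingParams(max_tokens=32, temperature=0.0)
)
print(outputs[0].outputs[0].text)
```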

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 2 additions & 2 deletions
@@ -135,7 +135,7 @@ def _create_config(self) -> TransformConfig:
     def _create_v_scheme(self) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
-            block_size=self.transform_block_size,
+            head_dim=self.transform_block_size,
             apply=[
                 TransformArgs(
                     targets=self.targets,

@@ -157,7 +157,7 @@ def _create_v_scheme(self) -> TransformScheme:
     def _create_u_scheme(self) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
-            block_size=self.transform_block_size,
+            head_dim=self.transform_block_size,
             apply=[
                 TransformArgs(
                     targets=self.targets,
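For context, the `head_dim` field these methods now populate belongs to compressed-tensors' `TransformScheme`. A rough standalone sketch of an equivalent scheme is below; the import path, target pattern, and location value are assumptions for illustration, not verified against compressed-tensors.

```python
from compressed_tensors.transform import TransformArgs, TransformScheme

# Loosely mirrors _create_v_scheme for a single Linear target: a
# random-Hadamard transform on the layer input, with head_dim (the value
# passed from QuIPModifier.transform_block_size) setting the size of each
# Hadamard block.
v_scheme = TransformScheme(
    type="random-hadamard",
    head_dim=128,
    apply=[TransformArgs(targets=["re:.*down_proj$"], location="input")],  # placeholders
)
```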

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 3 additions & 3 deletions
@@ -193,7 +193,7 @@ def _create_r1_scheme(self) -> TransformScheme:
             randomize=self.randomize,
             requires_grad=self.learnable,
             precision=self.precision,
-            block_size=self.transform_block_size,
+            head_dim=self.transform_block_size,
             apply=[
                 TransformArgs(
                     targets=[

@@ -240,7 +240,7 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             randomize=self.randomize,
             requires_grad=self.learnable,
             precision=self.precision,
-            block_size=head_dim,
+            head_dim=head_dim,
             apply=[
                 TransformArgs(targets=[self.mappings.attn_v], location="weight_output"),
                 TransformArgs(

@@ -262,7 +262,7 @@ def _create_r4_scheme(self) -> TransformScheme:
             randomize=self.randomize,
             requires_grad=self.learnable,
             precision=self.precision,
-            block_size=self.transform_block_size,
+            head_dim=self.transform_block_size,
             apply=[
                 TransformArgs(
                     targets=[*self.mappings.mlp_out],
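Conceptually, `head_dim` / `transform_block_size` controls the size of the Hadamard blocks: rather than one dense hidden_size x hidden_size rotation, the transform is block-diagonal with one Hadamard block of the given size repeated along the diagonal. A small, self-contained illustration (not llmcompressor code, sizes chosen for the example):

```python
import torch
from scipy.linalg import hadamard

hidden_size, block_size = 4096, 64

# One orthonormal Hadamard block of shape (block_size, block_size) ...
H = torch.tensor(hadamard(block_size), dtype=torch.float32) / block_size**0.5
# ... tiled along the diagonal to cover the full hidden dimension.
R = torch.block_diag(*([H] * (hidden_size // block_size)))

weight = torch.randn(hidden_size, hidden_size)
rotated = weight @ R  # rotate the weight in independent 64-wide chunks

# Sanity check: R is orthogonal, so R @ R.T == I up to float error.
assert torch.allclose(R @ R.T, torch.eye(hidden_size), atol=1e-5)
```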
